Haitam03 committed
Commit 5ef1d32 · verified · Parent: 01e27b1

Update app.py

Files changed (1): app.py (+13 -14)
app.py CHANGED

@@ -185,11 +185,16 @@ def load_model_and_vocabs():
     with open('latin_stoi.json', 'r', encoding='utf-8') as f:
         latin_stoi = json.load(f)
     with open('latin_itos.json', 'r', encoding='utf-8') as f:
-        latin_itos = json.load(f)
+        latin_itos_raw = json.load(f)
+    # Convert string keys to integers
+    latin_itos = {int(k): v for k, v in latin_itos_raw.items()}
+
     with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
         arabic_stoi = json.load(f)
     with open('arabic_itos.json', 'r', encoding='utf-8') as f:
-        arabic_itos = json.load(f)
+        arabic_itos_raw = json.load(f)
+    # Convert string keys to integers
+    arabic_itos = {int(k): v for k, v in arabic_itos_raw.items()}
 
     # Initialize model
     model = CTCTransliterator(
@@ -198,7 +203,7 @@ def load_model_and_vocabs():
         len(arabic_stoi),
         num_layers=3,
         dropout=0.3,
-        upsample_factor=2  # ← ADD THIS
+        upsample_factor=2
     ).to(device)
 
     # Load trained weights
@@ -216,22 +221,16 @@ def encode_text(text, vocab):
     """Encode text using vocabulary"""
     return torch.tensor([vocab.get(ch, 0) for ch in text.strip()], dtype=torch.long)
 
-def greedy_decode(log_probs, blank_id):
+def greedy_decode(log_probs, blank_id, itos, stoi):
     """
     Decode CTC outputs using greedy decoding.
-
-    Args:
-        log_probs: (T, B, C) - log probabilities from model
-        input_lengths: (B,) - actual lengths of each sequence (optional)
     """
-    # log_probs: (T, B, C)
-    eos_id = arabic_stoi.get('<eos>',len(arabic_stoi)-2)
+    eos_id = stoi.get('<eos>', len(stoi)-2)
     preds = log_probs.argmax(2).T.cpu().numpy()  # (B, T)
     results = []
     raw_results = []
 
     for i, pred in enumerate(preds):
-
         prev = None
         decoded = []
         raw_result = []
@@ -241,9 +240,9 @@ def greedy_decode(log_probs, blank_id):
                 break
             # CTC collapse: skip blanks and repeated characters
             if p != blank_id and p != prev:
-                decoded.append(arabic_itos[p])
+                decoded.append(itos[str(p)])  # Convert to string if needed
                 prev = p
-            raw_result.append(arabic_itos[p])
+            raw_result.append(itos[str(p)])
 
         results.append("".join(decoded))
         raw_results.append("".join(raw_result))
@@ -269,7 +268,7 @@ def transliterate_latin_to_arabic(text):
         out = model(src)
 
         # Decode output
-        decoded = greedy_decode(out, blank_id)
+        decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
         result = decoded[0] if decoded else ""
 
         # Cache the result in Firebase
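
The core fix in this commit works around a JSON quirk: json.dump coerces dictionary keys to strings, so an itos table built with integer keys comes back with string keys after a save/load round trip. A minimal sketch of the round trip and the two lookup strategies the diff uses (variable names here are illustrative, not from app.py):

import json

# itos maps token ids to characters; in memory it is built with int keys.
itos = {0: '<blank>', 1: 'a', 2: 'b'}

loaded = json.loads(json.dumps(itos))
print(list(loaded.keys()))   # ['0', '1', '2'] -- JSON object keys are strings

# Fix 1 (as in load_model_and_vocabs): convert keys back to int once at load time.
itos_int = {int(k): v for k, v in loaded.items()}
print(itos_int[1])           # 'a'

# Fix 2 (as in greedy_decode): keep string keys and look up with str(p).
print(loaded[str(1)])        # 'a'

Note that the two fixes are alternatives, not complements: once the keys have been converted to int, itos[str(p)] raises KeyError, so the lookup style has to match how the table was loaded.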
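The decoding side of the diff follows the standard CTC greedy-decoding recipe: take the argmax class per timestep, stop at <eos>, and collapse the path by dropping blanks and immediate repeats. A self-contained toy sketch of that rule (the helper name, vocabulary, and tensor below are made up for illustration, not taken from app.py):

import torch

def ctc_greedy_collapse(log_probs, blank_id, itos, eos_id=None):
    # log_probs: (T, B, C); argmax over classes, then transpose to (B, T).
    preds = log_probs.argmax(2).T.cpu().numpy()
    results = []
    for pred in preds:
        prev = None
        decoded = []
        for p in pred:
            if eos_id is not None and p == eos_id:
                break  # stop at end-of-sequence
            if p != blank_id and p != prev:
                decoded.append(itos[int(p)])  # skip blanks and immediate repeats
            prev = p
        results.append("".join(decoded))
    return results

# Toy vocab and a (T=6, B=1, C=4) score tensor whose argmax path is
# [a, a, blank, b, b, blank], which collapses to "ab".
itos = {0: '<blank>', 1: 'a', 2: 'b', 3: '<eos>'}
scores = torch.full((6, 1, 4), -10.0)
for t, c in enumerate([1, 1, 0, 2, 2, 0]):
    scores[t, 0, c] = 0.0

print(ctc_greedy_collapse(scores, blank_id=0, itos=itos, eos_id=3))  # ['ab']

Blanks between repeats are what let the model emit the same character twice ('a blank a' collapses to 'aa', while 'a a' collapses to 'a'); since CTC needs the output time axis to be at least as long as the target string, that is presumably what the model's upsample_factor=2 provides.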