Haitam03 commited on
Commit
e5eda94
·
verified ·
1 Parent(s): 5ef1d32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -185,16 +185,12 @@ def load_model_and_vocabs():
185
  with open('latin_stoi.json', 'r', encoding='utf-8') as f:
186
  latin_stoi = json.load(f)
187
  with open('latin_itos.json', 'r', encoding='utf-8') as f:
188
- latin_itos_raw = json.load(f)
189
- # Convert string keys to integers
190
- latin_itos = {int(k): v for k, v in latin_itos_raw.items()}
191
 
192
  with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
193
  arabic_stoi = json.load(f)
194
  with open('arabic_itos.json', 'r', encoding='utf-8') as f:
195
- arabic_itos_raw = json.load(f)
196
- # Convert string keys to integers
197
- arabic_itos = {int(k): v for k, v in arabic_itos_raw.items()}
198
 
199
  # Initialize model
200
  model = CTCTransliterator(
@@ -229,17 +225,17 @@ def greedy_decode(log_probs, blank_id, itos, stoi):
229
  preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
230
  results = []
231
  raw_results = []
232
-
233
  for i, pred in enumerate(preds):
234
  prev = None
235
  decoded = []
236
  raw_result = []
237
 
238
  for p in pred:
239
- if p == eos_id: # Stop at EOS!
240
  break
241
  # CTC collapse: skip blanks and repeated characters
242
- if p != blank_id and p != prev:
243
  decoded.append(itos[str(p)]) # Convert to string if needed
244
  prev = p
245
  raw_result.append(itos[str(p)])
 
185
  with open('latin_stoi.json', 'r', encoding='utf-8') as f:
186
  latin_stoi = json.load(f)
187
  with open('latin_itos.json', 'r', encoding='utf-8') as f:
188
+ latin_itos = json.load(f)
 
 
189
 
190
  with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
191
  arabic_stoi = json.load(f)
192
  with open('arabic_itos.json', 'r', encoding='utf-8') as f:
193
+ arabic_itos= json.load(f)
 
 
194
 
195
  # Initialize model
196
  model = CTCTransliterator(
 
225
  preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
226
  results = []
227
  raw_results = []
228
+ print(eos_id, blank_id)
229
  for i, pred in enumerate(preds):
230
  prev = None
231
  decoded = []
232
  raw_result = []
233
 
234
  for p in pred:
235
+ if str(p) == eos_id: # Stop at EOS!
236
  break
237
  # CTC collapse: skip blanks and repeated characters
238
+ if str(p) != blank_id and p != prev:
239
  decoded.append(itos[str(p)]) # Convert to string if needed
240
  prev = p
241
  raw_result.append(itos[str(p)])