Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -185,16 +185,12 @@ def load_model_and_vocabs():
|
|
| 185 |
with open('latin_stoi.json', 'r', encoding='utf-8') as f:
|
| 186 |
latin_stoi = json.load(f)
|
| 187 |
with open('latin_itos.json', 'r', encoding='utf-8') as f:
|
| 188 |
-
|
| 189 |
-
# Convert string keys to integers
|
| 190 |
-
latin_itos = {int(k): v for k, v in latin_itos_raw.items()}
|
| 191 |
|
| 192 |
with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
|
| 193 |
arabic_stoi = json.load(f)
|
| 194 |
with open('arabic_itos.json', 'r', encoding='utf-8') as f:
|
| 195 |
-
|
| 196 |
-
# Convert string keys to integers
|
| 197 |
-
arabic_itos = {int(k): v for k, v in arabic_itos_raw.items()}
|
| 198 |
|
| 199 |
# Initialize model
|
| 200 |
model = CTCTransliterator(
|
|
@@ -229,17 +225,17 @@ def greedy_decode(log_probs, blank_id, itos, stoi):
|
|
| 229 |
preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
|
| 230 |
results = []
|
| 231 |
raw_results = []
|
| 232 |
-
|
| 233 |
for i, pred in enumerate(preds):
|
| 234 |
prev = None
|
| 235 |
decoded = []
|
| 236 |
raw_result = []
|
| 237 |
|
| 238 |
for p in pred:
|
| 239 |
-
if p == eos_id: # Stop at EOS!
|
| 240 |
break
|
| 241 |
# CTC collapse: skip blanks and repeated characters
|
| 242 |
-
if p != blank_id and p != prev:
|
| 243 |
decoded.append(itos[str(p)]) # Convert to string if needed
|
| 244 |
prev = p
|
| 245 |
raw_result.append(itos[str(p)])
|
|
|
|
| 185 |
with open('latin_stoi.json', 'r', encoding='utf-8') as f:
|
| 186 |
latin_stoi = json.load(f)
|
| 187 |
with open('latin_itos.json', 'r', encoding='utf-8') as f:
|
| 188 |
+
latin_itos = json.load(f)
|
|
|
|
|
|
|
| 189 |
|
| 190 |
with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
|
| 191 |
arabic_stoi = json.load(f)
|
| 192 |
with open('arabic_itos.json', 'r', encoding='utf-8') as f:
|
| 193 |
+
arabic_itos= json.load(f)
|
|
|
|
|
|
|
| 194 |
|
| 195 |
# Initialize model
|
| 196 |
model = CTCTransliterator(
|
|
|
|
| 225 |
preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
|
| 226 |
results = []
|
| 227 |
raw_results = []
|
| 228 |
+
print(eos_id, blank_id)
|
| 229 |
for i, pred in enumerate(preds):
|
| 230 |
prev = None
|
| 231 |
decoded = []
|
| 232 |
raw_result = []
|
| 233 |
|
| 234 |
for p in pred:
|
| 235 |
+
if str(p) == eos_id: # Stop at EOS!
|
| 236 |
break
|
| 237 |
# CTC collapse: skip blanks and repeated characters
|
| 238 |
+
if str(p) != blank_id and p != prev:
|
| 239 |
decoded.append(itos[str(p)]) # Convert to string if needed
|
| 240 |
prev = p
|
| 241 |
raw_result.append(itos[str(p)])
|