Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -185,11 +185,16 @@ def load_model_and_vocabs():
|
|
| 185 |
with open('latin_stoi.json', 'r', encoding='utf-8') as f:
|
| 186 |
latin_stoi = json.load(f)
|
| 187 |
with open('latin_itos.json', 'r', encoding='utf-8') as f:
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
| 189 |
with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
|
| 190 |
arabic_stoi = json.load(f)
|
| 191 |
with open('arabic_itos.json', 'r', encoding='utf-8') as f:
|
| 192 |
-
|
|
|
|
|
|
|
| 193 |
|
| 194 |
# Initialize model
|
| 195 |
model = CTCTransliterator(
|
|
@@ -198,7 +203,7 @@ def load_model_and_vocabs():
|
|
| 198 |
len(arabic_stoi),
|
| 199 |
num_layers=3,
|
| 200 |
dropout=0.3,
|
| 201 |
-
upsample_factor=2
|
| 202 |
).to(device)
|
| 203 |
|
| 204 |
# Load trained weights
|
|
@@ -216,22 +221,16 @@ def encode_text(text, vocab):
|
|
| 216 |
"""Encode text using vocabulary"""
|
| 217 |
return torch.tensor([vocab.get(ch, 0) for ch in text.strip()], dtype=torch.long)
|
| 218 |
|
| 219 |
-
def greedy_decode(log_probs, blank_id):
|
| 220 |
"""
|
| 221 |
Decode CTC outputs using greedy decoding.
|
| 222 |
-
|
| 223 |
-
Args:
|
| 224 |
-
log_probs: (T, B, C) - log probabilities from model
|
| 225 |
-
input_lengths: (B,) - actual lengths of each sequence (optional)
|
| 226 |
"""
|
| 227 |
-
|
| 228 |
-
eos_id = arabic_stoi.get('<eos>',len(arabic_stoi)-2)
|
| 229 |
preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
|
| 230 |
results = []
|
| 231 |
raw_results = []
|
| 232 |
|
| 233 |
for i, pred in enumerate(preds):
|
| 234 |
-
|
| 235 |
prev = None
|
| 236 |
decoded = []
|
| 237 |
raw_result = []
|
|
@@ -241,9 +240,9 @@ def greedy_decode(log_probs, blank_id):
|
|
| 241 |
break
|
| 242 |
# CTC collapse: skip blanks and repeated characters
|
| 243 |
if p != blank_id and p != prev:
|
| 244 |
-
decoded.append(
|
| 245 |
prev = p
|
| 246 |
-
raw_result.append(
|
| 247 |
|
| 248 |
results.append("".join(decoded))
|
| 249 |
raw_results.append("".join(raw_result))
|
|
@@ -269,7 +268,7 @@ def transliterate_latin_to_arabic(text):
|
|
| 269 |
out = model(src)
|
| 270 |
|
| 271 |
# Decode output
|
| 272 |
-
decoded = greedy_decode(out, blank_id)
|
| 273 |
result = decoded[0] if decoded else ""
|
| 274 |
|
| 275 |
# Cache the result in Firebase
|
|
|
|
| 185 |
with open('latin_stoi.json', 'r', encoding='utf-8') as f:
|
| 186 |
latin_stoi = json.load(f)
|
| 187 |
with open('latin_itos.json', 'r', encoding='utf-8') as f:
|
| 188 |
+
latin_itos_raw = json.load(f)
|
| 189 |
+
# Convert string keys to integers
|
| 190 |
+
latin_itos = {int(k): v for k, v in latin_itos_raw.items()}
|
| 191 |
+
|
| 192 |
with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
|
| 193 |
arabic_stoi = json.load(f)
|
| 194 |
with open('arabic_itos.json', 'r', encoding='utf-8') as f:
|
| 195 |
+
arabic_itos_raw = json.load(f)
|
| 196 |
+
# Convert string keys to integers
|
| 197 |
+
arabic_itos = {int(k): v for k, v in arabic_itos_raw.items()}
|
| 198 |
|
| 199 |
# Initialize model
|
| 200 |
model = CTCTransliterator(
|
|
|
|
| 203 |
len(arabic_stoi),
|
| 204 |
num_layers=3,
|
| 205 |
dropout=0.3,
|
| 206 |
+
upsample_factor=2
|
| 207 |
).to(device)
|
| 208 |
|
| 209 |
# Load trained weights
|
|
|
|
| 221 |
"""Encode text using vocabulary"""
|
| 222 |
return torch.tensor([vocab.get(ch, 0) for ch in text.strip()], dtype=torch.long)
|
| 223 |
|
| 224 |
+
def greedy_decode(log_probs, blank_id, itos, stoi):
|
| 225 |
"""
|
| 226 |
Decode CTC outputs using greedy decoding.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
"""
|
| 228 |
+
eos_id = stoi.get('<eos>', len(stoi)-2)
|
|
|
|
| 229 |
preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
|
| 230 |
results = []
|
| 231 |
raw_results = []
|
| 232 |
|
| 233 |
for i, pred in enumerate(preds):
|
|
|
|
| 234 |
prev = None
|
| 235 |
decoded = []
|
| 236 |
raw_result = []
|
|
|
|
| 240 |
break
|
| 241 |
# CTC collapse: skip blanks and repeated characters
|
| 242 |
if p != blank_id and p != prev:
|
| 243 |
+
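decoded.append(itos[str(p)])  # NOTE(review): the itos keys are converted to int when the vocab JSON is loaded, so a str(p) key will not be found (KeyError) — this lookup should presumably be itos[int(p)]; confirm against the dict passed by the caller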
|
| 244 |
prev = p
|
| 245 |
+
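raw_result.append(itos[str(p)])  # NOTE(review): str(p) key only matches if itos still has string keys; with int-keyed itos use itos[int(p)] — confirm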
|
| 246 |
|
| 247 |
results.append("".join(decoded))
|
| 248 |
raw_results.append("".join(raw_result))
|
|
|
|
| 268 |
out = model(src)
|
| 269 |
|
| 270 |
# Decode output
|
| 271 |
+
decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
|
| 272 |
result = decoded[0] if decoded else ""
|
| 273 |
|
| 274 |
# Cache the result in Firebase
|