Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,17 +1,14 @@
|
|
| 1 |
-
# Local version with additional development features
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import json
|
| 6 |
-
import csv
|
| 7 |
import os
|
| 8 |
from datetime import datetime
|
| 9 |
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
#
|
| 12 |
-
import traceback
|
| 13 |
-
|
| 14 |
-
# Define the model architecture (same as your training code)
|
| 15 |
class CTCTransliterator(nn.Module):
|
| 16 |
def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3):
|
| 17 |
super().__init__()
|
|
@@ -31,150 +28,178 @@ class CTCTransliterator(nn.Module):
|
|
| 31 |
x = x.log_softmax(dim=2)
|
| 32 |
return x
|
| 33 |
|
| 34 |
-
# Cache
|
| 35 |
-
class
|
| 36 |
-
def __init__(self
|
| 37 |
-
self.
|
| 38 |
-
self.
|
| 39 |
-
self.load_cache()
|
| 40 |
|
| 41 |
-
def
|
| 42 |
-
"""
|
| 43 |
-
if os.path.exists(self.cache_file):
|
| 44 |
-
try:
|
| 45 |
-
with open(self.cache_file, 'r', encoding='utf-8') as f:
|
| 46 |
-
reader = csv.DictReader(f)
|
| 47 |
-
for row in reader:
|
| 48 |
-
key = f"{row['input']}_{row['direction']}"
|
| 49 |
-
self.cache[key] = {
|
| 50 |
-
'output': row['output'],
|
| 51 |
-
'corrected_output': row.get('corrected_output', ''),
|
| 52 |
-
'timestamp': row['timestamp'],
|
| 53 |
-
'usage_count': int(row.get('usage_count', 1))
|
| 54 |
-
}
|
| 55 |
-
print(f"Loaded {len(self.cache)} cached translations")
|
| 56 |
-
except Exception as e:
|
| 57 |
-
print(f"Error loading cache: {e}")
|
| 58 |
-
|
| 59 |
-
def save_cache(self):
|
| 60 |
-
"""Save cache to file"""
|
| 61 |
try:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
for key, data in self.cache.items():
|
| 68 |
-
input_text, direction = key.rsplit('_', 1)
|
| 69 |
-
writer.writerow({
|
| 70 |
-
'input': input_text,
|
| 71 |
-
'direction': direction,
|
| 72 |
-
'output': data['output'],
|
| 73 |
-
'corrected_output': data.get('corrected_output', ''),
|
| 74 |
-
'timestamp': data['timestamp'],
|
| 75 |
-
'usage_count': data['usage_count']
|
| 76 |
-
})
|
| 77 |
-
print(f"Cache saved with {len(self.cache)} entries")
|
| 78 |
except Exception as e:
|
| 79 |
-
print(f"
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
def
|
| 82 |
-
"""
|
|
|
|
|
|
|
| 83 |
key = f"{input_text}_{direction}"
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
def set(self, input_text, direction, output):
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def update_correction(self, input_text, direction, corrected_output):
|
| 104 |
-
"""Update with user correction"""
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
print(f"Correction saved: {input_text} → {corrected_output}")
|
| 109 |
-
self.save_cache()
|
| 110 |
return True
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def get_stats(self):
|
| 114 |
-
"""Get cache statistics
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
# Load vocabularies and model
|
| 126 |
def load_model_and_vocabs():
|
| 127 |
-
print("Loading model and vocabularies...")
|
| 128 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 129 |
-
print(f"Using device: {device}")
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
arabic_itos = json.load(f)
|
| 141 |
-
|
| 142 |
-
print(f"Latin vocab size: {len(latin_stoi)}")
|
| 143 |
-
print(f"Arabic vocab size: {len(arabic_stoi)}")
|
| 144 |
-
|
| 145 |
-
# Initialize model
|
| 146 |
-
model = CTCTransliterator(
|
| 147 |
-
input_dim=len(latin_stoi),
|
| 148 |
-
hidden_dim=256,
|
| 149 |
-
output_dim=len(arabic_stoi),
|
| 150 |
-
num_layers=3,
|
| 151 |
-
dropout=0.4
|
| 152 |
-
).to(device)
|
| 153 |
-
|
| 154 |
-
# Load trained weights
|
| 155 |
-
model.load_state_dict(torch.load('CER_0.091_BLEU_0.85_transliterator.pth', map_location=device))
|
| 156 |
-
model.eval()
|
| 157 |
-
print("Model loaded successfully!")
|
| 158 |
-
|
| 159 |
-
# Find blank ID (assuming it's 0)
|
| 160 |
-
blank_id = 0
|
| 161 |
-
|
| 162 |
-
return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Load everything at startup
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
cache_system = TransliterationCache()
|
| 174 |
-
print("App ready to launch!")
|
| 175 |
-
except Exception as e:
|
| 176 |
-
print(f"Startup failed: {e}")
|
| 177 |
-
exit(1)
|
| 178 |
|
| 179 |
def encode_text(text, vocab):
|
| 180 |
"""Encode text using vocabulary"""
|
|
@@ -197,14 +222,12 @@ def greedy_decode(log_probs, arabic_itos, blank_id):
|
|
| 197 |
return results
|
| 198 |
|
| 199 |
def transliterate_latin_to_arabic(text):
|
| 200 |
-
"""Transliterate Latin script to Arabic script with caching"""
|
| 201 |
if not text.strip():
|
| 202 |
return ""
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
# Check cache first
|
| 207 |
-
cached_result = cache_system.get(text, "Latin → Arabic")
|
| 208 |
if cached_result:
|
| 209 |
return cached_result
|
| 210 |
|
|
@@ -220,18 +243,13 @@ def transliterate_latin_to_arabic(text):
|
|
| 220 |
decoded = greedy_decode(out, arabic_itos, blank_id)
|
| 221 |
result = decoded[0] if decoded else ""
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
# Cache the result
|
| 226 |
-
cache_system.set(text, "Latin → Arabic", result)
|
| 227 |
|
| 228 |
return result
|
| 229 |
|
| 230 |
except Exception as e:
|
| 231 |
-
|
| 232 |
-
print(f"Translation failed: {error_msg}")
|
| 233 |
-
traceback.print_exc()
|
| 234 |
-
return error_msg
|
| 235 |
|
| 236 |
def transliterate_arabic_to_latin(text):
|
| 237 |
"""Transliterate Arabic script to Latin script (placeholder)"""
|
|
@@ -245,21 +263,11 @@ def transliterate(text, direction):
|
|
| 245 |
return transliterate_arabic_to_latin(text)
|
| 246 |
|
| 247 |
def save_correction(input_text, direction, corrected_output):
|
| 248 |
-
"""Save user correction to
|
| 249 |
-
if
|
| 250 |
-
return "Correction saved! Thank you for improving the model."
|
| 251 |
else:
|
| 252 |
-
return "Could not save correction."
|
| 253 |
-
|
| 254 |
-
def get_cache_stats():
|
| 255 |
-
"""Get cache statistics for development dashboard"""
|
| 256 |
-
stats = cache_system.get_stats()
|
| 257 |
-
return f"""
|
| 258 |
-
Cache Statistics:
|
| 259 |
-
• Total translations: {stats['total_translations']}
|
| 260 |
-
• Corrected translations: {stats['corrected_translations']}
|
| 261 |
-
• Most used translation: {stats['most_used_count']} times
|
| 262 |
-
"""
|
| 263 |
|
| 264 |
# Arabic keyboard layout
|
| 265 |
arabic_keys = [
|
|
@@ -269,25 +277,29 @@ arabic_keys = [
|
|
| 269 |
['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']
|
| 270 |
]
|
| 271 |
|
| 272 |
-
# Create Gradio interface
|
| 273 |
def create_interface():
|
| 274 |
-
with gr.Blocks(title="Darija Transliterator
|
| 275 |
gr.Markdown(
|
| 276 |
"""
|
| 277 |
-
# Darija Transliterator
|
| 278 |
Convert between Latin script and Arabic script for Moroccan Darija
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
**Debug Info**: Detailed logging in console
|
| 284 |
"""
|
| 285 |
)
|
| 286 |
|
| 287 |
-
#
|
| 288 |
with gr.Row():
|
| 289 |
-
stats_btn = gr.Button("Show
|
| 290 |
-
stats_display = gr.Textbox(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
with gr.Row():
|
| 293 |
with gr.Column(scale=1):
|
|
@@ -337,7 +349,7 @@ def create_interface():
|
|
| 337 |
with gr.Row():
|
| 338 |
space_btn = gr.Button("Space", size="sm", scale=2)
|
| 339 |
backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
|
| 340 |
-
clear_output_btn = gr.Button("
|
| 341 |
|
| 342 |
# Correction system
|
| 343 |
with gr.Group():
|
|
@@ -378,7 +390,7 @@ def create_interface():
|
|
| 378 |
|
| 379 |
# Stats button
|
| 380 |
stats_btn.click(
|
| 381 |
-
fn=
|
| 382 |
outputs=[stats_display]
|
| 383 |
).then(
|
| 384 |
fn=lambda: gr.update(visible=True),
|
|
@@ -389,7 +401,7 @@ def create_interface():
|
|
| 389 |
gr.Markdown("### Examples")
|
| 390 |
examples = [
|
| 391 |
["kifash nta?", "Latin → Arabic"],
|
| 392 |
-
["salam alikoum", "Latin → Arabic"],
|
| 393 |
["ana bem", "Latin → Arabic"],
|
| 394 |
["wach nta mjit?", "Latin → Arabic"],
|
| 395 |
["شكون نتا؟", "Arabic → Latin"],
|
|
@@ -433,20 +445,23 @@ def create_interface():
|
|
| 433 |
outputs=[correction_status]
|
| 434 |
)
|
| 435 |
|
| 436 |
-
#
|
| 437 |
gr.Markdown(
|
| 438 |
"""
|
| 439 |
-
###
|
|
|
|
| 440 |
|
| 441 |
-
**
|
| 442 |
-
**
|
| 443 |
-
**
|
| 444 |
-
**
|
|
|
|
|
|
|
| 445 |
|
| 446 |
-
**
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
"""
|
| 451 |
)
|
| 452 |
|
|
@@ -454,15 +469,5 @@ def create_interface():
|
|
| 454 |
|
| 455 |
# Launch the app
|
| 456 |
if __name__ == "__main__":
|
| 457 |
-
print("Starting Darija Transliterator (Local Development)")
|
| 458 |
demo = create_interface()
|
| 459 |
-
|
| 460 |
-
# Local development settings
|
| 461 |
-
demo.launch(
|
| 462 |
-
share=True, # Creates public URL for sharing
|
| 463 |
-
debug=True, # Enable debug mode
|
| 464 |
-
server_port=7860, # Fixed port
|
| 465 |
-
server_name="0.0.0.0" # Allow external access
|
| 466 |
-
)
|
| 467 |
-
|
| 468 |
-
print("Thanks for using Darija Transliterator!")
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import json
|
|
|
|
| 5 |
import os
|
| 6 |
from datetime import datetime
|
| 7 |
from torch.nn.utils.rnn import pad_sequence
|
| 8 |
+
import firebase_admin
|
| 9 |
+
from firebase_admin import credentials, firestore
|
| 10 |
|
| 11 |
+
# Define the model architecture
|
|
|
|
|
|
|
|
|
|
| 12 |
class CTCTransliterator(nn.Module):
|
| 13 |
def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3):
|
| 14 |
super().__init__()
|
|
|
|
| 28 |
x = x.log_softmax(dim=2)
|
| 29 |
return x
|
| 30 |
|
| 31 |
+
# Firebase Cache System
|
| 32 |
+
class FirebaseCache:
    """Firestore-backed cache of transliteration results and user corrections.

    Each entry is one document in the ``translations`` collection, keyed by a
    hash of ``(input_text, direction)``. If no credentials are found or the
    connection fails, ``self.db`` stays ``None`` and every method degrades to
    a harmless no-op (``get`` returns ``None``, writers return ``False``).
    """

    COLLECTION = 'translations'

    def __init__(self):
        self.db = None
        self.init_firebase()

    @staticmethod
    def _default_app_exists():
        """Return True when the default Firebase app is already initialised."""
        try:
            firebase_admin.get_app()
        except ValueError:
            # get_app() raises ValueError when the default app does not exist.
            return False
        return True

    @staticmethod
    def _load_credentials():
        """Return a credentials object, or None when none are configured.

        Lookup order:
        1. ``FIREBASE_CREDENTIALS`` environment variable holding the
           base64-encoded service-account JSON (HuggingFace Spaces).
        2. ``firebase-credentials.json`` in the working directory (local dev).
        """
        encoded = os.getenv('FIREBASE_CREDENTIALS')
        if encoded:
            import base64
            cred_data = json.loads(base64.b64decode(encoded).decode())
            return credentials.Certificate(cred_data)
        if os.path.exists('firebase-credentials.json'):
            return credentials.Certificate('firebase-credentials.json')
        return None

    def init_firebase(self):
        """Initialize Firebase connection; leave ``self.db`` as None on failure."""
        try:
            if self._default_app_exists():
                # Another instance already initialised the SDK: reuse it.
                self.db = firestore.client()
                return

            cred = self._load_credentials()
            if cred is None:
                print("No Firebase credentials found. Using local cache fallback.")
                return

            firebase_admin.initialize_app(cred)
            self.db = firestore.client()
            print("Firebase initialized successfully!")

        except Exception as e:
            # Best effort: the app must still start without a cache.
            print(f"Firebase initialization failed: {e}")
            print("Falling back to local cache mode")
            self.db = None

    def _create_cache_key(self, input_text, direction):
        """Create a safe document key for Firestore.

        The raw text is hashed so that special characters and long inputs
        never produce an invalid document id. Returns a 32-char hex string.
        """
        import hashlib
        key = f"{input_text}_{direction}"
        return hashlib.md5(key.encode()).hexdigest()

    def _doc_ref(self, input_text, direction):
        """Return the document reference for ``(input_text, direction)``."""
        doc_key = self._create_cache_key(input_text, direction)
        return self.db.collection(self.COLLECTION).document(doc_key)

    def get(self, input_text, direction):
        """Get cached translation from Firebase.

        Returns the stored ``output`` string on a hit, or ``None`` on a miss,
        when Firebase is not connected, or when the read fails.
        """
        if not self.db:
            return None

        try:
            doc_ref = self._doc_ref(input_text, direction)
            doc = doc_ref.get()
            if not doc.exists:
                return None
            data = doc.to_dict()
        except Exception as e:
            print(f"Cache read error: {e}")
            return None

        # Usage bookkeeping is secondary: a failure here must not turn a
        # successful read into a cache miss. Increment is applied server-side,
        # so concurrent hits do not overwrite each other's counts.
        try:
            doc_ref.update({
                'usage_count': firestore.Increment(1),
                'last_used': datetime.now()
            })
        except Exception as e:
            print(f"Cache usage update error: {e}")

        print(f"Cache hit: {input_text}")
        return data.get('output', '')

    def set(self, input_text, direction, output):
        """Store translation in Firebase.

        Written as a merge so that a ``corrected_output`` already saved for
        this entry survives when the model output is (re)cached. Returns True
        on success, False when Firebase is not connected or the write fails.
        """
        if not self.db:
            return False

        try:
            now = datetime.now()
            doc_data = {
                'input': input_text,
                'direction': direction,
                'output': output,
                'timestamp': now,
                'last_used': now,
                'usage_count': 1
            }

            self._doc_ref(input_text, direction).set(doc_data, merge=True)
            print(f"Cached: {input_text} → {output}")
            return True

        except Exception as e:
            print(f"Cache write error: {e}")
            return False

    def update_correction(self, input_text, direction, corrected_output):
        """Update translation with user correction.

        Uses a merge write (upsert), so the correction is kept even when the
        entry was never cached before; a plain ``update`` would fail on a
        missing document. Returns True on success, False otherwise.
        """
        if not self.db:
            return False

        try:
            self._doc_ref(input_text, direction).set({
                'input': input_text,
                'direction': direction,
                'corrected_output': corrected_output,
                'correction_timestamp': datetime.now()
            }, merge=True)
            print(f"Correction saved: {input_text} → {corrected_output}")
            return True

        except Exception as e:
            print(f"Correction save error: {e}")
            return False

    def get_stats(self):
        """Get cache statistics as a human-readable multi-line string."""
        if not self.db:
            return "Firebase not connected"

        try:
            total = 0
            corrected = 0
            total_usage = 0

            # Stream documents instead of materialising the whole collection.
            for doc in self.db.collection(self.COLLECTION).stream():
                data = doc.to_dict()
                total += 1
                if data.get('corrected_output'):
                    corrected += 1
                total_usage += data.get('usage_count', 0)

            average = total_usage / total if total > 0 else 0
            return "\n".join([
                "Firebase Cache Statistics:",
                f"• Total translations: {total}",
                f"• With corrections: {corrected}",
                f"• Total usage count: {total_usage}",
                f"• Average usage: {average:.1f} per translation",
            ])

        except Exception as e:
            return f"Error getting stats: {e}"
|
| 169 |
|
| 170 |
# Load vocabularies and model
|
| 171 |
def load_model_and_vocabs():
    """Read the vocabulary JSON files and build the trained transliterator.

    Picks CUDA when available (CPU otherwise), loads the four vocabulary
    tables from the working directory, constructs the CTC model sized from
    the vocabularies, restores its weights and switches it to eval mode.

    Returns:
        Tuple ``(model, latin_stoi, latin_itos, arabic_stoi, arabic_itos,
        blank_id, device)``; ``blank_id`` is fixed at 0.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _read_vocab(path):
        # All vocabulary files are UTF-8 encoded JSON.
        with open(path, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    # Same file order as before: Latin tables first, then Arabic.
    latin_stoi = _read_vocab('latin_stoi.json')
    latin_itos = _read_vocab('latin_itos.json')
    arabic_stoi = _read_vocab('arabic_stoi.json')
    arabic_itos = _read_vocab('arabic_itos.json')

    # Network dimensions follow the vocabulary sizes.
    network = CTCTransliterator(
        input_dim=len(latin_stoi),
        hidden_dim=256,
        output_dim=len(arabic_stoi),
        num_layers=3,
        dropout=0.4
    )
    model = network.to(device)

    # Restore the trained weights onto the selected device.
    weights = torch.load('CER_0.091_BLEU_0.85_transliterator.pth', map_location=device)
    model.load_state_dict(weights)
    model.eval()

    blank_id = 0
    return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
|
| 199 |
|
| 200 |
# Load everything at startup
|
| 201 |
+
# Executed at import time: loads the vocabularies and model weights from disk.
model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs()
# Constructing the cache immediately attempts the Firebase connection
# (FirebaseCache.__init__ calls init_firebase).
firebase_cache = FirebaseCache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
def encode_text(text, vocab):
|
| 205 |
"""Encode text using vocabulary"""
|
|
|
|
| 222 |
return results
|
| 223 |
|
| 224 |
def transliterate_latin_to_arabic(text):
|
| 225 |
+
"""Transliterate Latin script to Arabic script with Firebase caching"""
|
| 226 |
if not text.strip():
|
| 227 |
return ""
|
| 228 |
|
| 229 |
+
# Check Firebase cache first
|
| 230 |
+
cached_result = firebase_cache.get(text, "Latin → Arabic")
|
|
|
|
|
|
|
| 231 |
if cached_result:
|
| 232 |
return cached_result
|
| 233 |
|
|
|
|
| 243 |
decoded = greedy_decode(out, arabic_itos, blank_id)
|
| 244 |
result = decoded[0] if decoded else ""
|
| 245 |
|
| 246 |
+
# Cache the result in Firebase
|
| 247 |
+
firebase_cache.set(text, "Latin → Arabic", result)
|
|
|
|
|
|
|
| 248 |
|
| 249 |
return result
|
| 250 |
|
| 251 |
except Exception as e:
|
| 252 |
+
return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
def transliterate_arabic_to_latin(text):
|
| 255 |
"""Transliterate Arabic script to Latin script (placeholder)"""
|
|
|
|
| 263 |
return transliterate_arabic_to_latin(text)
|
| 264 |
|
| 265 |
def save_correction(input_text, direction, corrected_output):
    """Persist a user-supplied correction and return a status message.

    Delegates to ``firebase_cache.update_correction`` and maps its truthy or
    falsy result to the success or failure message.
    """
    saved = firebase_cache.update_correction(input_text, direction, corrected_output)
    if not saved:
        return "Could not save correction to Firebase."
    return "Correction saved to Firebase! Thank you for improving the model."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
# Arabic keyboard layout
|
| 273 |
arabic_keys = [
|
|
|
|
| 277 |
['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']
|
| 278 |
]
|
| 279 |
|
| 280 |
+
# Create Gradio interface
|
| 281 |
def create_interface():
|
| 282 |
+
with gr.Blocks(title="Darija Transliterator", theme=gr.themes.Soft()) as demo:
|
| 283 |
gr.Markdown(
|
| 284 |
"""
|
| 285 |
+
# Darija Transliterator
|
| 286 |
Convert between Latin script and Arabic script for Moroccan Darija
|
| 287 |
|
| 288 |
+
**Firebase-Powered**: Persistent caching across sessions
|
| 289 |
+
**Arabic Keyboard**: Built-in Arabic keyboard for corrections
|
| 290 |
+
**Real-time Stats**: Live usage analytics
|
|
|
|
| 291 |
"""
|
| 292 |
)
|
| 293 |
|
| 294 |
+
# Stats section
|
| 295 |
with gr.Row():
|
| 296 |
+
stats_btn = gr.Button("Show Statistics", variant="secondary")
|
| 297 |
+
stats_display = gr.Textbox(
|
| 298 |
+
label="Firebase Statistics",
|
| 299 |
+
interactive=False,
|
| 300 |
+
visible=False,
|
| 301 |
+
lines=5
|
| 302 |
+
)
|
| 303 |
|
| 304 |
with gr.Row():
|
| 305 |
with gr.Column(scale=1):
|
|
|
|
| 349 |
with gr.Row():
|
| 350 |
space_btn = gr.Button("Space", size="sm", scale=2)
|
| 351 |
backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
|
| 352 |
+
clear_output_btn = gr.Button("Clear Output", size="sm", scale=2)
|
| 353 |
|
| 354 |
# Correction system
|
| 355 |
with gr.Group():
|
|
|
|
| 390 |
|
| 391 |
# Stats button
|
| 392 |
stats_btn.click(
|
| 393 |
+
fn=firebase_cache.get_stats,
|
| 394 |
outputs=[stats_display]
|
| 395 |
).then(
|
| 396 |
fn=lambda: gr.update(visible=True),
|
|
|
|
| 401 |
gr.Markdown("### Examples")
|
| 402 |
examples = [
|
| 403 |
["kifash nta?", "Latin → Arabic"],
|
| 404 |
+
["salam alikoum", "Latin → Arabic"],
|
| 405 |
["ana bem", "Latin → Arabic"],
|
| 406 |
["wach nta mjit?", "Latin → Arabic"],
|
| 407 |
["شكون نتا؟", "Arabic → Latin"],
|
|
|
|
| 445 |
outputs=[correction_status]
|
| 446 |
)
|
| 447 |
|
| 448 |
+
# Information
|
| 449 |
gr.Markdown(
|
| 450 |
"""
|
| 451 |
+
### About
|
| 452 |
+
This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
|
| 453 |
|
| 454 |
+
**Firebase Features:**
|
| 455 |
+
- **Persistent Storage**: All translations are saved permanently
|
| 456 |
+
- **Analytics**: Track usage patterns and popular translations
|
| 457 |
+
- **Fast Responses**: Cached results load instantly
|
| 458 |
+
- **Global Access**: Data synced across all users
|
| 459 |
+
- **Corrections**: Help improve the model by fixing outputs
|
| 460 |
|
| 461 |
+
**How to help improve the model:**
|
| 462 |
+
1. Use the Arabic keyboard to correct any wrong translations
|
| 463 |
+
2. Click "Save Correction" to store your improvement
|
| 464 |
+
3. Your corrections help train better models for everyone!
|
| 465 |
"""
|
| 466 |
)
|
| 467 |
|
|
|
|
| 469 |
|
| 470 |
# Launch the app
|
| 471 |
if __name__ == "__main__":
    # Build the Gradio Blocks UI, then start serving it.
    demo = create_interface()
    # share=True requests a public URL for sharing in addition to the local server.
    demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|