Haitam03 commited on
Commit
4d7afb5
·
verified ·
1 Parent(s): 741c6b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -189
app.py CHANGED
@@ -1,17 +1,14 @@
1
- # Local version with additional development features
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
5
  import json
6
- import csv
7
  import os
8
  from datetime import datetime
9
  from torch.nn.utils.rnn import pad_sequence
 
 
10
 
11
- # Enable better error messages for local development
12
- import traceback
13
-
14
- # Define the model architecture (same as your training code)
15
  class CTCTransliterator(nn.Module):
16
  def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3):
17
  super().__init__()
@@ -31,150 +28,178 @@ class CTCTransliterator(nn.Module):
31
  x = x.log_softmax(dim=2)
32
  return x
33
 
34
- # Cache system for data collection
35
- class TransliterationCache:
36
- def __init__(self, cache_file="transliteration_cache.csv"):
37
- self.cache_file = cache_file
38
- self.cache = {}
39
- self.load_cache()
40
 
41
- def load_cache(self):
42
- """Load existing cache from file"""
43
- if os.path.exists(self.cache_file):
44
- try:
45
- with open(self.cache_file, 'r', encoding='utf-8') as f:
46
- reader = csv.DictReader(f)
47
- for row in reader:
48
- key = f"{row['input']}_{row['direction']}"
49
- self.cache[key] = {
50
- 'output': row['output'],
51
- 'corrected_output': row.get('corrected_output', ''),
52
- 'timestamp': row['timestamp'],
53
- 'usage_count': int(row.get('usage_count', 1))
54
- }
55
- print(f"Loaded {len(self.cache)} cached translations")
56
- except Exception as e:
57
- print(f"Error loading cache: {e}")
58
-
59
- def save_cache(self):
60
- """Save cache to file"""
61
  try:
62
- with open(self.cache_file, 'w', encoding='utf-8', newline='') as f:
63
- fieldnames = ['input', 'direction', 'output', 'corrected_output', 'timestamp', 'usage_count']
64
- writer = csv.DictWriter(f, fieldnames=fieldnames)
65
- writer.writeheader()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- for key, data in self.cache.items():
68
- input_text, direction = key.rsplit('_', 1)
69
- writer.writerow({
70
- 'input': input_text,
71
- 'direction': direction,
72
- 'output': data['output'],
73
- 'corrected_output': data.get('corrected_output', ''),
74
- 'timestamp': data['timestamp'],
75
- 'usage_count': data['usage_count']
76
- })
77
- print(f"Cache saved with {len(self.cache)} entries")
78
  except Exception as e:
79
- print(f"Error saving cache: {e}")
 
 
80
 
81
- def get(self, input_text, direction):
82
- """Get cached result if exists"""
 
 
83
  key = f"{input_text}_{direction}"
84
- if key in self.cache:
85
- self.cache[key]['usage_count'] += 1
86
- print(f"Using cached result for: {input_text}")
87
- self.save_cache()
88
- return self.cache[key]['output']
89
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def set(self, input_text, direction, output):
92
- """Cache a new result"""
93
- key = f"{input_text}_{direction}"
94
- self.cache[key] = {
95
- 'output': output,
96
- 'corrected_output': '',
97
- 'timestamp': datetime.now().isoformat(),
98
- 'usage_count': 1
99
- }
100
- print(f"Cached new translation: {input_text} → {output}")
101
- self.save_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def update_correction(self, input_text, direction, corrected_output):
104
- """Update with user correction"""
105
- key = f"{input_text}_{direction}"
106
- if key in self.cache:
107
- self.cache[key]['corrected_output'] = corrected_output
 
 
 
 
 
 
108
  print(f"Correction saved: {input_text} → {corrected_output}")
109
- self.save_cache()
110
  return True
111
- return False
 
 
 
112
 
113
  def get_stats(self):
114
- """Get cache statistics for development"""
115
- total = len(self.cache)
116
- corrected = sum(1 for item in self.cache.values() if item.get('corrected_output'))
117
- most_used = max(self.cache.values(), key=lambda x: x['usage_count'], default={'usage_count': 0})
118
-
119
- return {
120
- 'total_translations': total,
121
- 'corrected_translations': corrected,
122
- 'most_used_count': most_used['usage_count']
123
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  # Load vocabularies and model
126
  def load_model_and_vocabs():
127
- print("Loading model and vocabularies...")
128
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
129
- print(f"Using device: {device}")
130
 
131
- try:
132
- # Load vocabularies
133
- with open('latin_stoi.json', 'r', encoding='utf-8') as f:
134
- latin_stoi = json.load(f)
135
- with open('latin_itos.json', 'r', encoding='utf-8') as f:
136
- latin_itos = json.load(f)
137
- with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
138
- arabic_stoi = json.load(f)
139
- with open('arabic_itos.json', 'r', encoding='utf-8') as f:
140
- arabic_itos = json.load(f)
141
-
142
- print(f"Latin vocab size: {len(latin_stoi)}")
143
- print(f"Arabic vocab size: {len(arabic_stoi)}")
144
-
145
- # Initialize model
146
- model = CTCTransliterator(
147
- input_dim=len(latin_stoi),
148
- hidden_dim=256,
149
- output_dim=len(arabic_stoi),
150
- num_layers=3,
151
- dropout=0.4
152
- ).to(device)
153
-
154
- # Load trained weights
155
- model.load_state_dict(torch.load('CER_0.091_BLEU_0.85_transliterator.pth', map_location=device))
156
- model.eval()
157
- print("Model loaded successfully!")
158
-
159
- # Find blank ID (assuming it's 0)
160
- blank_id = 0
161
-
162
- return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
163
 
164
- except Exception as e:
165
- print(f"Error loading model: {e}")
166
- print("Full error details:")
167
- traceback.print_exc()
168
- raise
 
 
 
 
 
 
 
 
 
 
169
 
170
  # Load everything at startup
171
- try:
172
- model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs()
173
- cache_system = TransliterationCache()
174
- print("App ready to launch!")
175
- except Exception as e:
176
- print(f"Startup failed: {e}")
177
- exit(1)
178
 
179
  def encode_text(text, vocab):
180
  """Encode text using vocabulary"""
@@ -197,14 +222,12 @@ def greedy_decode(log_probs, arabic_itos, blank_id):
197
  return results
198
 
199
  def transliterate_latin_to_arabic(text):
200
- """Transliterate Latin script to Arabic script with caching"""
201
  if not text.strip():
202
  return ""
203
 
204
- print(f"Processing: {text}")
205
-
206
- # Check cache first
207
- cached_result = cache_system.get(text, "Latin → Arabic")
208
  if cached_result:
209
  return cached_result
210
 
@@ -220,18 +243,13 @@ def transliterate_latin_to_arabic(text):
220
  decoded = greedy_decode(out, arabic_itos, blank_id)
221
  result = decoded[0] if decoded else ""
222
 
223
- print(f"Generated: {result}")
224
-
225
- # Cache the result
226
- cache_system.set(text, "Latin → Arabic", result)
227
 
228
  return result
229
 
230
  except Exception as e:
231
- error_msg = f"Error: {str(e)}"
232
- print(f"Translation failed: {error_msg}")
233
- traceback.print_exc()
234
- return error_msg
235
 
236
  def transliterate_arabic_to_latin(text):
237
  """Transliterate Arabic script to Latin script (placeholder)"""
@@ -245,21 +263,11 @@ def transliterate(text, direction):
245
  return transliterate_arabic_to_latin(text)
246
 
247
  def save_correction(input_text, direction, corrected_output):
248
- """Save user correction to cache"""
249
- if cache_system.update_correction(input_text, direction, corrected_output):
250
- return "Correction saved! Thank you for improving the model."
251
  else:
252
- return "Could not save correction."
253
-
254
- def get_cache_stats():
255
- """Get cache statistics for development dashboard"""
256
- stats = cache_system.get_stats()
257
- return f"""
258
- Cache Statistics:
259
- • Total translations: {stats['total_translations']}
260
- • Corrected translations: {stats['corrected_translations']}
261
- • Most used translation: {stats['most_used_count']} times
262
- """
263
 
264
  # Arabic keyboard layout
265
  arabic_keys = [
@@ -269,25 +277,29 @@ arabic_keys = [
269
  ['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']
270
  ]
271
 
272
- # Create Gradio interface with development features
273
  def create_interface():
274
- with gr.Blocks(title="Darija Transliterator - Local Dev", theme=gr.themes.Soft()) as demo:
275
  gr.Markdown(
276
  """
277
- # Darija Transliterator (Local Development)
278
  Convert between Latin script and Arabic script for Moroccan Darija
279
 
280
- **Local Development Mode**
281
- **Smart Caching**: Results cached for faster responses
282
- **Arabic Keyboard**: Built-in Arabic keyboard for corrections
283
- **Debug Info**: Detailed logging in console
284
  """
285
  )
286
 
287
- # Development stats
288
  with gr.Row():
289
- stats_btn = gr.Button("Show Cache Stats", variant="secondary")
290
- stats_display = gr.Textbox(label="Statistics", interactive=False, visible=False)
 
 
 
 
 
291
 
292
  with gr.Row():
293
  with gr.Column(scale=1):
@@ -337,7 +349,7 @@ def create_interface():
337
  with gr.Row():
338
  space_btn = gr.Button("Space", size="sm", scale=2)
339
  backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
340
- clear_output_btn = gr.Button("🗑️ Clear Output", size="sm", scale=2)
341
 
342
  # Correction system
343
  with gr.Group():
@@ -378,7 +390,7 @@ def create_interface():
378
 
379
  # Stats button
380
  stats_btn.click(
381
- fn=get_cache_stats,
382
  outputs=[stats_display]
383
  ).then(
384
  fn=lambda: gr.update(visible=True),
@@ -389,7 +401,7 @@ def create_interface():
389
  gr.Markdown("### Examples")
390
  examples = [
391
  ["kifash nta?", "Latin → Arabic"],
392
- ["salam alikoum", "Latin → Arabic"],
393
  ["ana bem", "Latin → Arabic"],
394
  ["wach nta mjit?", "Latin → Arabic"],
395
  ["شكون نتا؟", "Arabic → Latin"],
@@ -433,20 +445,23 @@ def create_interface():
433
  outputs=[correction_status]
434
  )
435
 
436
- # Add information
437
  gr.Markdown(
438
  """
439
- ### Local Development Features
 
440
 
441
- **Debug Console**: Check your terminal for detailed logs
442
- **Cache Statistics**: Click "Show Cache Stats" to see usage data
443
- **Hot Reload**: Restart the script to see code changes
444
- **Error Details**: Full stack traces for easier debugging
 
 
445
 
446
- **File Locations:**
447
- - Cache: `transliteration_cache.csv`
448
- - Model: `CER_0.091_BLEU_0.85_transliterator.pth`
449
- - Vocabularies: `*_stoi.json` and `*_itos.json` files
450
  """
451
  )
452
 
@@ -454,15 +469,5 @@ def create_interface():
454
 
455
  # Launch the app
456
  if __name__ == "__main__":
457
- print("Starting Darija Transliterator (Local Development)")
458
  demo = create_interface()
459
-
460
- # Local development settings
461
- demo.launch(
462
- share=True, # Creates public URL for sharing
463
- debug=True, # Enable debug mode
464
- server_port=7860, # Fixed port
465
- server_name="0.0.0.0" # Allow external access
466
- )
467
-
468
- print("Thanks for using Darija Transliterator!")
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  import json
 
5
  import os
6
  from datetime import datetime
7
  from torch.nn.utils.rnn import pad_sequence
8
+ import firebase_admin
9
+ from firebase_admin import credentials, firestore
10
 
11
+ # Define the model architecture
 
 
 
12
  class CTCTransliterator(nn.Module):
13
  def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3):
14
  super().__init__()
 
28
  x = x.log_softmax(dim=2)
29
  return x
30
 
31
+ # Firebase Cache System
32
+ class FirebaseCache:
33
+ def __init__(self):
34
+ self.db = None
35
+ self.init_firebase()
 
36
 
37
+ def init_firebase(self):
38
+ """Initialize Firebase connection"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  try:
40
+ # Try to initialize Firebase
41
+ if not firebase_admin._apps:
42
+ # For HuggingFace Spaces, use environment variables
43
+ if os.getenv('FIREBASE_CREDENTIALS'):
44
+ # Parse credentials from environment variable
45
+ import base64
46
+ cred_data = json.loads(base64.b64decode(os.getenv('FIREBASE_CREDENTIALS')).decode())
47
+ cred = credentials.Certificate(cred_data)
48
+ elif os.path.exists('firebase-credentials.json'):
49
+ # For local development
50
+ cred = credentials.Certificate('firebase-credentials.json')
51
+ else:
52
+ print("No Firebase credentials found. Using local cache fallback.")
53
+ return
54
+
55
+ firebase_admin.initialize_app(cred)
56
+ self.db = firestore.client()
57
+ print("Firebase initialized successfully!")
58
+ else:
59
+ self.db = firestore.client()
60
 
 
 
 
 
 
 
 
 
 
 
 
61
  except Exception as e:
62
+ print(f"Firebase initialization failed: {e}")
63
+ print("Falling back to local cache mode")
64
+ self.db = None
65
 
66
+ def _create_cache_key(self, input_text, direction):
67
+ """Create a safe document key for Firestore"""
68
+ import hashlib
69
+ # Create hash to handle special characters and length limits
70
  key = f"{input_text}_{direction}"
71
+ return hashlib.md5(key.encode()).hexdigest()
72
+
73
+ def get(self, input_text, direction):
74
+ """Get cached translation from Firebase"""
75
+ if not self.db:
76
+ return None
77
+
78
+ try:
79
+ doc_key = self._create_cache_key(input_text, direction)
80
+ doc = self.db.collection('translations').document(doc_key).get()
81
+
82
+ if doc.exists:
83
+ data = doc.to_dict()
84
+ # Update usage count
85
+ self.db.collection('translations').document(doc_key).update({
86
+ 'usage_count': data.get('usage_count', 0) + 1,
87
+ 'last_used': datetime.now()
88
+ })
89
+ print(f"Cache hit: {input_text}")
90
+ return data.get('output', '')
91
+
92
+ return None
93
+
94
+ except Exception as e:
95
+ print(f"Cache read error: {e}")
96
+ return None
97
 
98
  def set(self, input_text, direction, output):
99
+ """Store translation in Firebase"""
100
+ if not self.db:
101
+ return False
102
+
103
+ try:
104
+ doc_key = self._create_cache_key(input_text, direction)
105
+ doc_data = {
106
+ 'input': input_text,
107
+ 'direction': direction,
108
+ 'output': output,
109
+ 'corrected_output': '',
110
+ 'timestamp': datetime.now(),
111
+ 'last_used': datetime.now(),
112
+ 'usage_count': 1
113
+ }
114
+
115
+ self.db.collection('translations').document(doc_key).set(doc_data)
116
+ print(f"Cached: {input_text} → {output}")
117
+ return True
118
+
119
+ except Exception as e:
120
+ print(f"Cache write error: {e}")
121
+ return False
122
 
123
  def update_correction(self, input_text, direction, corrected_output):
124
+ """Update translation with user correction"""
125
+ if not self.db:
126
+ return False
127
+
128
+ try:
129
+ doc_key = self._create_cache_key(input_text, direction)
130
+ self.db.collection('translations').document(doc_key).update({
131
+ 'corrected_output': corrected_output,
132
+ 'correction_timestamp': datetime.now()
133
+ })
134
  print(f"Correction saved: {input_text} → {corrected_output}")
 
135
  return True
136
+
137
+ except Exception as e:
138
+ print(f"Correction save error: {e}")
139
+ return False
140
 
141
  def get_stats(self):
142
+ """Get cache statistics"""
143
+ if not self.db:
144
+ return "Firebase not connected"
145
+
146
+ try:
147
+ docs = self.db.collection('translations').get()
148
+ total = len(docs)
149
+
150
+ corrected = 0
151
+ total_usage = 0
152
+
153
+ for doc in docs:
154
+ data = doc.to_dict()
155
+ if data.get('corrected_output'):
156
+ corrected += 1
157
+ total_usage += data.get('usage_count', 0)
158
+
159
+ return f"""
160
+ Firebase Cache Statistics:
161
+ • Total translations: {total}
162
+ • With corrections: {corrected}
163
+ • Total usage count: {total_usage}
164
+ • Average usage: {total_usage/total if total > 0 else 0:.1f} per translation
165
+ """.strip()
166
+
167
+ except Exception as e:
168
+ return f"Error getting stats: {e}"
169
 
170
  # Load vocabularies and model
171
  def load_model_and_vocabs():
 
172
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
173
 
174
+ # Load vocabularies
175
+ with open('latin_stoi.json', 'r', encoding='utf-8') as f:
176
+ latin_stoi = json.load(f)
177
+ with open('latin_itos.json', 'r', encoding='utf-8') as f:
178
+ latin_itos = json.load(f)
179
+ with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
180
+ arabic_stoi = json.load(f)
181
+ with open('arabic_itos.json', 'r', encoding='utf-8') as f:
182
+ arabic_itos = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ # Initialize model
185
+ model = CTCTransliterator(
186
+ input_dim=len(latin_stoi),
187
+ hidden_dim=256,
188
+ output_dim=len(arabic_stoi),
189
+ num_layers=3,
190
+ dropout=0.4
191
+ ).to(device)
192
+
193
+ # Load trained weights
194
+ model.load_state_dict(torch.load('CER_0.091_BLEU_0.85_transliterator.pth', map_location=device))
195
+ model.eval()
196
+
197
+ blank_id = 0
198
+ return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
199
 
200
  # Load everything at startup
201
+ model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs()
202
+ firebase_cache = FirebaseCache()
 
 
 
 
 
203
 
204
  def encode_text(text, vocab):
205
  """Encode text using vocabulary"""
 
222
  return results
223
 
224
  def transliterate_latin_to_arabic(text):
225
+ """Transliterate Latin script to Arabic script with Firebase caching"""
226
  if not text.strip():
227
  return ""
228
 
229
+ # Check Firebase cache first
230
+ cached_result = firebase_cache.get(text, "Latin → Arabic")
 
 
231
  if cached_result:
232
  return cached_result
233
 
 
243
  decoded = greedy_decode(out, arabic_itos, blank_id)
244
  result = decoded[0] if decoded else ""
245
 
246
+ # Cache the result in Firebase
247
+ firebase_cache.set(text, "Latin → Arabic", result)
 
 
248
 
249
  return result
250
 
251
  except Exception as e:
252
+ return f"Error: {str(e)}"
 
 
 
253
 
254
  def transliterate_arabic_to_latin(text):
255
  """Transliterate Arabic script to Latin script (placeholder)"""
 
263
  return transliterate_arabic_to_latin(text)
264
 
265
  def save_correction(input_text, direction, corrected_output):
266
+ """Save user correction to Firebase"""
267
+ if firebase_cache.update_correction(input_text, direction, corrected_output):
268
+ return "Correction saved to Firebase! Thank you for improving the model."
269
  else:
270
+ return "Could not save correction to Firebase."
 
 
 
 
 
 
 
 
 
 
271
 
272
  # Arabic keyboard layout
273
  arabic_keys = [
 
277
  ['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']
278
  ]
279
 
280
+ # Create Gradio interface
281
  def create_interface():
282
+ with gr.Blocks(title="Darija Transliterator", theme=gr.themes.Soft()) as demo:
283
  gr.Markdown(
284
  """
285
+ # Darija Transliterator
286
  Convert between Latin script and Arabic script for Moroccan Darija
287
 
288
+ **Firebase-Powered**: Persistent caching across sessions
289
+ **Arabic Keyboard**: Built-in Arabic keyboard for corrections
290
+ **Real-time Stats**: Live usage analytics
 
291
  """
292
  )
293
 
294
+ # Stats section
295
  with gr.Row():
296
+ stats_btn = gr.Button("Show Statistics", variant="secondary")
297
+ stats_display = gr.Textbox(
298
+ label="Firebase Statistics",
299
+ interactive=False,
300
+ visible=False,
301
+ lines=5
302
+ )
303
 
304
  with gr.Row():
305
  with gr.Column(scale=1):
 
349
  with gr.Row():
350
  space_btn = gr.Button("Space", size="sm", scale=2)
351
  backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
352
+ clear_output_btn = gr.Button("Clear Output", size="sm", scale=2)
353
 
354
  # Correction system
355
  with gr.Group():
 
390
 
391
  # Stats button
392
  stats_btn.click(
393
+ fn=firebase_cache.get_stats,
394
  outputs=[stats_display]
395
  ).then(
396
  fn=lambda: gr.update(visible=True),
 
401
  gr.Markdown("### Examples")
402
  examples = [
403
  ["kifash nta?", "Latin → Arabic"],
404
+ ["salam alikoum", "Latin → Arabic"],
405
  ["ana bem", "Latin → Arabic"],
406
  ["wach nta mjit?", "Latin → Arabic"],
407
  ["شكون نتا؟", "Arabic → Latin"],
 
445
  outputs=[correction_status]
446
  )
447
 
448
+ # Information
449
  gr.Markdown(
450
  """
451
+ ### About
452
+ This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
453
 
454
+ **Firebase Features:**
455
+ - **Persistent Storage**: All translations are saved permanently
456
+ - **Analytics**: Track usage patterns and popular translations
457
+ - **Fast Responses**: Cached results load instantly
458
+ - **Global Access**: Data synced across all users
459
+ - **Corrections**: Help improve the model by fixing outputs
460
 
461
+ **How to help improve the model:**
462
+ 1. Use the Arabic keyboard to correct any wrong translations
463
+ 2. Click "Save Correction" to store your improvement
464
+ 3. Your corrections help train better models for everyone!
465
  """
466
  )
467
 
 
469
 
470
  # Launch the app
471
  if __name__ == "__main__":
 
472
  demo = create_interface()
473
+ demo.launch(share=True)