File size: 17,668 Bytes
243285a
 
 
e5064af
243285a
ed1ecb7
 
243285a
4d7afb5
 
243285a
4d7afb5
243285a
e5064af
243285a
 
 
 
 
 
e5064af
243285a
 
 
e5064af
 
 
243285a
 
e5064af
 
 
 
 
 
 
243285a
 
 
 
4d7afb5
 
 
 
 
741c6b3
4d7afb5
 
ed1ecb7
4d7afb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741c6b3
ed1ecb7
4d7afb5
 
 
741c6b3
4d7afb5
 
 
 
ed1ecb7
4d7afb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741c6b3
ed1ecb7
4d7afb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741c6b3
ed1ecb7
4d7afb5
 
 
 
 
 
 
 
 
 
741c6b3
ed1ecb7
4d7afb5
 
 
 
741c6b3
 
4d7afb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c716b49
4d7afb5
 
 
 
 
 
 
 
ed1ecb7
243285a
 
 
 
 
 
 
 
e5eda94
5ef1d32
243285a
 
 
e5eda94
243285a
 
 
e5064af
 
 
 
 
5ef1d32
243285a
c6a824e
243285a
3a55a85
243285a
 
01e27b1
243285a
 
 
 
4d7afb5
243285a
 
 
 
 
5ef1d32
e5064af
 
 
5ef1d32
e5064af
243285a
e5064af
e5eda94
ee1c7a6
 
41e3f4e
e5064af
243285a
 
e5064af
 
01e27b1
3519daf
6434462
e5064af
 
6434462
3519daf
243285a
5ef1d32
e5064af
243285a
e5064af
3519daf
e5064af
243285a
 
 
4d7afb5
243285a
 
 
4d7afb5
 
ed1ecb7
 
 
243285a
 
ed1ecb7
243285a
 
 
ed1ecb7
243285a
 
5ef1d32
ed1ecb7
 
4d7afb5
 
243285a
ed1ecb7
243285a
 
 
 
 
ed1ecb7
 
243285a
 
 
 
7d960dc
243285a
 
 
ed1ecb7
4d7afb5
 
c716b49
ed1ecb7
c716b49
ed1ecb7
 
 
 
 
 
 
 
 
243285a
 
 
 
 
7a5f01d
243285a
ed1ecb7
4d7afb5
 
 
243285a
 
 
4d7afb5
741c6b3
4d7afb5
 
 
 
 
 
 
741c6b3
243285a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed1ecb7
243285a
ed1ecb7
 
7a5f01d
741c6b3
 
ed1ecb7
 
 
 
bc597e1
 
 
741c6b3
ed1ecb7
bc597e1
 
 
ed1ecb7
741c6b3
ed1ecb7
bc597e1
741c6b3
7a5f01d
741c6b3
ed1ecb7
 
7a5f01d
ed1ecb7
 
 
 
 
7a5f01d
243285a
741c6b3
 
 
 
 
 
 
 
243285a
 
741c6b3
 
 
 
 
 
 
243285a
 
741c6b3
 
 
 
 
 
 
 
 
 
4d7afb5
741c6b3
 
 
 
243285a
 
bc597e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243285a
 
 
 
 
ed1ecb7
 
 
243285a
 
 
 
 
 
 
 
 
 
 
 
 
ed1ecb7
 
 
 
 
 
 
 
 
4d7afb5
243285a
 
 
 
 
4d7afb5
 
 
 
 
 
ed1ecb7
4d7afb5
 
 
 
243285a
 
 
 
 
 
 
 
8f6b729
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
from datetime import datetime
from torch.nn.utils.rnn import pad_sequence
import firebase_admin
from firebase_admin import credentials, firestore

# Define the model architecture
class CTCTransliterator(nn.Module):
    """Bidirectional-LSTM character model whose outputs feed CTC greedy decoding.

    Token ids are embedded, run through a stacked BiLSTM, layer-normalised and
    then linearly interpolated along the time axis by ``upsample_factor``.
    Presumably this gives CTC enough frames when the output is longer than the
    input. A final linear layer produces per-frame log-probabilities.

    The model is time-major. ``x`` is a LongTensor of ids with shape
    ``(seq_len, batch)``. The result has shape
    ``(seq_len * upsample_factor, batch, output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3, upsample_factor=3):
        super().__init__()
        # Sub-module attribute names are the keys of the saved state_dict that
        # is loaded elsewhere; keep them (and their creation order) unchanged.
        self.embed = nn.Embedding(num_embeddings=input_dim, embedding_dim=hidden_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
        )
        # Forward and backward LSTM states are concatenated.
        both_directions = hidden_dim * 2
        self.layer_norm = nn.LayerNorm(both_directions)
        self.dropout = nn.Dropout(p=dropout)
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(in_features=both_directions, out_features=output_dim)

    def forward(self, x):
        """Return log-probabilities of shape (seq_len * upsample_factor, batch, output_dim)."""
        features, _ = self.lstm(self.embed(x))
        features = self.dropout(self.layer_norm(features))

        # F.interpolate stretches the last axis, so move time there:
        # (T, B, H) -> (B, H, T) -> interpolate -> back to (T', B, H).
        stretched = F.interpolate(
            features.permute(1, 2, 0),
            scale_factor=self.upsample_factor,
            mode='linear',
            align_corners=False,
        ).permute(2, 0, 1)

        return F.log_softmax(self.fc(stretched), dim=2)

# Firebase Cache System
class FirebaseCache:
    """Firestore-backed store of transliteration results and user corrections.

    Each ``(input_text, direction)`` pair maps to one document in the
    ``translations`` collection. The document is keyed by the MD5 digest of
    ``f"{input_text}_{direction}"``.

    If Firebase cannot be initialised, ``self.db`` stays ``None``. Every
    operation then does nothing and reports failure (``None``, ``False`` or a
    status string).
    """

    # Firestore collection holding one document per cached pair.
    _COLLECTION = 'translations'

    def __init__(self):
        self.db = None
        self.init_firebase()

    @staticmethod
    def _now():
        """Return the current time as a timezone-aware datetime.

        Firestore treats naive datetimes as UTC, so a naive ``datetime.now()``
        would be stored with the wrong instant on any host whose local zone is
        not UTC. ``astimezone()`` attaches the local offset.
        """
        return datetime.now().astimezone()

    def init_firebase(self):
        """Connect to Firestore; on any failure leave ``self.db`` as ``None``.

        Credentials come from one of two sources, checked in this order:
        the base64-encoded JSON in the ``FIREBASE_CREDENTIALS`` environment
        variable, or a local ``firebase-credentials.json`` file. If a Firebase
        app was already initialised in this process, its client is reused.
        """
        try:
            if not firebase_admin._apps:
                encoded = os.getenv('FIREBASE_CREDENTIALS')
                if encoded:
                    # HuggingFace Spaces: credentials passed via env variable.
                    import base64
                    cred_data = json.loads(base64.b64decode(encoded).decode())
                    cred = credentials.Certificate(cred_data)
                elif os.path.exists('firebase-credentials.json'):
                    # Local development.
                    cred = credentials.Certificate('firebase-credentials.json')
                else:
                    # NOTE: no local cache exists; caching is simply disabled.
                    print("No Firebase credentials found. Using local cache fallback.")
                    return

                firebase_admin.initialize_app(cred)
                self.db = firestore.client()
                print("Firebase initialized successfully!")
            else:
                self.db = firestore.client()

        except Exception as e:
            print(f"Firebase initialization failed: {e}")
            print("Falling back to local cache mode")
            self.db = None

    def _create_cache_key(self, input_text, direction):
        """Return a Firestore-safe document id for the pair.

        Hashing sidesteps Firestore's restrictions on special characters and id
        length. Do not change the scheme: it would orphan existing documents.
        """
        import hashlib
        key = f"{input_text}_{direction}"
        return hashlib.md5(key.encode()).hexdigest()

    def _doc_ref(self, input_text, direction):
        """Return the document reference for the pair (requires ``self.db``)."""
        return self.db.collection(self._COLLECTION).document(
            self._create_cache_key(input_text, direction))

    def get(self, input_text, direction):
        """Return the cached output for the pair, or ``None`` on a miss/error.

        On a hit, the document's ``usage_count`` and ``last_used`` fields are
        also updated, on a best-effort basis.
        """
        if self.db is None:
            return None

        try:
            doc_ref = self._doc_ref(input_text, direction)
            doc = doc_ref.get()
        except Exception as e:
            print(f"Cache read error: {e}")
            return None

        if not doc.exists:
            return None
        data = doc.to_dict() or {}

        # Usage tracking must not turn a hit into a miss. Otherwise the caller
        # would recompute and set() would overwrite the document, wiping any
        # saved correction.
        try:
            doc_ref.update({
                # Server-side atomic increment; a read-modify-write of the
                # counter loses updates when requests overlap.
                'usage_count': firestore.Increment(1),
                'last_used': self._now(),
            })
            print(f"Cache hit: {input_text}")
        except Exception as e:
            print(f"Cache usage update error: {e}")

        return data.get('output', '')

    def set(self, input_text, direction, output):
        """Create or overwrite the pair's document; return ``True`` on success.

        Overwriting resets ``corrected_output`` and ``usage_count``.
        """
        if self.db is None:
            return False

        try:
            now = self._now()
            self._doc_ref(input_text, direction).set({
                'input': input_text,
                'direction': direction,
                'output': output,
                'corrected_output': '',
                'timestamp': now,
                'last_used': now,
                'usage_count': 1,
            })
            print(f"Cached: {input_text} → {output}")
            return True

        except Exception as e:
            print(f"Cache write error: {e}")
            return False

    def update_correction(self, input_text, direction, corrected_output):
        """Attach a user correction to an existing document.

        Returns ``False`` when Firebase is unavailable or the update fails,
        for example because no document exists yet for this pair.
        """
        if self.db is None:
            return False

        try:
            self._doc_ref(input_text, direction).update({
                'corrected_output': corrected_output,
                'correction_timestamp': self._now(),
            })
            print(f"Correction saved: {input_text} → {corrected_output}")
            return True

        except Exception as e:
            print(f"Correction save error: {e}")
            return False

    def get_stats(self):
        """Return a human-readable summary of the cached collection.

        This reads every document in the collection.
        """
        if self.db is None:
            return "Firebase not connected"

        try:
            docs = self.db.collection(self._COLLECTION).get()
            total = len(docs)

            corrected = 0
            total_usage = 0

            for doc in docs:
                data = doc.to_dict() or {}
                if data.get('corrected_output'):
                    corrected += 1
                total_usage += data.get('usage_count', 0)

            return f"""
Cache Statistics:
• Total translations: {total}
• With corrections: {corrected}
• Total usage count: {total_usage}
• Average usage: {total_usage/total if total > 0 else 0:.1f} per translation
            """.strip()

        except Exception as e:
            return f"Error getting stats: {e}"

# Load vocabularies and model
def load_model_and_vocabs():
    """Load the character vocabularies and the trained transliteration model.

    Reads four UTF-8 JSON files from the current working directory
    (``latin_stoi.json``, ``latin_itos.json``, ``arabic_stoi.json``,
    ``arabic_itos.json``) and the ``best_model.pth`` weights from the same
    directory. The model runs on CUDA when available, otherwise on CPU, and is
    switched to eval mode.

    Returns:
        ``(model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id,
        device)``. ``blank_id`` is the ``<blank>`` index in the Arabic
        vocabulary, or ``len(arabic_itos) - 1`` if that token is missing.
    """
    def read_json(path):
        with open(path, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab_files = ('latin_stoi.json', 'latin_itos.json', 'arabic_stoi.json', 'arabic_itos.json')
    latin_stoi, latin_itos, arabic_stoi, arabic_itos = (read_json(path) for path in vocab_files)

    # Hyper-parameters must match those used when best_model.pth was trained.
    model = CTCTransliterator(
        input_dim=len(latin_stoi),
        hidden_dim=256,
        output_dim=len(arabic_stoi),
        num_layers=3,
        dropout=0.3,
        upsample_factor=2,
    )
    model = model.to(device)

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. Only a state_dict is consumed here, so weights_only=True should
    # work; confirm against the checkpoint, and only ever load trusted files.
    state = torch.load('best_model.pth', map_location=device, weights_only=False)
    model.load_state_dict(state)
    model.eval()

    blank_id = arabic_stoi.get('<blank>', len(arabic_itos) - 1)
    return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device

# Load everything at startup
# Runs once at import time. It reads the vocab JSON files and best_model.pth
# from the working directory and tries to connect to Firebase; without
# credentials, firebase_cache.db stays None. The transliteration functions
# below use these globals.
model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs()
firebase_cache = FirebaseCache()

def encode_text(text, vocab):
    """Map each character of the whitespace-stripped *text* to its vocab id.

    Characters missing from *vocab* are mapped to id 0. Returns a 1-D
    ``torch.long`` tensor, which is empty for blank input.
    """
    cleaned = text.strip()
    ids = list(map(lambda ch: vocab.get(ch, 0), cleaned))
    return torch.tensor(ids, dtype=torch.long)

def greedy_decode(log_probs, blank_id, itos, stoi):
    """Greedily decode CTC log-probabilities into one string per batch item.

    At each time step the arg-max label is taken. The standard CTC collapse is
    then applied: consecutive repeats are merged and blanks are dropped.
    Decoding of an item stops at its first ``<eos>`` label.

    Args:
        log_probs: Time-major tensor of shape ``(T, B, C)``, as returned by
            ``CTCTransliterator.forward``.
        blank_id: Index of the CTC blank label.
        itos: Mapping from label index *as a string* (JSON object keys) to the
            output character.
        stoi: Mapping from character to label index. It is used only to find
            ``<eos>``, falling back to ``len(stoi) - 2`` if that is absent.

    Returns:
        A list with one decoded string per batch item.
    """
    eos_id = stoi.get('<eos>', len(stoi) - 2)
    # argmax over classes gives (T, B); transpose to (B, T) to walk each item.
    best_paths = log_probs.argmax(2).T.cpu().numpy()

    results = []
    for path in best_paths:
        decoded = []
        prev = None
        for label in path:
            label = int(label)
            if label == eos_id:
                break
            if label != blank_id and label != prev:
                decoded.append(itos[str(label)])
            # Blanks also update prev, so "a <blank> a" correctly yields "aa".
            prev = label
        results.append("".join(decoded))

    return results

def transliterate_latin_to_arabic(text):
    """Convert Latin-script Darija *text* to Arabic script, consulting the cache.

    - Blank input yields ``""``.
    - A truthy cached result is returned without running the model.
    - Otherwise the model output is greedily decoded, written to the cache and
      returned.

    Any exception during inference is returned as an ``"Error: ..."`` string
    rather than raised, so the UI always receives text.
    """
    direction_label = "Latin → Arabic"
    if not text.strip():
        return ""

    hit = firebase_cache.get(text, direction_label)
    if hit:
        return hit

    try:
        # (seq_len,) ids -> (seq_len, 1): the model is time-major with batch 1.
        batch = encode_text(text, latin_stoi).unsqueeze(1).to(device)
        with torch.no_grad():
            log_probs = model(batch)

        candidates = greedy_decode(log_probs, blank_id, arabic_itos, arabic_stoi)
        output = candidates[0] if candidates else ""

        firebase_cache.set(text, direction_label, output)
        return output
    except Exception as e:
        return f"Error: {str(e)}"

def transliterate_arabic_to_latin(text):
    """Placeholder for Arabic → Latin conversion; *text* is currently ignored."""
    message = "Arabic to Latin transliteration not implemented yet."
    return message

def transliterate(text, direction):
    """Send *text* to the converter for the selected *direction*.

    Latin input is lower-cased before conversion. Any direction other than
    ``"Latin → Arabic"`` goes to the Arabic → Latin converter.
    """
    if direction != "Latin → Arabic":
        return transliterate_arabic_to_latin(text)
    return transliterate_latin_to_arabic(text.lower())

def save_correction(input_text, direction, corrected_output):
    """Attach the user's corrected output to the cached transliteration.

    ``transliterate`` lower-cases Latin input before it is cached, so the same
    normalisation is applied here. Without it, any input containing uppercase
    letters would point at a different, non-existent document and the save
    would always fail.

    Args:
        input_text: Text from the input box, as typed by the user.
        direction: Direction label selected in the UI.
        corrected_output: The (possibly hand-edited) output text to store.

    Returns:
        A status message for the UI.
    """
    if not input_text or not input_text.strip():
        return "Nothing to save: enter some input text and transliterate it first."

    cache_input = input_text.lower() if direction == "Latin → Arabic" else input_text
    if firebase_cache.update_correction(cache_input, direction, corrected_output):
        return "Correction saved to the database! Thank you for improving the model."
    return "Could not save correction to database."

# Arabic keyboard layout
# Each inner list is one row of on-screen buttons built in create_interface().
# Clicking a key appends its string verbatim to the output box, so multi-char
# keys such as 'لا' insert both characters. The last row also holds the
# Arabic-Indic digits.
arabic_keys = [
    ['ض', 'ص', 'ث', 'ق', 'ف', 'غ', 'ع', 'ه', 'خ', 'ح', 'ج', 'د'],
    ['ش', 'س', 'ي', 'ب', 'ل', 'ا', 'ت', 'ن', 'م', 'ك', 'ط'],
    ['ئ', 'ء', 'ؤ', 'ر', 'لا', 'ى', 'ة', 'و', 'ز', 'ظ'],
    ['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']
]

# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the transliterator.

    The UI contains:
    - a statistics panel;
    - the direction selector and the input/output boxes;
    - an on-screen Arabic keyboard, handled by client-side JS (``fn=None``)
      without a Python round trip;
    - the correction-saving controls;
    - examples and an About section.

    Every button and submit event is wired to the module-level functions.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for calling ``launch()``.
    """
    with gr.Blocks(title="Darija Transliterator", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # Darija Transliterator
            Convert between Latin script and Arabic script for Moroccan Darija
            
            **Firebase-Powered**: Persistent caching across sessions  
            **Arabic Keyboard**: Built-in Arabic keyboard for corrections  
            **Real-time Stats**: Live usage analytics
            """
        )
        
        # Stats section (the output box stays hidden until the button is clicked)
        with gr.Row():
            stats_btn = gr.Button("Show Statistics", variant="secondary")
            stats_display = gr.Textbox(
                label="Firebase Statistics", 
                interactive=False, 
                visible=False,
                lines=5
            )
        
        with gr.Row():
            with gr.Column(scale=1):
                direction = gr.Radio(
                    choices=["Latin → Arabic", "Arabic → Latin"],
                    value="Latin → Arabic",
                    label="Translation Direction"
                )
                
                input_text = gr.Textbox(
                    placeholder="Enter text to transliterate...",
                    label="Input Text",
                    lines=4,
                    max_lines=10
                )
                
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary")
                    translate_btn = gr.Button("Transliterate", variant="primary")
            
            with gr.Column(scale=1):
                # Editable so users can fix the output before saving a correction.
                output_text = gr.Textbox(
                    label="Output",
                    lines=4,
                    max_lines=10,
                    interactive=True
                )
                
                # Arabic Keyboard
                gr.Markdown("### Arabic Keyboard")
                gr.Markdown("*Click letters to edit the output text above*")
                
                with gr.Group():
                    for row in arabic_keys:
                        with gr.Row():
                            for char in row:
                                btn = gr.Button(char, size="sm", scale=1)
                                # The f-string is evaluated on each iteration, so
                                # each button's JS appends its own character.
                                btn.click(
                                    fn=None,
                                    js=f"(output_text) => output_text + '{char}'",
                                    inputs=[output_text],
                                    outputs=[output_text],
                                    show_progress=False,
                                    queue=False
                                )
                    
                    with gr.Row():
                        space_btn = gr.Button("Space", size="sm", scale=2)
                        backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
                        clear_output_btn = gr.Button("Clear Output", size="sm", scale=2)
                
                # Correction system
                with gr.Group():
                    gr.Markdown("### Correction System")
                    correction_status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        visible=False
                    )
                    save_correction_btn = gr.Button("Save Correction", variant="secondary")
        
        # Keyboard utility buttons: client-side JS only, like the letter keys.
        space_btn.click(
            fn=None,
            js="(output_text) => output_text + ' '",
            inputs=[output_text],
            outputs=[output_text],
            show_progress=False,
            queue=False
        )
        
        backspace_btn.click(
            fn=None,
            js="(output_text) => output_text.slice(0, -1)",
            inputs=[output_text],
            outputs=[output_text],
            show_progress=False,
            queue=False
        )
        
        clear_output_btn.click(
            fn=None,
            js="() => ''",
            outputs=[output_text],
            show_progress=False,
            queue=False
        )
        
        # Stats button: fetch the stats text, then reveal the hidden box.
        stats_btn.click(
            fn=firebase_cache.get_stats,
            outputs=[stats_display]
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[stats_display]
        )
        
        # Example inputs
        gr.Markdown("### Examples")
        examples = [
            ["makay3nich bli katkhdam bzaf", "Latin → Arabic"],
            ["rah bayn dkchi li katdir kolchi 3ay9 bik", "Latin → Arabic"],
            ["wach na9dar nakhod caipirinha, 3afak", "Latin → Arabic"],
            ["ghadi temchi f lkhedma mzyan", "Latin → Arabic"]
        ]
        
        gr.Examples(
            examples=examples,
            inputs=[input_text, direction],
            outputs=output_text,
            fn=transliterate,
            cache_examples=False
        )
        
        # Event handlers
        # After transliterating, reveal the (initially hidden) correction status box.
        translate_btn.click(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status]
        )
        
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text]
        )
        
        # NOTE(review): unlike translate_btn, pressing Enter does not reveal
        # correction_status; confirm whether that is intended.
        input_text.submit(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text
        )
        
        # Stores the current (possibly hand-edited) output box as the correction
        # for the current input text and direction.
        save_correction_btn.click(
            fn=save_correction,
            inputs=[input_text, direction, output_text],
            outputs=[correction_status]
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status]
        )
        
        # Information
        gr.Markdown(
            """
            ### About
            This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
            
            **Firebase Features:**
            - **Persistent Storage**: All translations are saved permanently
            - **Analytics**: Track usage patterns and popular translations  
            - **Fast Responses**: Cached results load instantly
            - **Global Access**: Data synced across all users
            - **Corrections**: Help improve the model by fixing outputs
            
            **How to help improve the model:**
            1. Use the Arabic keyboard to correct any wrong translations
            2. Click "Save Correction" to store your improvement
            3. Your corrections help train better models for everyone!
            """
        )
    
    return demo

# Launch the app
if __name__ == "__main__":
    # Keep the global name `demo`: Gradio tooling conventionally looks for it.
    # share=True requests a public share link in addition to the local server
    # (per Gradio's launch() docs).
    demo = create_interface()
    demo.launch(share=True)