Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import json | |
| import os | |
| from datetime import datetime | |
| from torch.nn.utils.rnn import pad_sequence | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore | |
| # Define the model architecture | |
class CTCTransliterator(nn.Module):
    """Bidirectional-LSTM character model emitting per-frame log-probabilities.

    Pipeline: embedding -> stacked bidirectional LSTM -> layer norm -> dropout
    -> linear interpolation along the time axis -> linear projection
    -> log-softmax over the class axis.

    Args:
        input_dim: Number of rows in the embedding table (index 0 is padding).
        hidden_dim: Embedding width and per-direction LSTM hidden size.
        output_dim: Number of classes produced by the final projection.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability for the LSTM and the post-norm dropout.
        upsample_factor: Scale factor applied to the time axis before the
            projection.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3, upsample_factor=3):
        super().__init__()
        # Both LSTM directions are concatenated, doubling the feature width.
        bidirectional_dim = hidden_dim * 2
        # Attribute names are kept as-is: they are the keys of the state dict.
        self.embed = nn.Embedding(input_dim, hidden_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
        )
        self.layer_norm = nn.LayerNorm(bidirectional_dim)
        self.dropout = nn.Dropout(dropout)
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(bidirectional_dim, output_dim)

    def forward(self, x):
        """Return log-probabilities for a time-major batch of token ids.

        Args:
            x: Integer tensor laid out as (seq_len, batch); the LSTM is not
                batch-first.

        Returns:
            Tensor of shape (seq_len * upsample_factor, batch, output_dim)
            holding log-softmax scores over the last axis.
        """
        features, _state = self.lstm(self.embed(x))
        features = self.dropout(self.layer_norm(features))
        # interpolate() stretches the last axis, so time has to sit last:
        # (seq_len, batch, feat) -> (batch, feat, seq_len).
        stretched = F.interpolate(
            features.permute(1, 2, 0),
            scale_factor=self.upsample_factor,
            mode='linear',
            align_corners=False,
        )
        # Back to time-major before projecting each frame onto the classes.
        logits = self.fc(stretched.permute(2, 0, 1))
        return F.log_softmax(logits, dim=2)
| # Firebase Cache System | |
class FirebaseCache:
    """Firestore-backed cache of transliteration results.

    Each cached result is one document in the ``translations`` collection,
    keyed by an MD5 digest of the input text and direction. When Firebase
    cannot be initialised, ``self.db`` stays ``None`` and every method
    degrades to a no-op result (``None`` / ``False`` / a status string).
    """

    def __init__(self):
        self.db = None
        self.init_firebase()

    def init_firebase(self):
        """Initialize the Firebase connection.

        Credentials come from the base64-encoded JSON in the
        ``FIREBASE_CREDENTIALS`` environment variable, or else from a local
        ``firebase-credentials.json`` file. Any failure is printed and leaves
        ``self.db`` as ``None``; nothing is raised.
        """
        try:
            if firebase_admin._apps:
                # An app was already initialised in this process; reuse it.
                self.db = firestore.client()
                return

            encoded_credentials = os.getenv('FIREBASE_CREDENTIALS')
            if encoded_credentials:
                # For HuggingFace Spaces: credentials passed via environment.
                import base64
                cred_data = json.loads(base64.b64decode(encoded_credentials).decode())
                cred = credentials.Certificate(cred_data)
            elif os.path.exists('firebase-credentials.json'):
                # For local development.
                cred = credentials.Certificate('firebase-credentials.json')
            else:
                print("No Firebase credentials found. Using local cache fallback.")
                return

            firebase_admin.initialize_app(cred)
            self.db = firestore.client()
            print("Firebase initialized successfully!")
        except Exception as e:
            print(f"Firebase initialization failed: {e}")
            print("Falling back to local cache mode")
            self.db = None

    def _create_cache_key(self, input_text, direction):
        """Return a fixed-length hex document id for (input_text, direction).

        Hashing keeps the id free of special characters and bounded in length.
        MD5 is used only as a key digest here, not for security.
        """
        import hashlib
        key = f"{input_text}_{direction}"
        return hashlib.md5(key.encode()).hexdigest()

    def _doc_ref(self, input_text, direction):
        """Return the document reference for this input/direction pair."""
        doc_key = self._create_cache_key(input_text, direction)
        return self.db.collection('translations').document(doc_key)

    def get(self, input_text, direction):
        """Return the cached output for the input, or ``None`` on a miss.

        A hit also bumps the document's ``usage_count`` and ``last_used``.
        That bookkeeping is best-effort: if it fails, the cached output is
        still returned instead of being reported as a miss.
        """
        if self.db is None:
            return None

        try:
            doc_ref = self._doc_ref(input_text, direction)
            snapshot = doc_ref.get()
            if not snapshot.exists:
                return None
            data = snapshot.to_dict()
        except Exception as e:
            print(f"Cache read error: {e}")
            return None

        try:
            # Server-side increment: concurrent hits cannot overwrite each
            # other's count the way a read-modify-write would.
            doc_ref.update({
                'usage_count': firestore.Increment(1),
                'last_used': datetime.now(),
            })
        except Exception as e:
            print(f"Cache usage update error: {e}")

        print(f"Cache hit: {input_text}")
        return data.get('output', '')

    def set(self, input_text, direction, output):
        """Store a fresh result document; return ``True`` on success.

        The document is written in full, so an existing document with the
        same key is replaced.
        """
        if self.db is None:
            return False
        try:
            now = datetime.now()
            doc_data = {
                'input': input_text,
                'direction': direction,
                'output': output,
                'corrected_output': '',
                'timestamp': now,
                'last_used': now,
                'usage_count': 1,
            }
            self._doc_ref(input_text, direction).set(doc_data)
            print(f"Cached: {input_text} → {output}")
            return True
        except Exception as e:
            print(f"Cache write error: {e}")
            return False

    def update_correction(self, input_text, direction, corrected_output):
        """Attach a user correction to an existing document.

        Returns ``True`` on success and ``False`` when Firebase is not
        connected or the update raises.
        """
        if self.db is None:
            return False
        try:
            self._doc_ref(input_text, direction).update({
                'corrected_output': corrected_output,
                'correction_timestamp': datetime.now(),
            })
            print(f"Correction saved: {input_text} → {corrected_output}")
            return True
        except Exception as e:
            print(f"Correction save error: {e}")
            return False

    def get_stats(self):
        """Return a human-readable summary of the whole collection.

        Reads every document in ``translations``; on failure an error string
        is returned instead of raising.
        """
        if self.db is None:
            return "Firebase not connected"
        try:
            # Materialise first so the count does not rely on len() of
            # whatever container the client returns.
            records = [doc.to_dict() for doc in self.db.collection('translations').get()]
            total = len(records)
            corrected = sum(1 for data in records if data.get('corrected_output'))
            total_usage = sum(data.get('usage_count', 0) for data in records)
            average = total_usage / total if total > 0 else 0
            return "\n".join([
                "Cache Statistics:",
                f"• Total translations: {total}",
                f"• With corrections: {corrected}",
                f"• Total usage count: {total_usage}",
                f"• Average usage: {average:.1f} per translation",
            ])
        except Exception as e:
            return f"Error getting stats: {e}"
| # Load vocabularies and model | |
def load_model_and_vocabs():
    """Load the four vocabulary JSON files and the trained model weights.

    Reads ``latin_stoi.json``, ``latin_itos.json``, ``arabic_stoi.json`` and
    ``arabic_itos.json`` plus ``best_model.pth`` from the working directory.

    Returns:
        Tuple ``(model, latin_stoi, latin_itos, arabic_stoi, arabic_itos,
        blank_id, device)`` with the model moved to ``device`` and switched
        to eval mode.
    """

    def read_json(path):
        with open(path, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab_files = (
        'latin_stoi.json',
        'latin_itos.json',
        'arabic_stoi.json',
        'arabic_itos.json',
    )
    latin_stoi, latin_itos, arabic_stoi, arabic_itos = [read_json(name) for name in vocab_files]

    # Vocabulary sizes fix the embedding and output dimensions.
    model = CTCTransliterator(
        len(latin_stoi),
        256,
        len(arabic_stoi),
        num_layers=3,
        dropout=0.3,
        upsample_factor=2,
    )
    model = model.to(device)

    # NOTE(review): weights_only=False permits arbitrary unpickling; only
    # acceptable because the checkpoint ships with the app.
    state_dict = torch.load('best_model.pth', map_location=device, weights_only=False)
    model.load_state_dict(state_dict)
    model.eval()

    # Fall back to the last index when the vocabulary has no '<blank>' entry.
    blank_id = arabic_stoi.get('<blank>', len(arabic_itos) - 1)
    return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
| # Load everything at startup | |
# Module-level singletons: built once at import time and shared by every request.
(
    model,
    latin_stoi,
    latin_itos,
    arabic_stoi,
    arabic_itos,
    blank_id,
    device,
) = load_model_and_vocabs()
firebase_cache = FirebaseCache()
def encode_text(text, vocab):
    """Convert text to a 1-D ``torch.long`` tensor of vocabulary ids.

    Surrounding whitespace is stripped first; characters missing from
    ``vocab`` map to id 0.
    """
    cleaned = text.strip()
    token_ids = [vocab.get(symbol, 0) for symbol in cleaned]
    return torch.tensor(token_ids, dtype=torch.long)
def greedy_decode(log_probs, blank_id, itos, stoi):
    """Decode CTC outputs by taking the best class per frame.

    Args:
        log_probs: Score tensor laid out as (time, batch, classes).
        blank_id: Class id of the CTC blank symbol.
        itos: Mapping from the *string* form of a class id to its character
            (string keys, as produced by a JSON round-trip).
        stoi: Mapping from character to class id; used only to find the
            ``'<eos>'`` id.

    Returns:
        One decoded string per batch element.

    Raises:
        KeyError: If an emitted class id has no entry in ``itos``.
    """
    # NOTE(review): when '<eos>' is missing, the second-to-last id is assumed
    # to be the end marker - confirm against how the vocabulary was built.
    eos_id = stoi.get('<eos>', len(stoi) - 2)

    # (T, B, C) -> best class per frame -> (B, T) as plain Python ints.
    best_paths = log_probs.argmax(2).T.tolist()

    results = []
    for path in best_paths:
        decoded = []
        prev = None
        for class_id in path:
            if class_id == eos_id:
                # Everything after the end marker is discarded.
                break
            # CTC collapse: drop blanks and immediate repeats.
            if class_id != blank_id and class_id != prev:
                decoded.append(itos[str(class_id)])
            # Track every frame, blanks included, so a blank between two
            # identical characters keeps both ("a", blank, "a" -> "aa").
            prev = class_id
        results.append("".join(decoded))
    return results
def transliterate_latin_to_arabic(text):
    """Transliterate Latin script to Arabic script with Firebase caching.

    Args:
        text: Latin-script input; ``None``, empty or whitespace-only input
            yields an empty string.

    Returns:
        The cached output when one exists; otherwise the model's decoded
        output (stored in the cache before returning). If inference raises,
        an ``"Error: ..."`` string is returned instead and nothing is cached.
    """
    if not text or not text.strip():
        return ""

    direction = "Latin → Arabic"

    # Compare against None: an empty string is a valid cached result, and
    # treating it as a miss would re-run the model and overwrite the stored
    # document on every request.
    cached_result = firebase_cache.get(text, direction)
    if cached_result is not None:
        return cached_result

    try:
        # (seq_len,) -> (seq_len, 1): the model expects a time-major batch.
        src = encode_text(text, latin_stoi).unsqueeze(1).to(device)

        with torch.no_grad():
            out = model(src)

        decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
        result = decoded[0] if decoded else ""

        firebase_cache.set(text, direction, result)
        return result
    except Exception as e:
        return f"Error: {str(e)}"
def transliterate_arabic_to_latin(text):
    """Placeholder for the reverse direction; ``text`` is ignored.

    Returns a fixed notice saying the direction is not implemented.
    """
    not_implemented_notice = "Arabic to Latin transliteration not implemented yet."
    return not_implemented_notice
def transliterate(text, direction):
    """Dispatch to the transliterator for the selected direction.

    Latin input is lower-cased before it is passed on; any direction other
    than "Latin → Arabic" goes to the Arabic-to-Latin handler unchanged.
    """
    if direction != "Latin → Arabic":
        return transliterate_arabic_to_latin(text)
    lowered = text.lower()
    return transliterate_latin_to_arabic(lowered)
def save_correction(input_text, direction, corrected_output):
    """Save a user correction to Firebase and return a status message.

    Args:
        input_text: The text that was transliterated, as typed by the user.
        direction: The direction label selected in the UI.
        corrected_output: The user-edited output to store.

    Returns:
        A success message when the cache accepted the update, otherwise a
        failure message.
    """
    lookup_text = input_text or ""
    if direction == "Latin → Arabic":
        # transliterate() lower-cases Latin input before it is cached, so the
        # same normalisation is needed to land on the same document.
        lookup_text = lookup_text.lower()

    if firebase_cache.update_correction(lookup_text, direction, corrected_output):
        return "Correction saved to the database! Thank you for improving the model."
    return "Could not save correction to database."
| # Arabic keyboard layout | |
# On-screen keyboard layout: one inner list per keyboard row, one entry per key.
# Keys are separated by single spaces so that the two-character key 'لا'
# stays a single button.
arabic_keys = [
    keyboard_row.split()
    for keyboard_row in (
        'ض ص ث ق ف غ ع ه خ ح ج د',
        'ش س ي ب ل ا ت ن م ك ط',
        'ئ ء ؤ ر لا ى ة و ز ظ',
        'ذ ١ ٢ ٣ ٤ ٥ ٦ ٧ ٨ ٩ ٠',
    )
]
| # Create Gradio interface | |
def _append_text_js(text):
    """Return a client-side handler that appends ``text`` to the textbox value.

    The text is embedded as a JSON string literal so quotes or backslashes
    cannot break out of the generated JavaScript, and a null textbox value is
    treated as empty instead of being stringified.
    """
    literal = json.dumps(text, ensure_ascii=False)
    return f"(output_text) => (output_text ?? '') + {literal}"


def _build_arabic_keyboard(output_text):
    """Render the on-screen keyboard and wire its character keys.

    Returns the (space, backspace, clear-output) buttons so the caller can
    attach their handlers.
    """
    with gr.Group():
        for row in arabic_keys:
            with gr.Row():
                for char in row:
                    # The handler text is built per key, so each button
                    # appends its own character.
                    gr.Button(char, size="sm", scale=1).click(
                        fn=None,
                        js=_append_text_js(char),
                        inputs=[output_text],
                        outputs=[output_text],
                        show_progress=False,
                        queue=False,
                    )
        with gr.Row():
            space_btn = gr.Button("Space", size="sm", scale=2)
            backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
            clear_output_btn = gr.Button("Clear Output", size="sm", scale=2)
    return space_btn, backspace_btn, clear_output_btn


def create_interface():
    """Build the Gradio Blocks UI and return it without launching.

    The page has a statistics panel, input/output boxes with a direction
    selector, an Arabic on-screen keyboard that edits the output box in the
    browser, a correction-saving control, and example inputs.
    """
    # NOTE(review): the original nesting of the layout containers could not
    # be recovered from the source; the arrangement below is a reconstruction
    # and should be checked against the deployed app.
    with gr.Blocks(title="Darija Transliterator", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "\n"
            "# Darija Transliterator\n"
            "Convert between Latin script and Arabic script for Moroccan Darija\n"
            "**Firebase-Powered**: Persistent caching across sessions\n"
            "**Arabic Keyboard**: Built-in Arabic keyboard for corrections\n"
            "**Real-time Stats**: Live usage analytics\n"
        )

        # Stats section
        with gr.Row():
            stats_btn = gr.Button("Show Statistics", variant="secondary")
            stats_display = gr.Textbox(
                label="Firebase Statistics",
                interactive=False,
                visible=False,
                lines=5,
            )

        with gr.Row():
            with gr.Column(scale=1):
                direction = gr.Radio(
                    choices=["Latin → Arabic", "Arabic → Latin"],
                    value="Latin → Arabic",
                    label="Translation Direction",
                )
                input_text = gr.Textbox(
                    placeholder="Enter text to transliterate...",
                    label="Input Text",
                    lines=4,
                    max_lines=10,
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary")
                    translate_btn = gr.Button("Transliterate", variant="primary")

            with gr.Column(scale=1):
                # Editable so users can correct the model's output in place.
                output_text = gr.Textbox(
                    label="Output",
                    lines=4,
                    max_lines=10,
                    interactive=True,
                )

                # Arabic Keyboard
                gr.Markdown("### Arabic Keyboard")
                gr.Markdown("*Click letters to edit the output text above*")
                space_btn, backspace_btn, clear_output_btn = _build_arabic_keyboard(output_text)

                # Correction system
                with gr.Group():
                    gr.Markdown("### Correction System")
                    correction_status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        visible=False,
                    )
                    save_correction_btn = gr.Button("Save Correction", variant="secondary")

        # Keyboard utility buttons: client-side only (fn=None, no queue).
        space_btn.click(
            fn=None,
            js=_append_text_js(" "),
            inputs=[output_text],
            outputs=[output_text],
            show_progress=False,
            queue=False,
        )
        backspace_btn.click(
            fn=None,
            js="(output_text) => (output_text ?? '').slice(0, -1)",
            inputs=[output_text],
            outputs=[output_text],
            show_progress=False,
            queue=False,
        )
        clear_output_btn.click(
            fn=None,
            js="() => ''",
            outputs=[output_text],
            show_progress=False,
            queue=False,
        )

        # Stats button: fill the box first, then reveal it.
        stats_btn.click(
            fn=firebase_cache.get_stats,
            outputs=[stats_display],
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[stats_display],
        )

        # Example inputs
        gr.Markdown("### Examples")
        examples = [
            ["makay3nich bli katkhdam bzaf", "Latin → Arabic"],
            ["rah bayn dkchi li katdir kolchi 3ay9 bik", "Latin → Arabic"],
            ["wach na9dar nakhod caipirinha, 3afak", "Latin → Arabic"],
            ["ghadi temchi f lkhedma mzyan", "Latin → Arabic"],
        ]
        gr.Examples(
            examples=examples,
            inputs=[input_text, direction],
            outputs=output_text,
            fn=transliterate,
            cache_examples=False,
        )

        # Event handlers
        translate_btn.click(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text,
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status],
        )
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text],
        )
        input_text.submit(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text,
        )
        # The (possibly edited) output box is what gets saved as the correction.
        save_correction_btn.click(
            fn=save_correction,
            inputs=[input_text, direction, output_text],
            outputs=[correction_status],
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status],
        )

        # Information
        gr.Markdown(
            "\n"
            "### About\n"
            "This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.\n"
            "**Firebase Features:**\n"
            "- **Persistent Storage**: All translations are saved permanently\n"
            "- **Analytics**: Track usage patterns and popular translations\n"
            "- **Fast Responses**: Cached results load instantly\n"
            "- **Global Access**: Data synced across all users\n"
            "- **Corrections**: Help improve the model by fixing outputs\n"
            "**How to help improve the model:**\n"
            "1. Use the Arabic keyboard to correct any wrong translations\n"
            '2. Click "Save Correction" to store your improvement\n'
            "3. Your corrections help train better models for everyone!\n"
        )
    return demo
| # Launch the app | |
if __name__ == "__main__":
    # Build the UI only when run as a script, then serve it with a share link.
    launch_options = {"share": True}
    demo = create_interface()
    demo.launch(**launch_options)