# Page header captured with the file (not part of the program):
# uploader: Haitam03
# commit: 3a55a85 (verified) - "Update app.py"
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
from datetime import datetime
from torch.nn.utils.rnn import pad_sequence
import firebase_admin
from firebase_admin import credentials, firestore
# Define the model architecture
class CTCTransliterator(nn.Module):
    """Character-level sequence model that emits per-step log-probabilities.

    Pipeline: embedding -> bidirectional LSTM stack -> LayerNorm -> dropout ->
    linear interpolation along the time axis -> linear projection ->
    log-softmax over the output vocabulary.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3, upsample_factor=3):
        super().__init__()
        # Forward and backward LSTM states are concatenated, doubling the width.
        bi_dim = hidden_dim * 2
        # Index 0 is the padding index: its embedding row is kept at zero.
        self.embed = nn.Embedding(input_dim, hidden_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
        )
        self.layer_norm = nn.LayerNorm(bi_dim)
        self.dropout = nn.Dropout(dropout)
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(bi_dim, output_dim)

    def forward(self, x):
        """Return log-probabilities of shape (seq_len * upsample_factor, batch, output_dim).

        ``x`` holds token indices laid out time-major, (seq_len, batch), since
        the LSTM is built without ``batch_first``.
        """
        embedded = self.embed(x)
        encoded, _ = self.lstm(embedded)
        encoded = self.dropout(self.layer_norm(encoded))
        # F.interpolate stretches the last axis, so move time to the end:
        # (seq_len, batch, features) -> (batch, features, seq_len)
        stretched = F.interpolate(
            encoded.permute(1, 2, 0),
            scale_factor=self.upsample_factor,
            mode='linear',
            align_corners=False,
        )
        # Restore the time-major layout before the per-step projection.
        logits = self.fc(stretched.permute(2, 0, 1))
        return F.log_softmax(logits, dim=2)
# Firebase Cache System
class FirebaseCache:
    """Firestore-backed store of transliteration results and user corrections.

    Each (input text, direction) pair maps to one document in the
    ``translations`` collection. If no credentials are found or
    initialisation fails, ``self.db`` stays ``None`` and every operation
    returns its "not available" value, so the app keeps working without
    persistence.
    """

    def __init__(self):
        self.db = None
        self.init_firebase()

    def init_firebase(self):
        """Connect to Firestore, leaving ``self.db`` as ``None`` on any failure.

        Credentials are looked up in this order: the ``FIREBASE_CREDENTIALS``
        environment variable (base64-encoded service-account JSON), then a
        local ``firebase-credentials.json`` file.
        """
        try:
            if not firebase_admin._apps:
                encoded_creds = os.getenv('FIREBASE_CREDENTIALS')
                if encoded_creds:
                    # Hosted deployment: credentials arrive via the environment.
                    import base64
                    cred_data = json.loads(base64.b64decode(encoded_creds).decode())
                    cred = credentials.Certificate(cred_data)
                elif os.path.exists('firebase-credentials.json'):
                    # Local development: credentials file next to the app.
                    cred = credentials.Certificate('firebase-credentials.json')
                else:
                    print("No Firebase credentials found. Using local cache fallback.")
                    return
                firebase_admin.initialize_app(cred)
                self.db = firestore.client()
                print("Firebase initialized successfully!")
            else:
                # An app was already initialised in this process; reuse it.
                self.db = firestore.client()
        except Exception as e:
            print(f"Firebase initialization failed: {e}")
            print("Falling back to local cache mode")
            self.db = None

    def _create_cache_key(self, input_text, direction):
        """Return a fixed-length hex document id for (input_text, direction).

        Hashing sidesteps special characters and length limits in document
        ids. MD5 is used only as a key-derivation hash, not for security.
        """
        import hashlib
        key = f"{input_text}_{direction}"
        return hashlib.md5(key.encode()).hexdigest()

    def _doc_ref(self, input_text, direction):
        """Return the Firestore document reference for this cache entry."""
        doc_key = self._create_cache_key(input_text, direction)
        return self.db.collection('translations').document(doc_key)

    def get(self, input_text, direction):
        """Return the cached output for this entry, or ``None`` on miss/error.

        A hit also bumps ``usage_count`` and ``last_used`` on the document.
        """
        if not self.db:
            return None
        try:
            doc_ref = self._doc_ref(input_text, direction)
            snapshot = doc_ref.get()
            if not snapshot.exists:
                return None
            data = snapshot.to_dict()
            output = data.get('output', '')
        except Exception as e:
            print(f"Cache read error: {e}")
            return None

        # Usage tracking is bookkeeping only: a failed write here must not
        # turn a successful read into a cache miss.
        try:
            doc_ref.update({
                'usage_count': data.get('usage_count', 0) + 1,
                'last_used': datetime.now(),
            })
        except Exception as e:
            print(f"Cache usage update error: {e}")

        print(f"Cache hit: {input_text}")
        return output

    def set(self, input_text, direction, output):
        """Store a fresh translation; return ``True`` on success.

        The document is written in full, so any previous content for the same
        key (including a stored correction) is replaced.
        """
        if not self.db:
            return False
        try:
            now = datetime.now()
            doc_data = {
                'input': input_text,
                'direction': direction,
                'output': output,
                'corrected_output': '',
                'timestamp': now,
                'last_used': now,
                'usage_count': 1,
            }
            self._doc_ref(input_text, direction).set(doc_data)
            print(f"Cached: {input_text} -> {output}")
            return True
        except Exception as e:
            print(f"Cache write error: {e}")
            return False

    def update_correction(self, input_text, direction, corrected_output):
        """Attach a user correction to an existing entry; return ``True`` on success.

        Returns ``False`` when Firestore is unavailable or the update call
        raises.
        """
        if not self.db:
            return False
        try:
            self._doc_ref(input_text, direction).update({
                'corrected_output': corrected_output,
                'correction_timestamp': datetime.now(),
            })
            print(f"Correction saved: {input_text} -> {corrected_output}")
            return True
        except Exception as e:
            print(f"Correction save error: {e}")
            return False

    def get_stats(self):
        """Return a human-readable summary of the ``translations`` collection.

        Reads every document in the collection to compute the totals. Returns
        a short message instead when Firestore is unavailable or the read
        fails.
        """
        if not self.db:
            return "Firebase not connected"
        try:
            docs = self.db.collection('translations').get()
            total = len(docs)
            corrected = 0
            total_usage = 0
            for doc in docs:
                data = doc.to_dict()
                if data.get('corrected_output'):
                    corrected += 1
                total_usage += data.get('usage_count', 0)
            # Guard the division so an empty collection reports 0.0.
            average = total_usage / total if total > 0 else 0
            return "\n".join([
                "Cache Statistics:",
                f"• Total translations: {total}",
                f"• With corrections: {corrected}",
                f"• Total usage count: {total_usage}",
                f"• Average usage: {average:.1f} per translation",
            ])
        except Exception as e:
            return f"Error getting stats: {e}"
# Load vocabularies and model
def load_model_and_vocabs():
    """Load the four vocabulary files and the trained model from the working directory.

    Returns:
        Tuple ``(model, latin_stoi, latin_itos, arabic_stoi, arabic_itos,
        blank_id, device)`` with the model moved to ``device`` and switched
        to eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def read_vocab(path):
        # Vocabulary files are UTF-8 JSON mappings.
        with open(path, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    latin_stoi = read_vocab('latin_stoi.json')
    latin_itos = read_vocab('latin_itos.json')
    arabic_stoi = read_vocab('arabic_stoi.json')
    arabic_itos = read_vocab('arabic_itos.json')

    # Input/output sizes come from the vocabularies; the remaining
    # hyper-parameters are fixed here.
    network = CTCTransliterator(
        len(latin_stoi),
        256,
        len(arabic_stoi),
        num_layers=3,
        dropout=0.3,
        upsample_factor=2,
    )
    network = network.to(device)

    # NOTE(review): weights_only=False permits arbitrary unpickling; only
    # load checkpoints from a trusted source.
    state = torch.load('best_model.pth', map_location=device, weights_only=False)
    network.load_state_dict(state)
    network.eval()

    # Fall back to the last index when the vocabulary has no '<blank>' entry.
    blank_id = arabic_stoi.get('<blank>', len(arabic_itos)-1)
    return network, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
# Load everything at startup
# Both statements run at import time: the model and vocabulary files are read
# once, and a FirebaseCache is constructed (its __init__ attempts the Firestore
# connection). The resulting module-level names are used by the functions below.
model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs()
firebase_cache = FirebaseCache()
def encode_text(text, vocab):
    """Convert the stripped text into a 1-D long tensor of vocabulary indices.

    Characters that are missing from ``vocab`` are encoded as index 0.
    """
    indices = []
    for symbol in text.strip():
        indices.append(vocab.get(symbol, 0))
    return torch.tensor(indices, dtype=torch.long)
def greedy_decode(log_probs, blank_id, itos, stoi):
    """Greedily decode CTC outputs into strings, one per batch item.

    Args:
        log_probs: Tensor indexed as (time, batch, vocab); the best class is
            taken over the last axis at every step.
        blank_id: Index of the CTC blank symbol.
        itos: Mapping from the *string* form of an index to its character.
        stoi: Mapping from character to index; used to find ``'<eos>'``
            (falls back to ``len(stoi) - 2`` when absent).

    Returns:
        List of decoded strings. Decoding of an item stops at the first EOS;
        blanks and immediate repeats are collapsed.
    """
    eos_id = stoi.get('<eos>', len(stoi) - 2)
    # (time, batch) best indices -> (batch, time) nested lists of Python ints,
    # so comparisons and str() below operate on plain ints.
    batch_paths = log_probs.argmax(2).T.tolist()

    results = []
    for path in batch_paths:
        decoded = []
        prev = None
        for token_id in path:
            if token_id == eos_id:
                break
            # CTC collapse: emit only non-blank symbols that differ from the
            # previous step. A blank in between resets `prev`, which allows
            # genuine double letters.
            if token_id != blank_id and token_id != prev:
                decoded.append(itos[str(token_id)])
            prev = token_id
        results.append("".join(decoded))
    return results
def transliterate_latin_to_arabic(text):
    """Return the Arabic-script transliteration of ``text``, using the cache when possible.

    Blank input yields an empty string. A non-empty cached value is returned
    directly; otherwise the model is run, the result is cached and returned.
    Any exception during inference is reported as an ``"Error: ..."`` string.
    """
    direction = "Latin → Arabic"

    if not text.strip():
        return ""

    # Serve from Firebase when a non-empty result is already stored.
    cached = firebase_cache.get(text, direction)
    if cached:
        return cached

    try:
        # Shape the indices as a single-item batch: (seq_len, 1).
        src = encode_text(text, latin_stoi).unsqueeze(1).to(device)
        with torch.no_grad():
            log_probs = model(src)
        hypotheses = greedy_decode(log_probs, blank_id, arabic_itos, arabic_stoi)
        result = hypotheses[0] if hypotheses else ""
        # Persist the fresh result so the next identical request is a hit.
        firebase_cache.set(text, direction, result)
        return result
    except Exception as e:
        return f"Error: {str(e)}"
def transliterate_arabic_to_latin(text):
    """Placeholder for the Arabic-to-Latin direction; ``text`` is ignored."""
    not_implemented_message = "Arabic to Latin transliteration not implemented yet."
    return not_implemented_message
def transliterate(text, direction):
    """Dispatch ``text`` to the transliterator for the requested direction.

    Latin input is lowercased before being passed on; any direction other
    than ``"Latin → Arabic"`` goes to the Arabic-to-Latin function.
    """
    if direction != "Latin → Arabic":
        return transliterate_arabic_to_latin(text)
    return transliterate_latin_to_arabic(text.lower())
def save_correction(input_text, direction, corrected_output):
    """Store the user's corrected output for a previously cached translation.

    Args:
        input_text: The text the user originally submitted.
        direction: The direction label selected in the UI.
        corrected_output: The edited output to record.

    Returns:
        A status message for display in the UI.
    """
    # transliterate() lowercases Latin input before it reaches the cache, so
    # the lookup key must be derived from the same lowercased text; otherwise
    # corrections for input containing uppercase letters never find their entry.
    if direction == "Latin → Arabic":
        input_text = input_text.lower()

    if firebase_cache.update_correction(input_text, direction, corrected_output):
        return "Correction saved to the database! Thank you for improving the model."
    return "Could not save correction to database."
# Arabic keyboard layout
# Rows of key labels for the on-screen keyboard; create_interface() renders one
# button per entry and appends the entry's text to the output box on click.
# Entries may be longer than one character ('لا' is lam followed by alef).
# The last row holds one letter followed by Arabic-Indic digits.
arabic_keys = [
    ['ض', 'ص', 'ث', 'ق', 'ف', 'غ', 'ع', 'ه', 'خ', 'ح', 'ج', 'د'],
    ['ش', 'س', 'ي', 'ب', 'ل', 'ا', 'ت', 'ن', 'م', 'ك', 'ط'],
    ['ئ', 'ء', 'ؤ', 'ر', 'لا', 'ى', 'ة', 'و', 'ز', 'ظ'],
    ['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']
]
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and wire up all of its event handlers.

    The page contains a statistics panel, the input/direction controls, an
    editable output box with an on-screen Arabic keyboard, a correction
    section, example inputs and an "About" text.

    Returns:
        The assembled ``gr.Blocks`` app; launching it is left to the caller.
    """
    with gr.Blocks(title="Darija Transliterator", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
# Darija Transliterator
Convert between Latin script and Arabic script for Moroccan Darija
**Firebase-Powered**: Persistent caching across sessions
**Arabic Keyboard**: Built-in Arabic keyboard for corrections
**Real-time Stats**: Live usage analytics
"""
        )
        # Stats section
        with gr.Row():
            stats_btn = gr.Button("Show Statistics", variant="secondary")
            # Hidden until the stats button is pressed (see the handler below).
            stats_display = gr.Textbox(
                label="Firebase Statistics",
                interactive=False,
                visible=False,
                lines=5
            )
        with gr.Row():
            with gr.Column(scale=1):
                direction = gr.Radio(
                    choices=["Latin → Arabic", "Arabic → Latin"],
                    value="Latin → Arabic",
                    label="Translation Direction"
                )
                input_text = gr.Textbox(
                    placeholder="Enter text to transliterate...",
                    label="Input Text",
                    lines=4,
                    max_lines=10
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary")
                    translate_btn = gr.Button("Transliterate", variant="primary")
            with gr.Column(scale=1):
                # Editable so users can fix the model output before saving it
                # as a correction.
                output_text = gr.Textbox(
                    label="Output",
                    lines=4,
                    max_lines=10,
                    interactive=True
                )
                # Arabic Keyboard
                gr.Markdown("### Arabic Keyboard")
                gr.Markdown("*Click letters to edit the output text above*")
                with gr.Group():
                    for row in arabic_keys:
                        with gr.Row():
                            for char in row:
                                btn = gr.Button(char, size="sm", scale=1)
                                # No Python callback (fn=None): the JS snippet
                                # appends this key's text to the output value.
                                # The character is interpolated into the snippet
                                # here, so each button carries its own value.
                                btn.click(
                                    fn=None,
                                    js=f"(output_text) => output_text + '{char}'",
                                    inputs=[output_text],
                                    outputs=[output_text],
                                    show_progress=False,
                                    queue=False
                                )
                    with gr.Row():
                        space_btn = gr.Button("Space", size="sm", scale=2)
                        backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
                        clear_output_btn = gr.Button("Clear Output", size="sm", scale=2)
                # Correction system
                with gr.Group():
                    gr.Markdown("### Correction System")
                    # Hidden initially; made visible by the handlers below.
                    correction_status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        visible=False
                    )
                    save_correction_btn = gr.Button("Save Correction", variant="secondary")
        # Keyboard utility buttons
        # Like the letter keys, these edit the output box through JS only.
        space_btn.click(
            fn=None,
            js="(output_text) => output_text + ' '",
            inputs=[output_text],
            outputs=[output_text],
            show_progress=False,
            queue=False
        )
        # Drops the last code unit of the output text.
        backspace_btn.click(
            fn=None,
            js="(output_text) => output_text.slice(0, -1)",
            inputs=[output_text],
            outputs=[output_text],
            show_progress=False,
            queue=False
        )
        clear_output_btn.click(
            fn=None,
            js="() => ''",
            outputs=[output_text],
            show_progress=False,
            queue=False
        )
        # Stats button
        # First fill the textbox with the summary, then reveal it.
        stats_btn.click(
            fn=firebase_cache.get_stats,
            outputs=[stats_display]
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[stats_display]
        )
        # Example inputs
        gr.Markdown("### Examples")
        examples = [
            ["makay3nich bli katkhdam bzaf", "Latin → Arabic"],
            ["rah bayn dkchi li katdir kolchi 3ay9 bik", "Latin → Arabic"],
            ["wach na9dar nakhod caipirinha, 3afak", "Latin → Arabic"],
            ["ghadi temchi f lkhedma mzyan", "Latin → Arabic"]
        ]
        # Each example supplies an (input text, direction) pair.
        gr.Examples(
            examples=examples,
            inputs=[input_text, direction],
            outputs=output_text,
            fn=transliterate,
            cache_examples=False
        )
        # Event handlers
        # Run the transliteration, then reveal the correction status box.
        translate_btn.click(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status]
        )
        # Reset both the input and the output textboxes.
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text]
        )
        # Submitting the input textbox triggers the same transliteration call
        # as the button.
        input_text.submit(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text
        )
        # The current (possibly user-edited) output text is what gets saved as
        # the correction; the returned message lands in the status box.
        save_correction_btn.click(
            fn=save_correction,
            inputs=[input_text, direction, output_text],
            outputs=[correction_status]
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status]
        )
        # Information
        gr.Markdown(
            """
### About
This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
**Firebase Features:**
- **Persistent Storage**: All translations are saved permanently
- **Analytics**: Track usage patterns and popular translations
- **Fast Responses**: Cached results load instantly
- **Global Access**: Data synced across all users
- **Corrections**: Help improve the model by fixing outputs
**How to help improve the model:**
1. Use the Arabic keyboard to correct any wrong translations
2. Click "Save Correction" to store your improvement
3. Your corrections help train better models for everyone!
"""
        )
    return demo
# Launch the app
# Only build and serve the UI when this file is executed as a script;
# share=True is passed through to Gradio's launch().
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)