import os
from typing import Dict, Optional, Tuple, cast

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)

MODEL_ID = "tasal9/ZamAI-mT5-Pashto"
BASE_MODEL = "google/mt5-base"
DEFAULT_DATASET = "tasal9/ZamAi-Pashto-Datasets-V2"
MAX_SEQ_LENGTH = 512

# Module-level cache so the model and tokenizer are only loaded once per process.
_MODEL_CACHE: Dict[str, Optional[object]] = {"model": None, "tokenizer": None}


def _device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_model() -> Tuple[AutoModelForSeq2SeqLM, PreTrainedTokenizerBase]:
    """Lazy-load the translation model for inference."""
    if _MODEL_CACHE["model"] is None or _MODEL_CACHE["tokenizer"] is None:
        # The uploaded adapter repo does not include spiece.model, so reuse the base tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        model.to(_device())
        _MODEL_CACHE["model"] = model
        _MODEL_CACHE["tokenizer"] = tokenizer
    return (
        cast(AutoModelForSeq2SeqLM, _MODEL_CACHE["model"]),
        cast(PreTrainedTokenizerBase, _MODEL_CACHE["tokenizer"]),
    )


def _direction_prefix(direction: str) -> str:
    return "translate Pashto to English: " if direction == "ps-en" else "translate English to Pashto: "


def _extract_pair(example: dict, direction: str):
    source_candidates = ["input", "en", "english", "source", "prompt"]
    target_candidates = ["output", "ps", "pashto", "target", "completion", "answer"]
    if direction == "ps-en":
        source_candidates, target_candidates = target_candidates, source_candidates
    source = next((example.get(key) for key in source_candidates if example.get(key)), None)
    target = next((example.get(key) for key in target_candidates if example.get(key)), None)
    if not source or not target:
        return None, None
    return str(source).strip(), str(target).strip()
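# Illustration only (hypothetical row; column names taken from the candidate
# lists above): for direction "en-ps" a dataset row such as
#
#   {"en": "Hello, how are you?", "ps": "سلام، ته څنګه یې؟"}
#
# resolves to (source, target) = (English text, Pashto text), with the two
# sides swapped for "ps-en". Rows where no candidate column matches are
# dropped later by _prepare_dataset.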
@spaces.GPU
def translate_text(text: str, direction: str) -> str:
    text = (text or "").strip()
    if not text:
        return "Please provide text to translate."
    model, translation_tokenizer = cast(
        Tuple[AutoModelForSeq2SeqLM, PreTrainedTokenizerBase], load_model()
    )
    inputs = translation_tokenizer(  # pylint: disable=not-callable
        _direction_prefix(direction) + text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )
    if torch.cuda.is_available():
        inputs = {k: v.to(_device()) for k, v in inputs.items()}
    outputs = model.generate(
        **inputs,
        max_length=MAX_SEQ_LENGTH,
        num_beams=4,
        early_stopping=True,
    )
    return translation_tokenizer.decode(outputs[0], skip_special_tokens=True)


def _prepare_dataset(dataset, direction: str):
    def convert(example):
        src, tgt = _extract_pair(example, direction)
        return {"source_text": src, "target_text": tgt}

    converted = dataset.map(convert)
    converted = converted.filter(lambda ex: ex["source_text"] and ex["target_text"])
    extra_cols = [col for col in converted.column_names if col not in {"source_text", "target_text"}]
    return converted.remove_columns(extra_cols)


@spaces.GPU
def start_training(
    dataset_name: str,
    direction: str,
    epochs: int,
    learning_rate: float,
    max_train_samples: int,
    push_to_hub: bool,
    repo_id: str,
) -> str:
    try:
        epochs = int(epochs)
        max_train_samples = int(max_train_samples) if max_train_samples else None
        training_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
        # 8-bit loading requires the bitsandbytes package; recent transformers
        # versions prefer passing a BitsAndBytesConfig via quantization_config.
        model = AutoModelForSeq2SeqLM.from_pretrained(
            BASE_MODEL,
            load_in_8bit=True,
            device_map="auto",
        )
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            # T5/mT5 name their attention query/value projections "q" and "v".
            target_modules=["q", "v"],
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        dataset = load_dataset(dataset_name, split="train", verification_mode="no_checks")
        if max_train_samples and len(dataset) > max_train_samples:
            dataset = dataset.shuffle(seed=42).select(range(max_train_samples))
        dataset = _prepare_dataset(dataset, direction)
        if len(dataset) == 0:
            return "❌ Could not find translation pairs in the dataset."

        def tokenize_batch(batch):
            model_inputs = training_tokenizer(
                [_direction_prefix(direction) + text for text in batch["source_text"]],
                max_length=MAX_SEQ_LENGTH,
                truncation=True,
                padding="max_length",
            )
            labels = training_tokenizer(
                batch["target_text"],
                max_length=MAX_SEQ_LENGTH,
                truncation=True,
                padding="max_length",
            )
            # Replace pad ids in the labels with -100 so they are ignored by the
            # loss; with padding="max_length" the collator never re-pads, so the
            # pad tokens would otherwise be scored.
            model_inputs["labels"] = [
                [(tok if tok != training_tokenizer.pad_token_id else -100) for tok in seq]
                for seq in labels["input_ids"]
            ]
            return model_inputs

        tokenized = dataset.map(tokenize_batch, batched=True, remove_columns=dataset.column_names)
        data_collator = DataCollatorForSeq2Seq(tokenizer=training_tokenizer, model=model)
        training_args = TrainingArguments(
            output_dir="./mt5_translation_outputs",
            num_train_epochs=epochs,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=8,
            learning_rate=float(learning_rate),
            warmup_steps=100,
            logging_steps=10,
            save_steps=200,
            save_total_limit=2,
            fp16=torch.cuda.is_available(),
            push_to_hub=bool(push_to_hub),
            hub_model_id=repo_id or MODEL_ID,
            hub_token=os.getenv("HF_TOKEN"),
            report_to="none",
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized,
            data_collator=data_collator,
        )
        trainer.train()
        trainer.save_model()
        if push_to_hub:
            trainer.push_to_hub(commit_message="Update ZamAI mT5 translator adapter")
        return f"✅ Training complete! Samples used: {len(tokenized)}"
    except Exception as exc:  # pragma: no cover - runtime feedback is shown in the UI
        return f"❌ Training error: {exc}"
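# Reusing a pushed adapter elsewhere is a sketch along these lines, assuming
# the Hub repo holds PEFT adapter weights produced by the LoRA config above
# (PeftModel is part of the `peft` package):
#
#   from peft import PeftModel
#   base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
#   model = PeftModel.from_pretrained(base, MODEL_ID)
#
# PeftModel.from_pretrained attaches the saved LoRA weights on top of the
# frozen base model; calling merge_and_unload() would bake them in permanently.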
with gr.Blocks(title="ZamAI mT5 Pashto Translator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🇦🇫 ZamAI mT5 Pashto Translation + Training")
    gr.Markdown(
        f"**Model:** `{MODEL_ID}` · **Base:** `{BASE_MODEL}` · **Dataset:** `{DEFAULT_DATASET}`"
    )

    with gr.Tab("Translate"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text",
                    placeholder="Type English or Pashto text...",
                    lines=4,
                )
                direction = gr.Dropdown(
                    choices=["en-ps", "ps-en"],
                    value="en-ps",
                    label="Direction",
                )
                translate_btn = gr.Button("Translate", variant="primary")
            with gr.Column():
                translation_output = gr.Textbox(
                    label="Translation",
                    lines=6,
                    interactive=False,
                )
        translate_btn.click(
            translate_text,
            inputs=[text_input, direction],
            outputs=translation_output,
        )

    with gr.Tab("Training"):
        dataset_name = gr.Textbox(label="Dataset (HF repo)", value=DEFAULT_DATASET)
        repo_id = gr.Textbox(label="Push trained adapter to", value=MODEL_ID)
        train_direction = gr.Dropdown(
            choices=["en-ps", "ps-en"], value="en-ps", label="Training direction"
        )
        epochs = gr.Slider(1, 5, value=1, step=1, label="Epochs")
        learning_rate = gr.Slider(1e-5, 5e-4, value=2e-4, label="Learning rate")
        max_samples = gr.Slider(200, 4000, value=1500, step=100, label="Max training samples")
        push_flag = gr.Checkbox(label="Push to Hugging Face Hub", value=True)
        train_btn = gr.Button("🚀 Start Training", variant="primary")
        training_status = gr.Textbox(label="Training Status", lines=8, interactive=False)
        train_btn.click(
            start_training,
            inputs=[
                dataset_name,
                train_direction,
                epochs,
                learning_rate,
                max_samples,
                push_flag,
                repo_id,
            ],
            outputs=training_status,
        )

    with gr.Tab("Tips"):
        gr.Markdown(
            """
            ### 📌 Tips
            - Works with datasets that expose `input/output`, `en/ps`, or `prompt/completion` columns.
            - Lower `Max training samples` for quick smoke tests.
            - Add a valid HF token (Settings → Tokens) to the Space secrets for automatic pushes.
            """
        )

if __name__ == "__main__":
    demo.launch()
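# Local smoke test (a sketch, assuming this file is saved as app.py; outside a
# ZeroGPU Space the @spaces.GPU decorators act as pass-throughs, so the UI also
# runs on plain CPU):
#
#   pip install gradio spaces torch transformers datasets peft bitsandbytes
#   python app.py   # then open the printed localhost URL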