|
|
"""Gradio UI interface for Caribbean Voices OWSM platform.""" |
|
|
import gradio as gr |
|
|
import time |
|
|
import os |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
from utils.status import get_status_display, get_data_loading_status |
|
|
from utils.entities import extract_entities_progress |
|
|
from training.espnet_trainer import run_espnet_training_progress |
|
|
from training.whisper_trainer import run_whisper_training_progress |
|
|
from models.inference import transcribe_audio, run_inference_owsm |
|
|
from models.loader import get_available_models, get_available_checkpoints |
|
|
from data.loader import load_data_from_hf_dataset |
|
|
from utils.logging import get_latest_log_file, get_all_log_files, get_log_directory |
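# NOTE (assumption): data/, models/, training/, and utils/ are sibling packages of
# this module inside the HF Space, which is why they are imported without a
# top-level package prefix.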
|
|
|
|
|
|
|
|
def create_interface():
    """Create and return the Gradio interface."""
    interface_start = time.time()

    with gr.Blocks(title="Caribbean Voices - OWSM Platform") as demo:
        gr.Markdown("""
        <div class="main-header">
            <h1>🎤 Caribbean Voices Hackathon</h1>
            <p>OWSM v3.1 Training & Inference Platform</p>
        </div>
        """)

        with gr.Tabs() as tabs:

            with gr.Tab("🏠 Home", id=0):
                status_display = gr.HTML(value=get_status_display())
                refresh_status_btn = gr.Button("🔄 Refresh Status", variant="secondary", size="lg")

                gr.Markdown("""
                <div class="nav-buttons-grid">
                """)
|
|
|
|
|
                nav_buttons_row1 = gr.Row()
                with nav_buttons_row1:
                    nav_load_data = gr.Button("📥 Load Data", variant="primary", size="lg", scale=1)
                    nav_entity_extraction = gr.Button("🔍 Entity Extraction", variant="primary", size="lg", scale=1)
                    nav_training = gr.Button("🏋️ Training", variant="primary", size="lg", scale=1)

                nav_buttons_row2 = gr.Row()
                with nav_buttons_row2:
                    nav_inference = gr.Button("🚀 Inference", variant="primary", size="lg", scale=1)
                    nav_single_file = gr.Button("🎯 Single File", variant="primary", size="lg", scale=1)
                    nav_about = gr.Button("📖 About", variant="secondary", size="lg", scale=1)

                gr.Markdown("</div>")

                gr.Markdown("""
                <div style="margin-top: 40px; padding: 20px; background: #fff; border-radius: 10px; border: 1px solid #e0e0e0;">
                <h2 style="color: #667eea; margin-top: 0;">📖 About This Project</h2>
                <p style="font-size: 1.05em; line-height: 1.6; color: #555;">
                The <strong>Caribbean Voices Hackathon</strong> project focuses on building an advanced Automatic Speech Recognition (ASR)
                system using OWSM v3.1 (Open Whisper-Style Model). This platform enables fine-tuning on Caribbean-accented speech
                with specialized entity extraction and contextual biasing for improved recognition of Caribbean proper nouns,
                locations, and organizations.
                </p>
                <h3 style="color: #667eea;">Key Features</h3>
                <ul style="font-size: 1.05em; line-height: 1.8; color: #555;">
                <li><strong>Entity Extraction:</strong> Automatically identifies Caribbean-specific entities from training transcripts</li>
                <li><strong>OWSM Fine-tuning:</strong> Fine-tune the OWSM v3.1 model with entity-weighted loss</li>
                <li><strong>Batch Inference:</strong> Process entire test sets efficiently</li>
                <li><strong>Single File Testing:</strong> Quick transcription with multiple model options</li>
                <li><strong>ESPnet Integration:</strong> Full support for ESPnet training recipes</li>
                </ul>
                </div>
                """)

                def refresh_status():
                    return get_status_display()

                refresh_status_btn.click(
                    fn=refresh_status,
                    outputs=[status_display]
                )
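                # The nav buttons switch tabs from JavaScript by clicking the matching
                # tab header in the DOM. Tab indices follow definition order: 0 Home,
                # 1 Load Data, 2 Entity Extraction, 3 Training, 4 Inference,
                # 5 Single File, 6 About.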
|
|
|
|
|
|
|
|
                def tab_switch_js(tab_index):
                    """Return a JS snippet that clicks the tab header at the given index."""
                    return (
                        "() => { setTimeout(() => { "
                        "const tabs = document.querySelectorAll('button[role=\\'tab\\']'); "
                        f"if(tabs[{tab_index}]) tabs[{tab_index}].click(); "
                        "}, 100); }"
                    )

                nav_load_data.click(None, None, None, js=tab_switch_js(1))
                nav_entity_extraction.click(None, None, None, js=tab_switch_js(2))
                nav_training.click(None, None, None, js=tab_switch_js(3))
                nav_inference.click(None, None, None, js=tab_switch_js(4))
                nav_single_file.click(None, None, None, js=tab_switch_js(5))
                nav_about.click(None, None, None, js=tab_switch_js(6))
            with gr.Tab("📥 Load Data"):
                gr.Markdown("### Load Dataset into HF Space")

                data_status_display = gr.Markdown(value=get_data_loading_status())
                refresh_data_status_btn = gr.Button("🔄 Refresh Status", variant="secondary", size="sm")

                gr.Markdown("""
                ---

                ### Load Dataset

                Data is automatically loaded from the Hugging Face dataset on startup.
                You can manually load a different dataset below if needed.
                """)

                hf_dataset_name = gr.Textbox(
                    label="Hugging Face Dataset Name",
                    placeholder="username/dataset-name",
                    value=""
                )
                hf_load_btn = gr.Button("Load from HF Dataset", variant="primary")
                hf_load_output = gr.Markdown()

                def refresh_data_status():
                    return get_data_loading_status()

                refresh_data_status_btn.click(
                    fn=refresh_data_status,
                    outputs=[data_status_display]
                )
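                # Declaring progress=gr.Progress() as a default argument is Gradio's
                # dependency-injection hook: the framework replaces it with a live
                # tracker so the loader can report progress in the UI.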
|
|
|
|
|
                def load_hf_and_refresh(dataset_name, progress=gr.Progress()):
                    result = load_data_from_hf_dataset(dataset_name, progress)
                    return result, get_data_loading_status()

                hf_load_btn.click(
                    fn=load_hf_and_refresh,
                    inputs=[hf_dataset_name],
                    outputs=[hf_load_output, data_status_display]
                )

            with gr.Tab("🔍 Entity Extraction"):
                gr.Markdown("### Extract Caribbean Entities from Training Data")
                gr.Markdown("""
                This extracts high-value Caribbean entities (proper nouns, locations, organizations)
                from the training transcripts. These entities will be used for:
                - Entity-weighted loss during training
                - Contextual biasing during inference
                """)

                extract_btn = gr.Button("Extract Entities", variant="primary")
                extract_output = gr.Markdown()
                extract_json = gr.JSON(label="Entities JSON")

                extract_btn.click(
                    fn=extract_entities_progress,
                    outputs=[extract_output, extract_json]
                )
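                # The structure of the JSON shown here is defined by utils/entities.py;
                # a plausible shape (an assumption for illustration, not the confirmed
                # schema) is a term-to-weight map such as {"Port of Spain": 3.0, ...}
                # consumed by the trainers below.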
|
|
|
|
|
|
|
|
            with gr.Tab("🏋️ Training"):
                gr.Markdown("### Model Training")
                gr.Markdown("""
                Choose your training framework:
                - **ESPnet Training**: For ESPnet OWSM models (requires ESPnet recipes)
                - **Whisper Training**: For Whisper models (full HuggingFace integration)
                """)

                with gr.Tabs() as training_tabs:

                    with gr.Tab("🔧 ESPnet Training"):
                        gr.Markdown("### ESPnet OWSM Model Training")
                        gr.Markdown("""
                        **ESPnet Training** - uses ESPnet's native framework.

                        This loads ESPnet models and prepares them for training with ESPnet recipes.
                        Full fine-tuning requires ESPnet training recipes.
                        """)

                        with gr.Row():
                            with gr.Column():
                                espnet_train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs (for ESPnet recipes)")
                                espnet_train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size (for ESPnet recipes)")
                                espnet_train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate (for ESPnet recipes)")
                                espnet_train_btn = gr.Button("Load ESPnet Model", variant="primary")

                            with gr.Column():
                                espnet_train_output = gr.Markdown()
                                espnet_train_metrics = gr.JSON(label="Model Info")

                        espnet_train_btn.click(
                            fn=run_espnet_training_progress,
                            inputs=[espnet_train_epochs, espnet_train_batch_size, espnet_train_lr],
                            outputs=[espnet_train_output, espnet_train_metrics]
                        )

                    with gr.Tab("🤗 Whisper Training"):
                        gr.Markdown("### Whisper Model Training")
                        gr.Markdown("""
                        **Whisper Training** - full HuggingFace transformers integration.

                        Fine-tune Whisper models with entity-weighted loss using HuggingFace's training framework,
                        with full support for HuggingFace features such as early stopping and WER metrics.
                        """)

                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("#### Training Hyperparameters")
                                whisper_train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
                                whisper_train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size")
                                whisper_train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate")

                                gr.Markdown("#### Speed Augmentation")
                                gr.Markdown("Speed factors for dataset expansion (creates multiple versions of each sample)")
                                speed_aug_enabled = gr.Checkbox(value=True, label="Enable Speed Augmentation")
                                speed_factor_min = gr.Slider(0.8, 1.0, value=0.9, step=0.05, label="Min Speed Factor")
                                speed_factor_max = gr.Slider(1.0, 1.2, value=1.1, step=0.05, label="Max Speed Factor")
                                speed_factor_count = gr.Slider(2, 5, value=3, step=1, label="Number of Speed Variants")
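                                # With the defaults (0.9-1.1, 3 variants) each clip would be
                                # resampled at roughly 0.9x, 1.0x, and 1.1x speed, tripling
                                # the effective dataset (assumption: evenly spaced factors;
                                # the actual logic lives in training/whisper_trainer.py).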
|
|
|
|
|
gr.Markdown("#### SpecAugment Parameters") |
|
|
gr.Markdown("Spectrogram augmentation settings (applied during training)") |
|
|
specaug_enabled = gr.Checkbox(value=True, label="Enable SpecAugment") |
|
|
specaug_time_mask = gr.Slider(0, 50, value=27, step=1, label="Time Mask Parameter") |
|
|
specaug_freq_mask = gr.Slider(0, 20, value=10, step=1, label="Frequency Mask Parameter") |
|
|
specaug_time_warp = gr.Checkbox(value=True, label="Enable Time Warping") |
|
|
specaug_warp_param = gr.Slider(0, 80, value=40, step=5, label="Time Warp Parameter") |
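                                # SpecAugment masks random spans of the log-mel spectrogram:
                                # the mask parameters bound the maximum width of each time
                                # mask (in frames) and frequency mask (in mel bins), and the
                                # warp parameter bounds how far time warping may displace a
                                # point (how the trainer applies them is assumed to follow
                                # the standard SpecAugment recipe).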
|
|
|
|
|
                                whisper_train_btn = gr.Button("Start Whisper Training", variant="primary", size="lg")

                            with gr.Column():
                                whisper_train_output = gr.Markdown()
                                whisper_train_metrics = gr.JSON(label="Training Metrics")

                                gr.Markdown("#### Training Logs")
                                log_info = gr.Markdown(f"Log directory: `{get_log_directory()}`")
                                latest_log_file = gr.File(
                                    label="Download Latest Training Log",
                                    visible=False
                                )
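                                # Callbacks below update this component by returning a new
                                # gr.File; Gradio treats a returned component instance as a
                                # property update to the wired output.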
|
|
|
|
|
                                def update_log_download():
                                    latest = get_latest_log_file("whisper_training")
                                    if latest and os.path.exists(latest):
                                        return gr.File(value=latest, visible=True)
                                    return gr.File(visible=False)

                                refresh_log_btn = gr.Button("🔄 Refresh Logs", variant="secondary", size="sm")
                                refresh_log_btn.click(
                                    fn=update_log_download,
                                    outputs=[latest_log_file]
                                )

                        def run_training_with_log_refresh(
                            epochs, batch_size, lr,
                            speed_aug_enabled, speed_factor_min, speed_factor_max, speed_factor_count,
                            specaug_enabled, specaug_time_mask, specaug_freq_mask, specaug_time_warp, specaug_warp_param,
                            progress=gr.Progress()
                        ):
                            """Run training and refresh the log download after completion."""
                            result = run_whisper_training_progress(
                                epochs, batch_size, lr,
                                speed_aug_enabled, speed_factor_min, speed_factor_max, speed_factor_count,
                                specaug_enabled, specaug_time_mask, specaug_freq_mask, specaug_time_warp, specaug_warp_param,
                                progress
                            )
                            latest_log = update_log_download()
                            return result[0], result[1], latest_log

                        whisper_train_btn.click(
                            fn=run_training_with_log_refresh,
                            inputs=[
                                whisper_train_epochs,
                                whisper_train_batch_size,
                                whisper_train_lr,
                                speed_aug_enabled,
                                speed_factor_min,
                                speed_factor_max,
                                speed_factor_count,
                                specaug_enabled,
                                specaug_time_mask,
                                specaug_freq_mask,
                                specaug_time_warp,
                                specaug_warp_param,
                            ],
                            outputs=[whisper_train_output, whisper_train_metrics, latest_log_file]
                        )
            with gr.Tab("🚀 Inference"):
                gr.Markdown("### Run Inference on Test Set")
                gr.Markdown("Generate transcriptions for all test files using a trained checkpoint or base model")

                checkpoint_choices = get_available_checkpoints()
                if not checkpoint_choices:
                    checkpoint_choices = ["No checkpoints available - train a model first"]
                checkpoint_default = checkpoint_choices[0]

                checkpoint_dropdown = gr.Dropdown(
                    choices=checkpoint_choices,
                    value=checkpoint_default,
                    label="Select Checkpoint/Model",
                    info="Choose a trained checkpoint or base model for inference"
                )

                def refresh_checkpoints():
                    """Refresh the checkpoint list."""
                    checkpoints = get_available_checkpoints()
                    if not checkpoints:
                        placeholder = "No checkpoints available - train a model first"
                        return gr.Dropdown(choices=[placeholder], value=placeholder)
                    return gr.Dropdown(choices=checkpoints, value=checkpoints[0])

                refresh_checkpoints_btn = gr.Button("🔄 Refresh Checkpoint List", variant="secondary", size="sm")
                refresh_checkpoints_btn.click(
                    fn=refresh_checkpoints,
                    outputs=[checkpoint_dropdown]
                )

                infer_btn = gr.Button("Run Inference", variant="primary")
                infer_output = gr.Markdown()
                infer_download = gr.File(label="Download Submission CSV")

                infer_btn.click(
                    fn=run_inference_owsm,
                    inputs=[checkpoint_dropdown],
                    outputs=[infer_output, infer_download]
                )
            with gr.Tab("🎯 Single File"):
                gr.Markdown("### Transcribe a Single Audio File")
                with gr.Row():
                    with gr.Column():
                        audio_input = gr.Audio(
                            label="Upload Audio File",
                            type="filepath",
                            sources=["upload", "microphone"]
                        )
                        model_choices = get_available_models()
                        model_choice = gr.Dropdown(
                            choices=model_choices,
                            value=model_choices[0],
                            label="Select Model"
                        )
                        max_seconds = gr.Slider(5, 60, value=30, step=5, label="Max Audio Length (seconds)")
                        transcribe_btn = gr.Button("Transcribe", variant="primary")

                    with gr.Column():
                        transcription_output = gr.Textbox(
                            label="Transcription",
                            lines=5,
                            placeholder="Transcription will appear here..."
                        )
                        info_output = gr.Markdown(label="Processing Info")

                transcribe_btn.click(
                    fn=transcribe_audio,
                    inputs=[audio_input, model_choice, max_seconds],
                    outputs=[transcription_output, info_output]
                )
            with gr.Tab("📖 About"):
                gr.Markdown("""
                ## Caribbean Voices Hackathon - OWSM v3.1 Platform

                ### Features
                - **Entity Extraction**: Extract Caribbean entities from training data
                - **Model Training**: Fine-tune OWSM v3.1 with entity-weighted loss
                - **Batch Inference**: Generate transcriptions for the test set
                - **Single File Transcription**: Quick transcription with multiple models

                ### OWSM v3.1 Features
                - **Emergent Contextual Biasing**: Improves proper noun recognition
                - **Entity-Weighted Loss**: Prioritizes Caribbean entities during training
                - **Competition Compliant**: Single model, no external data

                ### Available Models
                - **Wav2Vec2 Models**: Fast baseline models
                - **OWSM v3.1 Small**: Open Whisper-style model with ESPnet

                ### Workflow
                1. **Extract Entities**: Run entity extraction on the training data
                2. **Train Model**:
                   - **ESPnet Training**: Load ESPnet models (requires ESPnet recipes for fine-tuning)
                   - **Whisper Training**: Full HuggingFace fine-tuning with entity-weighted loss
                3. **Run Inference**: Generate test-set transcriptions
                4. **Download Results**: Get the submission CSV file

                ### Technical Details
                - **ESPnet Framework**: ESPnet + PyTorch for ESPnet OWSM models
                - **Whisper Framework**: HuggingFace transformers for Whisper models
                - **Model**: OWSM v3.1 E-Branchformer (ESPnet) or Whisper (HuggingFace)
                - **Entity Extraction**: Frequency + capitalization analysis
                - **Training**: Entity-weighted cross-entropy loss

                ### Documentation
                See `ESPNET_OWSM_SETUP.md` and `IMPLEMENTATION_SUMMARY.md` for details.
                """)
    interface_time = time.time() - interface_start
    timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
    print(f"[{timestamp}] ⏱️ Total interface creation: {interface_time:.3f}s")

    css_path = Path(__file__).parent / "styles.css"
    return demo, css_path
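

# A minimal manual-launch sketch (assumption: the real entry point is a separate
# app module that reads styles.css and passes it to gr.Blocks, since Gradio takes
# custom CSS at construction time; launching here ignores the returned stylesheet).
if __name__ == "__main__":
    ui, _css_path = create_interface()
    ui.launch()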
|
|
|
|
|
|