"""Gradio UI interface for Caribbean Voices OWSM platform."""
import gradio as gr
import time
import os
from pathlib import Path
from datetime import datetime
# Import modules
from utils.status import get_status_display, get_data_loading_status
from utils.entities import extract_entities_progress
from training.espnet_trainer import run_espnet_training_progress
from training.whisper_trainer import run_whisper_training_progress
from models.inference import transcribe_audio, run_inference_owsm
from models.loader import get_available_models, get_available_checkpoints
from data.loader import load_data_from_hf_dataset
from utils.logging import get_latest_log_file, get_all_log_files, get_log_directory
def create_interface():
"""Create and return the Gradio interface"""
interface_start = time.time()
with gr.Blocks(title="Caribbean Voices - OWSM Platform") as demo:
gr.Markdown("""
🎤 Caribbean Voices Hackathon
OWSM v3.1 Training & Inference Platform
""")
with gr.Tabs() as tabs:
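            # NOTE: the Home tab's navigation buttons switch tabs by index, so the
            # tab declaration order below must stay in sync with those handlers.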
# Tab 1: Status & Setup (Homepage)
with gr.Tab("🏠 Home", id=0):
status_display = gr.HTML(value=get_status_display())
refresh_status_btn = gr.Button("🔄 Refresh Status", variant="secondary", size="lg")
# Navigation buttons
gr.Markdown("""
""")
nav_buttons_row1 = gr.Row()
with nav_buttons_row1:
nav_load_data = gr.Button("📥 Load Data", variant="primary", size="lg", scale=1)
nav_entity_extraction = gr.Button("🔍 Entity Extraction", variant="primary", size="lg", scale=1)
nav_training = gr.Button("🏋️ Training", variant="primary", size="lg", scale=1)
nav_buttons_row2 = gr.Row()
with nav_buttons_row2:
nav_inference = gr.Button("🚀 Inference", variant="primary", size="lg", scale=1)
nav_single_file = gr.Button("🎯 Single File", variant="primary", size="lg", scale=1)
nav_about = gr.Button("📚 About", variant="secondary", size="lg", scale=1)
gr.Markdown("
")
# Add project info section
gr.Markdown("""
📊 About This Project
The Caribbean Voices Hackathon project focuses on building an advanced Automatic Speech Recognition (ASR)
system using OWSM v3.1 (Open Whisper-Style Model). This platform enables fine-tuning on Caribbean-accented speech
with specialized entity extraction and contextual biasing for improved recognition of Caribbean proper nouns,
locations, and organizations.
Key Features
- Entity Extraction: Automatically identifies Caribbean-specific entities from training transcripts
- OWSM Fine-tuning: Fine-tune the OWSM v3.1 model with entity-weighted loss
- Batch Inference: Process entire test sets efficiently
- Single File Testing: Quick transcription with multiple model options
- ESPnet Integration: Full support for ESPnet training recipes
""")
def refresh_status():
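                    """Re-render the status HTML shown on the Home tab."""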
return get_status_display()
refresh_status_btn.click(
fn=refresh_status,
outputs=[status_display]
)
                # Navigation button handlers - use JavaScript to click the target tab header.
                # Indices follow tab declaration order (Home is 0, so targets start at 1).
                def switch_tab_js(tab_index):
                    """Return a JS snippet that clicks the tab header at the given index."""
                    return (
                        "() => { setTimeout(() => { "
                        "const tabs = document.querySelectorAll('button[role=\\'tab\\']'); "
                        f"if(tabs[{tab_index}]) tabs[{tab_index}].click(); "
                        "}, 100); }"
                    )
                for nav_btn, tab_index in [
                    (nav_load_data, 1),
                    (nav_entity_extraction, 2),
                    (nav_training, 3),
                    (nav_inference, 4),
                    (nav_single_file, 5),
                    (nav_about, 6),
                ]:
                    nav_btn.click(None, None, None, js=switch_tab_js(tab_index))
# Tab 2: Data Loading
with gr.Tab("📥 Load Data"):
gr.Markdown("### Load Dataset into HF Space")
# Show current data status
data_status_display = gr.Markdown(value=get_data_loading_status())
refresh_data_status_btn = gr.Button("🔄 Refresh Status", variant="secondary", size="sm")
gr.Markdown("""
---
### Load Dataset
Data is automatically loaded from the Hugging Face dataset on startup.
You can manually load a different dataset below if needed.
""")
hf_dataset_name = gr.Textbox(
label="Hugging Face Dataset Name",
placeholder="username/dataset-name",
value=""
)
hf_load_btn = gr.Button("Load from HF Dataset", variant="primary")
hf_load_output = gr.Markdown()
# Refresh data status when buttons are clicked
def refresh_data_status():
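                    """Re-render the data loading status panel."""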
return get_data_loading_status()
refresh_data_status_btn.click(
fn=refresh_data_status,
outputs=[data_status_display]
)
def load_hf_and_refresh(dataset_name, progress=gr.Progress()):
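                    """Load the requested HF dataset, then refresh the data status panel."""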
result = load_data_from_hf_dataset(dataset_name, progress)
return result, get_data_loading_status()
hf_load_btn.click(
fn=load_hf_and_refresh,
inputs=[hf_dataset_name],
outputs=[hf_load_output, data_status_display]
)
# Tab 3: Entity Extraction
with gr.Tab("🔍 Entity Extraction"):
gr.Markdown("### Extract Caribbean Entities from Training Data")
gr.Markdown("""
This extracts high-value Caribbean entities (proper nouns, locations, organizations)
from the training transcripts. These entities will be used for:
- Entity-weighted loss during training
- Contextual biasing during inference
""")
extract_btn = gr.Button("Extract Entities", variant="primary")
extract_output = gr.Markdown()
extract_json = gr.JSON(label="Entities JSON")
extract_btn.click(
fn=extract_entities_progress,
outputs=[extract_output, extract_json]
)
# Tab 4: Training (with sub-tabs for ESPnet and Whisper)
with gr.Tab("🏋️ Training"):
gr.Markdown("### Model Training")
gr.Markdown("""
Choose your training framework:
- **ESPnet Training**: For ESPnet OWSM models (requires ESPnet recipes)
- **Whisper Training**: For Whisper models (full HuggingFace integration)
""")
with gr.Tabs() as training_tabs:
# ESPnet Training Tab
with gr.Tab("🔧 ESPnet Training"):
gr.Markdown("### ESPnet OWSM Model Training")
gr.Markdown("""
**ESPnet Training** - Uses ESPnet's native framework.
This loads ESPnet models and prepares them for training with ESPnet recipes.
Full fine-tuning requires ESPnet training recipes.
""")
with gr.Row():
with gr.Column():
espnet_train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs (for ESPnet recipes)")
espnet_train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size (for ESPnet recipes)")
espnet_train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate (for ESPnet recipes)")
espnet_train_btn = gr.Button("Load ESPnet Model", variant="primary")
with gr.Column():
espnet_train_output = gr.Markdown()
espnet_train_metrics = gr.JSON(label="Model Info")
espnet_train_btn.click(
fn=run_espnet_training_progress,
inputs=[espnet_train_epochs, espnet_train_batch_size, espnet_train_lr],
outputs=[espnet_train_output, espnet_train_metrics]
)
# Whisper Training Tab
with gr.Tab("🎤 Whisper Training"):
gr.Markdown("### Whisper Model Training")
gr.Markdown("""
**Whisper Training** - Full HuggingFace transformers integration.
Fine-tune Whisper models with entity-weighted loss using HuggingFace's training framework.
Includes full support for HuggingFace features like early stopping, WER metrics, etc.
""")
with gr.Row():
with gr.Column():
gr.Markdown("#### Training Hyperparameters")
whisper_train_epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
whisper_train_batch_size = gr.Slider(1, 32, value=4, step=1, label="Batch Size")
whisper_train_lr = gr.Slider(1e-6, 1e-3, value=3e-5, step=1e-6, label="Learning Rate")
gr.Markdown("#### Speed Augmentation")
gr.Markdown("Speed factors for dataset expansion (creates multiple versions of each sample)")
speed_aug_enabled = gr.Checkbox(value=True, label="Enable Speed Augmentation")
speed_factor_min = gr.Slider(0.8, 1.0, value=0.9, step=0.05, label="Min Speed Factor")
speed_factor_max = gr.Slider(1.0, 1.2, value=1.1, step=0.05, label="Max Speed Factor")
speed_factor_count = gr.Slider(2, 5, value=3, step=1, label="Number of Speed Variants")
gr.Markdown("#### SpecAugment Parameters")
gr.Markdown("Spectrogram augmentation settings (applied during training)")
specaug_enabled = gr.Checkbox(value=True, label="Enable SpecAugment")
specaug_time_mask = gr.Slider(0, 50, value=27, step=1, label="Time Mask Parameter")
specaug_freq_mask = gr.Slider(0, 20, value=10, step=1, label="Frequency Mask Parameter")
specaug_time_warp = gr.Checkbox(value=True, label="Enable Time Warping")
specaug_warp_param = gr.Slider(0, 80, value=40, step=5, label="Time Warp Parameter")
whisper_train_btn = gr.Button("Start Whisper Training", variant="primary", size="lg")
with gr.Column():
whisper_train_output = gr.Markdown()
whisper_train_metrics = gr.JSON(label="Training Metrics")
gr.Markdown("#### Training Logs")
log_info = gr.Markdown(f"Log directory: `{get_log_directory()}`")
latest_log_file = gr.File(
label="Download Latest Training Log",
visible=False
)
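                                # Helper: expose the newest whisper training log for download,
                                # keeping the file widget hidden until a log exists.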
def update_log_download():
latest = get_latest_log_file("whisper_training")
if latest and os.path.exists(latest):
return gr.File(value=latest, visible=True)
return gr.File(visible=False)
refresh_log_btn = gr.Button("🔄 Refresh Logs", variant="secondary", size="sm")
refresh_log_btn.click(
fn=update_log_download,
outputs=[latest_log_file]
)
def run_training_with_log_refresh(
epochs, batch_size, lr,
speed_aug_enabled, speed_factor_min, speed_factor_max, speed_factor_count,
specaug_enabled, specaug_time_mask, specaug_freq_mask, specaug_time_warp, specaug_warp_param,
progress=gr.Progress()
):
"""Run training and refresh log download after completion."""
result = run_whisper_training_progress(
epochs, batch_size, lr,
speed_aug_enabled, speed_factor_min, speed_factor_max, speed_factor_count,
specaug_enabled, specaug_time_mask, specaug_freq_mask, specaug_time_warp, specaug_warp_param,
progress
)
latest_log = update_log_download()
return result[0], result[1], latest_log
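                                # Wire the training button; latest_log_file is included in the
                                # outputs so the download link refreshes when training finishes.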
whisper_train_btn.click(
fn=run_training_with_log_refresh,
inputs=[
whisper_train_epochs,
whisper_train_batch_size,
whisper_train_lr,
speed_aug_enabled,
speed_factor_min,
speed_factor_max,
speed_factor_count,
specaug_enabled,
specaug_time_mask,
specaug_freq_mask,
specaug_time_warp,
specaug_warp_param,
],
outputs=[whisper_train_output, whisper_train_metrics, latest_log_file]
)
# Tab 5: Inference
with gr.Tab("🚀 Inference"):
gr.Markdown("### Run Inference on Test Set")
gr.Markdown("Generate transcriptions for all test files using a trained checkpoint or base model")
# Checkpoint selection
checkpoint_choices = get_available_checkpoints()
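                # Fall back to a placeholder entry when no trained checkpoints exist yet.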
                if not checkpoint_choices:
                    checkpoint_choices = ["No checkpoints available - train a model first"]
                checkpoint_default = checkpoint_choices[0]
checkpoint_dropdown = gr.Dropdown(
choices=checkpoint_choices,
value=checkpoint_default,
label="Select Checkpoint/Model",
info="Choose a trained checkpoint or base model for inference"
)
                def refresh_checkpoints():
                    """Refresh the checkpoint dropdown with the latest available checkpoints."""
                    checkpoints = get_available_checkpoints()
                    if not checkpoints:
                        placeholder = "No checkpoints available - train a model first"
                        return gr.Dropdown(choices=[placeholder], value=placeholder)
                    return gr.Dropdown(choices=checkpoints, value=checkpoints[0])
refresh_checkpoints_btn = gr.Button("🔄 Refresh Checkpoint List", variant="secondary", size="sm")
refresh_checkpoints_btn.click(
fn=refresh_checkpoints,
outputs=[checkpoint_dropdown]
)
infer_btn = gr.Button("Run Inference", variant="primary")
infer_output = gr.Markdown()
infer_download = gr.File(label="Download Submission CSV")
infer_btn.click(
fn=run_inference_owsm,
inputs=[checkpoint_dropdown],
outputs=[infer_output, infer_download]
)
# Tab 6: Single File Transcription
with gr.Tab("🎯 Single File"):
gr.Markdown("### Transcribe a Single Audio File")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="Upload Audio File",
type="filepath",
sources=["upload", "microphone"]
)
                        available_models = get_available_models()
                        model_choice = gr.Dropdown(
                            choices=available_models,
                            value=available_models[0] if available_models else None,
                            label="Select Model"
                        )
max_seconds = gr.Slider(5, 60, value=30, step=5, label="Max Audio Length (seconds)")
transcribe_btn = gr.Button("Transcribe", variant="primary")
with gr.Column():
transcription_output = gr.Textbox(
label="Transcription",
lines=5,
placeholder="Transcription will appear here..."
)
info_output = gr.Markdown(label="Processing Info")
transcribe_btn.click(
fn=transcribe_audio,
inputs=[audio_input, model_choice, max_seconds],
outputs=[transcription_output, info_output]
)
# Tab 7: About
with gr.Tab("📚 About"):
gr.Markdown("""
## Caribbean Voices Hackathon - OWSM v3.1 Platform
### Features
- **Entity Extraction**: Extract Caribbean entities from training data
- **Model Training**: Fine-tune OWSM v3.1 with entity-weighted loss
- **Batch Inference**: Generate transcriptions for test set
- **Single File Transcription**: Quick transcription with multiple models
### OWSM v3.1 Features
- **Emergent Contextual Biasing**: Improves proper noun recognition
- **Entity-Weighted Loss**: Prioritizes Caribbean entities during training
- **Competition Compliant**: Single model, no external data
### Available Models
- **Wav2Vec2 Models**: Fast baseline models
- **OWSM v3.1 Small**: Open Whisper-style model with ESPnet
### Workflow
1. **Extract Entities**: Run entity extraction on training data
2. **Train Model**:
- **ESPnet Training**: Load ESPnet models (requires ESPnet recipes for fine-tuning)
- **Whisper Training**: Full HuggingFace fine-tuning with entity-weighted loss
3. **Run Inference**: Generate test set transcriptions
4. **Download Results**: Get submission CSV file
### Technical Details
- **ESPnet Framework**: ESPnet + PyTorch for ESPnet OWSM models
- **Whisper Framework**: HuggingFace transformers for Whisper models
- **Model**: OWSM v3.1 E-Branchformer (ESPnet) or Whisper (HuggingFace)
- **Entity Extraction**: Frequency + capitalization analysis
- **Training**: Entity-weighted cross-entropy loss
### Documentation
See `ESPNET_OWSM_SETUP.md` and `IMPLEMENTATION_SUMMARY.md` for details.
""")
interface_time = time.time() - interface_start
timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
print(f"[{timestamp}] ⏱️ Total interface creation: {interface_time:.3f}s")
# Return demo and CSS path for Gradio 6.x (CSS goes in launch())
css_path = Path(__file__).parent / "styles.css"
return demo, css_path
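

if __name__ == "__main__":
    # Minimal local smoke test (hypothetical entry point; the Space's real launcher may differ).
    # Per the Gradio 6.x note above, the caller should pass the stylesheet at css_path to launch().
    demo, css_path = create_interface()
    demo.launch()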