Fix training pipeline: use composition for model wrapper and HF Datasets for audio loading
Browse files- owsm_model.py +103 -39
- training/trainer.py +71 -208
owsm_model.py
CHANGED
|
@@ -4,31 +4,30 @@ This implements loss re-weighting for proper nouns without external data.
|
|
| 4 |
"""
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
-
from transformers import AutoModelForSpeechSeq2Seq
|
| 8 |
-
from
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
class OWSMWithEntityLoss(AutoModelForSpeechSeq2Seq):
|
| 12 |
"""
|
| 13 |
-
OWSM model
|
| 14 |
-
|
| 15 |
-
This is fully compliant with competition rules:
|
| 16 |
-
- No external data (entities come from training transcripts only)
|
| 17 |
-
- Single model (no reranker, no second model)
|
| 18 |
-
- Reproducible (deterministic loss computation)
|
| 19 |
|
| 20 |
-
|
|
|
|
| 21 |
"""
|
| 22 |
|
| 23 |
-
def __init__(self, config, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0):
|
| 24 |
"""
|
| 25 |
Args:
|
| 26 |
config: Model configuration
|
|
|
|
| 27 |
tokenizer: Tokenizer for converting entity words to token IDs
|
| 28 |
high_value_tokens: Set of entity words (lowercase) to up-weight
|
| 29 |
entity_weight: Multiplier for entity token errors (default: 3.0)
|
| 30 |
"""
|
| 31 |
super().__init__(config)
|
|
|
|
| 32 |
self.tokenizer = tokenizer
|
| 33 |
self.entity_weight = entity_weight
|
| 34 |
|
|
@@ -47,11 +46,13 @@ class OWSMWithEntityLoss(AutoModelForSpeechSeq2Seq):
|
|
| 47 |
all_entity_token_ids.update(token_id_set)
|
| 48 |
|
| 49 |
print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs")
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Pre-compute vocab_weights tensor for O(1) lookup during training
|
| 53 |
vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer)
|
| 54 |
-
self.vocab_weights
|
| 55 |
|
| 56 |
# Set entity token weights
|
| 57 |
for token_id in all_entity_token_ids:
|
|
@@ -61,59 +62,122 @@ class OWSMWithEntityLoss(AutoModelForSpeechSeq2Seq):
|
|
| 61 |
# Store for debugging
|
| 62 |
self.entity_token_ids = all_entity_token_ids
|
| 63 |
self.high_value_tokens = high_value_tokens
|
| 64 |
-
|
| 65 |
-
def
|
| 66 |
-
"""
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# Flatten
|
| 74 |
-
|
| 75 |
-
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
loss = loss_fct(flat_logits, flat_labels)
|
| 80 |
|
| 81 |
-
#
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
self.vocab_weights = self.vocab_weights.to(flat_labels.device)
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# Apply weights
|
| 90 |
weighted_loss = loss * weights
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
|
| 96 |
if weighted_loss.numel() == 0:
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
@classmethod
|
| 102 |
def from_pretrained(cls, pretrained_model_name_or_path: str,
|
| 103 |
tokenizer, high_value_tokens: Set[str],
|
| 104 |
entity_weight: float = 3.0, **kwargs):
|
| 105 |
"""Load pretrained OWSM model and wrap with entity-weighted loss."""
|
|
|
|
| 106 |
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
| 107 |
pretrained_model_name_or_path, **kwargs
|
| 108 |
)
|
| 109 |
|
|
|
|
| 110 |
model = cls(
|
| 111 |
config=base_model.config,
|
|
|
|
| 112 |
tokenizer=tokenizer,
|
| 113 |
high_value_tokens=high_value_tokens,
|
| 114 |
entity_weight=entity_weight
|
| 115 |
)
|
| 116 |
|
| 117 |
-
model.load_state_dict(base_model.state_dict(), strict=True)
|
| 118 |
return model
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
+
from transformers import AutoModelForSpeechSeq2Seq, PreTrainedModel
|
| 8 |
+
from transformers.modeling_outputs import Seq2SeqLMOutput
|
| 9 |
+
from typing import Set, Optional, Dict, Any
|
| 10 |
|
| 11 |
+
class OWSMWithEntityLoss(PreTrainedModel):
|
|
|
|
| 12 |
"""
|
| 13 |
+
Wrapper around OWSM model that implements weighted cross-entropy loss
|
| 14 |
+
to up-weight errors on entity tokens.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
This model wraps the base model using composition rather than inheritance
|
| 17 |
+
to avoid issues with the AutoModel factory pattern.
|
| 18 |
"""
|
| 19 |
|
| 20 |
+
def __init__(self, config, base_model, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0):
|
| 21 |
"""
|
| 22 |
Args:
|
| 23 |
config: Model configuration
|
| 24 |
+
base_model: The instantiated base model (SpeechEncoderDecoderModel)
|
| 25 |
tokenizer: Tokenizer for converting entity words to token IDs
|
| 26 |
high_value_tokens: Set of entity words (lowercase) to up-weight
|
| 27 |
entity_weight: Multiplier for entity token errors (default: 3.0)
|
| 28 |
"""
|
| 29 |
super().__init__(config)
|
| 30 |
+
self.model = base_model
|
| 31 |
self.tokenizer = tokenizer
|
| 32 |
self.entity_weight = entity_weight
|
| 33 |
|
|
|
|
| 46 |
all_entity_token_ids.update(token_id_set)
|
| 47 |
|
| 48 |
print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs")
|
| 49 |
+
if self.entity_word_to_token_ids:
|
| 50 |
+
avg_tokens = sum(len(ids) for ids in self.entity_word_to_token_ids.values()) / len(self.entity_word_to_token_ids)
|
| 51 |
+
print(f" → Average tokens per entity: {avg_tokens:.2f}")
|
| 52 |
|
| 53 |
# Pre-compute vocab_weights tensor for O(1) lookup during training
|
| 54 |
vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer)
|
| 55 |
+
self.register_buffer('vocab_weights', torch.ones(vocab_size, dtype=torch.float32))
|
| 56 |
|
| 57 |
# Set entity token weights
|
| 58 |
for token_id in all_entity_token_ids:
|
|
|
|
| 62 |
# Store for debugging
|
| 63 |
self.entity_token_ids = all_entity_token_ids
|
| 64 |
self.high_value_tokens = high_value_tokens
|
| 65 |
+
|
| 66 |
+
def get_encoder(self):
|
| 67 |
+
"""Delegate to sub-model's encoder."""
|
| 68 |
+
return self.model.get_encoder()
|
| 69 |
+
|
| 70 |
+
def get_decoder(self):
|
| 71 |
+
"""Delegate to sub-model's decoder."""
|
| 72 |
+
return self.model.get_decoder()
|
| 73 |
|
| 74 |
+
def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, labels=None, **kwargs):
|
| 75 |
+
"""
|
| 76 |
+
Forward pass that computes weighted loss if labels are provided.
|
| 77 |
+
Delegates to underlying model.
|
| 78 |
+
"""
|
| 79 |
+
outputs = self.model(
|
| 80 |
+
input_features=input_features,
|
| 81 |
+
attention_mask=attention_mask,
|
| 82 |
+
decoder_input_ids=decoder_input_ids,
|
| 83 |
+
labels=labels,
|
| 84 |
+
return_dict=True,
|
| 85 |
+
**kwargs
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Without labels there is no loss to re-weight, so return the base model's outputs unchanged
|
| 89 |
+
if labels is None:
|
| 90 |
+
return outputs
|
| 91 |
+
|
| 92 |
+
# Custom Loss Computation
|
| 93 |
+
logits = outputs.logits # [B, T, V]
|
| 94 |
|
| 95 |
# Flatten
|
| 96 |
+
# Standard CrossEntropyLoss expects [N, C] logits and [N] labels
|
| 97 |
+
# where N is batch_size * sequence_length
|
| 98 |
|
| 99 |
+
flat_logits = logits.view(-1, logits.size(-1))
|
| 100 |
+
flat_labels = labels.view(-1)
|
|
|
|
| 101 |
|
| 102 |
+
# Create per-token weights
|
| 103 |
+
# Use pre-computed weights: O(1) lookup
|
| 104 |
+
# labels can be -100 (ignore), we need to handle that for lookup
|
|
|
|
| 105 |
|
| 106 |
+
# Create a mask for valid labels (not -100)
|
| 107 |
+
valid_mask = (flat_labels != -100)
|
| 108 |
+
|
| 109 |
+
# Use index 0 as a placeholder for the lookup wherever the label is -100
|
| 110 |
+
# This avoids index out of bounds. We'll mask the loss anyway.
|
| 111 |
+
safe_labels = flat_labels.clone()
|
| 112 |
+
safe_labels[~valid_mask] = 0
|
| 113 |
+
|
| 114 |
+
# Get weights
|
| 115 |
+
weights = self.vocab_weights[safe_labels]
|
| 116 |
+
|
| 117 |
+
# Compute unreduced loss
|
| 118 |
+
loss_fct = nn.CrossEntropyLoss(reduction="none")
|
| 119 |
+
loss = loss_fct(flat_logits, flat_labels)
|
| 120 |
|
| 121 |
# Apply weights
|
| 122 |
weighted_loss = loss * weights
|
| 123 |
|
| 124 |
+
# Drop ignored positions. CrossEntropyLoss (ignore_index=-100 by default)
|
| 125 |
+
# already yields 0 loss for them, but with reduction='none' they would still
|
| 126 |
+
# count toward the mean below and dilute it; keeping only valid tokens
|
| 127 |
+
# makes the final loss an average over the tokens that actually carry labels.
|
| 128 |
+
weighted_loss = weighted_loss[valid_mask]
|
| 129 |
|
| 130 |
if weighted_loss.numel() == 0:
|
| 131 |
+
final_loss = torch.tensor(0.0, device=logits.device, requires_grad=True)
|
| 132 |
+
else:
|
| 133 |
+
final_loss = weighted_loss.mean()
|
| 134 |
+
|
| 135 |
+
return Seq2SeqLMOutput(
|
| 136 |
+
loss=final_loss,
|
| 137 |
+
logits=logits,
|
| 138 |
+
past_key_values=outputs.past_key_values,
|
| 139 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 140 |
+
decoder_attentions=outputs.decoder_attentions,
|
| 141 |
+
cross_attentions=outputs.cross_attentions,
|
| 142 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 143 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 144 |
+
encoder_attentions=outputs.encoder_attentions,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def generate(self, *args, **kwargs):
|
| 148 |
+
"""Delegate generation to the underlying model."""
|
| 149 |
+
return self.model.generate(*args, **kwargs)
|
| 150 |
+
|
| 151 |
+
def prepare_inputs_for_generation(self, *args, **kwargs):
|
| 152 |
+
"""Delegate to underlying model."""
|
| 153 |
+
return self.model.prepare_inputs_for_generation(*args, **kwargs)
|
| 154 |
|
| 155 |
@classmethod
|
| 156 |
def from_pretrained(cls, pretrained_model_name_or_path: str,
|
| 157 |
tokenizer, high_value_tokens: Set[str],
|
| 158 |
entity_weight: float = 3.0, **kwargs):
|
| 159 |
"""Load pretrained OWSM model and wrap with entity-weighted loss."""
|
| 160 |
+
# Load the base model using the Auto class
|
| 161 |
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
| 162 |
pretrained_model_name_or_path, **kwargs
|
| 163 |
)
|
| 164 |
|
| 165 |
+
# Initialize wrapper
|
| 166 |
model = cls(
|
| 167 |
config=base_model.config,
|
| 168 |
+
base_model=base_model,
|
| 169 |
tokenizer=tokenizer,
|
| 170 |
high_value_tokens=high_value_tokens,
|
| 171 |
entity_weight=entity_weight
|
| 172 |
)
|
| 173 |
|
|
|
|
| 174 |
return model
|
| 175 |
|
| 176 |
+
def save_pretrained(self, save_directory, **kwargs):
|
| 177 |
+
"""
|
| 178 |
+
Save the underlying model to the directory.
|
| 179 |
+
This ensures that the saved model is a standard OWSM model
|
| 180 |
+
that can be loaded with AutoModelForSpeechSeq2Seq for inference.
|
| 181 |
+
"""
|
| 182 |
+
print(f"Saving underlying model to {save_directory}...")
|
| 183 |
+
self.model.save_pretrained(save_directory, **kwargs)
|
training/trainer.py
CHANGED
|
@@ -4,24 +4,17 @@ import json
|
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
import random
|
| 7 |
-
import
|
| 8 |
-
from
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
from datasets import Dataset, Audio
|
| 11 |
from transformers import (
|
| 12 |
AutoProcessor,
|
| 13 |
-
AutoModelForSpeechSeq2Seq,
|
| 14 |
Seq2SeqTrainingArguments,
|
| 15 |
Seq2SeqTrainer,
|
| 16 |
DataCollatorForSeq2Seq,
|
| 17 |
EarlyStoppingCallback,
|
| 18 |
)
|
| 19 |
-
from sklearn.model_selection import train_test_split
|
| 20 |
-
import torchaudio
|
| 21 |
-
|
| 22 |
-
from data.manager import ENTITIES_PATH, AUDIO_DIR, MODEL_OUTPUT_DIR
|
| 23 |
-
from data.loader import get_train_dataframe
|
| 24 |
from owsm_model import OWSMWithEntityLoss
|
|
|
|
| 25 |
|
| 26 |
# Set seeds for reproducibility
|
| 27 |
SEED = 42
|
|
@@ -36,7 +29,7 @@ torch.use_deterministic_algorithms(True, warn_only=True)
|
|
| 36 |
MODEL_NAME = "espnet/owsm_v3.1_ebf_small"
|
| 37 |
TARGET_SR = 16000
|
| 38 |
MAX_AUDIO_LENGTH = 30 # seconds
|
| 39 |
-
|
| 40 |
|
| 41 |
def compute_wer_metric(predictions, labels, tokenizer):
|
| 42 |
"""Compute Word Error Rate metric."""
|
|
@@ -51,13 +44,11 @@ def compute_wer_metric(predictions, labels, tokenizer):
|
|
| 51 |
return 1.0 if len(hyp_words) > 0 else 0.0
|
| 52 |
|
| 53 |
# Rough WER approximation based on word-set overlap (not a true Levenshtein distance)
|
| 54 |
-
# For simplicity, use character-level edit distance approximation
|
| 55 |
ref_str = ' '.join(ref_words)
|
| 56 |
hyp_str = ' '.join(hyp_words)
|
| 57 |
if ref_str == hyp_str:
|
| 58 |
return 0.0
|
| 59 |
|
| 60 |
-
# Count word-level differences
|
| 61 |
ref_set = set(ref_words)
|
| 62 |
hyp_set = set(hyp_words)
|
| 63 |
common = len(ref_set & hyp_set)
|
|
@@ -66,7 +57,6 @@ def compute_wer_metric(predictions, labels, tokenizer):
|
|
| 66 |
|
| 67 |
# Decode predictions and labels
|
| 68 |
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
|
| 69 |
-
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
| 70 |
|
| 71 |
# Replace -100 with pad token for decoding
|
| 72 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
|
@@ -87,76 +77,22 @@ def compute_wer_metric(predictions, labels, tokenizer):
|
|
| 87 |
return {"wer": wer}
|
| 88 |
|
| 89 |
|
| 90 |
-
def
|
| 91 |
"""
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
FAILS LOUDLY if file doesn't exist - no silent fallbacks.
|
| 95 |
-
"""
|
| 96 |
-
if not os.path.exists(path):
|
| 97 |
-
raise FileNotFoundError(
|
| 98 |
-
f"Audio file not found: {path}\n"
|
| 99 |
-
f"Expected audio file at: {os.path.abspath(path)}\n"
|
| 100 |
-
f"Please ensure all audio files are available before training."
|
| 101 |
-
)
|
| 102 |
-
|
| 103 |
-
try:
|
| 104 |
-
wav, sr = torchaudio.load(path)
|
| 105 |
-
except Exception as e:
|
| 106 |
-
raise RuntimeError(
|
| 107 |
-
f"Failed to load audio file: {path}\n"
|
| 108 |
-
f"Error: {str(e)}\n"
|
| 109 |
-
f"Please check that the file is a valid audio file."
|
| 110 |
-
) from e
|
| 111 |
-
|
| 112 |
-
if sr != target_sr:
|
| 113 |
-
wav = torchaudio.functional.resample(wav, sr, target_sr)
|
| 114 |
-
|
| 115 |
-
# Convert to mono if stereo
|
| 116 |
-
if wav.shape[0] > 1:
|
| 117 |
-
wav = wav.mean(0, keepdim=True)
|
| 118 |
-
|
| 119 |
-
return wav.squeeze(0).numpy() # Return as numpy array for processor
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def prepare_dataset(df: pd.DataFrame, audio_dir: str, processor, is_training: bool = True):
|
| 123 |
-
"""
|
| 124 |
-
Prepare dataset for training/inference.
|
| 125 |
-
|
| 126 |
-
FAILS LOUDLY if audio files are missing - no silent fallbacks.
|
| 127 |
"""
|
| 128 |
|
| 129 |
def prepare_batch(batch):
|
| 130 |
"""Process a batch of examples."""
|
| 131 |
-
|
| 132 |
-
transcriptions = batch["
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
| 136 |
-
for audio_path in audio_paths:
|
| 137 |
-
full_path = os.path.join(audio_dir, audio_path)
|
| 138 |
-
|
| 139 |
-
# Validate file exists BEFORE processing
|
| 140 |
-
if not os.path.exists(full_path):
|
| 141 |
-
raise FileNotFoundError(
|
| 142 |
-
f"Audio file not found during dataset preparation: {audio_path}\n"
|
| 143 |
-
f"Full path: {os.path.abspath(full_path)}\n"
|
| 144 |
-
f"Audio directory: {os.path.abspath(audio_dir)}\n"
|
| 145 |
-
f"Please ensure all audio files are available before training."
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
audio = load_audio(full_path)
|
| 149 |
-
|
| 150 |
-
# Truncate if too long
|
| 151 |
-
max_samples = TARGET_SR * MAX_AUDIO_LENGTH
|
| 152 |
-
if len(audio) > max_samples:
|
| 153 |
-
audio = audio[:max_samples]
|
| 154 |
-
|
| 155 |
-
audio_arrays.append(audio)
|
| 156 |
|
| 157 |
# Process audio with processor
|
| 158 |
inputs = processor(
|
| 159 |
-
|
| 160 |
sampling_rate=TARGET_SR,
|
| 161 |
return_tensors="pt",
|
| 162 |
padding=True,
|
|
@@ -178,85 +114,26 @@ def prepare_dataset(df: pd.DataFrame, audio_dir: str, processor, is_training: bo
|
|
| 178 |
|
| 179 |
return batch
|
| 180 |
|
| 181 |
-
#
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
}
|
| 186 |
-
|
| 187 |
-
# Validate all audio files exist BEFORE creating dataset
|
| 188 |
-
if is_training:
|
| 189 |
-
print("Validating audio files exist...")
|
| 190 |
-
missing_files = []
|
| 191 |
-
for audio_path in dataset_dict["audio_path"]:
|
| 192 |
-
full_path = os.path.join(audio_dir, audio_path)
|
| 193 |
-
if not os.path.exists(full_path):
|
| 194 |
-
missing_files.append(full_path)
|
| 195 |
-
|
| 196 |
-
if missing_files:
|
| 197 |
-
error_msg = (
|
| 198 |
-
f"Found {len(missing_files)} missing audio files:\n"
|
| 199 |
-
f"First 10 missing files:\n"
|
| 200 |
-
)
|
| 201 |
-
for f in missing_files[:10]:
|
| 202 |
-
error_msg += f" - {f}\n"
|
| 203 |
-
if len(missing_files) > 10:
|
| 204 |
-
error_msg += f" ... and {len(missing_files) - 10} more\n"
|
| 205 |
-
error_msg += f"\nPlease ensure all audio files are available before training."
|
| 206 |
-
raise FileNotFoundError(error_msg)
|
| 207 |
-
|
| 208 |
-
print(f"✓ All {len(dataset_dict['audio_path'])} audio files validated")
|
| 209 |
-
|
| 210 |
-
dataset = Dataset.from_dict(dataset_dict)
|
| 211 |
|
| 212 |
# Process in batches
|
| 213 |
dataset = dataset.map(
|
| 214 |
prepare_batch,
|
| 215 |
batched=True,
|
| 216 |
batch_size=16,
|
| 217 |
-
remove_columns=
|
|
|
|
| 218 |
)
|
| 219 |
|
| 220 |
return dataset
|
| 221 |
|
| 222 |
|
| 223 |
-
def
|
| 224 |
-
"""
|
| 225 |
-
Create stratified train/val split based on:
|
| 226 |
-
- Utterance length (short vs long)
|
| 227 |
-
- Presence of Caribbean keywords
|
| 228 |
-
"""
|
| 229 |
-
# Create bins for stratification
|
| 230 |
-
df['word_count'] = df['Transcription'].str.split().str.len()
|
| 231 |
-
df['has_caribbean'] = df['Transcription'].str.lower().str.contains(
|
| 232 |
-
'caribbean|bbc|trinidad|jamaica|guyana|haiti|barbados',
|
| 233 |
-
regex=True,
|
| 234 |
-
na=False
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
-
# Create stratification key
|
| 238 |
-
df['length_bin'] = pd.cut(df['word_count'], bins=5, labels=False)
|
| 239 |
-
df['stratify_key'] = df['length_bin'].astype(str) + '_' + df['has_caribbean'].astype(str)
|
| 240 |
-
|
| 241 |
-
train_df, val_df = train_test_split(
|
| 242 |
-
df,
|
| 243 |
-
test_size=test_size,
|
| 244 |
-
stratify=df['stratify_key'],
|
| 245 |
-
random_state=SEED
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
print(f"Train: {len(train_df):,} samples")
|
| 249 |
-
print(f"Val: {len(val_df):,} samples")
|
| 250 |
-
|
| 251 |
-
return train_df, val_df
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
def run_training_progress(epochs: int, batch_size: int, learning_rate: float, progress=None) -> Tuple[str, str]:
|
| 255 |
"""
|
| 256 |
-
Run OWSM training with progress tracking.
|
| 257 |
-
|
| 258 |
-
Uses espnet/owsm_v3.1_ebf_small with NO FALLBACKS.
|
| 259 |
-
If model loading fails, raises exception with clear error message.
|
| 260 |
"""
|
| 261 |
try:
|
| 262 |
if progress:
|
|
@@ -278,72 +155,63 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
|
|
| 278 |
print(f"Loaded {len(high_value_entities)} high-value entities")
|
| 279 |
|
| 280 |
if progress:
|
| 281 |
-
progress(0.1, desc="Loading
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
f"Please ensure the dataset is loaded."
|
| 288 |
-
)
|
| 289 |
-
print(f"Loaded {len(train_df):,} training samples")
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
# Create train/val split
|
| 292 |
if progress:
|
| 293 |
progress(0.15, desc="Creating train/val split...")
|
| 294 |
-
train_df_split, val_df_split = create_stratified_split(train_df, test_size=0.1)
|
| 295 |
|
| 296 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
if progress:
|
| 298 |
progress(0.2, desc=f"Loading processor: {MODEL_NAME}...")
|
| 299 |
print(f"\nLoading processor: {MODEL_NAME}")
|
| 300 |
-
print("NOTE: Using espnet/owsm_v3.1_ebf_small with NO FALLBACKS")
|
| 301 |
-
print("If this fails, check that the model is available on HuggingFace.")
|
| 302 |
|
| 303 |
try:
|
| 304 |
processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
| 305 |
except Exception as e:
|
| 306 |
-
|
| 307 |
-
f"FAILED to load processor from {MODEL_NAME}\n\n"
|
| 308 |
-
f"Error: {str(e)}\n\n"
|
| 309 |
-
f"This training pipeline requires espnet/owsm_v3.1_ebf_small.\n"
|
| 310 |
-
f"No fallbacks are configured. Please ensure:\n"
|
| 311 |
-
f"1. The model is available on HuggingFace\n"
|
| 312 |
-
f"2. You have internet access to download the model\n"
|
| 313 |
-
f"3. You have sufficient disk space\n"
|
| 314 |
-
f"4. The transformers library supports this model\n\n"
|
| 315 |
-
f"If the model is not available, you may need to use ESPnet's native training framework."
|
| 316 |
-
)
|
| 317 |
-
raise RuntimeError(error_msg) from e
|
| 318 |
|
| 319 |
print(f"✓ Processor loaded successfully")
|
| 320 |
|
| 321 |
-
# Load model
|
| 322 |
if progress:
|
| 323 |
progress(0.25, desc=f"Loading model: {MODEL_NAME}...")
|
| 324 |
print(f"\nLoading model: {MODEL_NAME}")
|
| 325 |
|
| 326 |
try:
|
|
|
|
| 327 |
model = OWSMWithEntityLoss.from_pretrained(
|
| 328 |
MODEL_NAME,
|
| 329 |
tokenizer=processor.tokenizer,
|
| 330 |
high_value_tokens=high_value_entities,
|
| 331 |
-
entity_weight=3.0,
|
| 332 |
)
|
| 333 |
except Exception as e:
|
| 334 |
-
|
| 335 |
-
f"FAILED to load model from {MODEL_NAME}\n\n"
|
| 336 |
-
f"Error: {str(e)}\n\n"
|
| 337 |
-
f"This training pipeline requires espnet/owsm_v3.1_ebf_small.\n"
|
| 338 |
-
f"No fallbacks are configured. Please ensure:\n"
|
| 339 |
-
f"1. The model is available on HuggingFace\n"
|
| 340 |
-
f"2. You have internet access to download the model\n"
|
| 341 |
-
f"3. You have sufficient disk space\n"
|
| 342 |
-
f"4. The transformers library supports this model\n"
|
| 343 |
-
f"5. AutoModelForSpeechSeq2Seq can load this model\n\n"
|
| 344 |
-
f"If the model is not available, you may need to use ESPnet's native training framework."
|
| 345 |
-
)
|
| 346 |
-
raise RuntimeError(error_msg) from e
|
| 347 |
|
| 348 |
print(f"✓ Model loaded successfully")
|
| 349 |
|
|
@@ -353,14 +221,14 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
|
|
| 353 |
|
| 354 |
# Prepare datasets
|
| 355 |
if progress:
|
| 356 |
-
progress(0.3, desc="
|
| 357 |
-
print("\
|
| 358 |
-
train_dataset =
|
| 359 |
|
| 360 |
if progress:
|
| 361 |
-
progress(0.4, desc="
|
| 362 |
-
print("
|
| 363 |
-
val_dataset =
|
| 364 |
|
| 365 |
# Training arguments
|
| 366 |
if progress:
|
|
@@ -370,7 +238,7 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
|
|
| 370 |
output_dir=MODEL_OUTPUT_DIR,
|
| 371 |
per_device_train_batch_size=batch_size,
|
| 372 |
per_device_eval_batch_size=batch_size,
|
| 373 |
-
gradient_accumulation_steps=4,
|
| 374 |
learning_rate=learning_rate,
|
| 375 |
warmup_steps=500,
|
| 376 |
num_train_epochs=epochs,
|
|
@@ -380,20 +248,23 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
|
|
| 380 |
save_steps=1000,
|
| 381 |
logging_steps=100,
|
| 382 |
load_best_model_at_end=True,
|
| 383 |
-
metric_for_best_model="wer",
|
| 384 |
-
greater_is_better=False,
|
| 385 |
save_total_limit=3,
|
| 386 |
-
fp16=torch.cuda.is_available(),
|
| 387 |
dataloader_num_workers=4,
|
| 388 |
-
report_to="none",
|
| 389 |
seed=SEED,
|
| 390 |
-
predict_with_generate=True,
|
|
|
|
| 391 |
)
|
| 392 |
|
| 393 |
# Data collator
|
| 394 |
data_collator = DataCollatorForSeq2Seq(
|
| 395 |
processor=processor,
|
| 396 |
-
model=model,
|
|
|
|
|
|
|
| 397 |
padding=True,
|
| 398 |
)
|
| 399 |
|
|
@@ -435,6 +306,7 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
|
|
| 435 |
progress(0.95, desc="Saving model...")
|
| 436 |
|
| 437 |
print(f"\nSaving model to {MODEL_OUTPUT_DIR}...")
|
|
|
|
| 438 |
model.save_pretrained(MODEL_OUTPUT_DIR)
|
| 439 |
processor.save_pretrained(MODEL_OUTPUT_DIR)
|
| 440 |
|
|
@@ -471,21 +343,12 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
|
|
| 471 |
The model is now ready for inference!
|
| 472 |
"""
|
| 473 |
|
| 474 |
-
return success_msg,
|
| 475 |
|
| 476 |
-
except FileNotFoundError as e:
|
| 477 |
-
error_msg = f"❌ File Not Found Error:\n\n{str(e)}"
|
| 478 |
-
if progress:
|
| 479 |
-
progress(1.0, desc="Error!")
|
| 480 |
-
return error_msg, ""
|
| 481 |
-
except RuntimeError as e:
|
| 482 |
-
error_msg = f"❌ Runtime Error:\n\n{str(e)}"
|
| 483 |
-
if progress:
|
| 484 |
-
progress(1.0, desc="Error!")
|
| 485 |
-
return error_msg, ""
|
| 486 |
except Exception as e:
|
| 487 |
import traceback
|
| 488 |
-
error_msg = f"❌
|
|
|
|
| 489 |
if progress:
|
| 490 |
progress(1.0, desc="Error!")
|
| 491 |
-
return error_msg,
|
|
|
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
import random
|
| 7 |
+
from typing import Tuple, Optional, Dict, Any
|
| 8 |
+
from datasets import load_dataset, Audio, DatasetDict
|
|
|
|
|
|
|
| 9 |
from transformers import (
|
| 10 |
AutoProcessor,
|
|
|
|
| 11 |
Seq2SeqTrainingArguments,
|
| 12 |
Seq2SeqTrainer,
|
| 13 |
DataCollatorForSeq2Seq,
|
| 14 |
EarlyStoppingCallback,
|
| 15 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from owsm_model import OWSMWithEntityLoss
|
| 17 |
+
from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR
|
| 18 |
|
| 19 |
# Set seeds for reproducibility
|
| 20 |
SEED = 42
|
|
|
|
| 29 |
MODEL_NAME = "espnet/owsm_v3.1_ebf_small"
|
| 30 |
TARGET_SR = 16000
|
| 31 |
MAX_AUDIO_LENGTH = 30 # seconds
|
| 32 |
+
HF_DATASET_NAME = "shaun3141/caribbean-voices-hackathon"
|
| 33 |
|
| 34 |
def compute_wer_metric(predictions, labels, tokenizer):
|
| 35 |
"""Compute Word Error Rate metric."""
|
|
|
|
| 44 |
return 1.0 if len(hyp_words) > 0 else 0.0
|
| 45 |
|
| 46 |
# Rough WER approximation based on word-set overlap (not a true Levenshtein distance)
|
|
|
|
| 47 |
ref_str = ' '.join(ref_words)
|
| 48 |
hyp_str = ' '.join(hyp_words)
|
| 49 |
if ref_str == hyp_str:
|
| 50 |
return 0.0
|
| 51 |
|
|
|
|
| 52 |
ref_set = set(ref_words)
|
| 53 |
hyp_set = set(hyp_words)
|
| 54 |
common = len(ref_set & hyp_set)
|
|
|
|
| 57 |
|
| 58 |
# Decode predictions and labels
|
| 59 |
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
|
|
|
|
| 60 |
|
| 61 |
# Replace -100 with pad token for decoding
|
| 62 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
|
|
|
| 77 |
return {"wer": wer}
|
| 78 |
|
| 79 |
|
| 80 |
+
def prepare_dataset_hf(dataset, processor):
|
| 81 |
"""
|
| 82 |
+
Prepare dataset using Hugging Face Datasets built-in audio handling.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
"""
|
| 84 |
|
| 85 |
def prepare_batch(batch):
|
| 86 |
"""Process a batch of examples."""
|
| 87 |
+
audio = batch["audio"]
|
| 88 |
+
transcriptions = batch["transcription"] # Note: check lowercase 't' in dataset
|
| 89 |
|
| 90 |
+
# Each audio entry is already a dictionary with 'array' and 'sampling_rate'
|
| 91 |
+
# because we cast it to Audio() in the loading step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# Process audio with processor
|
| 94 |
inputs = processor(
|
| 95 |
+
[x["array"] for x in audio],
|
| 96 |
sampling_rate=TARGET_SR,
|
| 97 |
return_tensors="pt",
|
| 98 |
padding=True,
|
|
|
|
| 114 |
|
| 115 |
return batch
|
| 116 |
|
| 117 |
+
# Remove columns that are not needed
|
| 118 |
+
# 'transcription' is dropped too, to save memory; the labels replace it.
|
| 119 |
+
# Take the column list from the dataset itself rather than hard-coding it.
|
| 120 |
+
column_names = dataset.column_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
# Process in batches
|
| 123 |
dataset = dataset.map(
|
| 124 |
prepare_batch,
|
| 125 |
batched=True,
|
| 126 |
batch_size=16,
|
| 127 |
+
remove_columns=column_names,
|
| 128 |
+
desc="Preprocessing dataset",
|
| 129 |
)
|
| 130 |
|
| 131 |
return dataset
|
| 132 |
|
| 133 |
|
| 134 |
+
def run_training_progress(epochs: int, batch_size: int, learning_rate: float, progress=None) -> Tuple[str, Optional[Dict[str, Any]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
"""
|
| 136 |
+
Run OWSM training with progress tracking using HF Datasets.
|
|
|
|
|
|
|
|
|
|
| 137 |
"""
|
| 138 |
try:
|
| 139 |
if progress:
|
|
|
|
| 155 |
print(f"Loaded {len(high_value_entities)} high-value entities")
|
| 156 |
|
| 157 |
if progress:
|
| 158 |
+
progress(0.1, desc="Loading dataset from Hugging Face...")
|
| 159 |
+
|
| 160 |
+
# Load dataset from HF
|
| 161 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 162 |
+
print(f"Loading dataset: {HF_DATASET_NAME}")
|
| 163 |
+
dataset = load_dataset(HF_DATASET_NAME, token=hf_token)
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
if 'train' not in dataset:
|
| 166 |
+
raise ValueError(f"Dataset {HF_DATASET_NAME} does not contain a 'train' split.")
|
| 167 |
+
|
| 168 |
+
train_full = dataset['train']
|
| 169 |
+
print(f"Loaded {len(train_full):,} total training samples")
|
| 170 |
+
|
| 171 |
+
# Cast to Audio to ensure correct sampling rate
|
| 172 |
+
train_full = train_full.cast_column("audio", Audio(sampling_rate=TARGET_SR))
|
| 173 |
+
|
| 174 |
# Create train/val split
|
| 175 |
if progress:
|
| 176 |
progress(0.15, desc="Creating train/val split...")
|
|
|
|
| 177 |
|
| 178 |
+
# Simple random split since we don't want to download all audio to stratify by length/content
|
| 179 |
+
# unless we want to iterate the whole dataset which might be slow.
|
| 180 |
+
# We'll use a seeded random split for speed and simplicity with the remote dataset.
|
| 181 |
+
split_dataset = train_full.train_test_split(test_size=0.1, seed=SEED)
|
| 182 |
+
train_dataset_raw = split_dataset['train']
|
| 183 |
+
val_dataset_raw = split_dataset['test']
|
| 184 |
+
|
| 185 |
+
print(f"Train: {len(train_dataset_raw):,} samples")
|
| 186 |
+
print(f"Val: {len(val_dataset_raw):,} samples")
|
| 187 |
+
|
| 188 |
+
# Load processor
|
| 189 |
if progress:
|
| 190 |
progress(0.2, desc=f"Loading processor: {MODEL_NAME}...")
|
| 191 |
print(f"\nLoading processor: {MODEL_NAME}")
|
|
|
|
|
|
|
| 192 |
|
| 193 |
try:
|
| 194 |
processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
| 195 |
except Exception as e:
|
| 196 |
+
raise RuntimeError(f"FAILED to load processor from {MODEL_NAME}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
print(f"✓ Processor loaded successfully")
|
| 199 |
|
| 200 |
+
# Load model
|
| 201 |
if progress:
|
| 202 |
progress(0.25, desc=f"Loading model: {MODEL_NAME}...")
|
| 203 |
print(f"\nLoading model: {MODEL_NAME}")
|
| 204 |
|
| 205 |
try:
|
| 206 |
+
# Use our new wrapper class
|
| 207 |
model = OWSMWithEntityLoss.from_pretrained(
|
| 208 |
MODEL_NAME,
|
| 209 |
tokenizer=processor.tokenizer,
|
| 210 |
high_value_tokens=high_value_entities,
|
| 211 |
+
entity_weight=3.0,
|
| 212 |
)
|
| 213 |
except Exception as e:
|
| 214 |
+
raise RuntimeError(f"FAILED to load model from {MODEL_NAME}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
print(f"✓ Model loaded successfully")
|
| 217 |
|
|
|
|
| 221 |
|
| 222 |
# Prepare datasets
|
| 223 |
if progress:
|
| 224 |
+
progress(0.3, desc="Preprocessing training dataset...")
|
| 225 |
+
print("\nPreprocessing training dataset...")
|
| 226 |
+
train_dataset = prepare_dataset_hf(train_dataset_raw, processor)
|
| 227 |
|
| 228 |
if progress:
|
| 229 |
+
progress(0.4, desc="Preprocessing validation dataset...")
|
| 230 |
+
print("Preprocessing validation dataset...")
|
| 231 |
+
val_dataset = prepare_dataset_hf(val_dataset_raw, processor)
|
| 232 |
|
| 233 |
# Training arguments
|
| 234 |
if progress:
|
|
|
|
| 238 |
output_dir=MODEL_OUTPUT_DIR,
|
| 239 |
per_device_train_batch_size=batch_size,
|
| 240 |
per_device_eval_batch_size=batch_size,
|
| 241 |
+
gradient_accumulation_steps=4,
|
| 242 |
learning_rate=learning_rate,
|
| 243 |
warmup_steps=500,
|
| 244 |
num_train_epochs=epochs,
|
|
|
|
| 248 |
save_steps=1000,
|
| 249 |
logging_steps=100,
|
| 250 |
load_best_model_at_end=True,
|
| 251 |
+
metric_for_best_model="wer",
|
| 252 |
+
greater_is_better=False,
|
| 253 |
save_total_limit=3,
|
| 254 |
+
fp16=torch.cuda.is_available(),
|
| 255 |
dataloader_num_workers=4,
|
| 256 |
+
report_to="none",
|
| 257 |
seed=SEED,
|
| 258 |
+
predict_with_generate=True,
|
| 259 |
+
generation_max_length=200, # Prevent infinite generation
|
| 260 |
)
|
| 261 |
|
| 262 |
# Data collator
|
| 263 |
data_collator = DataCollatorForSeq2Seq(
|
| 264 |
processor=processor,
|
| 265 |
+
model=model, # The trainer needs the model for the collator sometimes if it uses it for padding?
|
| 266 |
+
# DataCollatorForSeq2Seq pads via its tokenizer; the model argument is optional and
|
| 267 |
+
# is only used to derive decoder_input_ids when the model supports it, so passing it is fine.
|
| 268 |
padding=True,
|
| 269 |
)
|
| 270 |
|
|
|
|
| 306 |
progress(0.95, desc="Saving model...")
|
| 307 |
|
| 308 |
print(f"\nSaving model to {MODEL_OUTPUT_DIR}...")
|
| 309 |
+
# This calls our custom save_pretrained which saves the inner model
|
| 310 |
model.save_pretrained(MODEL_OUTPUT_DIR)
|
| 311 |
processor.save_pretrained(MODEL_OUTPUT_DIR)
|
| 312 |
|
|
|
|
| 343 |
The model is now ready for inference!
|
| 344 |
"""
|
| 345 |
|
| 346 |
+
return success_msg, final_metrics
|
| 347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
except Exception as e:
|
| 349 |
import traceback
|
| 350 |
+
error_msg = f"❌ Error during training: {str(e)}\n\n{traceback.format_exc()}"
|
| 351 |
+
print(error_msg)
|
| 352 |
if progress:
|
| 353 |
progress(1.0, desc="Error!")
|
| 354 |
+
return error_msg, None
|