shaun3141 committed on
Commit
0ac13f6
·
1 Parent(s): ef24863

Fix training pipeline: use composition for model wrapper and HF Datasets for audio loading

Browse files
Files changed (2) hide show
  1. owsm_model.py +103 -39
  2. training/trainer.py +71 -208
owsm_model.py CHANGED
@@ -4,31 +4,30 @@ This implements loss re-weighting for proper nouns without external data.
4
  """
5
  import torch
6
  import torch.nn as nn
7
- from transformers import AutoModelForSpeechSeq2Seq
8
- from typing import Set, Optional, Dict
 
9
 
10
-
11
- class OWSMWithEntityLoss(AutoModelForSpeechSeq2Seq):
12
  """
13
- OWSM model with weighted cross-entropy loss that up-weights errors on entity tokens.
14
-
15
- This is fully compliant with competition rules:
16
- - No external data (entities come from training transcripts only)
17
- - Single model (no reranker, no second model)
18
- - Reproducible (deterministic loss computation)
19
 
20
- FIXED: Now weights ALL tokens that make up an entity, not just the first token.
 
21
  """
22
 
23
- def __init__(self, config, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0):
24
  """
25
  Args:
26
  config: Model configuration
 
27
  tokenizer: Tokenizer for converting entity words to token IDs
28
  high_value_tokens: Set of entity words (lowercase) to up-weight
29
  entity_weight: Multiplier for entity token errors (default: 3.0)
30
  """
31
  super().__init__(config)
 
32
  self.tokenizer = tokenizer
33
  self.entity_weight = entity_weight
34
 
@@ -47,11 +46,13 @@ class OWSMWithEntityLoss(AutoModelForSpeechSeq2Seq):
47
  all_entity_token_ids.update(token_id_set)
48
 
49
  print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs")
50
- print(f" → Average tokens per entity: {sum(len(ids) for ids in self.entity_word_to_token_ids.values()) / max(len(self.entity_word_to_token_ids), 1):.2f}")
 
 
51
 
52
  # Pre-compute vocab_weights tensor for O(1) lookup during training
53
  vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer)
54
- self.vocab_weights = torch.ones(vocab_size, dtype=torch.float32)
55
 
56
  # Set entity token weights
57
  for token_id in all_entity_token_ids:
@@ -61,59 +62,122 @@ class OWSMWithEntityLoss(AutoModelForSpeechSeq2Seq):
61
  # Store for debugging
62
  self.entity_token_ids = all_entity_token_ids
63
  self.high_value_tokens = high_value_tokens
64
-
65
- def compute_loss(self, model_outputs, labels, attention_mask=None):
66
- """Compute weighted cross-entropy loss with higher weight for entity tokens."""
67
- logits = model_outputs.logits # [B, T, V]
 
 
 
 
68
 
69
- # Shift for teacher forcing
70
- shift_logits = logits[..., :-1, :].contiguous()
71
- shift_labels = labels[..., 1:].contiguous()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Flatten
74
- flat_logits = shift_logits.view(-1, shift_logits.size(-1))
75
- flat_labels = shift_labels.view(-1)
76
 
77
- # Compute per-token loss
78
- loss_fct = nn.CrossEntropyLoss(reduction="none")
79
- loss = loss_fct(flat_logits, flat_labels)
80
 
81
- # Get weights using pre-computed vocab_weights tensor (O(1) lookup)
82
- # Move vocab_weights to same device as labels
83
- if self.vocab_weights.device != flat_labels.device:
84
- self.vocab_weights = self.vocab_weights.to(flat_labels.device)
85
 
86
- # Lookup weights for each token in the batch (O(1) operation)
87
- weights = self.vocab_weights[flat_labels]
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  # Apply weights
90
  weighted_loss = loss * weights
91
 
92
- # Ignore padding tokens
93
- padding_mask = (flat_labels != -100)
94
- weighted_loss = weighted_loss[padding_mask]
 
 
95
 
96
  if weighted_loss.numel() == 0:
97
- return loss[padding_mask].mean() if padding_mask.any() else loss.mean()
98
-
99
- return weighted_loss.mean()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  @classmethod
102
  def from_pretrained(cls, pretrained_model_name_or_path: str,
103
  tokenizer, high_value_tokens: Set[str],
104
  entity_weight: float = 3.0, **kwargs):
105
  """Load pretrained OWSM model and wrap with entity-weighted loss."""
 
106
  base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
107
  pretrained_model_name_or_path, **kwargs
108
  )
109
 
 
110
  model = cls(
111
  config=base_model.config,
 
112
  tokenizer=tokenizer,
113
  high_value_tokens=high_value_tokens,
114
  entity_weight=entity_weight
115
  )
116
 
117
- model.load_state_dict(base_model.state_dict(), strict=True)
118
  return model
119
 
 
 
 
 
 
 
 
 
 
4
  """
5
  import torch
6
  import torch.nn as nn
7
+ from transformers import AutoModelForSpeechSeq2Seq, PreTrainedModel
8
+ from transformers.modeling_outputs import Seq2SeqLMOutput
9
+ from typing import Set, Optional, Dict, Any
10
 
11
+ class OWSMWithEntityLoss(PreTrainedModel):
 
12
  """
13
+ Wrapper around OWSM model that implements weighted cross-entropy loss
14
+ to up-weight errors on entity tokens.
 
 
 
 
15
 
16
+ This model wraps the base model using composition rather than inheritance
17
+ to avoid issues with the AutoModel factory pattern.
18
  """
19
 
20
+ def __init__(self, config, base_model, tokenizer, high_value_tokens: Set[str], entity_weight: float = 3.0):
21
  """
22
  Args:
23
  config: Model configuration
24
+ base_model: The instantiated base model (SpeechEncoderDecoderModel)
25
  tokenizer: Tokenizer for converting entity words to token IDs
26
  high_value_tokens: Set of entity words (lowercase) to up-weight
27
  entity_weight: Multiplier for entity token errors (default: 3.0)
28
  """
29
  super().__init__(config)
30
+ self.model = base_model
31
  self.tokenizer = tokenizer
32
  self.entity_weight = entity_weight
33
 
 
46
  all_entity_token_ids.update(token_id_set)
47
 
48
  print(f" → Mapped to {len(all_entity_token_ids)} unique token IDs")
49
+ if self.entity_word_to_token_ids:
50
+ avg_tokens = sum(len(ids) for ids in self.entity_word_to_token_ids.values()) / len(self.entity_word_to_token_ids)
51
+ print(f" → Average tokens per entity: {avg_tokens:.2f}")
52
 
53
  # Pre-compute vocab_weights tensor for O(1) lookup during training
54
  vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else len(tokenizer)
55
+ self.register_buffer('vocab_weights', torch.ones(vocab_size, dtype=torch.float32))
56
 
57
  # Set entity token weights
58
  for token_id in all_entity_token_ids:
 
62
  # Store for debugging
63
  self.entity_token_ids = all_entity_token_ids
64
  self.high_value_tokens = high_value_tokens
65
+
66
+ def get_encoder(self):
67
+ """Delegate to sub-model's encoder."""
68
+ return self.model.get_encoder()
69
+
70
+ def get_decoder(self):
71
+ """Delegate to sub-model's decoder."""
72
+ return self.model.get_decoder()
73
 
74
+ def forward(self, input_features=None, attention_mask=None, decoder_input_ids=None, labels=None, **kwargs):
75
+ """
76
+ Forward pass that computes weighted loss if labels are provided.
77
+ Delegates to underlying model.
78
+ """
79
+ outputs = self.model(
80
+ input_features=input_features,
81
+ attention_mask=attention_mask,
82
+ decoder_input_ids=decoder_input_ids,
83
+ labels=labels,
84
+ return_dict=True,
85
+ **kwargs
86
+ )
87
+
88
+ # If we are not training or have no labels, return standard outputs
89
+ if labels is None:
90
+ return outputs
91
+
92
+ # Custom Loss Computation
93
+ logits = outputs.logits # [B, T, V]
94
 
95
  # Flatten
96
+ # Standard CrossEntropyLoss expects [N, C] logits and [N] labels
97
+ # where N is batch_size * sequence_length
98
 
99
+ flat_logits = logits.view(-1, logits.size(-1))
100
+ flat_labels = labels.view(-1)
 
101
 
102
+ # Create per-token weights
103
+ # Use pre-computed weights: O(1) lookup
104
+ # labels can be -100 (ignore), we need to handle that for lookup
 
105
 
106
+ # Create a mask for valid labels (not -100)
107
+ valid_mask = (flat_labels != -100)
108
+
109
+ # Use padding token ID (usually 0 or 1) for lookup where label is -100
110
+ # This avoids index out of bounds. We'll mask the loss anyway.
111
+ safe_labels = flat_labels.clone()
112
+ safe_labels[~valid_mask] = 0
113
+
114
+ # Get weights
115
+ weights = self.vocab_weights[safe_labels]
116
+
117
+ # Compute unreduced loss
118
+ loss_fct = nn.CrossEntropyLoss(reduction="none")
119
+ loss = loss_fct(flat_logits, flat_labels)
120
 
121
  # Apply weights
122
  weighted_loss = loss * weights
123
 
124
+ # Apply masking (CrossEntropyLoss usually handles -100 by ignoring,
125
+ # but since we used reduction='none', we have to double check)
126
+ # The loss for -100 labels should be 0 from CrossEntropyLoss if used correctly,
127
+ # but explicit masking is safer with custom weighting.
128
+ weighted_loss = weighted_loss[valid_mask]
129
 
130
  if weighted_loss.numel() == 0:
131
+ final_loss = torch.tensor(0.0, device=logits.device, requires_grad=True)
132
+ else:
133
+ final_loss = weighted_loss.mean()
134
+
135
+ return Seq2SeqLMOutput(
136
+ loss=final_loss,
137
+ logits=logits,
138
+ past_key_values=outputs.past_key_values,
139
+ decoder_hidden_states=outputs.decoder_hidden_states,
140
+ decoder_attentions=outputs.decoder_attentions,
141
+ cross_attentions=outputs.cross_attentions,
142
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
143
+ encoder_hidden_states=outputs.encoder_hidden_states,
144
+ encoder_attentions=outputs.encoder_attentions,
145
+ )
146
+
147
+ def generate(self, *args, **kwargs):
148
+ """Delegate generation to the underlying model."""
149
+ return self.model.generate(*args, **kwargs)
150
+
151
+ def prepare_inputs_for_generation(self, *args, **kwargs):
152
+ """Delegate to underlying model."""
153
+ return self.model.prepare_inputs_for_generation(*args, **kwargs)
154
 
155
  @classmethod
156
  def from_pretrained(cls, pretrained_model_name_or_path: str,
157
  tokenizer, high_value_tokens: Set[str],
158
  entity_weight: float = 3.0, **kwargs):
159
  """Load pretrained OWSM model and wrap with entity-weighted loss."""
160
+ # Load the base model using the Auto class
161
  base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
162
  pretrained_model_name_or_path, **kwargs
163
  )
164
 
165
+ # Initialize wrapper
166
  model = cls(
167
  config=base_model.config,
168
+ base_model=base_model,
169
  tokenizer=tokenizer,
170
  high_value_tokens=high_value_tokens,
171
  entity_weight=entity_weight
172
  )
173
 
 
174
  return model
175
 
176
+ def save_pretrained(self, save_directory, **kwargs):
177
+ """
178
+ Save the underlying model to the directory.
179
+ This ensures that the saved model is a standard OWSM model
180
+ that can be loaded with AutoModelForSpeechSeq2Seq for inference.
181
+ """
182
+ print(f"Saving underlying model to {save_directory}...")
183
+ self.model.save_pretrained(save_directory, **kwargs)
training/trainer.py CHANGED
@@ -4,24 +4,17 @@ import json
4
  import torch
5
  import numpy as np
6
  import random
7
- import pandas as pd
8
- from typing import Tuple, Optional
9
- from pathlib import Path
10
- from datasets import Dataset, Audio
11
  from transformers import (
12
  AutoProcessor,
13
- AutoModelForSpeechSeq2Seq,
14
  Seq2SeqTrainingArguments,
15
  Seq2SeqTrainer,
16
  DataCollatorForSeq2Seq,
17
  EarlyStoppingCallback,
18
  )
19
- from sklearn.model_selection import train_test_split
20
- import torchaudio
21
-
22
- from data.manager import ENTITIES_PATH, AUDIO_DIR, MODEL_OUTPUT_DIR
23
- from data.loader import get_train_dataframe
24
  from owsm_model import OWSMWithEntityLoss
 
25
 
26
  # Set seeds for reproducibility
27
  SEED = 42
@@ -36,7 +29,7 @@ torch.use_deterministic_algorithms(True, warn_only=True)
36
  MODEL_NAME = "espnet/owsm_v3.1_ebf_small"
37
  TARGET_SR = 16000
38
  MAX_AUDIO_LENGTH = 30 # seconds
39
-
40
 
41
  def compute_wer_metric(predictions, labels, tokenizer):
42
  """Compute Word Error Rate metric."""
@@ -51,13 +44,11 @@ def compute_wer_metric(predictions, labels, tokenizer):
51
  return 1.0 if len(hyp_words) > 0 else 0.0
52
 
53
  # Simple Levenshtein-like WER
54
- # For simplicity, use character-level edit distance approximation
55
  ref_str = ' '.join(ref_words)
56
  hyp_str = ' '.join(hyp_words)
57
  if ref_str == hyp_str:
58
  return 0.0
59
 
60
- # Count word-level differences
61
  ref_set = set(ref_words)
62
  hyp_set = set(hyp_words)
63
  common = len(ref_set & hyp_set)
@@ -66,7 +57,6 @@ def compute_wer_metric(predictions, labels, tokenizer):
66
 
67
  # Decode predictions and labels
68
  decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
69
- decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
70
 
71
  # Replace -100 with pad token for decoding
72
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
@@ -87,76 +77,22 @@ def compute_wer_metric(predictions, labels, tokenizer):
87
  return {"wer": wer}
88
 
89
 
90
- def load_audio(path: str, target_sr: int = TARGET_SR) -> np.ndarray:
91
  """
92
- Load and resample audio to target sample rate, convert to mono.
93
-
94
- FAILS LOUDLY if file doesn't exist - no silent fallbacks.
95
- """
96
- if not os.path.exists(path):
97
- raise FileNotFoundError(
98
- f"Audio file not found: {path}\n"
99
- f"Expected audio file at: {os.path.abspath(path)}\n"
100
- f"Please ensure all audio files are available before training."
101
- )
102
-
103
- try:
104
- wav, sr = torchaudio.load(path)
105
- except Exception as e:
106
- raise RuntimeError(
107
- f"Failed to load audio file: {path}\n"
108
- f"Error: {str(e)}\n"
109
- f"Please check that the file is a valid audio file."
110
- ) from e
111
-
112
- if sr != target_sr:
113
- wav = torchaudio.functional.resample(wav, sr, target_sr)
114
-
115
- # Convert to mono if stereo
116
- if wav.shape[0] > 1:
117
- wav = wav.mean(0, keepdim=True)
118
-
119
- return wav.squeeze(0).numpy() # Return as numpy array for processor
120
-
121
-
122
- def prepare_dataset(df: pd.DataFrame, audio_dir: str, processor, is_training: bool = True):
123
- """
124
- Prepare dataset for training/inference.
125
-
126
- FAILS LOUDLY if audio files are missing - no silent fallbacks.
127
  """
128
 
129
  def prepare_batch(batch):
130
  """Process a batch of examples."""
131
- audio_paths = batch["audio_path"]
132
- transcriptions = batch["Transcription"]
133
 
134
- # Load and process audio - FAIL LOUDLY if missing
135
- audio_arrays = []
136
- for audio_path in audio_paths:
137
- full_path = os.path.join(audio_dir, audio_path)
138
-
139
- # Validate file exists BEFORE processing
140
- if not os.path.exists(full_path):
141
- raise FileNotFoundError(
142
- f"Audio file not found during dataset preparation: {audio_path}\n"
143
- f"Full path: {os.path.abspath(full_path)}\n"
144
- f"Audio directory: {os.path.abspath(audio_dir)}\n"
145
- f"Please ensure all audio files are available before training."
146
- )
147
-
148
- audio = load_audio(full_path)
149
-
150
- # Truncate if too long
151
- max_samples = TARGET_SR * MAX_AUDIO_LENGTH
152
- if len(audio) > max_samples:
153
- audio = audio[:max_samples]
154
-
155
- audio_arrays.append(audio)
156
 
157
  # Process audio with processor
158
  inputs = processor(
159
- audio_arrays,
160
  sampling_rate=TARGET_SR,
161
  return_tensors="pt",
162
  padding=True,
@@ -178,85 +114,26 @@ def prepare_dataset(df: pd.DataFrame, audio_dir: str, processor, is_training: bo
178
 
179
  return batch
180
 
181
- # Create dataset
182
- dataset_dict = {
183
- "audio_path": df["ID"].apply(lambda x: f"{x}.wav").tolist(),
184
- "Transcription": df["Transcription"].astype(str).tolist(),
185
- }
186
-
187
- # Validate all audio files exist BEFORE creating dataset
188
- if is_training:
189
- print("Validating audio files exist...")
190
- missing_files = []
191
- for audio_path in dataset_dict["audio_path"]:
192
- full_path = os.path.join(audio_dir, audio_path)
193
- if not os.path.exists(full_path):
194
- missing_files.append(full_path)
195
-
196
- if missing_files:
197
- error_msg = (
198
- f"Found {len(missing_files)} missing audio files:\n"
199
- f"First 10 missing files:\n"
200
- )
201
- for f in missing_files[:10]:
202
- error_msg += f" - {f}\n"
203
- if len(missing_files) > 10:
204
- error_msg += f" ... and {len(missing_files) - 10} more\n"
205
- error_msg += f"\nPlease ensure all audio files are available before training."
206
- raise FileNotFoundError(error_msg)
207
-
208
- print(f"✓ All {len(dataset_dict['audio_path'])} audio files validated")
209
-
210
- dataset = Dataset.from_dict(dataset_dict)
211
 
212
  # Process in batches
213
  dataset = dataset.map(
214
  prepare_batch,
215
  batched=True,
216
  batch_size=16,
217
- remove_columns=["audio_path"], # Keep Transcription for reference
 
218
  )
219
 
220
  return dataset
221
 
222
 
223
- def create_stratified_split(df: pd.DataFrame, test_size: float = 0.1):
224
- """
225
- Create stratified train/val split based on:
226
- - Utterance length (short vs long)
227
- - Presence of Caribbean keywords
228
- """
229
- # Create bins for stratification
230
- df['word_count'] = df['Transcription'].str.split().str.len()
231
- df['has_caribbean'] = df['Transcription'].str.lower().str.contains(
232
- 'caribbean|bbc|trinidad|jamaica|guyana|haiti|barbados',
233
- regex=True,
234
- na=False
235
- )
236
-
237
- # Create stratification key
238
- df['length_bin'] = pd.cut(df['word_count'], bins=5, labels=False)
239
- df['stratify_key'] = df['length_bin'].astype(str) + '_' + df['has_caribbean'].astype(str)
240
-
241
- train_df, val_df = train_test_split(
242
- df,
243
- test_size=test_size,
244
- stratify=df['stratify_key'],
245
- random_state=SEED
246
- )
247
-
248
- print(f"Train: {len(train_df):,} samples")
249
- print(f"Val: {len(val_df):,} samples")
250
-
251
- return train_df, val_df
252
-
253
-
254
- def run_training_progress(epochs: int, batch_size: int, learning_rate: float, progress=None) -> Tuple[str, str]:
255
  """
256
- Run OWSM training with progress tracking.
257
-
258
- Uses espnet/owsm_v3.1_ebf_small with NO FALLBACKS.
259
- If model loading fails, raises exception with clear error message.
260
  """
261
  try:
262
  if progress:
@@ -278,72 +155,63 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
278
  print(f"Loaded {len(high_value_entities)} high-value entities")
279
 
280
  if progress:
281
- progress(0.1, desc="Loading training data from dataset...")
282
- try:
283
- train_df = get_train_dataframe()
284
- except ValueError as e:
285
- raise FileNotFoundError(
286
- f"Training data not available: {str(e)}. "
287
- f"Please ensure the dataset is loaded."
288
- )
289
- print(f"Loaded {len(train_df):,} training samples")
290
 
 
 
 
 
 
 
 
 
 
291
  # Create train/val split
292
  if progress:
293
  progress(0.15, desc="Creating train/val split...")
294
- train_df_split, val_df_split = create_stratified_split(train_df, test_size=0.1)
295
 
296
- # Load processor - NO FALLBACKS
 
 
 
 
 
 
 
 
 
 
297
  if progress:
298
  progress(0.2, desc=f"Loading processor: {MODEL_NAME}...")
299
  print(f"\nLoading processor: {MODEL_NAME}")
300
- print("NOTE: Using espnet/owsm_v3.1_ebf_small with NO FALLBACKS")
301
- print("If this fails, check that the model is available on HuggingFace.")
302
 
303
  try:
304
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
305
  except Exception as e:
306
- error_msg = (
307
- f"FAILED to load processor from {MODEL_NAME}\n\n"
308
- f"Error: {str(e)}\n\n"
309
- f"This training pipeline requires espnet/owsm_v3.1_ebf_small.\n"
310
- f"No fallbacks are configured. Please ensure:\n"
311
- f"1. The model is available on HuggingFace\n"
312
- f"2. You have internet access to download the model\n"
313
- f"3. You have sufficient disk space\n"
314
- f"4. The transformers library supports this model\n\n"
315
- f"If the model is not available, you may need to use ESPnet's native training framework."
316
- )
317
- raise RuntimeError(error_msg) from e
318
 
319
  print(f"✓ Processor loaded successfully")
320
 
321
- # Load model - NO FALLBACKS
322
  if progress:
323
  progress(0.25, desc=f"Loading model: {MODEL_NAME}...")
324
  print(f"\nLoading model: {MODEL_NAME}")
325
 
326
  try:
 
327
  model = OWSMWithEntityLoss.from_pretrained(
328
  MODEL_NAME,
329
  tokenizer=processor.tokenizer,
330
  high_value_tokens=high_value_entities,
331
- entity_weight=3.0, # 3x weight for entity errors
332
  )
333
  except Exception as e:
334
- error_msg = (
335
- f"FAILED to load model from {MODEL_NAME}\n\n"
336
- f"Error: {str(e)}\n\n"
337
- f"This training pipeline requires espnet/owsm_v3.1_ebf_small.\n"
338
- f"No fallbacks are configured. Please ensure:\n"
339
- f"1. The model is available on HuggingFace\n"
340
- f"2. You have internet access to download the model\n"
341
- f"3. You have sufficient disk space\n"
342
- f"4. The transformers library supports this model\n"
343
- f"5. AutoModelForSpeechSeq2Seq can load this model\n\n"
344
- f"If the model is not available, you may need to use ESPnet's native training framework."
345
- )
346
- raise RuntimeError(error_msg) from e
347
 
348
  print(f"✓ Model loaded successfully")
349
 
@@ -353,14 +221,14 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
353
 
354
  # Prepare datasets
355
  if progress:
356
- progress(0.3, desc="Preparing training dataset...")
357
- print("\nPreparing training dataset...")
358
- train_dataset = prepare_dataset(train_df_split, AUDIO_DIR, processor, is_training=True)
359
 
360
  if progress:
361
- progress(0.4, desc="Preparing validation dataset...")
362
- print("Preparing validation dataset...")
363
- val_dataset = prepare_dataset(val_df_split, AUDIO_DIR, processor, is_training=False)
364
 
365
  # Training arguments
366
  if progress:
@@ -370,7 +238,7 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
370
  output_dir=MODEL_OUTPUT_DIR,
371
  per_device_train_batch_size=batch_size,
372
  per_device_eval_batch_size=batch_size,
373
- gradient_accumulation_steps=4, # Effective batch size = batch_size * 4
374
  learning_rate=learning_rate,
375
  warmup_steps=500,
376
  num_train_epochs=epochs,
@@ -380,20 +248,23 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
380
  save_steps=1000,
381
  logging_steps=100,
382
  load_best_model_at_end=True,
383
- metric_for_best_model="wer", # Use WER instead of loss
384
- greater_is_better=False, # Lower WER is better
385
  save_total_limit=3,
386
- fp16=torch.cuda.is_available(), # Use mixed precision if GPU available
387
  dataloader_num_workers=4,
388
- report_to="none", # Disable wandb/tensorboard
389
  seed=SEED,
390
- predict_with_generate=True, # Need to generate for WER calculation
 
391
  )
392
 
393
  # Data collator
394
  data_collator = DataCollatorForSeq2Seq(
395
  processor=processor,
396
- model=model,
 
 
397
  padding=True,
398
  )
399
 
@@ -435,6 +306,7 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
435
  progress(0.95, desc="Saving model...")
436
 
437
  print(f"\nSaving model to {MODEL_OUTPUT_DIR}...")
 
438
  model.save_pretrained(MODEL_OUTPUT_DIR)
439
  processor.save_pretrained(MODEL_OUTPUT_DIR)
440
 
@@ -471,21 +343,12 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
471
  The model is now ready for inference!
472
  """
473
 
474
- return success_msg, json.dumps(final_metrics, indent=2)
475
 
476
- except FileNotFoundError as e:
477
- error_msg = f"❌ File Not Found Error:\n\n{str(e)}"
478
- if progress:
479
- progress(1.0, desc="Error!")
480
- return error_msg, ""
481
- except RuntimeError as e:
482
- error_msg = f"❌ Runtime Error:\n\n{str(e)}"
483
- if progress:
484
- progress(1.0, desc="Error!")
485
- return error_msg, ""
486
  except Exception as e:
487
  import traceback
488
- error_msg = f"❌ Unexpected Error: {str(e)}\n\n{traceback.format_exc()}"
 
489
  if progress:
490
  progress(1.0, desc="Error!")
491
- return error_msg, ""
 
4
  import torch
5
  import numpy as np
6
  import random
7
+ from typing import Tuple, Optional, Dict, Any
8
+ from datasets import load_dataset, Audio, DatasetDict
 
 
9
  from transformers import (
10
  AutoProcessor,
 
11
  Seq2SeqTrainingArguments,
12
  Seq2SeqTrainer,
13
  DataCollatorForSeq2Seq,
14
  EarlyStoppingCallback,
15
  )
 
 
 
 
 
16
  from owsm_model import OWSMWithEntityLoss
17
+ from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR
18
 
19
  # Set seeds for reproducibility
20
  SEED = 42
 
29
  MODEL_NAME = "espnet/owsm_v3.1_ebf_small"
30
  TARGET_SR = 16000
31
  MAX_AUDIO_LENGTH = 30 # seconds
32
+ HF_DATASET_NAME = "shaun3141/caribbean-voices-hackathon"
33
 
34
  def compute_wer_metric(predictions, labels, tokenizer):
35
  """Compute Word Error Rate metric."""
 
44
  return 1.0 if len(hyp_words) > 0 else 0.0
45
 
46
  # Simple Levenshtein-like WER
 
47
  ref_str = ' '.join(ref_words)
48
  hyp_str = ' '.join(hyp_words)
49
  if ref_str == hyp_str:
50
  return 0.0
51
 
 
52
  ref_set = set(ref_words)
53
  hyp_set = set(hyp_words)
54
  common = len(ref_set & hyp_set)
 
57
 
58
  # Decode predictions and labels
59
  decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
 
60
 
61
  # Replace -100 with pad token for decoding
62
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
 
77
  return {"wer": wer}
78
 
79
 
80
+ def prepare_dataset_hf(dataset, processor):
81
  """
82
+ Prepare dataset using Hugging Face Datasets built-in audio handling.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  """
84
 
85
  def prepare_batch(batch):
86
  """Process a batch of examples."""
87
+ audio = batch["audio"]
88
+ transcriptions = batch["transcription"] # Note: check lowercase 't' in dataset
89
 
90
+ # Audio is already a dictionary with 'array' and 'sampling_rate'
91
+ # because we cast it to Audio() in the loading step
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # Process audio with processor
94
  inputs = processor(
95
+ [x["array"] for x in audio],
96
  sampling_rate=TARGET_SR,
97
  return_tensors="pt",
98
  padding=True,
 
114
 
115
  return batch
116
 
117
+ # Remove columns that are not needed
118
+ # Note: We keep 'transcription' maybe? No, remove it to save memory, we have labels.
119
+ # But check what columns exist first.
120
+ column_names = dataset.column_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  # Process in batches
123
  dataset = dataset.map(
124
  prepare_batch,
125
  batched=True,
126
  batch_size=16,
127
+ remove_columns=column_names,
128
+ desc="Preprocessing dataset",
129
  )
130
 
131
  return dataset
132
 
133
 
134
+ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, progress=None) -> Tuple[str, Optional[Dict[str, Any]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  """
136
+ Run OWSM training with progress tracking using HF Datasets.
 
 
 
137
  """
138
  try:
139
  if progress:
 
155
  print(f"Loaded {len(high_value_entities)} high-value entities")
156
 
157
  if progress:
158
+ progress(0.1, desc="Loading dataset from Hugging Face...")
159
+
160
+ # Load dataset from HF
161
+ hf_token = os.getenv("HF_TOKEN")
162
+ print(f"Loading dataset: {HF_DATASET_NAME}")
163
+ dataset = load_dataset(HF_DATASET_NAME, token=hf_token)
 
 
 
164
 
165
+ if 'train' not in dataset:
166
+ raise ValueError(f"Dataset {HF_DATASET_NAME} does not contain a 'train' split.")
167
+
168
+ train_full = dataset['train']
169
+ print(f"Loaded {len(train_full):,} total training samples")
170
+
171
+ # Cast to Audio to ensure correct sampling rate
172
+ train_full = train_full.cast_column("audio", Audio(sampling_rate=TARGET_SR))
173
+
174
  # Create train/val split
175
  if progress:
176
  progress(0.15, desc="Creating train/val split...")
 
177
 
178
+ # Simple random split since we don't want to download all audio to stratify by length/content
179
+ # unless we want to iterate the whole dataset which might be slow.
180
+ # We'll use a random split for speed and simplicity with the streamed/remote dataset.
181
+ split_dataset = train_full.train_test_split(test_size=0.1, seed=SEED)
182
+ train_dataset_raw = split_dataset['train']
183
+ val_dataset_raw = split_dataset['test']
184
+
185
+ print(f"Train: {len(train_dataset_raw):,} samples")
186
+ print(f"Val: {len(val_dataset_raw):,} samples")
187
+
188
+ # Load processor
189
  if progress:
190
  progress(0.2, desc=f"Loading processor: {MODEL_NAME}...")
191
  print(f"\nLoading processor: {MODEL_NAME}")
 
 
192
 
193
  try:
194
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
195
  except Exception as e:
196
+ raise RuntimeError(f"FAILED to load processor from {MODEL_NAME}: {e}")
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  print(f"✓ Processor loaded successfully")
199
 
200
+ # Load model
201
  if progress:
202
  progress(0.25, desc=f"Loading model: {MODEL_NAME}...")
203
  print(f"\nLoading model: {MODEL_NAME}")
204
 
205
  try:
206
+ # Use our new wrapper class
207
  model = OWSMWithEntityLoss.from_pretrained(
208
  MODEL_NAME,
209
  tokenizer=processor.tokenizer,
210
  high_value_tokens=high_value_entities,
211
+ entity_weight=3.0,
212
  )
213
  except Exception as e:
214
+ raise RuntimeError(f"FAILED to load model from {MODEL_NAME}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  print(f"✓ Model loaded successfully")
217
 
 
221
 
222
  # Prepare datasets
223
  if progress:
224
+ progress(0.3, desc="Preprocessing training dataset...")
225
+ print("\nPreprocessing training dataset...")
226
+ train_dataset = prepare_dataset_hf(train_dataset_raw, processor)
227
 
228
  if progress:
229
+ progress(0.4, desc="Preprocessing validation dataset...")
230
+ print("Preprocessing validation dataset...")
231
+ val_dataset = prepare_dataset_hf(val_dataset_raw, processor)
232
 
233
  # Training arguments
234
  if progress:
 
238
  output_dir=MODEL_OUTPUT_DIR,
239
  per_device_train_batch_size=batch_size,
240
  per_device_eval_batch_size=batch_size,
241
+ gradient_accumulation_steps=4,
242
  learning_rate=learning_rate,
243
  warmup_steps=500,
244
  num_train_epochs=epochs,
 
248
  save_steps=1000,
249
  logging_steps=100,
250
  load_best_model_at_end=True,
251
+ metric_for_best_model="wer",
252
+ greater_is_better=False,
253
  save_total_limit=3,
254
+ fp16=torch.cuda.is_available(),
255
  dataloader_num_workers=4,
256
+ report_to="none",
257
  seed=SEED,
258
+ predict_with_generate=True,
259
+ generation_max_length=200, # Prevent infinite generation
260
  )
261
 
262
  # Data collator
263
  data_collator = DataCollatorForSeq2Seq(
264
  processor=processor,
265
+ model=model, # The trainer needs the model for the collator sometimes if it uses it for padding?
266
+ # Actually DataCollatorForSeq2Seq uses tokenizer usually.
267
+ # But passing model is fine.
268
  padding=True,
269
  )
270
 
 
306
  progress(0.95, desc="Saving model...")
307
 
308
  print(f"\nSaving model to {MODEL_OUTPUT_DIR}...")
309
+ # This calls our custom save_pretrained which saves the inner model
310
  model.save_pretrained(MODEL_OUTPUT_DIR)
311
  processor.save_pretrained(MODEL_OUTPUT_DIR)
312
 
 
343
  The model is now ready for inference!
344
  """
345
 
346
+ return success_msg, final_metrics
347
 
 
 
 
 
 
 
 
 
 
 
348
  except Exception as e:
349
  import traceback
350
+ error_msg = f"❌ Error during training: {str(e)}\n\n{traceback.format_exc()}"
351
+ print(error_msg)
352
  if progress:
353
  progress(1.0, desc="Error!")
354
+ return error_msg, None