shaun3141 committed on
Commit c0cd25b · 1 Parent(s): 5c2e86a

Fix: Ensure generation_config and pad_token handling for Whisper training


- Copy generation_config from base model to prevent NoneType errors in Seq2SeqTrainer
- Set task='transcribe' and clear deprecated forced_decoder_ids
- Fix pad_token == eos_token issue by setting pad_token to unk_token
- Ensure pad_token_id is set in generation_config and model config
- Copy additional attributes (main_input_name, forced_decoder_ids, suppress_tokens) for compatibility
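The first bullets target a concrete failure: Seq2SeqTrainer's prediction_step reads attributes off model.generation_config, so a wrapper that never copies it from the base model raises AttributeError on None. A condensed, slightly simplified sketch of the guard the diffs below add (illustrative only; ensure_generation_config is a made-up helper name, and model/base_model/tokenizer are placeholders):

from transformers import GenerationConfig

def ensure_generation_config(model, base_model, tokenizer):
    # Prefer the base model's generation_config; rebuild it otherwise so it is never None
    gen_cfg = getattr(base_model, "generation_config", None)
    if gen_cfg is None:
        try:
            gen_cfg = GenerationConfig.from_model_config(model.config)
        except Exception:
            gen_cfg = GenerationConfig()  # minimal config, still not None
    model.generation_config = gen_cfg

    # Prefer task="transcribe" over the deprecated forced_decoder_ids
    if getattr(gen_cfg, "task", None) is None:
        gen_cfg.task = "transcribe"
    gen_cfg.forced_decoder_ids = None

    # Keep pad_token_id consistent with the tokenizer to silence attention-mask warnings
    if tokenizer.pad_token_id is not None:
        gen_cfg.pad_token_id = tokenizer.pad_token_id
    return model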

Files changed (3)
  1. owsm_model.py +56 -0
  2. training/trainer.py +64 -5
  3. training/whisper_trainer.py +85 -10
owsm_model.py CHANGED
@@ -175,6 +175,62 @@ class OWSMWithEntityLoss(PreTrainedModel):
             entity_weight=entity_weight
         )

+        # Copy important attributes from base model to ensure full compatibility
+        # with transformers components like Seq2SeqTrainer, data collators, etc.
+
+        # 1. generation_config - Required for Seq2SeqTrainer evaluation
+        #    Seq2SeqTrainer accesses model.generation_config._from_model_config in prediction_step
+        if hasattr(base_model, 'generation_config') and base_model.generation_config is not None:
+            # Copy generation_config from base model (preferred method)
+            model.generation_config = base_model.generation_config
+        else:
+            # Fallback: create generation_config from model config
+            # This handles cases where base model doesn't have generation_config set
+            try:
+                from transformers import GenerationConfig
+                model.generation_config = GenerationConfig.from_model_config(model.config)
+            except Exception:
+                # If GenerationConfig.from_model_config fails, create a minimal config
+                # This ensures generation_config is never None, preventing AttributeError
+                from transformers import GenerationConfig
+                model.generation_config = GenerationConfig()
+
+        # 1b. Ensure generation_config uses modern task/language flags instead of deprecated forced_decoder_ids
+        #     For Whisper models, prefer task="transcribe" and language settings over forced_decoder_ids
+        #     Setting task/language will cause forced_decoder_ids to be ignored (as per transformers deprecation)
+        if hasattr(model.generation_config, 'task'):
+            if model.generation_config.task is None:
+                # Set default task for Whisper models (transcribe, not translate)
+                model.generation_config.task = "transcribe"
+            # If task is set, forced_decoder_ids will be ignored, so we can clear it to avoid warnings
+            if hasattr(model.generation_config, 'forced_decoder_ids') and model.generation_config.forced_decoder_ids is not None:
+                # Clear forced_decoder_ids when task is set to avoid deprecation warnings
+                model.generation_config.forced_decoder_ids = None
+
+        # 1c. Ensure pad_token_id is set in generation_config to avoid attention mask warnings
+        #     This is important when pad_token_id == eos_token_id
+        if hasattr(tokenizer, 'pad_token_id') and tokenizer.pad_token_id is not None:
+            if hasattr(model.generation_config, 'pad_token_id'):
+                model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+        # If base model has language set, preserve it; otherwise default to None (auto-detect)
+        # Note: For Caribbean Voices, we want transcription, not translation to English,
+        #       so we don't force language='en' - let the model auto-detect or use what's in config
+
+        # 2. main_input_name - Important for data collators and input handling
+        #    e.g., "input_features" for Whisper, "input_values" for Wav2Vec2
+        if hasattr(base_model, 'main_input_name'):
+            model.main_input_name = base_model.main_input_name
+
+        # 3. Model-specific config attributes that might be set on the instance
+        #    Note: forced_decoder_ids is deprecated in favor of task/language flags in generation_config
+        #    We still copy it for backward compatibility, but the modern approach is preferred
+        for attr_name in ['forced_decoder_ids', 'suppress_tokens']:
+            if hasattr(base_model, attr_name):
+                attr_value = getattr(base_model, attr_name)
+                if attr_value is not None:
+                    setattr(model, attr_name, attr_value)
+
         return model

     def save_pretrained(self, save_directory, **kwargs):
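After the wrapper returns, the copied attributes can be sanity-checked directly. The snippet below is illustrative and assumes `model` is the object returned by the factory above and `tokenizer` is the tokenizer passed to it:

# Sanity checks on the wrapped model (illustrative, not part of the commit)
assert model.generation_config is not None
print(model.generation_config.task)                 # expected: "transcribe" unless the base model set one
print(model.generation_config.forced_decoder_ids)   # expected: None once task is set
print(model.generation_config.pad_token_id == tokenizer.pad_token_id)
print(model.main_input_name)                         # e.g. "input_features" for Whisper encoders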
training/trainer.py CHANGED
@@ -1,9 +1,12 @@
 """Training logic for OWSM fine-tuning."""
 import os
+# Disable tokenizers parallelism to avoid fork warning with DataLoader workers
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import json
 import torch
 import numpy as np
 import random
+import hashlib
 from typing import Tuple, Optional, Dict, Any
 from datasets import load_dataset, Audio, DatasetDict, disable_caching
 from transformers import (
@@ -16,7 +19,7 @@ from transformers import (
     WhisperProcessor,
 )
 from owsm_model import OWSMWithEntityLoss
-from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR
+from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR, CACHE_DIR

 # Disable dataset caching to save disk space
 disable_caching()
@@ -82,10 +85,39 @@ def compute_wer_metric(predictions, labels, tokenizer):
     return {"wer": wer}


-def prepare_dataset_hf(dataset, processor):
+def get_cache_key(dataset_name: str, model_name: str, split: str, seed: int) -> str:
+    """Generate a cache key based on dataset, model, split, and seed."""
+    cache_string = f"{dataset_name}_{model_name}_{split}_{seed}"
+    return hashlib.md5(cache_string.encode()).hexdigest()
+
+
+def prepare_dataset_hf(dataset, processor, dataset_name: str = None, model_name: str = None, split: str = None, use_cache: bool = True):
     """
     Prepare dataset using Hugging Face Datasets built-in audio handling.
+    Supports caching to avoid reprocessing.
+
+    Args:
+        dataset: The dataset to process
+        processor: The processor to use for preprocessing
+        dataset_name: Name of the dataset (for cache key)
+        model_name: Name of the model (for cache key)
+        split: Split name ('train' or 'val') (for cache key)
+        use_cache: Whether to use cache if available
     """
+    # Try to load from cache if enabled and cache key components provided
+    if use_cache and dataset_name and model_name and split:
+        cache_key = get_cache_key(dataset_name, model_name, split, SEED)
+        cache_path = os.path.join(CACHE_DIR, cache_key)
+
+        if os.path.exists(cache_path):
+            print(f"Loading preprocessed {split} dataset from cache: {cache_path}")
+            try:
+                from datasets import load_from_disk
+                cached_dataset = load_from_disk(cache_path)
+                print(f"✓ Successfully loaded cached {split} dataset ({len(cached_dataset):,} samples)")
+                return cached_dataset
+            except Exception as e:
+                print(f"⚠ Failed to load from cache: {e}. Reprocessing...")

     def prepare_batch(batch):
         """Process a batch of examples."""
@@ -131,11 +163,22 @@ def prepare_dataset_hf(dataset, processor):
         batched=True,
         batch_size=16,
         remove_columns=column_names,
-        desc="Preprocessing dataset",
+        desc=None,  # Disable progress bar for dataset preprocessing
         load_from_cache_file=False,  # Don't load from cache
         keep_in_memory=True,  # Keep in memory to avoid disk writes
     )

+    # Save to cache if enabled and cache key components provided
+    if use_cache and dataset_name and model_name and split:
+        cache_key = get_cache_key(dataset_name, model_name, split, SEED)
+        cache_path = os.path.join(CACHE_DIR, cache_key)
+        print(f"Saving preprocessed {split} dataset to cache: {cache_path}")
+        try:
+            dataset.save_to_disk(cache_path)
+            print(f"✓ Successfully cached {split} dataset ({len(dataset):,} samples)")
+        except Exception as e:
+            print(f"⚠ Failed to save to cache: {e}. Continuing without cache...")
+
     return dataset


@@ -448,12 +491,26 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
     if progress:
         progress(0.3, desc="Preprocessing training dataset...")
     print("\nPreprocessing training dataset...")
-    train_dataset = prepare_dataset_hf(train_dataset_raw, processor)
+    train_dataset = prepare_dataset_hf(
+        train_dataset_raw,
+        processor,
+        dataset_name=HF_DATASET_NAME,
+        model_name=MODEL_NAME,
+        split="train",
+        use_cache=True
+    )

     if progress:
         progress(0.4, desc="Preprocessing validation dataset...")
     print("Preprocessing validation dataset...")
-    val_dataset = prepare_dataset_hf(val_dataset_raw, processor)
+    val_dataset = prepare_dataset_hf(
+        val_dataset_raw,
+        processor,
+        dataset_name=HF_DATASET_NAME,
+        model_name=MODEL_NAME,
+        split="val",
+        use_cache=True
+    )

     # Training arguments
     if progress:
@@ -479,10 +536,12 @@ def run_training_progress(epochs: int, batch_size: int, learning_rate: float, pr
         save_total_limit=3,
         fp16=torch.cuda.is_available(),
         dataloader_num_workers=4,
+        dataloader_pin_memory=True,  # Faster CPU→GPU transfers for GPU training
         report_to="none",
         seed=SEED,
         predict_with_generate=True,  # Still used for seq2seq generation during eval
         generation_max_length=200,  # Prevent infinite generation
+        disable_tqdm=True,  # Disable progress bars during training
     )

     # Data collator
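The trainer change also introduces on-disk caching of preprocessed datasets. The cache key is a deterministic md5 over dataset, model, split, and seed, so a rerun with identical settings reuses the saved dataset under CACHE_DIR. A small illustration of the key scheme (the dataset/model names and seed are hypothetical examples, and the cache directory stands in for CACHE_DIR from data.manager):

import hashlib
import os

def get_cache_key(dataset_name: str, model_name: str, split: str, seed: int) -> str:
    # Same scheme as the diff above: deterministic md5 over the four components
    return hashlib.md5(f"{dataset_name}_{model_name}_{split}_{seed}".encode()).hexdigest()

CACHE_DIR = "/tmp/dataset_cache"  # placeholder for the repo's CACHE_DIR
key = get_cache_key("example/caribbean-voices", "openai/whisper-small", "train", 42)
print(os.path.join(CACHE_DIR, key))  # identical inputs always map to the same cached folder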
training/whisper_trainer.py CHANGED
@@ -3,10 +3,13 @@ Whisper training using HuggingFace transformers.
 Full integration with HuggingFace training features.
 """
 import os
+# Disable tokenizers parallelism to avoid fork warning with DataLoader workers
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import json
 import torch
 import numpy as np
 import random
+import hashlib
 from typing import Tuple, Optional, Dict, Any, List, Union
 from dataclasses import dataclass
 from datasets import load_dataset, Audio, disable_caching
@@ -17,7 +20,7 @@ from transformers import (
     EarlyStoppingCallback,
 )
 from owsm_model import OWSMWithEntityLoss
-from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR
+from data.manager import ENTITIES_PATH, MODEL_OUTPUT_DIR, BASE_DIR, CACHE_DIR

 # Disable dataset caching to save disk space
 disable_caching()
@@ -93,12 +96,13 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         batch = {"input_features": input_features_batch}

         # Pad labels (text tokens) using the processor's tokenizer
-        # Pass as list of dicts with "input_ids" key
-        label_features_dicts = [{"input_ids": label.tolist() if isinstance(label, np.ndarray) else label} for label in labels_list]
-        labels_batch = self.processor.tokenizer.pad(
-            label_features_dicts,
+        # Use tokenizer.__call__() for better performance with fast tokenizers
+        label_ids_list = [label.tolist() if isinstance(label, np.ndarray) else label for label in labels_list]
+        labels_batch = self.processor.tokenizer(
+            label_ids_list,
             return_tensors="pt",
-            padding=True
+            padding=True,
+            truncation=False  # Already handled in preprocessing
         )

         # Replace padding token id's of the labels by -100 so they are ignored by the loss function
@@ -161,10 +165,39 @@ def compute_wer_metric(predictions, labels, tokenizer):
     return {"wer": np.mean(wer_scores)}


-def prepare_whisper_dataset(dataset, processor):
+def get_cache_key(dataset_name: str, model_name: str, split: str, seed: int) -> str:
+    """Generate a cache key based on dataset, model, split, and seed."""
+    cache_string = f"{dataset_name}_{model_name}_{split}_{seed}"
+    return hashlib.md5(cache_string.encode()).hexdigest()
+
+
+def prepare_whisper_dataset(dataset, processor, dataset_name: str = None, model_name: str = None, split: str = None, use_cache: bool = True):
     """
     Prepare dataset for Whisper training using Hugging Face Datasets.
+    Supports caching to avoid reprocessing.
+
+    Args:
+        dataset: The dataset to process
+        processor: The processor to use for preprocessing
+        dataset_name: Name of the dataset (for cache key)
+        model_name: Name of the model (for cache key)
+        split: Split name ('train' or 'val') (for cache key)
+        use_cache: Whether to use cache if available
     """
+    # Try to load from cache if enabled and cache key components provided
+    if use_cache and dataset_name and model_name and split:
+        cache_key = get_cache_key(dataset_name, model_name, split, SEED)
+        cache_path = os.path.join(CACHE_DIR, cache_key)
+
+        if os.path.exists(cache_path):
+            print(f"Loading preprocessed {split} dataset from cache: {cache_path}")
+            try:
+                from datasets import load_from_disk
+                cached_dataset = load_from_disk(cache_path)
+                print(f"✓ Successfully loaded cached {split} dataset ({len(cached_dataset):,} samples)")
+                return cached_dataset
+            except Exception as e:
+                print(f"⚠ Failed to load from cache: {e}. Reprocessing...")

     def prepare_batch(batch):
         """Process a batch of examples."""
@@ -231,11 +264,22 @@ def prepare_whisper_dataset(dataset, processor):
         batched=True,
         batch_size=16,
         remove_columns=column_names,
-        desc="Preprocessing dataset",
+        desc=None,  # Disable progress bar for dataset preprocessing
        load_from_cache_file=False,  # Don't load from cache
        keep_in_memory=True,  # Keep in memory to avoid disk writes
     )

+    # Save to cache if enabled and cache key components provided
+    if use_cache and dataset_name and model_name and split:
+        cache_key = get_cache_key(dataset_name, model_name, split, SEED)
+        cache_path = os.path.join(CACHE_DIR, cache_key)
+        print(f"Saving preprocessed {split} dataset to cache: {cache_path}")
+        try:
+            dataset.save_to_disk(cache_path)
+            print(f"✓ Successfully cached {split} dataset ({len(dataset):,} samples)")
+        except Exception as e:
+            print(f"⚠ Failed to save to cache: {e}. Continuing without cache...")
+
     return dataset


@@ -299,6 +343,17 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
     processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
     print(f"✓ Whisper processor loaded successfully")

+    # Fix pad_token issue: Whisper tokenizers often have pad_token_id == eos_token_id
+    # This causes warnings about attention masks. Set pad_token to unk_token if needed.
+    if processor.tokenizer.pad_token_id == processor.tokenizer.eos_token_id:
+        if processor.tokenizer.unk_token_id is not None:
+            processor.tokenizer.pad_token_id = processor.tokenizer.unk_token_id
+            processor.tokenizer.pad_token = processor.tokenizer.unk_token
+            print(f"✓ Set pad_token to unk_token ({processor.tokenizer.unk_token_id}) to avoid attention mask warnings")
+        else:
+            # If no unk_token, use eos_token but ensure attention masks are always passed
+            print(f"⚠ pad_token == eos_token ({processor.tokenizer.eos_token_id}). Ensure attention masks are passed during generation.")
+
     # Load Whisper model
     if progress:
         progress(0.25, desc=f"Loading Whisper model: {WHISPER_MODEL_NAME}...")
@@ -314,6 +369,10 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
         attn_implementation="eager",
     )

+    # Update model config to match tokenizer pad_token_id
+    if hasattr(model.config, 'pad_token_id'):
+        model.config.pad_token_id = processor.tokenizer.pad_token_id
+
     print(f"✓ Whisper model loaded successfully")

     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -324,12 +383,26 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
     if progress:
         progress(0.3, desc="Preprocessing training dataset...")
     print("\nPreprocessing training dataset...")
-    train_dataset = prepare_whisper_dataset(train_dataset_raw, processor)
+    train_dataset = prepare_whisper_dataset(
+        train_dataset_raw,
+        processor,
+        dataset_name=HF_DATASET_NAME,
+        model_name=WHISPER_MODEL_NAME,
+        split="train",
+        use_cache=True
+    )

     if progress:
         progress(0.4, desc="Preprocessing validation dataset...")
     print("Preprocessing validation dataset...")
-    val_dataset = prepare_whisper_dataset(val_dataset_raw, processor)
+    val_dataset = prepare_whisper_dataset(
+        val_dataset_raw,
+        processor,
+        dataset_name=HF_DATASET_NAME,
+        model_name=WHISPER_MODEL_NAME,
+        split="val",
+        use_cache=True
+    )

     # Training arguments
     if progress:
@@ -355,10 +428,12 @@ def run_whisper_training_progress(epochs: int, batch_size: int, learning_rate: f
         save_total_limit=3,
         fp16=torch.cuda.is_available(),
         dataloader_num_workers=4,
+        dataloader_pin_memory=True,  # Faster CPU→GPU transfers for GPU training
         report_to="none",
         seed=SEED,
         predict_with_generate=True,  # Still used for seq2seq generation during eval
         generation_max_length=200,
+        disable_tqdm=True,  # Disable progress bars during training
     )

     # Data collator
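The pad_token fix only changes anything when the tokenizer exposes an unk_token distinct from eos, so a quick check after loading the processor makes the outcome explicit. This is an illustrative snippet, not part of the commit; the checkpoint name is an example and the result depends on the tokenizer's special-token setup:

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")  # example checkpoint
tok = processor.tokenizer
print("pad:", tok.pad_token_id, "eos:", tok.eos_token_id, "unk:", tok.unk_token_id)

if tok.pad_token_id == tok.eos_token_id and tok.unk_token_id is not None:
    tok.pad_token = tok.unk_token  # also updates pad_token_id
print("pad after fix:", tok.pad_token_id)
# If pad and eos still collide (some tokenizers reuse one id for all special tokens),
# make sure attention masks are passed explicitly during generation.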