import enum

import numpy as np
import torch
import torchaudio
import transformers

import wavlm_phoneme_fr_it

SAMPLING_RATE = 16_000


class Languages(enum.Enum):
    """Languages the model can be told to expect.

    ``prepare_model_inputs`` maps ``FR`` to the language code 0.0 and every
    other value (i.e. ``IT``) to 1.0.
    """

    FR = 0
    IT = 1


class Scoring(enum.Enum):
    """Scoring strategies.

    NOTE(review): not referenced by the functions in this module; how each
    member is applied is defined elsewhere — confirm with its callers.
    """

    NUMBER_CORRECT = 0
    PHONEME_DELETION = 1


def get_model():
    """Load the French/Italian WavLM phonemizer model and its processor.

    Both are loaded with ``from_pretrained`` from the
    ``hugofara/wavlm-base-plus-phonemizer-fr-it`` checkpoint.

    Returns:
        tuple: ``(model, processor)`` — a ``WavLMPhonemeFrIt`` instance and the
        matching ``transformers`` processor.
    """
    checkpoint = "hugofara/wavlm-base-plus-phonemizer-fr-it"
    # Special tokens are passed explicitly so the processor's tokenizer uses
    # the same [UNK]/[PAD] tokens and "|" word delimiter as the checkpoint.
    processor = transformers.AutoProcessor.from_pretrained(
        checkpoint,
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    )
    model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(checkpoint)
    return model, processor


def preprocess_audio(audio_data, target_sample_rate=SAMPLING_RATE):
    """Convert audio to mono float32 at ``target_sample_rate``, peak-normalized.

    Args:
        audio_data: A ``(sample_rate, samples)`` pair, or ``None``. ``samples``
            is a NumPy array; when it has more than one dimension it is
            averaged over axis 1 to produce mono.
        target_sample_rate: Sample rate (Hz) of the returned waveform.

    Returns:
        A 1-D ``float32`` array scaled so that its largest absolute value is
        1.0. Silent (all-zero) and empty input are returned unscaled. Returns
        ``None`` if ``audio_data`` is ``None``.
    """
    if audio_data is None:
        return None

    sample_rate, audio = audio_data

    # Down-mix to mono. Averaging over axis 1 assumes a
    # (num_samples, num_channels) layout — TODO confirm with the caller.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Convert once, up front: resampling and normalization both work in
    # float32, and torch.from_numpy then always sees a supported dtype.
    audio = audio.astype(np.float32)

    # An empty recording has nothing to resample, and np.max raises on a
    # zero-size array, so return it as-is.
    if audio.size == 0:
        return audio

    if sample_rate != target_sample_rate:
        audio_tensor = torch.from_numpy(audio).unsqueeze(0)  # (1, num_samples)
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        audio = resampler(audio_tensor).squeeze(0).numpy()

    # Peak normalization; skipped for silence to avoid dividing by zero.
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak
    return audio


def prepare_model_inputs(audio, processor, sampling_rate=SAMPLING_RATE, language=Languages.FR):
    """Build the keyword inputs for one forward pass of the model.

    Args:
        audio: Mono waveform, e.g. the output of ``preprocess_audio``.
        processor: The processor returned by ``get_model``.
        sampling_rate: Sample rate (Hz) of ``audio``, forwarded to the processor.
        language: ``Languages.FR`` selects code 0.0; any other value selects
            1.0 (Italian).

    Returns:
        The processor output (PyTorch tensors) with an extra ``"language"``
        entry: a ``float32`` tensor of shape ``(1, 1)``.
    """
    inputs = processor(
        audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True
    )
    # Language flag passed to the model through **inputs in run_inference.
    # Its (1, 1) shape assumes a batch of exactly one utterance.
    language_code = 0.0 if language is Languages.FR else 1.0
    inputs["language"] = torch.tensor([[language_code]], dtype=torch.float32)
    return inputs


def run_inference(model, inputs):
    """Run the model without gradient tracking and pick the top token per step.

    Args:
        model: The model returned by ``get_model``.
        inputs: Keyword inputs, e.g. from ``prepare_model_inputs``.

    Returns:
        tuple: ``(outputs, predicted_ids)`` — the raw model output and the
        argmax of ``outputs.logits`` over its last dimension.
    """
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_ids = torch.argmax(outputs.logits, dim=-1)
    return outputs, predicted_ids


def decode_transcription(processor, predicted_ids):
    """Decode predicted token IDs and return the text of the first batch item."""
    return processor.batch_decode(predicted_ids)[0]


def compare_with_target(transcription, target_word):
    """Format a transcription as Markdown and optionally look for a target in it.

    The check is case-insensitive and ignores literal ``[pad]`` tokens. Any run
    of whitespace is treated as a single space, so removing a token cannot
    break a multi-word match. A match means the target appears anywhere in the
    transcription as a substring.

    Args:
        transcription: Decoded text, e.g. from ``decode_transcription``.
        target_word: Text to look for. ``None``, empty or whitespace-only
            skips the check.

    Returns:
        A Markdown string with the transcription and, if a target was given,
        a match / no-match line.
    """
    result = f"**Transcription:** {transcription}\n\n"
    if not target_word or not target_word.strip():
        return result

    # " ".join(s.split()) strips the ends and collapses inner whitespace.
    target_clean = " ".join(target_word.lower().split())
    transcription_clean = " ".join(transcription.lower().replace("[pad]", "").split())

    if target_clean in transcription_clean:
        result += f"✅ **Match found!** The target word '{target_word}' appears in the transcription."
    else:
        result += f"❌ **No exact match.** The target word '{target_word}' was not found in the transcription."
    return result