Hugo Farajallah
feat(scoring): adds new scoring method.
9bc684b
import enum
import numpy as np
import torch
import torchaudio
import transformers
import wavlm_phoneme_fr_it
# Target sampling rate in Hz: audio is resampled to this rate before it is
# handed to the processor.
SAMPLING_RATE = 16000
class Languages(enum.Enum):
    """Languages the FR/IT phonemizer can be conditioned on."""

    FR = 0  # French
    IT = 1  # Italian
class Scoring(enum.Enum):
    """Identifiers for the available scoring methods.

    NOTE(review): the scoring logic itself is not in this part of the file;
    see the callers for what each method computes.
    """

    NUMBER_CORRECT = 0
    PHONEME_DELETION = 1
def get_model():
    """Load the FR/IT WavLM phonemizer model and its processor.

    Both objects are loaded from the ``hugofara/wavlm-base-plus-phonemizer-fr-it``
    checkpoint. The processor is given explicit special tokens.

    Returns:
        tuple: ``(model, processor)``. ``model`` is a
        ``wavlm_phoneme_fr_it.WavLMPhonemeFrIt`` and ``processor`` comes from
        ``transformers.AutoProcessor``.
    """
    repo_id = "hugofara/wavlm-base-plus-phonemizer-fr-it"
    # Token overrides forwarded to the processor loader.
    tokenizer_overrides = dict(
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    )
    processor = transformers.AutoProcessor.from_pretrained(repo_id, **tokenizer_overrides)
    model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(repo_id)
    return model, processor
def preprocess_audio(audio_data, target_sample_rate=SAMPLING_RATE):
    """Convert audio to mono float32 at ``target_sample_rate``, peak-normalised.

    Args:
        audio_data: ``(sample_rate, audio)`` pair, or ``None``. ``audio`` is
            array-like. When it has more than one dimension, channels are
            assumed to be on axis 1 (i.e. ``(samples, channels)``) and are
            averaged to mono.
        target_sample_rate: Sampling rate (Hz) of the returned signal.

    Returns:
        ``None`` if ``audio_data`` is ``None``. Otherwise a 1-D float32 array,
        resampled when ``sample_rate != target_sample_rate`` and divided by its
        peak absolute value (left unscaled when all samples are zero). An empty
        input yields an empty float32 array.
    """
    if audio_data is None:
        return None
    sample_rate, audio = audio_data
    audio = np.asarray(audio)

    # Downmix to mono by averaging the channel axis.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Zero-length clips: np.max on an empty array raises, and there is
    # nothing to resample or normalise.
    if audio.size == 0:
        return np.zeros(0, dtype=np.float32)

    # Resample if necessary using torchaudio.
    if sample_rate != target_sample_rate:
        # torch.from_numpy rejects negative-stride views, so make the buffer
        # contiguous first.
        audio_tensor = torch.from_numpy(np.ascontiguousarray(audio)).float().unsqueeze(0)
        resampled = torchaudio.transforms.Resample(sample_rate, target_sample_rate)(audio_tensor)
        audio = resampled.squeeze(0).numpy()

    # Cast before taking abs(): on int16 input, abs(-32768) would overflow.
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak
    return audio
def prepare_model_inputs(audio, processor, sampling_rate=SAMPLING_RATE, language=Languages.FR):
    """Run ``processor`` on ``audio`` and attach the language-conditioning tensor.

    Args:
        audio: Waveform to feed to ``processor`` (e.g. the output of
            ``preprocess_audio``).
        processor: Processor used to build the model inputs (e.g. from
            ``get_model``).
        sampling_rate: Sampling rate of ``audio``, forwarded to the processor.
        language: A ``Languages`` member, or its integer value. FR maps to a
            code of 0.0 and IT to 1.0.

    Returns:
        The processor output (PyTorch tensors, padded) with an extra
        ``"language"`` entry: a float32 tensor of shape (1, 1) holding the
        language code.

    Raises:
        ValueError: If ``language`` is neither a ``Languages`` member nor one
            of its values. Previously, any such input was silently treated as
            Italian.
    """
    # Validate before doing any work. Languages(member) returns the member
    # itself; Languages(value) looks the member up by value.
    language = Languages(language)

    inputs = processor(
        audio,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
    )
    # Language code = enum value (FR=0, IT=1), as a float.
    # NOTE(review): shape is fixed at (1, 1), i.e. one utterance per call;
    # confirm with the model before batching several clips.
    inputs["language"] = torch.tensor([[float(language.value)]], dtype=torch.float32)
    return inputs
def run_inference(model, inputs):
    """Call ``model`` on ``inputs`` without gradient tracking.

    Args:
        model: Callable model whose output exposes a ``logits`` tensor.
        inputs: Mapping of keyword arguments for ``model``.

    Returns:
        tuple: ``(outputs, predicted_ids)``. ``outputs`` is the raw model
        output. ``predicted_ids`` is the argmax of ``outputs.logits`` over its
        last dimension.
    """
    with torch.no_grad():
        model_output = model(**inputs)
    # Greedy decoding: the highest-scoring class at every position.
    best_ids = model_output.logits.argmax(dim=-1)
    return model_output, best_ids
def decode_transcription(processor, predicted_ids):
    """Decode a batch of predicted token IDs and return the first sequence's text.

    Args:
        processor: Object exposing ``batch_decode`` (e.g. the processor from
            ``get_model``).
        predicted_ids: Batch of token-ID sequences, e.g. as returned by
            ``run_inference``.

    Returns:
        The decoded string for the first item of the batch.
    """
    decoded_batch = processor.batch_decode(predicted_ids)
    first_transcription = decoded_batch[0]
    return first_transcription
def compare_with_target(transcription, target_word):
    """Build a Markdown report of ``transcription`` and whether it contains ``target_word``.

    The check is a case-insensitive substring test after removing ``[PAD]``
    tokens from the transcription. A target that is part of a longer string
    therefore also counts as a match.

    Args:
        transcription: Decoded model output (e.g. from ``decode_transcription``).
            ``None`` is tolerated and treated as empty for the comparison.
        target_word: Expected word. When ``None``, empty or whitespace-only,
            no comparison line is added.

    Returns:
        A Markdown string: the transcription header, followed by a match or
        no-match line when a target word was given.
    """
    result = f"**Transcription:** {transcription}\n\n"
    if not target_word or not target_word.strip():
        return result

    target_clean = target_word.strip().lower()
    # Lower-case first so both "[PAD]" and "[pad]" are removed.
    transcription_clean = (transcription or "").lower().replace("[pad]", "").strip()

    if target_clean in transcription_clean:
        result += f"✅ **Match found!** The target word '{target_word}' appears in the transcription."
    else:
        result += f"❌ **No exact match.** The target word '{target_word}' was not found in the transcription."
    return result