|
|
import enum |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torchaudio |
|
|
import transformers |
|
|
import wavlm_phoneme_fr_it |
|
|
|
|
|
# Sampling rate (samples per second) that incoming audio is resampled to and
# that is reported to the processor by default.
SAMPLING_RATE = 16000
|
|
|
|
|
class Languages(enum.Enum):
    """Language selector for the French/Italian phoneme model.

    Member values line up with the numeric language flag handed to the
    model (FR -> 0, IT -> 1).
    """

    FR = 0  # French
    IT = 1  # Italian
|
|
|
|
|
|
|
|
class Scoring(enum.Enum):
    """Available scoring modes.

    NOTE(review): this enum is not referenced in the visible code; the
    member meanings below are inferred from their names only — confirm
    against the code that consumes it.
    """

    NUMBER_CORRECT = 0  # presumably: count of correctly produced items
    PHONEME_DELETION = 1  # presumably: scoring for a phoneme-deletion task
|
|
|
|
|
|
|
|
def get_model():
    """Load the phoneme model and its matching processor.

    Both objects come from the ``hugofara/wavlm-base-plus-phonemizer-fr-it``
    checkpoint via ``from_pretrained``; that call may fetch files from the
    Hugging Face Hub. Nothing is cached here, so every call loads afresh.

    Returns:
        tuple: ``(model, processor)``. The processor is created with
        ``[UNK]`` / ``[PAD]`` special tokens and ``|`` as word delimiter.
    """
    repo_id = "hugofara/wavlm-base-plus-phonemizer-fr-it"
    tokenizer_overrides = {
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "word_delimiter_token": "|",
    }

    # Processor is loaded before the model, as in the original ordering.
    loaded_processor = transformers.AutoProcessor.from_pretrained(
        repo_id, **tokenizer_overrides
    )
    loaded_model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(repo_id)
    return loaded_model, loaded_processor
|
|
|
|
|
|
|
|
def preprocess_audio(audio_data, target_sample_rate=SAMPLING_RATE):
    """Convert audio to mono float32 at ``target_sample_rate``, peak-normalized.

    Args:
        audio_data: ``(sample_rate, samples)`` tuple, or ``None``. When
            ``samples`` has more than one dimension, axis 1 is averaged away
            (assumes a ``(num_samples, num_channels)`` layout, e.g. Gradio's
            numpy audio — confirm against callers).
        target_sample_rate: Sampling rate the audio is resampled to when it
            differs from ``sample_rate``.

    Returns:
        A 1-D ``np.float32`` array scaled so that its peak absolute value is
        1 (all-zero audio is returned unscaled). Returns ``None`` when
        ``audio_data`` is ``None`` or contains no samples.
    """
    if audio_data is None:
        return None

    sample_rate, audio = audio_data
    audio = np.asarray(audio)

    # An empty recording can be neither resampled nor normalized (np.max on
    # a zero-size array raises ValueError); treat it like missing audio.
    if audio.size == 0:
        return None

    # Down-mix to mono by averaging across channels.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    if sample_rate != target_sample_rate:
        # Resample expects a (channel, time) tensor, hence the unsqueeze.
        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
        resampled = torchaudio.transforms.Resample(sample_rate, target_sample_rate)(audio_tensor)
        audio = resampled.squeeze(0).numpy()

    # Cast before abs() so integer PCM extremes (e.g. int16 -32768) cannot
    # overflow during normalization.
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak

    return audio
|
|
|
|
|
|
|
|
def prepare_model_inputs(audio, processor, sampling_rate=SAMPLING_RATE, language=Languages.FR):
    """Build the keyword inputs for the model from preprocessed audio.

    Args:
        audio: Audio handed straight to ``processor``.
        processor: Callable processor, e.g. the one returned by ``get_model``.
        sampling_rate: Sampling rate of ``audio``, forwarded to the processor.
        language: ``Languages`` member selecting the model's language flag.

    Returns:
        The processor output (PyTorch tensors, padded) with an extra
        ``"language"`` entry: a ``(1, 1)`` float32 tensor holding the
        language flag (0.0 for ``Languages.FR``, 1.0 for ``Languages.IT``).

    Raises:
        TypeError: If ``language`` is not a ``Languages`` member.
    """
    # Validate explicitly: an unchecked "is Languages.FR" test would send any
    # other value (a typo, the string "FR", None, ...) down the Italian path.
    if not isinstance(language, Languages):
        raise TypeError(f"language must be a Languages member, got {language!r}")

    inputs = processor(
        audio,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
    )

    # Enum values are the model's numeric language flags (FR=0, IT=1).
    language_code = float(language.value)
    inputs["language"] = torch.tensor([[language_code]], dtype=torch.float32)

    return inputs
|
|
|
|
|
|
|
|
def run_inference(model, inputs):
    """Run the model with gradient tracking disabled and pick top tokens.

    Args:
        model: Callable model; its output must expose a ``logits`` tensor.
        inputs: Mapping unpacked as keyword arguments into ``model``.

    Returns:
        tuple: ``(outputs, predicted_ids)``. ``predicted_ids`` is the argmax
        of ``outputs.logits`` over the last dimension.
    """
    with torch.no_grad():
        model_output = model(**inputs)
        best_ids = model_output.logits.argmax(dim=-1)

    return model_output, best_ids
|
|
|
|
|
|
|
|
def decode_transcription(processor, predicted_ids):
    """Turn predicted token IDs into text via ``processor.batch_decode``.

    Only the first decoded sequence of the batch is returned.
    """
    decoded_batch = processor.batch_decode(predicted_ids)
    first_sequence = decoded_batch[0]
    return first_sequence
|
|
|
|
|
|
|
|
def compare_with_target(transcription, target_word):
    """Format a transcription as Markdown and report whether it contains a target.

    Args:
        transcription: Decoded model output text.
        target_word: Text to look for. ``None``, empty, or whitespace-only
            values skip the comparison.

    Returns:
        A Markdown string with the transcription. When a target is given, a
        match / no-match line follows. Matching is a case-insensitive
        substring test after removing ``[PAD]`` tokens and collapsing
        whitespace.
    """
    result = f"**Transcription:** {transcription}\n\n"

    if not target_word or not target_word.strip():
        return result

    # Collapse whitespace on both sides: deleting "[pad]" tokens can leave
    # double spaces that would otherwise break a multi-token target ("a b").
    target_clean = " ".join(target_word.lower().split())
    transcription_clean = " ".join(transcription.lower().replace("[pad]", "").split())

    if target_clean in transcription_clean:
        result += f"✅ **Match found!** The target word '{target_word}' appears in the transcription."
    else:
        result += f"❌ **No exact match.** The target word '{target_word}' was not found in the transcription."

    return result
|
|
|