File size: 3,006 Bytes
898ed95
 
0e90d9f
 
 
a9d4833
 
 
0e90d9f
 
898ed95
 
 
a9d4833
9bc684b
 
 
 
 
 
a9d4833
 
 
 
 
 
 
 
 
 
0e90d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
898ed95
0e90d9f
 
 
 
 
 
 
 
 
898ed95
 
0e90d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import enum

import numpy as np
import torch
import torchaudio
import transformers
import wavlm_phoneme_fr_it

# Audio sample rate (16 kHz) used as the default resampling target in
# preprocess_audio and as the default rate passed to the processor in
# prepare_model_inputs.
SAMPLING_RATE = 16_000

class Languages(enum.Enum):
    """Languages that can be selected when preparing model inputs."""

    # prepare_model_inputs sends 0.0 to the model for FR and 1.0 for any
    # other value, which matches the member values below.
    FR = 0
    IT = 1


class Scoring(enum.Enum):
    """Scoring modes.

    Not referenced by any other definition visible in this part of the file.
    """

    # NOTE(review): presumably selects between counting correct answers and a
    # phoneme-deletion task — confirm against the code that consumes this enum.
    NUMBER_CORRECT = 0
    PHONEME_DELETION = 1


def get_model():
    """Load the pretrained phonemizer model together with its processor.

    Returns:
        A ``(model, processor)`` tuple; both are loaded from the same
        checkpoint identifier.
    """
    model_id = "hugofara/wavlm-base-plus-phonemizer-fr-it"

    # Special tokens handed to the processor when it is instantiated.
    special_tokens = {
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "word_delimiter_token": "|",
    }

    processor = transformers.AutoProcessor.from_pretrained(model_id, **special_tokens)
    model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(model_id)
    return model, processor


def preprocess_audio(audio_data, target_sample_rate=SAMPLING_RATE):
    """Convert raw audio to a mono, peak-normalized float32 waveform.

    Args:
        audio_data: ``(sample_rate, samples)`` pair, or ``None``. ``samples``
            is a NumPy array; when it has more than one dimension it is
            averaged over axis 1 to obtain a single channel.
        target_sample_rate: Rate the waveform is resampled to when it differs
            from the input ``sample_rate``.

    Returns:
        ``None`` when ``audio_data`` is ``None``. Otherwise a float32 array
        scaled so that its largest absolute sample is 1. Silent audio is
        returned unscaled, and an empty recording yields an empty float32
        array.
    """
    if audio_data is None:
        return None

    sample_rate, audio = audio_data

    # Collapse multi-channel audio to mono by averaging over axis 1.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # An empty recording has no peak: np.max raises ValueError on a zero-size
    # array, so return an empty float32 waveform instead of crashing.
    if audio.size == 0:
        return audio.astype(np.float32)

    # Resample with torchaudio only when the rates differ.
    if sample_rate != target_sample_rate:
        waveform = torch.from_numpy(audio).float().unsqueeze(0)
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        audio = resampler(waveform).squeeze(0).numpy()

    # Peak-normalize; the peak is computed once and reused.
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak

    return audio


def prepare_model_inputs(audio, processor, sampling_rate=SAMPLING_RATE, language=Languages.FR):
    """Build the model's input batch from a waveform.

    Runs ``processor`` on ``audio`` (requesting padded PyTorch tensors) and
    attaches a ``"language"`` entry: a float32 tensor ``[[code]]`` where the
    code is 0.0 for ``Languages.FR`` and 1.0 for anything else.

    Returns:
        The object returned by ``processor`` with the ``"language"`` entry set.
    """
    # Language code expected by the French/Italian model.
    if language is Languages.FR:
        language_code = 0.
    else:
        language_code = 1.

    processor_options = {
        "sampling_rate": sampling_rate,
        "return_tensors": "pt",
        "padding": True,
    }
    inputs = processor(audio, **processor_options)
    inputs["language"] = torch.tensor([[language_code]], dtype=torch.float32)
    return inputs


def run_inference(model, inputs):
    """Run a forward pass without gradient tracking.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``logits`` tensor.
        inputs: Mapping of keyword arguments for the model.

    Returns:
        A ``(outputs, predicted_ids)`` tuple, where ``predicted_ids`` holds the
        argmax of the logits over their last dimension.
    """
    with torch.no_grad():
        outputs = model(**inputs)
        # Greedy choice: highest-scoring class at every position.
        predicted_ids = outputs.logits.argmax(dim=-1)
    return outputs, predicted_ids


def decode_transcription(processor, predicted_ids):
    """Return the decoded text of the first sequence in the batch.

    ``processor.batch_decode`` is applied to ``predicted_ids`` and only its
    first result is returned.
    """
    decoded_batch = processor.batch_decode(predicted_ids)
    return decoded_batch[0]


def compare_with_target(transcription, target_word):
    """Format the transcription and report whether it contains the target.

    The check is a case-insensitive substring test performed after removing
    ``[PAD]`` markers from the transcription and trimming whitespace from both
    strings.

    Args:
        transcription: Decoded model output.
        target_word: Word to look for. ``None``, empty or whitespace-only
            values disable the comparison.

    Returns:
        A Markdown string that starts with the transcription and, when a
        target is given, ends with a match / no-match message.
    """
    result = f"**Transcription:** {transcription}\n\n"

    # No usable target: report the transcription alone.
    if not target_word or not target_word.strip():
        return result

    target_clean = target_word.strip().lower()
    transcription_clean = transcription.lower().replace("[pad]", "").strip()

    if target_clean in transcription_clean:
        # The check-mark emoji was previously stored as mis-decoded bytes.
        result += f"✅ **Match found!** The target word '{target_word}' appears in the transcription."
    else:
        result += f"❌ **No exact match.** The target word '{target_word}' was not found in the transcription."

    return result