|
|
import gradio as gr |
|
|
|
|
|
import common |
|
|
from grapheme_to_phoneme import Grapheme2Phoneme |
|
|
import aligner |
|
|
|
|
|
# Load the shared ASR model and its processor once at import time so every
# request served by the UI reuses the same weights.
model, processor = common.get_model()


# CPU-only grapheme-to-phoneme converters, one per supported UI language.
phonemizer_fr = Grapheme2Phoneme(language="fr", cuda=False)


phonemizer_it = Grapheme2Phoneme(language="it", cuda=False)
|
|
|
|
|
|
|
|
def phonemize_text(text, language):
    """Convert text to phonemes using the appropriate phonemizer."""
    if not text or not text.strip():
        return ""

    # "French" selects the French G2P model; anything else falls back to Italian.
    g2p = phonemizer_fr if language == "French" else phonemizer_it
    phoneme_words = g2p.phonemize([text.strip()])

    if not phoneme_words or not phoneme_words[0]:
        return ""
    return " ".join(word.replace(" ", "") for word in phoneme_words)
|
|
|
|
|
|
|
|
def process_audio_advanced(audio_data, target_word, language, advanced_mode, insertion_cost, deletion_cost, threshold, temperature, scoring_method):
    """Run ASR on recorded audio and optionally score it against a target word.

    Args:
        audio_data: Raw audio from the Gradio ``Audio`` component
            (``type="numpy"``), or ``None`` if nothing was recorded.
        target_word: Optional word the speaker intended to say; it is
            phonemized and compared against the transcription.
        language: ``"French"`` or ``"Italian"``; selects phonemizer and
            model language.
        advanced_mode: When ``True`` (and a target word is given), run the
            dynamic-programming alignment analysis; otherwise do a simple
            substring match of the phonemized target.
        insertion_cost: Alignment penalty for extra predicted phonemes.
        deletion_cost: Alignment penalty for missing predicted phonemes.
        threshold: Minimum similarity for a phoneme match.
        temperature: Softmax temperature applied to prediction logits.
        scoring_method: ``"NUMBER_CORRECT"`` or ``"PHONEME_DELETION"``.

    Returns:
        Tuple ``(result_markdown, phonemized_target, alignment_markdown,
        alignment_figure_or_None)``.
    """
    if audio_data is None:
        return "Please record some audio first.", "", "", None

    # Phonemize the target up front so it can be shown even if alignment fails.
    phonemized_target = ""
    if target_word and target_word.strip():
        phonemized_target = phonemize_text(target_word, language)

    audio = common.preprocess_audio(audio_data)
    if audio is None:
        return "Failed to process audio.", "", "", None

    lang_enum = common.Languages.FR if language == "French" else common.Languages.IT
    inputs = common.prepare_model_inputs(audio, processor, language=lang_enum)

    outputs, predicted_ids = common.run_inference(model, inputs)
    transcription = common.decode_transcription(processor, predicted_ids)

    result = f"**Language:** {language}\n\n"
    result += f"**Transcription:** {transcription}\n\n"

    alignment_result = ""
    alignment_plot_fig = None

    if target_word and target_word.strip():
        result += f"**Target Word:** {target_word}\n"
        result += f"**Target Phonemes:** {phonemized_target}\n\n"

        if advanced_mode and phonemized_target:
            try:
                target_encoded = aligner.encode_phonemes(
                    phonemized_target, processor.tokenizer
                )

                prediction_logits = outputs.logits

                # Full-sequence alignment: used for the reported score and
                # the number of matched points.
                matching, alignment_score = aligner.bellman_matching(
                    prediction_logits,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm
                )

                weights = [insertion_cost, deletion_cost, threshold, temperature]
                scoring_enum = common.Scoring.NUMBER_CORRECT if scoring_method == "NUMBER_CORRECT" else common.Scoring.PHONEME_DELETION
                score = aligner.get_alignment_score(
                    prediction_logits,
                    target_encoded,
                    weights,
                    processor.tokenizer.pad_token_id,
                    scoring=scoring_enum
                )

                # Strip pad frames before plotting so the matrix stays readable.
                reduced_prediction = aligner.remove_pad_tokens(
                    prediction_logits, processor.tokenizer.pad_token_id, temperature
                )

                path_matrix = aligner.compute_path_matrix(
                    reduced_prediction,
                    target_encoded,
                    aligner.l2_logit_norm,
                    insertion_cost,
                    deletion_cost
                )

                # Re-run the matching on the reduced sequence so the plotted
                # path lines up with the reduced path matrix.
                matching_for_plot, _ = aligner.bellman_matching(
                    reduced_prediction,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm
                )

                alignment_plot_fig = aligner.display_matrix_result(
                    path_matrix, matching_for_plot, reduced_prediction, target_encoded, processor
                )

                alignment_result = "**🔬 Advanced Alignment Analysis:**\n\n"
                alignment_result += f"**Scoring Method:** {scoring_method}\n"
                alignment_result += f"**Settings:** Insertion={insertion_cost}, Deletion={deletion_cost}, Threshold={threshold}, Temperature={temperature}\n\n"
                alignment_result += f"**Alignment Score:** {alignment_score:.3f}\n"
                alignment_result += f"**Matching Points:** {len(matching)}\n"

                if scoring_method == "NUMBER_CORRECT":
                    alignment_result += f"**Correct Phonemes:** {score}/{target_encoded.shape[1]}\n\n"
                    accuracy = score / target_encoded.shape[1] if target_encoded.shape[1] > 0 else 0
                    if accuracy >= 0.9:
                        alignment_result += "✅ **Excellent Match!** Most target phonemes are correctly aligned."
                    elif accuracy >= 0.7:
                        alignment_result += "⚠️ **Good Match!** Most target phonemes align well."
                    else:
                        alignment_result += "❌ **Poor Match.** Many target phonemes don't align correctly."
                else:
                    alignment_result += f"**Classification Score:** {score}/2\n\n"
                    if score == 2:
                        alignment_result += "✅ **Perfect Match!** Target phonemes align perfectly with transcription."
                    elif score == 1:
                        alignment_result += "⚠️ **Close Match!** Target phonemes align with 1 minor error."
                    else:
                        alignment_result += "❌ **Poor Match.** Target phonemes don't align well with transcription."

            except Exception as e:
                # Surface alignment failures in the UI instead of crashing the app.
                alignment_result = f"**⚠️ Alignment Error:** {str(e)}"
        else:
            # Simple mode: substring check of the phonemized target within
            # the (pad-stripped, lower-cased) transcription.
            transcription_clean = transcription.lower().replace("[pad]", "").strip()
            phonemized_target_clean = phonemized_target.lower().strip()

            if phonemized_target_clean in transcription_clean:
                result += "✅ **Phoneme Match!** The phonemized target appears in the transcription."
            else:
                result += "❌ **No phoneme match.** The phonemized target was not found in the transcription."

    return result, phonemized_target, alignment_result, alignment_plot_fig
|
|
|
|
|
|
|
|
|
|
|
def process_audio(audio_data, target_word, language):
    """Simple audio processing without advanced features.

    Delegates to process_audio_advanced with advanced mode disabled and
    default alignment parameters, returning only the result text and the
    phonemized target.
    """
    full = process_audio_advanced(
        audio_data, target_word, language,
        False, 1.3, 3.0, 0.7, 1.0, "NUMBER_CORRECT",
    )
    return full[0], full[1]
|
|
|
|
|
|
|
|
def create_interface():
    """Create and return the Gradio Blocks interface for the ASR demo.

    Layout: a left column with language/mode controls, alignment-parameter
    sliders (hidden unless advanced mode is enabled), the target-word inputs
    and the audio recorder; a right column with the result markdown, the
    alignment-analysis markdown, and the alignment-matrix plot.

    Returns:
        gr.Blocks: the assembled (but not launched) demo.
    """

    with gr.Blocks(title="WavLM ASR Demo") as demo:
        gr.Markdown("# WavLM ASR Capabilities Demo")
        gr.Markdown("Record audio and optionally specify a target word to test the ASR model's accuracy.")

        with gr.Row():
            with gr.Column():
                language_radio = gr.Radio(
                    choices=["French", "Italian"],
                    value="French",
                    label="Model Language",
                    info="Select the language for ASR recognition"
                )

                advanced_mode = gr.Checkbox(
                    label="🔬 Advanced Mode",
                    value=False,
                    info="Use advanced alignment analysis for more accurate matching"
                )

                # Alignment hyper-parameters; only visible in advanced mode
                # (toggled by toggle_advanced_features below).
                with gr.Group(visible=False) as weight_controls:
                    gr.Markdown("### ⚙️ Alignment Parameters")

                    insertion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Insertion Cost",
                        info="Penalty for extra phonemes in prediction (lower = more lenient)"
                    )

                    deletion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=3.0,
                        value=0.7,
                        step=0.1,
                        label="Deletion Cost",
                        info="Penalty for missing phonemes in prediction (higher = stricter)"
                    )

                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        label="Match Threshold",
                        info="Minimum similarity for phoneme match (higher = stricter)"
                    )

                    temperature = gr.Slider(
                        minimum=0.5,
                        maximum=10,
                        value=1.0,
                        step=0.1,
                        label="Temperature",
                        info="Softmax temperature for prediction confidence (1.0 = normal)"
                    )

                    scoring_method = gr.Radio(
                        choices=["NUMBER_CORRECT", "PHONEME_DELETION"],
                        value="NUMBER_CORRECT",
                        label="Scoring Method",
                        info="Method for calculating alignment scores"
                    )

                target_word_input = gr.Textbox(
                    label="Target Word (optional)",
                    placeholder="Enter a word you expect to say...",
                    info="Will be converted to phonemes for comparison"
                )

                phonemes_display = gr.Textbox(
                    label="Target Phonemes",
                    interactive=False,
                    placeholder="Phonemes will appear here...",
                    info="Automatic phoneme conversion of your target word"
                )

                audio_input = gr.Audio(
                    label="Record Audio",
                    sources=["microphone", "upload"],
                    type="numpy"
                )

                process_btn = gr.Button("Process Audio", variant="primary")

            with gr.Column():
                output_text = gr.Markdown(
                    value="Results will appear here after processing..."
                )

                alignment_output = gr.Markdown(
                    value="",
                    visible=False,
                    label="Alignment Analysis"
                )

                alignment_plot = gr.Plot(
                    label="Alignment Matrix",
                    visible=False
                )

        def update_phonemes(text, language):
            # Live phoneme preview as the user types or switches language.
            if text and text.strip():
                return phonemize_text(text, language)
            return ""

        def toggle_advanced_features(advanced):
            # Show/hide the alignment output, parameter group and plot together.
            return (
                gr.update(visible=advanced),
                gr.update(visible=advanced),
                gr.update(visible=advanced)
            )

        target_word_input.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display
        )

        language_radio.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display
        )

        advanced_mode.change(
            fn=toggle_advanced_features,
            inputs=advanced_mode,
            outputs=[alignment_output, weight_controls, alignment_plot]
        )

        def process_with_mode(audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method):
            # Thin adapter so both the button and the auto-trigger on new
            # audio share one handler with the full parameter set.
            result, phonemes, alignment, plot_fig = process_audio_advanced(
                audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method
            )
            return result, phonemes, alignment, plot_fig

        process_btn.click(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
        )

        # Also process automatically whenever a new recording/upload arrives.
        audio_input.change(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
        )

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    create_interface().launch()
|
|
|