# Author: Hugo Farajallah
# fix(files): the main file should be named app.py.
# commit 560a824
import gradio as gr
import common
from grapheme_to_phoneme import Grapheme2Phoneme
import aligner
# Load the shared ASR model and its processor once at module import
# (reused by every request handled by the Gradio app).
model, processor = common.get_model()
# Initialize phonemizers for both languages
# NOTE(review): cuda=False forces CPU phonemization — presumably to keep the
# demo runnable without a GPU; confirm before enabling CUDA.
phonemizer_fr = Grapheme2Phoneme(language="fr", cuda=False)
phonemizer_it = Grapheme2Phoneme(language="it", cuda=False)
def phonemize_text(text, language):
    """Convert *text* to a space-separated phoneme string.

    Uses the French phonemizer when ``language`` is ``"French"`` and the
    Italian one otherwise. Returns ``""`` for empty/blank input or when the
    phonemizer produces nothing.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return ""
    if language == "French":
        g2p = phonemizer_fr
    else:
        g2p = phonemizer_it
    phoneme_words = g2p.phonemize([cleaned])
    if not phoneme_words or not phoneme_words[0]:
        return ""
    # Collapse intra-word spaces so each word becomes one phoneme token.
    return " ".join(word.replace(" ", "") for word in phoneme_words)
def process_audio_advanced(audio_data, target_word, language, advanced_mode, insertion_cost, deletion_cost, threshold, temperature, scoring_method):
    """Run ASR on recorded audio and optionally score it against a target word.

    Parameters
    ----------
    audio_data : audio payload from the Gradio ``Audio`` component, or None.
    target_word : optional word to phonemize and compare with the transcript.
    language : "French" or "Italian" — selects model language and phonemizer.
    advanced_mode : when True (and a target exists), run alignment analysis.
    insertion_cost, deletion_cost, threshold, temperature : alignment weights.
    scoring_method : "NUMBER_CORRECT" or "PHONEME_DELETION".

    Returns
    -------
    tuple
        (result_markdown, phonemized_target, alignment_markdown, plot_figure);
        the figure is None unless advanced alignment produced one.
    """
    if audio_data is None:
        return "Please record some audio first.", "", "", None
    # Convert target word to phonemes if provided
    phonemized_target = ""
    if target_word and target_word.strip():
        phonemized_target = phonemize_text(target_word, language)
    # Preprocess audio; None signals the recording could not be used.
    audio = common.preprocess_audio(audio_data)
    if audio is None:
        return "Failed to process audio.", "", "", None
    # Prepare model inputs with correct language
    lang_enum = common.Languages.FR if language == "French" else common.Languages.IT
    inputs = common.prepare_model_inputs(audio, processor, language=lang_enum)
    # Run inference
    outputs, predicted_ids = common.run_inference(model, inputs)
    # Decode transcription
    transcription = common.decode_transcription(processor, predicted_ids)
    # Create basic result (markdown shown in the UI)
    result = f"**Language:** {language}\n\n"
    result += f"**Transcription:** {transcription}\n\n"
    alignment_result = ""
    alignment_plot_fig = None
    if target_word and target_word.strip():
        result += f"**Target Word:** {target_word}\n"
        result += f"**Target Phonemes:** {phonemized_target}\n\n"
        if advanced_mode and phonemized_target:
            # Advanced mode: Use alignment
            try:
                # Encode target phonemes into ids understood by the aligner.
                target_encoded = aligner.encode_phonemes(
                    phonemized_target, processor.tokenizer
                )
                # Get model logits (raw outputs before softmax)
                prediction_logits = outputs.logits
                # Perform alignment using user-defined weights
                matching, alignment_score = aligner.bellman_matching(
                    prediction_logits,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm
                )
                # Calculate alignment score using user-defined weights and scoring method
                weights = [insertion_cost, deletion_cost, threshold, temperature]
                scoring_enum = common.Scoring.NUMBER_CORRECT if scoring_method == "NUMBER_CORRECT" else common.Scoring.PHONEME_DELETION
                score = aligner.get_alignment_score(
                    prediction_logits,
                    target_encoded,
                    weights,
                    processor.tokenizer.pad_token_id,
                    scoring=scoring_enum
                )
                # Use reduced prediction tensor for alignment plot (remove temporal effects)
                reduced_prediction = aligner.remove_pad_tokens(
                    prediction_logits, processor.tokenizer.pad_token_id, temperature
                )
                # Generate alignment plot with reduced prediction
                path_matrix = aligner.compute_path_matrix(
                    reduced_prediction,
                    target_encoded,
                    aligner.l2_logit_norm,
                    insertion_cost,
                    deletion_cost
                )
                # Re-compute matching with reduced prediction for visualization
                # (the score above is still computed on the full logits).
                matching_for_plot, _ = aligner.bellman_matching(
                    reduced_prediction,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm
                )
                alignment_plot_fig = aligner.display_matrix_result(
                    path_matrix, matching_for_plot, reduced_prediction, target_encoded, processor
                )
                # Build the markdown report for the advanced-analysis panel.
                alignment_result = f"**πŸ”¬ Advanced Alignment Analysis:**\n\n"
                alignment_result += f"**Scoring Method:** {scoring_method}\n"
                alignment_result += f"**Settings:** Insertion={insertion_cost}, Deletion={deletion_cost}, Threshold={threshold}, Temperature={temperature}\n\n"
                alignment_result += f"**Alignment Score:** {alignment_score:.3f}\n"
                alignment_result += f"**Matching Points:** {len(matching)}\n"
                if scoring_method == "NUMBER_CORRECT":
                    # NOTE(review): assumes target_encoded has shape
                    # (1, n_phonemes), so shape[1] is the phoneme count —
                    # confirm against aligner.encode_phonemes.
                    alignment_result += f"**Correct Phonemes:** {score}/{target_encoded.shape[1]}\n\n"
                    accuracy = score / target_encoded.shape[1] if target_encoded.shape[1] > 0 else 0
                    if accuracy >= 0.9:
                        alignment_result += "βœ… **Excellent Match!** Most target phonemes are correctly aligned."
                    elif accuracy >= 0.7:
                        alignment_result += "⚠️ **Good Match!** Most target phonemes align well."
                    else:
                        alignment_result += "❌ **Poor Match.** Many target phonemes don't align correctly."
                else:  # PHONEME_DELETION
                    alignment_result += f"**Classification Score:** {score}/2\n\n"
                    if score == 2:
                        alignment_result += "βœ… **Perfect Match!** Target phonemes align perfectly with transcription."
                    elif score == 1:
                        alignment_result += "⚠️ **Close Match!** Target phonemes align with 1 minor error."
                    else:
                        alignment_result += "❌ **Poor Match.** Target phonemes don't align well with transcription."
            except Exception as e:
                # Surface alignment failures to the UI instead of crashing the app.
                alignment_result = f"**⚠️ Alignment Error:** {str(e)}"
        else:
            # Simple mode: String matching against the cleaned transcription.
            transcription_clean = transcription.lower().replace("[pad]", "").strip()
            phonemized_target_clean = phonemized_target.lower().strip()
            if phonemized_target_clean in transcription_clean:
                result += f"βœ… **Phoneme Match!** The phonemized target appears in the transcription."
            else:
                result += f"❌ **No phoneme match.** The phonemized target was not found in the transcription."
    return result, phonemized_target, alignment_result, alignment_plot_fig
# Keep the simple function for backward compatibility
def process_audio(audio_data, target_word, language):
    """Simple audio processing without advanced features.

    Backward-compatible wrapper around :func:`process_audio_advanced` with
    advanced mode disabled and fixed default weights; returns only the
    result text and the phonemized target.
    """
    full_result = process_audio_advanced(
        audio_data,
        target_word,
        language,
        False,
        1.3,
        3.0,
        0.7,
        1.0,
        "NUMBER_CORRECT",
    )
    # Drop the alignment text and plot — simple mode has no use for them.
    return full_result[0], full_result[1]
def create_interface():
    """Create and return the Gradio interface.

    Lays out a two-column Blocks UI: inputs (language, advanced-mode toggle,
    alignment parameter sliders, target word, audio recorder) on the left and
    markdown results plus the alignment plot on the right, then wires the
    change/click events to the processing functions above.
    """
    with gr.Blocks(title="WavLM ASR Demo") as demo:
        gr.Markdown("# WavLM ASR Capabilities Demo")
        gr.Markdown("Record audio and optionally specify a target word to test the ASR model's accuracy.")
        with gr.Row():
            with gr.Column():
                language_radio = gr.Radio(
                    choices=["French", "Italian"],
                    value="French",
                    label="Model Language",
                    info="Select the language for ASR recognition"
                )
                advanced_mode = gr.Checkbox(
                    label="πŸ”¬ Advanced Mode",
                    value=False,
                    info="Use advanced alignment analysis for more accurate matching"
                )
                # Advanced mode weight sliders (initially hidden)
                with gr.Group(visible=False) as weight_controls:
                    gr.Markdown("### βš™οΈ Alignment Parameters")
                    insertion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Insertion Cost",
                        info="Penalty for extra phonemes in prediction (lower = more lenient)"
                    )
                    deletion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=3.0,
                        value=0.7,
                        step=0.1,
                        label="Deletion Cost",
                        info="Penalty for missing phonemes in prediction (higher = stricter)"
                    )
                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        label="Match Threshold",
                        info="Minimum similarity for phoneme match (higher = stricter)"
                    )
                    temperature = gr.Slider(
                        minimum=0.5,
                        maximum=10,
                        value=1.0,
                        step=0.1,
                        label="Temperature",
                        info="Softmax temperature for prediction confidence (1.0 = normal)"
                    )
                    scoring_method = gr.Radio(
                        choices=["NUMBER_CORRECT", "PHONEME_DELETION"],
                        value="NUMBER_CORRECT",
                        label="Scoring Method",
                        info="Method for calculating alignment scores"
                    )
                target_word_input = gr.Textbox(
                    label="Target Word (optional)",
                    placeholder="Enter a word you expect to say...",
                    info="Will be converted to phonemes for comparison"
                )
                phonemes_display = gr.Textbox(
                    label="Target Phonemes",
                    interactive=False,
                    placeholder="Phonemes will appear here...",
                    info="Automatic phoneme conversion of your target word"
                )
                audio_input = gr.Audio(
                    label="Record Audio",
                    sources=["microphone", "upload"],
                    type="numpy"
                )
                process_btn = gr.Button("Process Audio", variant="primary")
            with gr.Column():
                output_text = gr.Markdown(
                    value="Results will appear here after processing..."
                )
                alignment_output = gr.Markdown(
                    value="",
                    visible=False,
                    label="Alignment Analysis"
                )
                alignment_plot = gr.Plot(
                    label="Alignment Matrix",
                    visible=False
                )

        # Update phonemes when target word or language changes
        def update_phonemes(text, language):
            if text and text.strip():
                return phonemize_text(text, language)
            return ""

        # Toggle alignment output and weight controls visibility based on advanced mode
        def toggle_advanced_features(advanced):
            return (
                gr.update(visible=advanced),  # alignment_output
                gr.update(visible=advanced),  # weight_controls
                gr.update(visible=advanced)   # alignment_plot
            )

        target_word_input.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display
        )
        language_radio.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display
        )
        advanced_mode.change(
            fn=toggle_advanced_features,
            inputs=advanced_mode,
            outputs=[alignment_output, weight_controls, alignment_plot]
        )

        # Main processing function (thin pass-through so Gradio collects all inputs)
        def process_with_mode(audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method):
            result, phonemes, alignment, plot_fig = process_audio_advanced(
                audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method
            )
            return result, phonemes, alignment, plot_fig

        process_btn.click(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
        )
        # Auto-process when audio is recorded
        audio_input.change(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    create_interface().launch()