import gradio as gr

import aligner
import common
from grapheme_to_phoneme import Grapheme2Phoneme

model, processor = common.get_model()

# Initialize phonemizers for both languages
phonemizer_fr = Grapheme2Phoneme(language="fr", cuda=False)
phonemizer_it = Grapheme2Phoneme(language="it", cuda=False)


def phonemize_text(text, language):
    """Convert text to phonemes using the appropriate phonemizer."""
    if not text or not text.strip():
        return ""
    phonemizer = phonemizer_fr if language == "French" else phonemizer_it
    phonemes = phonemizer.phonemize([text.strip()])
    return " ".join([word.replace(" ", "") for word in phonemes]) if phonemes and phonemes[0] else ""


def process_audio_advanced(audio_data, target_word, language, advanced_mode,
                           insertion_cost, deletion_cost, threshold, temperature,
                           scoring_method):
    """Process recorded audio, with advanced alignment analysis if enabled."""
    if audio_data is None:
        return "Please record some audio first.", "", "", None

    # Convert the target word to phonemes if one was provided
    phonemized_target = ""
    if target_word and target_word.strip():
        phonemized_target = phonemize_text(target_word, language)

    # Preprocess audio
    audio = common.preprocess_audio(audio_data)
    if audio is None:
        return "Failed to process audio.", "", "", None

    # Prepare model inputs for the selected language
    lang_enum = common.Languages.FR if language == "French" else common.Languages.IT
    inputs = common.prepare_model_inputs(audio, processor, language=lang_enum)

    # Run inference
    outputs, predicted_ids = common.run_inference(model, inputs)

    # Decode transcription
    transcription = common.decode_transcription(processor, predicted_ids)

    # Build the basic result
    result = f"**Language:** {language}\n\n"
    result += f"**Transcription:** {transcription}\n\n"

    alignment_result = ""
    alignment_plot_fig = None

    if target_word and target_word.strip():
        result += f"**Target Word:** {target_word}\n"
        result += f"**Target Phonemes:** {phonemized_target}\n\n"

        if advanced_mode and phonemized_target:
            # Advanced mode: use alignment
            try:
                # Encode target phonemes
                target_encoded = aligner.encode_phonemes(
                    phonemized_target, processor.tokenizer
                )

                # Get model logits (raw outputs before softmax)
                prediction_logits = outputs.logits

                # Perform alignment using the user-defined weights
                matching, alignment_score = aligner.bellman_matching(
                    prediction_logits,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm,
                )

                # Calculate the alignment score using the user-defined weights
                # and the chosen scoring method
                weights = [insertion_cost, deletion_cost, threshold, temperature]
                scoring_enum = (
                    common.Scoring.NUMBER_CORRECT
                    if scoring_method == "NUMBER_CORRECT"
                    else common.Scoring.PHONEME_DELETION
                )
                score = aligner.get_alignment_score(
                    prediction_logits,
                    target_encoded,
                    weights,
                    processor.tokenizer.pad_token_id,
                    scoring=scoring_enum,
                )

                # Use the pad-reduced prediction tensor for the alignment plot
                # (removes temporal effects)
                reduced_prediction = aligner.remove_pad_tokens(
                    prediction_logits, processor.tokenizer.pad_token_id, temperature
                )

                # Generate the alignment plot from the reduced prediction
                path_matrix = aligner.compute_path_matrix(
                    reduced_prediction,
                    target_encoded,
                    aligner.l2_logit_norm,
                    insertion_cost,
                    deletion_cost,
                )

                # Re-compute the matching on the reduced prediction for visualization
                matching_for_plot, _ = aligner.bellman_matching(
                    reduced_prediction,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm,
                )

                alignment_plot_fig = aligner.display_matrix_result(
                    path_matrix,
                    matching_for_plot,
                    reduced_prediction,
                    target_encoded,
                    processor,
                )
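                # A note on the step above (an assumption, based on the helper
                # names rather than on `aligner`'s documentation):
                # `remove_pad_tokens` appears to drop frames dominated by the
                # pad/blank token, softening the logits with `temperature`, so
                # that the path matrix reflects phoneme-level alignment rather
                # than raw frame timing. The matching is re-computed on the
                # same reduced tensor so the plot's indices stay consistent
                # with it.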
                alignment_result = "**🔬 Advanced Alignment Analysis:**\n\n"
                alignment_result += f"**Scoring Method:** {scoring_method}\n"
                alignment_result += (
                    f"**Settings:** Insertion={insertion_cost}, Deletion={deletion_cost}, "
                    f"Threshold={threshold}, Temperature={temperature}\n\n"
                )
                alignment_result += f"**Alignment Score:** {alignment_score:.3f}\n"
                alignment_result += f"**Matching Points:** {len(matching)}\n"

                if scoring_method == "NUMBER_CORRECT":
                    alignment_result += f"**Correct Phonemes:** {score}/{target_encoded.shape[1]}\n\n"
                    accuracy = score / target_encoded.shape[1] if target_encoded.shape[1] > 0 else 0
                    if accuracy >= 0.9:
                        alignment_result += "✅ **Excellent Match!** Nearly all target phonemes are correctly aligned."
                    elif accuracy >= 0.7:
                        alignment_result += "⚠️ **Good Match.** Most target phonemes align correctly."
                    else:
                        alignment_result += "❌ **Poor Match.** Many target phonemes don't align correctly."
                else:  # PHONEME_DELETION
                    alignment_result += f"**Classification Score:** {score}/2\n\n"
                    if score == 2:
                        alignment_result += "✅ **Perfect Match!** Target phonemes align perfectly with the transcription."
                    elif score == 1:
                        alignment_result += "⚠️ **Close Match.** Target phonemes align with one minor error."
                    else:
                        alignment_result += "❌ **Poor Match.** Target phonemes don't align well with the transcription."
            except Exception as e:
                alignment_result = f"**⚠️ Alignment Error:** {e}"
        else:
            # Simple mode: substring matching on phonemes
            transcription_clean = transcription.lower().replace("[pad]", "").strip()
            phonemized_target_clean = phonemized_target.lower().strip()

            if phonemized_target_clean in transcription_clean:
                result += "✅ **Phoneme Match!** The phonemized target appears in the transcription."
            else:
                result += "❌ **No phoneme match.** The phonemized target was not found in the transcription."
    return result, phonemized_target, alignment_result, alignment_plot_fig


# Keep the simple function for backward compatibility
def process_audio(audio_data, target_word, language):
    """Simple audio processing without advanced features."""
    result, phonemes, _, _ = process_audio_advanced(
        audio_data, target_word, language,
        False, 1.3, 3.0, 0.7, 1.0, "NUMBER_CORRECT",
    )
    return result, phonemes


def create_interface():
    """Create and return the Gradio interface."""
    with gr.Blocks(title="WavLM ASR Demo") as demo:
        gr.Markdown("# WavLM ASR Capabilities Demo")
        gr.Markdown(
            "Record audio and optionally specify a target word "
            "to test the ASR model's accuracy."
        )

        with gr.Row():
            with gr.Column():
                language_radio = gr.Radio(
                    choices=["French", "Italian"],
                    value="French",
                    label="Model Language",
                    info="Select the language for ASR recognition",
                )

                advanced_mode = gr.Checkbox(
                    label="🔬 Advanced Mode",
                    value=False,
                    info="Use advanced alignment analysis for more accurate matching",
                )

                # Advanced-mode parameter sliders (initially hidden)
                with gr.Group(visible=False) as weight_controls:
                    gr.Markdown("### ⚙️ Alignment Parameters")

                    insertion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Insertion Cost",
                        info="Penalty for extra phonemes in the prediction (lower = more lenient)",
                    )

                    deletion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=3.0,
                        value=0.7,
                        step=0.1,
                        label="Deletion Cost",
                        info="Penalty for missing phonemes in the prediction (higher = stricter)",
                    )

                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        label="Match Threshold",
                        info="Minimum similarity for a phoneme match (higher = stricter)",
                    )

                    temperature = gr.Slider(
                        minimum=0.5,
                        maximum=10.0,
                        value=1.0,
                        step=0.1,
                        label="Temperature",
                        info="Softmax temperature for prediction confidence (1.0 = normal)",
                    )

                    scoring_method = gr.Radio(
                        choices=["NUMBER_CORRECT", "PHONEME_DELETION"],
                        value="NUMBER_CORRECT",
                        label="Scoring Method",
                        info="Method for calculating alignment scores",
                    )

                target_word_input = gr.Textbox(
                    label="Target Word (optional)",
                    placeholder="Enter a word you expect to say...",
                    info="Will be converted to phonemes for comparison",
                )

                phonemes_display = gr.Textbox(
                    label="Target Phonemes",
                    interactive=False,
                    placeholder="Phonemes will appear here...",
                    info="Automatic phoneme conversion of your target word",
                )

                audio_input = gr.Audio(
                    label="Record Audio",
                    sources=["microphone", "upload"],
                    type="numpy",
                )

                process_btn = gr.Button("Process Audio", variant="primary")

            with gr.Column():
                output_text = gr.Markdown(
                    value="Results will appear here after processing..."
                )
                alignment_output = gr.Markdown(
                    value="",
                    visible=False,
                    label="Alignment Analysis",
                )

                alignment_plot = gr.Plot(
                    label="Alignment Matrix",
                    visible=False,
                )

        # Update the phoneme preview when the target word or language changes
        def update_phonemes(text, language):
            if text and text.strip():
                return phonemize_text(text, language)
            return ""

        # Toggle the alignment output, weight controls, and plot
        # based on advanced mode
        def toggle_advanced_features(advanced):
            return (
                gr.update(visible=advanced),  # alignment_output
                gr.update(visible=advanced),  # weight_controls
                gr.update(visible=advanced),  # alignment_plot
            )

        target_word_input.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display,
        )

        language_radio.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display,
        )

        advanced_mode.change(
            fn=toggle_advanced_features,
            inputs=advanced_mode,
            outputs=[alignment_output, weight_controls, alignment_plot],
        )

        # Main processing function
        def process_with_mode(audio_data, target_word, language, advanced,
                              ins_cost, del_cost, thresh, temp, score_method):
            return process_audio_advanced(
                audio_data, target_word, language, advanced,
                ins_cost, del_cost, thresh, temp, score_method,
            )

        process_btn.click(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot],
        )

        # Auto-process whenever new audio is recorded or uploaded
        audio_input.change(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot],
        )

    return demo


if __name__ == "__main__":
    my_demo = create_interface()
    my_demo.launch()
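# Headless usage sketch (assumptions: `common`'s helpers behave exactly as
# called above, and "sample.wav" is a hypothetical 16 kHz test recording).
# This exercises the same transcription pipeline without the Gradio UI,
# which can be handy for debugging; `gr.Audio(type="numpy")` delivers a
# `(sample_rate, ndarray)` tuple, so we mimic that shape here:
#
#     import soundfile as sf
#     data, sr = sf.read("sample.wav")
#     audio = common.preprocess_audio((sr, data))
#     inputs = common.prepare_model_inputs(audio, processor, language=common.Languages.FR)
#     outputs, predicted_ids = common.run_inference(model, inputs)
#     print(common.decode_transcription(processor, predicted_ids))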