File size: 13,780 Bytes
a9d4833 898ed95 b57fa93 a9d4833 898ed95 a9d4833 898ed95 b6bd379 031164f a9d4833 b6bd379 898ed95 a9d4833 0e90d9f b6bd379 a9d4833 898ed95 a9d4833 0e90d9f a9d4833 0e90d9f a9d4833 031164f 898ed95 031164f b6bd379 031164f 898ed95 031164f b57fa93 b6bd379 031164f 25551f0 b57fa93 031164f 25551f0 b57fa93 031164f 898ed95 b6bd379 25551f0 b6bd379 b57fa93 031164f b6bd379 54520eb b6bd379 54520eb b57fa93 54520eb b57fa93 54520eb b6bd379 b57fa93 b6bd379 54520eb b57fa93 54520eb b57fa93 54520eb b57fa93 54520eb 031164f b6bd379 25551f0 031164f b6bd379 031164f 898ed95 031164f 898ed95 b6bd379 031164f b6bd379 031164f a9d4833 898ed95 031164f 25551f0 d19ec84 25551f0 d19ec84 25551f0 d19ec84 25551f0 d19ec84 25551f0 b6bd379 a9d4833 898ed95 a9d4833 031164f b6bd379 898ed95 25551f0 b6bd379 25551f0 031164f 898ed95 031164f 25551f0 031164f b6bd379 031164f b6bd379 031164f b6bd379 031164f a9d4833 031164f 25551f0 b6bd379 a9d4833 031164f 25551f0 b6bd379 a9d4833 0e90d9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
import gradio as gr
import common
from grapheme_to_phoneme import Grapheme2Phoneme
import aligner
# Load the shared ASR model and its feature processor once at import time;
# both are reused by every request handled below.
model, processor = common.get_model()
# Initialize phonemizers for both languages
# CPU-only instances (cuda=False); phonemize_text() selects one per call.
phonemizer_fr = Grapheme2Phoneme(language="fr", cuda=False)
phonemizer_it = Grapheme2Phoneme(language="it", cuda=False)
def phonemize_text(text, language):
    """Convert *text* to a space-separated phoneme string.

    Selects the French phonemizer when *language* is "French", otherwise
    the Italian one, and strips intra-word spacing from each phonemized
    word.  Returns "" for empty/blank input or an empty phonemizer result.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return ""
    if language == "French":
        g2p = phonemizer_fr
    else:
        g2p = phonemizer_it
    phonemes = g2p.phonemize([cleaned])
    if not phonemes or not phonemes[0]:
        return ""
    return " ".join(word.replace(" ", "") for word in phonemes)
def process_audio_advanced(audio_data, target_word, language, advanced_mode, insertion_cost, deletion_cost, threshold, temperature, scoring_method):
    """Process recorded audio with advanced alignment if enabled.

    Args:
        audio_data: Audio from the Gradio Audio component (numpy payload) or None.
        target_word: Optional word the speaker intended to say; phonemized for
            comparison against the transcription.
        language: "French" or "Italian" — selects model language and phonemizer.
        advanced_mode: When True (and a target is given), run alignment analysis
            instead of simple substring matching.
        insertion_cost: Alignment penalty for extra phonemes in the prediction.
        deletion_cost: Alignment penalty for missing phonemes in the prediction.
        threshold: Minimum similarity for a phoneme match (passed via weights).
        temperature: Softmax temperature used when reducing pad tokens.
        scoring_method: "NUMBER_CORRECT" or "PHONEME_DELETION".

    Returns:
        Tuple of (result_markdown, phonemized_target, alignment_markdown,
        alignment_plot_figure_or_None).
    """
    if audio_data is None:
        return "Please record some audio first.", "", "", None

    # Convert target word to phonemes if provided
    phonemized_target = ""
    if target_word and target_word.strip():
        phonemized_target = phonemize_text(target_word, language)

    # Preprocess audio
    audio = common.preprocess_audio(audio_data)
    if audio is None:
        return "Failed to process audio.", "", "", None

    # Prepare model inputs with correct language
    lang_enum = common.Languages.FR if language == "French" else common.Languages.IT
    inputs = common.prepare_model_inputs(audio, processor, language=lang_enum)

    # Run inference
    outputs, predicted_ids = common.run_inference(model, inputs)

    # Decode transcription
    transcription = common.decode_transcription(processor, predicted_ids)

    # Create basic result
    result = f"**Language:** {language}\n\n"
    result += f"**Transcription:** {transcription}\n\n"

    alignment_result = ""
    alignment_plot_fig = None

    if target_word and target_word.strip():
        result += f"**Target Word:** {target_word}\n"
        result += f"**Target Phonemes:** {phonemized_target}\n\n"

        if advanced_mode and phonemized_target:
            # Advanced mode: Use alignment
            try:
                # Encode target phonemes
                target_encoded = aligner.encode_phonemes(
                    phonemized_target, processor.tokenizer
                )
                # Get model logits (raw outputs before softmax)
                prediction_logits = outputs.logits
                # Perform alignment using user-defined weights
                matching, alignment_score = aligner.bellman_matching(
                    prediction_logits,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm
                )
                # Calculate alignment score using user-defined weights and scoring method
                weights = [insertion_cost, deletion_cost, threshold, temperature]
                scoring_enum = (
                    common.Scoring.NUMBER_CORRECT
                    if scoring_method == "NUMBER_CORRECT"
                    else common.Scoring.PHONEME_DELETION
                )
                score = aligner.get_alignment_score(
                    prediction_logits,
                    target_encoded,
                    weights,
                    processor.tokenizer.pad_token_id,
                    scoring=scoring_enum
                )
                # Use reduced prediction tensor for alignment plot (remove temporal effects)
                reduced_prediction = aligner.remove_pad_tokens(
                    prediction_logits, processor.tokenizer.pad_token_id, temperature
                )
                # Generate alignment plot with reduced prediction
                path_matrix = aligner.compute_path_matrix(
                    reduced_prediction,
                    target_encoded,
                    aligner.l2_logit_norm,
                    insertion_cost,
                    deletion_cost
                )
                # Re-compute matching with reduced prediction for visualization
                matching_for_plot, _ = aligner.bellman_matching(
                    reduced_prediction,
                    target_encoded,
                    insertion_cost=insertion_cost,
                    deletion_cost=deletion_cost,
                    metric=aligner.l2_logit_norm
                )
                alignment_plot_fig = aligner.display_matrix_result(
                    path_matrix, matching_for_plot, reduced_prediction, target_encoded, processor
                )

                alignment_result = "**🔬 Advanced Alignment Analysis:**\n\n"
                alignment_result += f"**Scoring Method:** {scoring_method}\n"
                alignment_result += f"**Settings:** Insertion={insertion_cost}, Deletion={deletion_cost}, Threshold={threshold}, Temperature={temperature}\n\n"
                alignment_result += f"**Alignment Score:** {alignment_score:.3f}\n"
                alignment_result += f"**Matching Points:** {len(matching)}\n"

                if scoring_method == "NUMBER_CORRECT":
                    alignment_result += f"**Correct Phonemes:** {score}/{target_encoded.shape[1]}\n\n"
                    # Guard against a zero-length target to avoid ZeroDivisionError.
                    accuracy = score / target_encoded.shape[1] if target_encoded.shape[1] > 0 else 0
                    if accuracy >= 0.9:
                        alignment_result += "✅ **Excellent Match!** Most target phonemes are correctly aligned."
                    elif accuracy >= 0.7:
                        alignment_result += "⚠️ **Good Match!** Most target phonemes align well."
                    else:
                        alignment_result += "❌ **Poor Match.** Many target phonemes don't align correctly."
                else:  # PHONEME_DELETION
                    alignment_result += f"**Classification Score:** {score}/2\n\n"
                    if score == 2:
                        alignment_result += "✅ **Perfect Match!** Target phonemes align perfectly with transcription."
                    elif score == 1:
                        alignment_result += "⚠️ **Close Match!** Target phonemes align with 1 minor error."
                    else:
                        alignment_result += "❌ **Poor Match.** Target phonemes don't align well with transcription."
            except Exception as e:
                # Surface alignment failures in the UI instead of crashing the app.
                alignment_result = f"**⚠️ Alignment Error:** {str(e)}"
        else:
            # Simple mode: String matching against the cleaned transcription
            transcription_clean = transcription.lower().replace("[pad]", "").strip()
            phonemized_target_clean = phonemized_target.lower().strip()
            if phonemized_target_clean in transcription_clean:
                result += "✅ **Phoneme Match!** The phonemized target appears in the transcription."
            else:
                result += "❌ **No phoneme match.** The phonemized target was not found in the transcription."

    return result, phonemized_target, alignment_result, alignment_plot_fig
# Keep the simple function for backward compatibility
def process_audio(audio_data, target_word, language):
    """Simple audio processing without advanced features.

    Thin wrapper that calls process_audio_advanced() with advanced mode
    disabled and default alignment parameters, returning only the result
    text and the phonemized target.
    """
    text, phonemes, _alignment, _plot = process_audio_advanced(
        audio_data,
        target_word,
        language,
        False,
        1.3,
        3.0,
        0.7,
        1.0,
        "NUMBER_CORRECT",
    )
    return text, phonemes
def create_interface():
    """Create and return the Gradio Blocks interface.

    Layout: a left column with language/mode controls, alignment-parameter
    sliders (hidden until Advanced Mode is enabled), target-word inputs and
    the audio recorder; a right column with the textual result, the alignment
    analysis and the alignment-matrix plot.  Event handlers keep the phoneme
    preview in sync and trigger processing on button click or new audio.
    """
    with gr.Blocks(title="WavLM ASR Demo") as demo:
        gr.Markdown("# WavLM ASR Capabilities Demo")
        gr.Markdown("Record audio and optionally specify a target word to test the ASR model's accuracy.")
        with gr.Row():
            with gr.Column():
                language_radio = gr.Radio(
                    choices=["French", "Italian"],
                    value="French",
                    label="Model Language",
                    info="Select the language for ASR recognition"
                )
                advanced_mode = gr.Checkbox(
                    label="🔬 Advanced Mode",
                    value=False,
                    info="Use advanced alignment analysis for more accurate matching"
                )
                # Advanced mode weight sliders (initially hidden)
                with gr.Group(visible=False) as weight_controls:
                    gr.Markdown("### ⚙️ Alignment Parameters")
                    insertion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Insertion Cost",
                        info="Penalty for extra phonemes in prediction (lower = more lenient)"
                    )
                    deletion_cost = gr.Slider(
                        minimum=0.5,
                        maximum=3.0,
                        value=0.7,
                        step=0.1,
                        label="Deletion Cost",
                        info="Penalty for missing phonemes in prediction (higher = stricter)"
                    )
                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        label="Match Threshold",
                        info="Minimum similarity for phoneme match (higher = stricter)"
                    )
                    temperature = gr.Slider(
                        minimum=0.5,
                        maximum=10,
                        value=1.0,
                        step=0.1,
                        label="Temperature",
                        info="Softmax temperature for prediction confidence (1.0 = normal)"
                    )
                    scoring_method = gr.Radio(
                        choices=["NUMBER_CORRECT", "PHONEME_DELETION"],
                        value="NUMBER_CORRECT",
                        label="Scoring Method",
                        info="Method for calculating alignment scores"
                    )
                target_word_input = gr.Textbox(
                    label="Target Word (optional)",
                    placeholder="Enter a word you expect to say...",
                    info="Will be converted to phonemes for comparison"
                )
                phonemes_display = gr.Textbox(
                    label="Target Phonemes",
                    interactive=False,
                    placeholder="Phonemes will appear here...",
                    info="Automatic phoneme conversion of your target word"
                )
                audio_input = gr.Audio(
                    label="Record Audio",
                    sources=["microphone", "upload"],
                    type="numpy"
                )
                process_btn = gr.Button("Process Audio", variant="primary")
            with gr.Column():
                output_text = gr.Markdown(
                    value="Results will appear here after processing..."
                )
                alignment_output = gr.Markdown(
                    value="",
                    visible=False,
                    label="Alignment Analysis"
                )
                alignment_plot = gr.Plot(
                    label="Alignment Matrix",
                    visible=False
                )

        # Update phonemes when target word or language changes
        def update_phonemes(text, language):
            if text and text.strip():
                return phonemize_text(text, language)
            return ""

        # Toggle alignment output and weight controls visibility based on advanced mode
        def toggle_advanced_features(advanced):
            return (
                gr.update(visible=advanced),  # alignment_output
                gr.update(visible=advanced),  # weight_controls
                gr.update(visible=advanced)   # alignment_plot
            )

        target_word_input.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display
        )
        language_radio.change(
            fn=update_phonemes,
            inputs=[target_word_input, language_radio],
            outputs=phonemes_display
        )
        advanced_mode.change(
            fn=toggle_advanced_features,
            inputs=advanced_mode,
            outputs=[alignment_output, weight_controls, alignment_plot]
        )

        # Main processing function
        def process_with_mode(audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method):
            result, phonemes, alignment, plot_fig = process_audio_advanced(
                audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method
            )
            return result, phonemes, alignment, plot_fig

        process_btn.click(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
        )
        # Auto-process when audio is recorded
        audio_input.change(
            fn=process_with_mode,
            inputs=[audio_input, target_word_input, language_radio, advanced_mode,
                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    demo_app = create_interface()
    demo_app.launch()
|