Hugo Farajallah committed
Commit b6bd379 · 1 Parent(s): 9bc684b

feat(HF): display alignment matrix as well as two scoring systems.

Files changed (3):
  1. dataset_process.py +29 -12
  2. hf_space.py +65 -25
  3. vocab.json +96 -95
dataset_process.py CHANGED
@@ -121,16 +121,24 @@ def solve_path(prediction, target, path_matrix):
     return matching
 
 
-def display_matrix_result(path_matrix, matching, prediction, target):
-    """Display all the information resulting from a Bellman matching of matrices."""
-    fig, axis = plt.subplots()
-    _model, processor = common.get_model()
+def display_matrix_result(path_matrix, matching, prediction, target, processor=None):
+    """Display all the information resulting from a Bellman matching of matrices.
+
+    Returns the figure instead of showing it directly for use in Gradio.
+    """
+    fig, axis = plt.subplots(figsize=(10, 6))
+
+    if processor is None:
+        _model, processor = common.get_model()
 
     # Display the matrix
-    axis.matshow(path_matrix.T, aspect="auto")
+    im = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
+    plt.colorbar(im, ax=axis)
 
     # Set the labels for the axes
-    axis.set_xlabel('Predicted String')
+    axis.set_xlabel('Predicted String', fontsize=12)
+    axis.set_title('Alignment Matrix: Predicted vs Target Phonemes', fontsize=14, pad=20)
+
     # String for the x-axis
     predicted_labels = tuple(map(processor.decode, torch.argmax(prediction, -1)[0]))
     axis.set_xticks(
@@ -143,7 +151,7 @@ def display_matrix_result(path_matrix, matching, prediction, target):
         minor=True
     )
 
-    axis.set_ylabel('Target String')
+    axis.set_ylabel('Target String', fontsize=12)
     target_labels = tuple(map(processor.decode, torch.argmax(target, -1)[0]))
     axis.set_yticks(
         [i for i, label in enumerate(target_labels) if label == ""],
@@ -154,16 +162,25 @@ def display_matrix_result(path_matrix, matching, prediction, target):
         labels=[label for label in target_labels if label != ""],
         minor=True
     )
-    # axis.yaxis.grid(which="major", color='k', linestyle='--')
 
-    axis.grid(which="major", color="black")
-    axis.grid(which="minor", linestyle="--")
+    axis.grid(which="major", color="black", alpha=0.3)
+    axis.grid(which="minor", linestyle="--", alpha=0.2)
+
+    # Plot the optimal path in red
     axis.plot(
         [val[0] for val in matching],
         [val[1] for val in matching],
-        color="red"
+        color="red",
+        linewidth=2,
+        marker='o',
+        markersize=3,
+        label="Optimal Alignment Path"
     )
-    plt.show()
+
+    axis.legend()
+    plt.tight_layout()
+
+    return fig
 
 
 def bellman_matching(prediction, target, insertion_cost=1.3, deletion_cost=3, metric=l2_logit_norm):
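
The key change here is that `display_matrix_result` now returns the Matplotlib figure instead of calling `plt.show()`, which is what lets Gradio render it. A minimal self-contained sketch of that pattern, with stand-in data — the random matrix and diagonal path below are illustrative placeholders, not the repo's real alignment output:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; a Space's server has no display
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr


def make_alignment_figure():
    # Stand-in data: an 8x5 cost matrix and a roughly diagonal alignment path.
    path_matrix = np.random.rand(8, 5)
    matching = [(i, min(i, 4)) for i in range(8)]

    fig, axis = plt.subplots(figsize=(10, 6))
    im = axis.matshow(path_matrix.T, aspect="auto", cmap="Blues")
    fig.colorbar(im, ax=axis)
    axis.plot([p[0] for p in matching], [p[1] for p in matching],
              color="red", marker="o", label="Optimal Alignment Path")
    axis.legend()
    fig.tight_layout()
    return fig  # Gradio draws a returned Figure inside a gr.Plot component


demo = gr.Interface(fn=make_alignment_figure, inputs=None, outputs=gr.Plot())
```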
hf_space.py CHANGED
@@ -21,10 +21,10 @@ def phonemize_text(text, language):
     return " ".join([word.replace(" ", "") for word in phonemes]) if phonemes and phonemes[0] else ""
 
 
-def process_audio_advanced(audio_data, target_word, language, advanced_mode, insertion_cost, deletion_cost, threshold, temperature):
+def process_audio_advanced(audio_data, target_word, language, advanced_mode, insertion_cost, deletion_cost, threshold, temperature, scoring_method):
     """Process recorded audio with advanced alignment if enabled"""
     if audio_data is None:
-        return "Please record some audio first.", "", ""
+        return "Please record some audio first.", "", "", None
 
     # Convert target word to phonemes if provided
     phonemized_target = ""
@@ -34,7 +34,7 @@ def process_audio_advanced(audio_data, target_word, language, advanced_mode, ins
     # Preprocess audio
     audio = common.preprocess_audio(audio_data)
     if audio is None:
-        return "Failed to process audio.", "", ""
+        return "Failed to process audio.", "", "", None
 
     # Prepare model inputs with correct language
     lang_enum = common.Languages.FR if language == "French" else common.Languages.IT
@@ -51,6 +51,7 @@ def process_audio_advanced(audio_data, target_word, language, advanced_mode, ins
     result += f"**Transcription:** {transcription}\n\n"
 
     alignment_result = ""
+    alignment_plot_fig = None
 
     if target_word and target_word.strip():
         result += f"**Target Word:** {target_word}\n"
@@ -61,7 +62,7 @@ def process_audio_advanced(audio_data, target_word, language, advanced_mode, ins
         try:
             # Encode target phonemes
             target_encoded = dataset_process.encode_phonemes(
-                phonemized_target.split(), processor.tokenizer
+                phonemized_target, processor.tokenizer
             )
 
             # Get model logits (raw outputs before softmax)
@@ -76,26 +77,52 @@ def process_audio_advanced(audio_data, target_word, language, advanced_mode, ins
                 metric=dataset_process.l2_logit_norm
             )
 
-            # Calculate alignment score using user-defined weights
+            # Calculate alignment score using user-defined weights and scoring method
             weights = [insertion_cost, deletion_cost, threshold, temperature]
+            scoring_enum = common.Scoring.NUMBER_CORRECT if scoring_method == "NUMBER_CORRECT" else common.Scoring.PHONEME_DELETION
             score = dataset_process.get_alignment_score(
                 prediction_logits,
                 target_encoded,
-                weights
+                weights,
+                94,
+                scoring=scoring_enum
+            )
+
+            # Generate alignment plot
+            path_matrix = dataset_process.compute_path_matrix(
+                prediction_logits,
+                target_encoded,
+                dataset_process.l2_logit_norm,
+                insertion_cost,
+                deletion_cost
+            )
+            alignment_plot_fig = dataset_process.display_matrix_result(
+                path_matrix, matching, prediction_logits, target_encoded, processor
             )
 
             alignment_result = f"**🔬 Advanced Alignment Analysis:**\n\n"
+            alignment_result += f"**Scoring Method:** {scoring_method}\n"
             alignment_result += f"**Settings:** Insertion={insertion_cost}, Deletion={deletion_cost}, Threshold={threshold}, Temperature={temperature}\n\n"
             alignment_result += f"**Alignment Score:** {alignment_score:.3f}\n"
             alignment_result += f"**Matching Points:** {len(matching)}\n"
-            alignment_result += f"**Classification Score:** {score}/2\n\n"
 
-            if score == 2:
-                alignment_result += "✅ **Perfect Match!** Target phonemes align perfectly with transcription."
-            elif score == 1:
-                alignment_result += "⚠️ **Close Match!** Target phonemes align with 1 minor error."
-            else:
-                alignment_result += "❌ **Poor Match.** Target phonemes don't align well with transcription."
+            if scoring_method == "NUMBER_CORRECT":
+                alignment_result += f"**Correct Phonemes:** {score}/{target_encoded.shape[1]}\n\n"
+                accuracy = score / target_encoded.shape[1] if target_encoded.shape[1] > 0 else 0
+                if accuracy >= 0.9:
+                    alignment_result += "✅ **Excellent Match!** Most target phonemes are correctly aligned."
+                elif accuracy >= 0.7:
+                    alignment_result += "⚠️ **Good Match!** Most target phonemes align well."
+                else:
+                    alignment_result += "❌ **Poor Match.** Many target phonemes don't align correctly."
+            else:  # PHONEME_DELETION
+                alignment_result += f"**Classification Score:** {score}/2\n\n"
+                if score == 2:
+                    alignment_result += "✅ **Perfect Match!** Target phonemes align perfectly with transcription."
+                elif score == 1:
+                    alignment_result += "⚠️ **Close Match!** Target phonemes align with 1 minor error."
+                else:
+                    alignment_result += "❌ **Poor Match.** Target phonemes don't align well with transcription."
 
         except Exception as e:
             alignment_result = f"**⚠️ Alignment Error:** {str(e)}"
@@ -109,13 +136,13 @@ def process_audio_advanced(audio_data, target_word, language, advanced_mode, ins
     else:
         result += f"❌ **No phoneme match.** The phonemized target was not found in the transcription."
 
-    return result, phonemized_target, alignment_result
+    return result, phonemized_target, alignment_result, alignment_plot_fig
 
 
 # Keep the simple function for backward compatibility
 def process_audio(audio_data, target_word, language):
     """Simple audio processing without advanced features"""
-    result, phonemes, _ = process_audio_advanced(audio_data, target_word, language, False, 1.3, 3.0, 0.7, 1.0)
+    result, phonemes, _, _ = process_audio_advanced(audio_data, target_word, language, False, 1.3, 3.0, 0.7, 1.0, "NUMBER_CORRECT")
     return result, phonemes
 
 
@@ -181,6 +208,13 @@ def create_interface():
             info="Softmax temperature for prediction confidence (1.0 = normal)"
         )
 
+        scoring_method = gr.Radio(
+            choices=["NUMBER_CORRECT", "PHONEME_DELETION"],
+            value="NUMBER_CORRECT",
+            label="Scoring Method",
+            info="Method for calculating alignment scores"
+        )
+
         target_word_input = gr.Textbox(
             label="Target Word (optional)",
             placeholder="Enter a word you expect to say...",
@@ -213,6 +247,11 @@ def create_interface():
             label="Alignment Analysis"
         )
 
+        alignment_plot = gr.Plot(
+            label="Alignment Matrix",
+            visible=False
+        )
+
         # Update phonemes when target word or language changes
         def update_phonemes(text, language):
             if text and text.strip():
@@ -223,7 +262,8 @@ def create_interface():
         def toggle_advanced_features(advanced):
             return (
                 gr.update(visible=advanced),  # alignment_output
-                gr.update(visible=advanced)  # weight_controls
+                gr.update(visible=advanced),  # weight_controls
+                gr.update(visible=advanced)  # alignment_plot
             )
 
         target_word_input.change(
@@ -241,29 +281,29 @@ def create_interface():
         advanced_mode.change(
             fn=toggle_advanced_features,
             inputs=advanced_mode,
-            outputs=[alignment_output, weight_controls]
+            outputs=[alignment_output, weight_controls, alignment_plot]
         )
 
         # Main processing function
-        def process_with_mode(audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp):
-            result, phonemes, alignment = process_audio_advanced(
-                audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp
+        def process_with_mode(audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method):
+            result, phonemes, alignment, plot_fig = process_audio_advanced(
+                audio_data, target_word, language, advanced, ins_cost, del_cost, thresh, temp, score_method
             )
-            return result, phonemes, alignment
+            return result, phonemes, alignment, plot_fig
 
         process_btn.click(
             fn=process_with_mode,
             inputs=[audio_input, target_word_input, language_radio, advanced_mode,
-                    insertion_cost, deletion_cost, threshold, temperature],
-            outputs=[output_text, phonemes_display, alignment_output]
+                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
+            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
        )
 
         # Auto-process when audio is recorded
         audio_input.change(
             fn=process_with_mode,
             inputs=[audio_input, target_word_input, language_radio, advanced_mode,
-                    insertion_cost, deletion_cost, threshold, temperature],
-            outputs=[output_text, phonemes_display, alignment_output]
+                    insertion_cost, deletion_cost, threshold, temperature, scoring_method],
+            outputs=[output_text, phonemes_display, alignment_output, alignment_plot]
         )
 
     return demo
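
`bellman_matching`, `compute_path_matrix`, and `get_alignment_score` are this repo's own helpers (defined in dataset_process.py), so their exact semantics are not visible in this diff. As a rough illustration of the dynamic program behind a path matrix like the one plotted above, here is an edit-distance-style accumulated-cost sketch with the same insertion/deletion cost knobs; `sub_cost` stands in for the repo's `l2_logit_norm` frame-distance metric and is an assumption, not the actual implementation:

```python
import numpy as np


def toy_path_matrix(pred_len, target_len, sub_cost, insertion_cost=1.3, deletion_cost=3.0):
    """Accumulated cost of aligning a prediction of pred_len frames to a
    target of target_len phonemes; sub_cost(i, j) is the local match cost."""
    D = np.full((pred_len + 1, target_len + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(pred_len + 1):
        for j in range(target_len + 1):
            if i > 0 and j > 0:  # match / substitute
                D[i, j] = min(D[i, j], D[i - 1, j - 1] + sub_cost(i - 1, j - 1))
            if i > 0:  # extra predicted frame (insertion)
                D[i, j] = min(D[i, j], D[i - 1, j] + insertion_cost)
            if j > 0:  # missed target phoneme (deletion)
                D[i, j] = min(D[i, j], D[i, j - 1] + deletion_cost)
    return D


# Backtracking from D[-1, -1] along the cheapest predecessors yields the list
# of (prediction_index, target_index) pairs — the "matching" that
# display_matrix_result draws as the red path.
```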
vocab.json CHANGED
@@ -1,96 +1,97 @@
 {
-  "[PAD]": 0,
-  "[UNK]": 1,
-  "a": 2,
-  "ã": 3,
-  "b": 4,
-  "c": 5,
-  "d": 6,
-  "d͡z": 7,
-  "d͡ʒ": 8,
-  "e": 9,
-  "ẽ": 10,
-  "f": 11,
-  "g": 12,
-  "h": 13,
-  "i": 14,
-  "j": 15,
-  "k": 16,
-  "l": 17,
-  "m": 18,
-  "": 19,
-  "n": 20,
-  "": 21,
-  "o": 22,
-  "": 23,
-  "p": 24,
-  "": 25,
-  "r": 26,
-  "s": 27,
-  "": 28,
-  "t": 29,
-  "": 30,
-  "t͡s": 31,
-  "t͡ʃ": 32,
-  "u": 33,
-  "": 34,
-  "v": 35,
-  "": 36,
-  "w": 37,
-  "y": 38,
-  "": 39,
-  "z": 40,
-  "": 41,
-  "ø": 42,
-  "øʼ": 43,
-  "ŋ": 44,
-  "ŋʼ": 45,
-  "ɲ": 46,
-  "œ": 47,
-  "œ̃": 48,
-  "ɑ̃": 49,
-  "ɑ̃ʼ": 50,
-  "ɔ": 51,
-  "ɔʼ": 52,
-  "ɔ̃": 53,
-  "ɔ̃ʼ": 54,
-  "ə": 55,
-  "əʼ": 56,
-  "ɛ": 57,
-  "ɛʼ": 58,
-  "ɛː": 59,
-  "ɛ̃": 60,
-  "ɛ̃ʼ": 61,
-  "ɥ": 62,
-  "ʁ": 63,
-  "ʁʼ": 64,
-  "ʃ": 65,
-  "ʈ": 66,
-  "ʒ": 67,
-  "ʒʼ": 68,
-  "ʼ": 69,
-  "ʼa": 70,
-  "ʼe": 71,
-  "ʼi": 72,
-  "ʼj": 73,
-  "ʼo": 74,
-  "ʼu": 75,
-  "ʼy": 76,
-  "ʼœ": 77,
-  "ʼœ̃": 78,
-  "ʼɑ̃": 79,
-  "ʼɔ": 80,
-  "ʼɔ̃": 81,
-  "ʼə": 82,
-  "ʼɛ": 83,
-  "ʼɛ̃": 84,
-  "ʼɥ": 85,
-  "ˈe": 86,
-  "ˈh": 87,
-  "ˈk": 88,
-  "ˈp": 89,
-  "ˈs": 90,
-  "ˈu": 91,
-  "ˈæ": 92,
-  "ˈð": 93
-}
+  "|": 0,
+  "a": 1,
+  "ã": 2,
+  "b": 3,
+  "c": 4,
+  "d": 5,
+  "d͡z": 6,
+  "d͡ʒ": 7,
+  "e": 8,
+  "ẽ": 9,
+  "f": 10,
+  "g": 11,
+  "h": 12,
+  "i": 13,
+  "j": 14,
+  "k": 15,
+  "l": 16,
+  "m": 17,
+  "": 18,
+  "n": 19,
+  "": 20,
+  "o": 21,
+  "": 22,
+  "p": 23,
+  "": 24,
+  "r": 25,
+  "s": 26,
+  "": 27,
+  "t": 28,
+  "": 29,
+  "t͡s": 30,
+  "t͡ʃ": 31,
+  "u": 32,
+  "": 33,
+  "v": 34,
+  "": 35,
+  "w": 36,
+  "y": 37,
+  "": 38,
+  "z": 39,
+  "": 40,
+  "ø": 41,
+  "øʼ": 42,
+  "ŋ": 43,
+  "ŋʼ": 44,
+  "ɲ": 45,
+  "œ": 46,
+  "œ̃": 47,
+  "ɑ̃": 48,
+  "ɑ̃ʼ": 49,
+  "ɔ": 50,
+  "ɔʼ": 51,
+  "ɔ̃": 52,
+  "ɔ̃ʼ": 53,
+  "ə": 54,
+  "əʼ": 55,
+  "ɛ": 56,
+  "ɛʼ": 57,
+  "ɛː": 58,
+  "ɛ̃": 59,
+  "ɛ̃ʼ": 60,
+  "ɥ": 61,
+  "ʁ": 62,
+  "ʁʼ": 63,
+  "ʃ": 64,
+  "ʈ": 65,
+  "ʒ": 66,
+  "ʒʼ": 67,
+  "ʼ": 68,
+  "ʼa": 69,
+  "ʼe": 70,
+  "ʼi": 71,
+  "ʼj": 72,
+  "ʼo": 73,
+  "ʼu": 74,
+  "ʼy": 75,
+  "ʼœ": 76,
+  "ʼœ̃": 77,
+  "ʼɑ̃": 78,
+  "ʼɔ": 79,
+  "ʼɔ̃": 80,
+  "ʼə": 81,
+  "ʼɛ": 82,
+  "ʼɛ̃": 83,
+  "ʼɥ": 84,
+  "ˈe": 85,
+  "ˈh": 86,
+  "ˈk": 87,
+  "ˈp": 88,
+  "ˈs": 89,
+  "ˈu": 90,
+  "ˈæ": 91,
+  "ˈð": 92,
+  "[UNK]": 93,
+  "[PAD]": 94
+}
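
The vocabulary is re-indexed rather than extended: `[PAD]` moves from id 0 to 94, `[UNK]` from 1 to 93, and the word delimiter `|` takes id 0, shifting every phoneme down by one. This is presumably why the literal `94` (the new pad/blank id) is passed to `get_alignment_score` in hf_space.py above. A minimal sketch of loading the updated file with the tokenizer class normally used for CTC vocabularies of this shape — assuming the repo follows the standard Wav2Vec2 setup; `common.get_model()` may construct it differently:

```python
from transformers import Wav2Vec2CTCTokenizer

# The special-token names must match the keys in vocab.json after this commit.
tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
print(tokenizer.pad_token_id)  # 94 with the re-indexed vocabulary
```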