Hugo Farajallah committed
Commit 0e90d9f · 1 Parent(s): a9d4833

refactor(code): apply DRY and SRP principles to the code.

Files changed (3)
  1. common.py +76 -0
  2. hf_space.py +13 -53
  3. main.py +10 -9
common.py CHANGED
@@ -1,6 +1,11 @@
+import numpy as np
+import torch
+import torchaudio
 import transformers
 import wavlm_phoneme_fr_it
 
+SAMPLING_RATE = 16_000
+
 
 def get_model():
     checkpoint = "hugofara/wavlm-base-plus-phonemizer-fr-it"
@@ -12,3 +17,74 @@ def get_model():
         checkpoint
     )
     return model, processor
+
+
+def preprocess_audio(audio_data, target_sample_rate=SAMPLING_RATE):
+    """Convert audio to the correct format and sample rate"""
+    if audio_data is None:
+        return None
+
+    sample_rate, audio = audio_data
+
+    # Ensure audio is in the correct format (mono, float32)
+    if len(audio.shape) > 1:
+        audio = audio.mean(axis=1)  # Convert to mono if stereo
+
+    # Resample if necessary using torchaudio
+    if sample_rate != target_sample_rate:
+        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
+        resampled = torchaudio.transforms.Resample(sample_rate, target_sample_rate)(audio_tensor)
+        audio = resampled.squeeze(0).numpy()
+
+    # Normalize audio
+    audio = audio.astype(np.float32)
+    if np.max(np.abs(audio)) > 0:
+        audio = audio / np.max(np.abs(audio))
+
+    return audio
+
+
+def prepare_model_inputs(audio, processor, sampling_rate=SAMPLING_RATE):
+    """Prepare inputs for the model"""
+    inputs = processor(
+        audio,
+        sampling_rate=sampling_rate,
+        return_tensors="pt",
+        padding=True
+    )
+
+    # Add language tensor (assuming French/Italian model)
+    inputs["language"] = torch.tensor([[0]])
+
+    return inputs
+
+
+def run_inference(model, inputs):
+    """Run model inference and return predictions"""
+    with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+
+    return outputs, predicted_ids
+
+
+def decode_transcription(processor, predicted_ids):
+    """Decode predicted IDs to text"""
+    return processor.batch_decode(predicted_ids)[0]
+
+
+def compare_with_target(transcription, target_word):
+    """Compare transcription with target word and return formatted result"""
+    result = f"**Transcription:** {transcription}\n\n"
+
+    if target_word and target_word.strip():
+        target_clean = target_word.strip().lower()
+        transcription_clean = transcription.lower().replace("[pad]", "").strip()
+
+        if target_clean in transcription_clean:
+            result += f"✅ **Match found!** The target word '{target_word}' appears in the transcription."
+        else:
+            result += f"❌ **No exact match.** The target word '{target_word}' was not found in the transcription."
+
+    return result
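For reference, the extracted helpers chain into a single pipeline, which is exactly how the new process_audio in hf_space.py below consumes them. A minimal sketch: the (sample_rate, samples) tuple mimics what Gradio's Audio component delivers, and the silent one-second input is purely illustrative.

import numpy as np

import common

model, processor = common.get_model()

# Hypothetical input: one second of 44.1 kHz audio as a
# (sample_rate, samples) tuple, as Gradio's microphone widget returns it.
audio_data = (44_100, np.zeros(44_100, dtype=np.float32))

audio = common.preprocess_audio(audio_data)             # mono, 16 kHz, normalized
inputs = common.prepare_model_inputs(audio, processor)  # adds the language tensor
_, predicted_ids = common.run_inference(model, inputs)
transcription = common.decode_transcription(processor, predicted_ids)
print(common.compare_with_target(transcription, "test"))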
hf_space.py CHANGED
@@ -1,71 +1,31 @@
 import gradio as gr
-import numpy as np
-import torch
-import torchaudio
 
 import common
 
 model, processor = common.get_model()
 
-SAMPLING_RATE = 16_000
-
 
 def process_audio(audio_data, target_word):
     """Process recorded audio and return ASR output with target word comparison"""
     if audio_data is None:
         return "Please record some audio first."
 
-    # Extract audio data and sample rate
-    sample_rate, audio = audio_data
-
-    # Ensure audio is in the correct format (mono, float32)
-    if len(audio.shape) > 1:
-        audio = audio.mean(axis=1)  # Convert to mono if stereo
-
-    # Resample if necessary using torchaudio
-    if sample_rate != SAMPLING_RATE:
-        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
-        resampled = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)(audio_tensor)
-        audio = resampled.squeeze(0).numpy()
-
-    # Normalize audio
-    audio = audio.astype(np.float32)
-    if np.max(np.abs(audio)) > 0:
-        audio = audio / np.max(np.abs(audio))
+    # Preprocess audio
+    audio = common.preprocess_audio(audio_data)
+    if audio is None:
+        return "Failed to process audio."
 
-    # Process with the model
-    inputs = processor(
-        audio,
-        sampling_rate=SAMPLING_RATE,
-        return_tensors="pt",
-        padding=True
-    )
-
-    # Add language tensor (assuming French/Italian model)
-    inputs["language"] = torch.tensor([[0]])
+    # Prepare model inputs
+    inputs = common.prepare_model_inputs(audio, processor)
 
     # Run inference
-    with torch.no_grad():
-        outputs = model(**inputs)
-        logits = outputs.logits
-        predicted_ids = torch.argmax(logits, dim=-1)
-
-    # Decode the prediction
-    transcription = processor.batch_decode(predicted_ids)[0]
-
-    # Compare with target word if provided
-    result = f"**Transcription:** {transcription}\n\n"
-
-    if target_word and target_word.strip():
-        target_clean = target_word.strip().lower()
-        transcription_clean = transcription.lower().replace("[pad]", "").strip()
+    outputs, predicted_ids = common.run_inference(model, inputs)
 
-        if target_clean in transcription_clean:
-            result += f"✅ **Match found!** The target word '{target_word}' appears in the transcription."
-        else:
-            result += f"❌ **No exact match.** The target word '{target_word}' was not found in the transcription."
+    # Decode transcription
+    transcription = common.decode_transcription(processor, predicted_ids)
 
-    return result
+    # Compare with target word
+    return common.compare_with_target(transcription, target_word)
 
 
 def create_interface():
@@ -114,5 +74,5 @@ def create_interface():
 
 
 if __name__ == "__main__":
-    demo = create_interface()
-    demo.launch()
+    my_demo = create_interface()
+    my_demo.launch()
main.py CHANGED
@@ -7,16 +7,17 @@ import sounddevice as sd
 import torch
 from torchcodec.decoders import AudioDecoder
 import wavlm_phoneme_fr_it
+import json
 
 import common
 
-SAMPLING_RATE = 16_000
-VOCAB_SIZE = 97
-
 
 def fake_model(chunk):
     output_length = int(chunk.shape[0] * 0.02)
-    return np.random.rand(output_length, VOCAB_SIZE)
+    with open("vocab.json", "r") as vocab_file:
+        vocab = json.loads(vocab_file.read())
+    vocab_size = len(vocab) + 3
+    return np.random.rand(output_length, vocab_size)
 
 
 def update_frame(frames, ax, matrix_plot, tokenizer=None):
@@ -48,8 +49,8 @@ def main(record_mic=False):
     if record_mic:
         print("Recording the microphone...")
         waveform = sd.rec(
-            int(audio_duration * SAMPLING_RATE),
-            samplerate=SAMPLING_RATE,
+            int(audio_duration * common.SAMPLING_RATE),
+            samplerate=common.SAMPLING_RATE,
             channels=1
         ).T
         sd.wait()  # Wait until recording is finished
@@ -58,7 +59,7 @@ def main(record_mic=False):
         audio_file = "ceci est un test.wav"
         decoded = AudioDecoder(audio_file).get_all_samples()
         waveform = decoded.data.numpy()
-        assert decoded.sample_rate == SAMPLING_RATE, f"Bad audio frequency {decoded.sample_rate}"
+        assert decoded.sample_rate == common.SAMPLING_RATE, f"Bad audio frequency {decoded.sample_rate}"
 
     # Split audio
     chunks = []
@@ -71,7 +72,7 @@ def main(record_mic=False):
     inputs = processor(
         chunks,
         return_attention_mask=True,
-        sampling_rate=SAMPLING_RATE,
+        sampling_rate=common.SAMPLING_RATE,
         padding=True
     )
     inputs.update({
@@ -101,7 +102,7 @@ def main(record_mic=False):
     ax.set_title("Animation Preview")
     matrix_plot = ax.matshow(logit_groups[0][0], animated=True, vmin=0, vmax=1)
    logits_list = []
-    masks = inputs["attention_mask"].sum(dim=1) / SAMPLING_RATE
+    masks = inputs["attention_mask"].sum(dim=1) / common.SAMPLING_RATE
     for i, chunk in enumerate(chunks):
         # logits = fake_model(chunk)  # for testing purposes only
         logits_list.append(logits)
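Since fake_model now derives its width from vocab.json rather than the hard-coded VOCAB_SIZE = 97, a quick smoke test can confirm the shapes still line up. A hypothetical check, assuming vocab.json sits next to main.py and that len(vocab) + 3 equals the old constant:

import numpy as np

# One second of 16 kHz audio should yield int(16_000 * 0.02) = 320 logit frames.
chunk = np.zeros(16_000, dtype=np.float32)
logits = fake_model(chunk)
assert logits.shape[0] == 320
assert logits.shape[1] == 97  # holds only if len(vocab) == 94, as the removed constant implied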