"""Visualize, layer by layer, the phoneme activations produced for progressively longer audio chunks, and save the animation as a video."""
import functools
import json
import os

import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
import sounddevice as sd
import torch
from torchcodec.decoders import AudioDecoder

import common
import wavlm_phoneme_fr_it


def fake_model(chunk):
    """Stand-in model returning random activations, for testing the animation without running the real model."""
    output_length = int(chunk.shape[0] * 0.02)
    with open("vocab.json", "r") as vocab_file:
        vocab = json.loads(vocab_file.read())
    vocab_size = len(vocab) + 3  # vocabulary plus the special tokens
    return np.random.rand(output_length, vocab_size)


def update_frame(frames, ax, matrix_plot, tokenizer=None, colorbar=None):
    """Draw one animation frame. ``frames`` is a single (layer_index, timestamp, data) tuple."""
    ax.clear()
    ax.set_title(
        "Activation levels for WavLM Base+'s hidden layers\n"
        f"Layer = {frames[0] + 1}, T = {frames[1]:.1f}s"
    )
    ax.set_xlabel("Phoneme Vocabulary")
    ax.set_ylabel("Time Steps, and Selected Phoneme")
    data = frames[2].detach().clone()
    matrix_plot = ax.matshow(data, vmin=0, vmax=1, cmap='Blues')
    if tokenizer is not None:
        # Label the x axis with the phonemes that were actually predicted, and the y axis
        # with the phoneme selected at each time step (blank/padding steps are skipped).
        label_ids = torch.argmax(data, -1)
        labels = tokenizer.batch_decode(label_ids)
        ax.set_xticks([i for v, i in tokenizer.vocab.items() if v in labels])
        ax.set_xticklabels(
            [v for v, i in tokenizer.vocab.items() if v in labels],
            rotation=45, ha='right'
        )
        ax.set_yticks([i for i, v in enumerate(labels) if v not in ("", "[PAD]")])
        ax.set_yticklabels([v for i, v in enumerate(labels) if v not in ("", "[PAD]")])

        # Position the decoded text below the plot with proper spacing
        decoded_text = tokenizer.decode(label_ids)
        if len(decoded_text) > 50:
            decoded_text = decoded_text[:50] + "..."
        ax.text(
            0.5, -0.15, f"Decoded: {decoded_text}",
            transform=ax.transAxes,
            ha='center', va='top',
            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
        )
    plt.tight_layout()
    return ax, matrix_plot


def main(record_mic=False):
    """
    Record an inference run of the model.

    :param bool record_mic: True to record from the microphone, False to use a dummy file.
    :return str: Path of the output file.
""" audio_duration = 5 split_length = 0.1 if record_mic: print("Recording the microphone...") waveform = sd.rec( int(audio_duration * common.SAMPLING_RATE), samplerate=common.SAMPLING_RATE, channels=1 ).T sd.wait() # Wait until recording is finished print("Recording finished.") else: audio_file = "ceci est un test.wav" decoded = AudioDecoder(audio_file).get_all_samples() waveform = decoded.data.numpy() assert decoded.sample_rate == common.SAMPLING_RATE, f"Bad audio frequency {decoded.sample_rate}" # Split audio chunks = [] for i in np.linspace(0, waveform.shape[1], int(audio_duration / split_length), dtype=np.uint64): if i == 0: continue chunks.append(waveform[0, :i]) model, processor = common.get_model() inputs = processor( chunks, return_attention_mask=True, sampling_rate=common.SAMPLING_RATE, padding=True ) inputs.update({ "input_values": torch.tensor(np.array(inputs["input_values"])), "attention_mask": torch.tensor(np.array(inputs["attention_mask"])), "language": torch.tensor([[0] for _ in enumerate(chunks)]) }) # Inference time hidden_outputs = [] with torch.no_grad(): output = model(**inputs, output_hidden_states=True) for hidden in output.hidden_states: hidden_balanced = wavlm_phoneme_fr_it.add_language_to_hidden(hidden, inputs["language"]) hidden_outputs.append(torch.softmax(model.lm_head(hidden_balanced), dim=-1)) logits = torch.softmax(output.logits, dim=-1) logit_groups = [ [ torch.zeros((logits.shape[1], logits.shape[2])) for __ in enumerate(chunks) ] for _ in range(len(hidden_outputs) + 1) ] fig, ax = plt.subplots(animated=True) ax.set_title("Animation Preview") matrix_plot = ax.matshow(logit_groups[0][0], animated=True, vmin=0, vmax=1, cmap='Blues') # Add colorbar once for the entire animation colorbar = plt.colorbar(matrix_plot, ax=ax, label='Activation Level') logits_list = [] masks = inputs["attention_mask"].sum(dim=1) / common.SAMPLING_RATE for i, chunk in enumerate(chunks): # logits = fake_model(chunk) # for testing purposes only logits_list.append(logits) time_indices = int(logits.shape[1] * masks[i]) for j, layer in enumerate(hidden_outputs + [logits]): logit_groups[j][i][:time_indices] = layer[i, :time_indices] # Flatten frames flattened = [] for layer_index, layer_logits_list in enumerate(logit_groups): for time_stamp_index, logits in enumerate(layer_logits_list): flattened.append( (layer_index, time_stamp_index * audio_duration / int(audio_duration / split_length), logits) ) # Animate global animation animation = matplotlib.animation.FuncAnimation( fig, functools.partial( update_frame, ax=ax, matrix_plot=matrix_plot, tokenizer=processor.tokenizer ), flattened, interval=100, repeat=False, # blit=True ) plt.show() # Save to file dir_path = "outputs" if not os.path.exists(dir_path) or not os.path.isdir(dir_path): os.makedirs(dir_path) if os.path.exists(f"{dir_path}/animated.webm"): i = 1 while os.path.exists(f"{dir_path}/animated_({i}).webm"): i += 1 file_name = f"{dir_path}/animated_({i}).webm" else: file_name = f"{dir_path}/animated.webm" animation.save(file_name) return file_name if __name__ == "__main__": animation = None main(record_mic=False)