Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2PhonemeCTCTokenizer | |
| import librosa | |
| from itertools import groupby | |
| from datasets import load_dataset | |
# Hub checkpoint shared by the model, processor and tokenizer below.
# Alternative checkpoint kept for reference: "bookbot/wav2vec2-ljspeech-gruut"
checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"

# The three loads are independent of one another; each only needs the
# checkpoint name.
processor = AutoProcessor.from_pretrained(checkpoint)
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(checkpoint)
model = AutoModelForCTC.from_pretrained(checkpoint)

# Sampling rate the processor's feature extractor is configured with.
sr = processor.feature_extractor.sampling_rate
def decode_phonemes(
    ids: "torch.Tensor", processor: "AutoProcessor", ignore_stress: bool = False
) -> str:
    """Greedy CTC-style decoding of token ids into a space-separated phoneme string.

    Consecutive duplicate ids are collapsed first; special tokens (and the
    tokenizer's word-delimiter id) are then dropped; the remaining ids are
    decoded one at a time and joined with single spaces.

    Args:
        ids: One-dimensional sequence of token ids for a single utterance.
            Anything exposing ``tolist()`` (e.g. a tensor) or any plain
            iterable of ints is accepted.
        processor: Object providing ``tokenizer.all_special_ids``,
            ``tokenizer.word_delimiter_token_id`` and ``decode(token_id)``.
        ignore_stress: If true, remove the IPA primary ("ˈ") and secondary
            ("ˌ") stress marks from the decoded string.

    Returns:
        The decoded phonemes separated by single spaces ("" if nothing remains).
    """
    # Convert to plain Python ints once, so that grouping and membership tests
    # compare hashable ints instead of individual tensor elements.
    id_list = ids.tolist() if hasattr(ids, "tolist") else list(ids)

    # Collapse runs of the same id (CTC repeats).
    collapsed = [token_id for token_id, _ in groupby(id_list)]

    # Ids that must never appear in the output. The word-delimiter id is added
    # as-is; if it is None it simply never matches an int id.
    inner_tokenizer = processor.tokenizer
    skipped_ids = set(inner_tokenizer.all_special_ids)
    skipped_ids.add(inner_tokenizer.word_delimiter_token_id)

    phonemes = [
        processor.decode(token_id)
        for token_id in collapsed
        if token_id not in skipped_ids
    ]
    prediction = " ".join(phonemes)

    if ignore_stress:
        prediction = prediction.replace("ˈ", "").replace("ˌ", "")
    return prediction
def text_to_phonemes(text: str) -> str:
    """Convert text to phonemes using the module-level tokenizer's phonemizer.

    The text is phonemized with ``phonemizer_lang="en-us"`` and the elapsed
    time of the call is printed to stdout.

    Args:
        text: The text to phonemize.

    Returns:
        Whatever ``tokenizer.phonemize`` returns for ``text``.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    started = time.perf_counter()
    phonemes = tokenizer.phonemize(text, phonemizer_lang="en-us")
    elapsed = time.perf_counter() - started
    print(f"Execution time of text_to_phonemes: {elapsed:.6f} seconds")
    return phonemes
def separate_characters(input_string):
    """Return the non-space characters of *input_string* joined by single spaces.

    Every " " in the input is discarded first, so existing spacing does not
    influence the result; each remaining character becomes its own token.
    """
    return " ".join(char for char in input_string if char != " ")
def predict_phonemes(audio_array):
    """Run the module-level model on raw audio and return the decoded phonemes.

    Args:
        audio_array: Raw audio samples for one utterance.
            NOTE(review): assumed to already be sampled at ``sr`` (the
            processor's sampling rate); no resampling is done here — confirm
            against callers.

    Returns:
        The predicted phonemes as a space-separated string with IPA stress
        marks removed (see ``decode_phonemes`` with ``ignore_stress=True``).
    """
    # Stating the sampling rate explicitly lets the feature extractor check it
    # against its own configuration instead of warning about a missing value.
    inputs = processor(
        audio_array, sampling_rate=sr, return_tensors="pt", padding=True
    )

    # Inference only: no gradients needed. Forward everything the processor
    # produced (input values plus an attention mask, when one is returned).
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: most likely token id per frame, first batch item only.
    predicted_ids = torch.argmax(logits, dim=-1)
    return decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
def adjust_phonemes(predicted: str) -> str:
    """Normalize the whitespace of a predicted phoneme string.

    Every run of whitespace is collapsed to a single space and leading and
    trailing whitespace is removed. Splitting and re-joining handles runs of
    any length, which a single ``str.replace`` of a fixed pattern cannot.

    Args:
        predicted: Space-separated phonemes as produced by the decoder.

    Returns:
        The same phonemes separated by exactly one space each.
    """
    # Phoneme-specific substitutions (e.g. dropping a stand-alone schwa) could
    # be added here if needed; none are applied at the moment.
    return " ".join(predicted.split())
def calculate_score(expected: str, predicted: str) -> float:
    """Score a prediction as the fraction of expected phonemes matched in place.

    Both strings are split on whitespace and compared position by position;
    the number of equal pairs is divided by the number of expected phonemes.

    NOTE: the comparison is purely positional, so one inserted or missing
    phoneme shifts every later position and those phonemes stop matching.

    Args:
        expected: Space-separated reference phonemes.
        predicted: Space-separated predicted phonemes.

    Returns:
        A value between 0.0 and 1.0; 0.0 when ``expected`` has no phonemes.
    """
    expected_list = expected.split()
    if not expected_list:
        # Always return a float, as the annotation promises.
        return 0.0

    predicted_list = predicted.split()
    # zip stops at the shorter sequence; unmatched trailing items count as misses.
    correct_matches = sum(
        1 for want, got in zip(expected_list, predicted_list) if want == got
    )
    return correct_matches / len(expected_list)
def test_sound():
    """Run the phoneme pipeline end to end on one sample of a dummy dataset.

    Loads the first validation sample, derives the expected phonemes from its
    transcript, predicts phonemes from its audio, scores the prediction and
    prints intermediate results and the total execution time.

    Returns:
        A dict with a single ``"text"`` key holding a multi-line report
        (transcript, expected / predicted / adjusted phonemes and score).
    """
    started = time.perf_counter()

    # NOTE(review): trust_remote_code=True permits running the dataset
    # repository's own loading code locally; keep it only for trusted sources.
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation",
        trust_remote_code=True,
    )

    # Fetch the sample once and read both fields from it, rather than indexing
    # the dataset separately for each field.
    sample = ds[0]
    audio_array = sample["audio"]["array"]
    expected_transcript = sample["text"]

    # Expected phonemes: phonemize the transcript, then split it into
    # space-separated single characters.
    expected_phonemes = separate_characters(text_to_phonemes(expected_transcript))

    # Predicted phonemes from the audio, plus a whitespace-normalized copy.
    predicted_phonemes = predict_phonemes(audio_array)
    adjusted_phonemes = adjust_phonemes(predicted_phonemes)

    print(f"Expected Phonemes: {expected_phonemes}")
    print(f"Predicted Phonemes: {predicted_phonemes}")
    print(f"Adjusted Phonemes: {adjusted_phonemes}")

    # Score the normalized prediction against the expected phonemes.
    score = calculate_score(expected_phonemes, adjusted_phonemes)

    report = (
        f"Transcript: {expected_transcript}\n"
        f"Expected Phonemes: {expected_phonemes}\n"
        f"Predicted Phonemes: {predicted_phonemes}\n"
        f"Adjusted Phonemes: {adjusted_phonemes}\n"
        f"Score: {score:.2f}"
    )

    execution_time = time.perf_counter() - started
    print(f"Execution time: {execution_time:.6f} seconds")
    return {"text": report}