| import torch |
| import numpy as np |
| import torchaudio |
| import sentencepiece |
| import logging |
| from pathlib import Path |
| from moshi.models import loaders, LMGen |
|
|
# Configure the root logger so INFO-and-above records are emitted; this runs
# once, at import time, and has no effect if the root logger already has
# handlers installed.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger used by every method of InferenceRecipe below.
logger = logging.getLogger(__name__)
|
|
class InferenceRecipe:
    """Handles model inference for the Any-to-Any model.

    The recipe loads three components from ``model_path`` (the Mimi audio
    codec, a SentencePiece text tokenizer and the Moshi language model wrapped
    in ``LMGen``) and exposes :meth:`inference`, which turns an input waveform
    into a generated waveform plus the text pieces emitted alongside it.
    """

    # Number of Mimi codebooks requested from the codec; ``_generate`` checks
    # that every code frame has exactly this many rows.
    _NUM_CODEBOOKS = 8

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Initialize the model.

        Args:
            model_path (str): Path to model directory with pre-downloaded files
            device (str): Device to run on ('cuda' or 'cpu')

        Raises:
            RuntimeError: If any of the expected model files is missing.
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)

        # Audio is processed at 24 kHz in frames of sample_rate / frame_rate
        # samples (1920 samples with these values).
        self.sample_rate = 24000
        self.frame_rate = 12.5

        logger.info(f"Initializing models from {model_path}")
        self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
        self.mimi = self.mimi.to(self.device)
        self.lm_gen = self.lm_gen.to(self.device)
        logger.info("Model initialization complete")

    @property
    def _frame_size(self) -> int:
        """Number of audio samples per codec frame."""
        return int(self.sample_rate / self.frame_rate)

    def _resolve_model_file(self, filename: str, description: str) -> Path:
        """Return ``model_path / filename``, failing if it does not exist.

        Args:
            filename: File name relative to ``self.model_path``.
            description: Human-readable component name used in log and error
                messages (e.g. ``"text tokenizer"``).

        Raises:
            RuntimeError: If the file is not present on disk.
        """
        path = self.model_path / filename
        if not path.exists():
            # Error messages start with a capital letter, log lines do not.
            label = description[:1].upper() + description[1:]
            raise RuntimeError(f"{label} not found at {path}")
        logger.info(f"Loading {description} from {path}")
        return path

    def _initialize_models(self):
        """Initialize all required model components.

        Returns:
            tuple: ``(mimi, text_tokenizer, lm_gen)``.

        Raises:
            RuntimeError: If a required file is missing. Any exception raised
                by the underlying loaders is logged and re-raised unchanged.
        """
        logger.info("Initializing models...")

        try:
            mimi_path = self._resolve_model_file(loaders.MIMI_NAME, "MIMI model")
            mimi = loaders.get_mimi(str(mimi_path), device=self.device)
            mimi.set_num_codebooks(self._NUM_CODEBOOKS)

            tokenizer_path = self._resolve_model_file(
                loaders.TEXT_TOKENIZER_NAME, "text tokenizer"
            )
            text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))

            moshi_path = self._resolve_model_file(loaders.MOSHI_NAME, "language model")
            moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)

            return mimi, text_tokenizer, lm_gen

        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise

    def _load_audio(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """Load and preprocess audio.

        The array is converted to a float32 tensor of shape
        ``(1, channels, samples)`` on ``self.device``, resampled to
        ``self.sample_rate`` when needed, and trimmed so its length is a whole
        number of frames.

        Args:
            audio_array: Either a 1-D mono signal ``(samples,)`` or a 2-D
                array whose last axis is time (presumably
                ``(channels, samples)`` — confirm against callers).
            sample_rate: Sample rate of ``audio_array`` in Hz.

        Returns:
            torch.Tensor: 3-D waveform tensor ready for frame-wise encoding.

        Raises:
            ValueError: If ``audio_array`` is neither 1-D nor 2-D.
        """
        try:
            wav = torch.from_numpy(audio_array).float()

            if wav.dim() == 1:
                # Mono input (e.g. what librosa.load returns by default) needs
                # both a batch and a channel axis; adding only one would make
                # the 3-axis slicing below fail.
                wav = wav[None, None, :]
            elif wav.dim() == 2:
                wav = wav.unsqueeze(0)
            else:
                raise ValueError(
                    f"Expected 1-D or 2-D audio array, got shape {tuple(wav.shape)}"
                )

            wav = wav.to(self.device)

            if sample_rate != self.sample_rate:
                logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                ).to(self.device)
                wav = resampler(wav)

            # Drop the trailing partial frame so the encoder only ever sees
            # full frames.
            frame_size = self._frame_size
            orig_length = wav.shape[-1]
            wav = wav[..., :(orig_length // frame_size) * frame_size]
            if wav.shape[-1] != orig_length:
                logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")

            return wav

        except Exception as e:
            logger.error(f"Audio loading failed: {str(e)}")
            raise

    def _pad_codes(self, all_codes, time_seconds=30):
        """Extend ``all_codes`` with encoded silence up to a minimum duration.

        Args:
            all_codes (list): Code frames as produced by ``_encode_audio``.
                The list is extended in place.
            time_seconds: Minimum duration, in seconds, the code sequence
                should cover.

        Returns:
            list: The same ``all_codes`` list, padded if it was too short.
        """
        try:
            min_frames = int(time_seconds * self.frame_rate)
            frames_to_add = min_frames - len(all_codes)

            if frames_to_add > 0:
                logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                # One frame of digital silence, re-encoded for every pad frame.
                chunk = torch.zeros(
                    1, 1, self._frame_size, dtype=torch.float32, device=self.device
                )
                with torch.no_grad(), self.mimi.streaming(batch_size=1):
                    for _ in range(frames_to_add):
                        all_codes.append(self.mimi.encode(chunk))

            return all_codes

        except Exception as e:
            logger.error(f"Code padding failed: {str(e)}")
            raise

    def _encode_audio(self, wav: torch.Tensor):
        """Convert audio to codes.

        Args:
            wav: 3-D waveform whose last dimension is a multiple of the frame
                size (as returned by ``_load_audio``).

        Returns:
            list: One code tensor per frame, each with a trailing dimension
            of 1.
        """
        try:
            frame_size = self._frame_size
            all_codes = []

            with torch.no_grad(), self.mimi.streaming(batch_size=1):
                for offset in range(0, wav.shape[-1], frame_size):
                    frame = wav[:, :, offset: offset + frame_size]
                    codes = self.mimi.encode(frame.to(self.device))
                    assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
                    all_codes.append(codes)

            logger.info(f"Encoded {len(all_codes)} frames")
            return all_codes

        except Exception as e:
            logger.error(f"Audio encoding failed: {str(e)}")
            raise

    def _warmup(self):
        """Run a warmup pass.

        Pushes a single frame of silence through encode -> LM step -> decode
        and, on CUDA, waits for the device to finish.
        """
        try:
            chunk = torch.zeros(
                1, 1, self._frame_size, dtype=torch.float32, device=self.device
            )

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                codes = self.mimi.encode(chunk)
                tokens = self.lm_gen.step(codes[:, :, 0:1])
                if tokens is not None:
                    _ = self.mimi.decode(tokens[:, 1:])

            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            logger.info("Warmup pass completed")

        except Exception as e:
            logger.error(f"Warmup failed: {str(e)}")
            raise

    def _generate(self, all_codes):
        """Generate audio and text from codes.

        Args:
            all_codes (list): Code frames of shape ``(1, 8, 1)`` each.

        Returns:
            tuple: ``(wav, text)`` where ``wav`` is the concatenation of all
            decoded audio chunks along the last axis and ``text`` is the
            joined text pieces.

        Raises:
            RuntimeError: If the language model produced no output at all, so
                there is nothing to concatenate.
        """
        try:
            out_wav_chunks = []
            text_output = []
            expected_shape = (1, self._NUM_CODEBOOKS, 1)

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                for i, code in enumerate(all_codes):
                    assert code.shape == expected_shape, f"Expected code shape {expected_shape}, got {code.shape}"
                    tokens_out = self.lm_gen.step(code.to(self.device))

                    if tokens_out is not None:
                        # Row 0 is read as the text token below; the remaining
                        # rows are handed to the codec for decoding.
                        wav_chunk = self.mimi.decode(tokens_out[:, 1:])
                        out_wav_chunks.append(wav_chunk)

                        text_token = tokens_out[0, 0, 0].item()
                        # Ids 0 and 3 are skipped — presumably padding/special
                        # tokens; confirm against the tokenizer.
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            # SentencePiece marks word boundaries with U+2581.
                            _text = _text.replace("▁", " ")
                            text_output.append(_text)

                    if (i + 1) % 100 == 0:
                        logger.info(f"Processed {i + 1}/{len(all_codes)} frames")

            if not out_wav_chunks:
                # torch.cat on an empty list fails with an opaque message.
                raise RuntimeError(
                    "No audio was generated: the language model produced no output tokens"
                )

            wav = torch.cat(out_wav_chunks, dim=-1)
            text = ''.join(text_output)

            logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
            return wav, text

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise

    def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
        """Run inference on input audio.

        Args:
            audio_array (np.ndarray): Input audio as numpy array
            sample_rate (int): Sample rate of input audio

        Returns:
            dict: Contains generated audio array and optional transcribed text
        """
        try:
            logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate} Hz, self device: {self.device}")

            # Preprocess: tensor on the target device, resampled, frame-aligned.
            wav = self._load_audio(audio_array, sample_rate)

            # Encode, then pad with silence up to the default minimum length.
            all_codes = self._encode_audio(wav)
            all_codes = self._pad_codes(all_codes)

            self._warmup()

            out_wav, text = self._generate(all_codes)

            # Back to a NumPy array with singleton dimensions removed.
            output = out_wav.cpu().numpy().squeeze()

            logger.info("Inference completed successfully")
            return {
                "audio": output,
                "text": text
            }

        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise
|
|
if __name__ == "__main__":
    # Manual smoke test: load the models, run one wav file through the
    # recipe and report what came back. librosa is only needed here, so it is
    # imported locally rather than at module level.
    import librosa

    recipe = InferenceRecipe("/path/to/models", device="cuda")

    # sr=None keeps the file's native sample rate instead of librosa's default.
    waveform, native_sr = librosa.load("test.wav", sr=None)

    outputs = recipe.inference(waveform, native_sr)
    generated_audio = outputs["audio"]
    generated_text = outputs["text"]
    print(f"Generated {len(generated_audio)} samples of audio")
    print(f"Generated text: {generated_text}")