# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Pytriton server for token2wav conversion and ASR""" | |
import argparse
import logging
import random
import sys
from typing import List

import numpy as np
import torch
from scipy.signal import resample

# CosyVoice vendors Matcha-TTS; it must be on sys.path before cosyvoice is imported.
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")

from cosyvoice.cli.cosyvoice import CosyVoice2
from datasets import load_dataset
from jiwer import wer
from omnisense.models import OmniSenseVoiceSmall
from pypinyin import Style, lazy_pinyin
from pytriton.decorators import batch
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.triton import Triton, TritonConfig
from tn.chinese.normalizer import Normalizer as ZhNormalizer

# Chinese text normalizer (cached globally)
zh_tn_model = ZhNormalizer(
    cache_dir="./cache",
    remove_erhua=False,
    remove_interjections=False,
    remove_puncts=True,
    overwrite_cache=True,
)
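# Rough illustration of what the normalizer does (actual output may differ by
# WeTextProcessing version): digits are verbalized into Chinese characters and
# punctuation is stripped (remove_puncts=True), e.g.
#     zh_tn_model.normalize("共找到123条结果。")  # -> roughly "共找到一百二十三条结果"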

logger = logging.getLogger("token2wav_asr_server")


class _ASR_Server:
    """Wraps a single OmniSenseVoiceSmall model instance for Triton."""

    def __init__(self, device_id: int):
        self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
    @batch
    def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
        """
        WAV, WAV_LENS: padded waveform batch and per-utterance lengths.
        LANGUAGE, TEXT_NORM: accepted for backward compatibility, not used.
        See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
        """
        logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
        # Trim padding so each utterance is transcribed at its true length
        wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
        results = self._model.transcribe_single_batch(
            wavs,
            language="zh",
            textnorm="woitn",
        )
        texts = [result.text for result in results]
        transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
        return {"TRANSCRIPTS": transcripts}


def audio_decode_cosyvoice2(
    audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
    """
    Synthesize a waveform from CosyVoice2 speech tokens, conditioned on a
    zero-shot prompt (prompt text plus 16 kHz prompt speech).
    """
    # Build prompt features: prompt speech tokens, mel features and speaker embedding
    model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
        "empty", prompt_text, prompt_speech_16k, 24000
    )
    # Flow-matching decoder: speech tokens -> mel spectrogram
    tts_mel, _ = codec_decoder.model.flow.inference(
        token=audio_tokens.to(codec_decoder.model.device),
        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
            codec_decoder.model.device
        ),
        prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
            codec_decoder.model.device
        ),
        prompt_token_len=torch.tensor(
            [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
        ).to(codec_decoder.model.device),
        prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
            codec_decoder.model.device
        ),
        prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
            codec_decoder.model.device
        ),
        embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
        finalize=True,
    )
    # HiFT vocoder: mel spectrogram -> waveform (24 kHz)
    audio_hat, _ = codec_decoder.model.hift.inference(
        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
    )
    return audio_hat
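
# Usage sketch (illustrative only; `speech_tokens` would be a (1, T) LongTensor of
# CosyVoice2 codec tokens and `decoder` a loaded CosyVoice2 instance):
#
#     wav_24k = audio_decode_cosyvoice2(speech_tokens, prompt_text, prompt_wav_16k, decoder)
#     # -> float tensor of shape (1, num_samples) at 24 kHz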


def get_random_prompt_from_dataset(dataset):
    """
    Get a random prompt text and its speech from the pre-loaded dataset.
    Returns (prompt_text, prompt_speech_16k).
    """
    random_idx = random.randint(0, len(dataset) - 1)
    sample = dataset[random_idx]
    # Extract audio data
    audio_data = sample["audio"]
    audio_array = audio_data["array"]
    sample_rate = audio_data["sampling_rate"]
    # Resample to 16 kHz if needed
    if sample_rate != 16000:
        num_samples = int(len(audio_array) * (16000 / sample_rate))
        audio_array = resample(audio_array, num_samples)
    # Convert to torch tensor with a leading batch dimension
    prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
    prompt_text = sample["text"]
    # Remove spaces in the prompt text
    prompt_text = prompt_text.replace(" ", "")
    return prompt_text, prompt_speech_16k


class _Token2Wav_ASR:
    """Wraps a CosyVoice2 token2wav decoder plus a SenseVoice ASR model for Triton."""

    def __init__(self, device_id: int):
        self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
        self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
        # Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model.
        # CosyVoice2 internally uses the generic "cuda" device, so we switch the
        # current CUDA context to the desired card before the object is created;
        # all parameters loaded with the generic "cuda" device then reside on this
        # GPU. We keep the selected id in `self.device_id` and set the context
        # again on every forward call to avoid race conditions when several
        # instances are used in the same process.
        self.device_id = device_id
        # Construct the TTS codec decoder under the correct CUDA device context
        with torch.cuda.device(self.device_id):
            self.codec_decoder = CosyVoice2(
                "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
            )
    @batch
    def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
        """
        TOKENS, TOKEN_LENS: padded CosyVoice2 speech-token batch and per-item lengths.
        GT_TEXT: ground-truth transcripts (BYTES) used for reward computation.
        Interface modeled on: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
        """
        # Ensure the default CUDA device is set correctly for this invocation
        torch.cuda.set_device(self.device_id)
        if self.device_id == 0:
            print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
        tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
        # Decode ground-truth text strings (BYTES -> str)
        if GT_TEXT.ndim == 2:
            gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
        else:
            gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
        wavs = []
        for tokens in tokens_list:
            prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
            audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
            audio_hat = audio_decode_cosyvoice2(
                audio_tokens,
                prompt_text,
                prompt_speech_16k,
                self.codec_decoder,
            )
            # Resample from 24 kHz to 16 kHz with scipy for the ASR model
            audio_hat = audio_hat.squeeze(0).float().cpu()
            audio_hat = audio_hat.numpy()
            num_samples = int(len(audio_hat) * (16000 / 24000))
            audio_hat = resample(audio_hat, num_samples)
            wavs.append(audio_hat)
        results = self.asr_model.transcribe_single_batch(
            wavs,
            language="zh",
            textnorm="woitn",
        )
        texts = [result.text for result in results]
        # ---------------- Reward computation ----------------
        rewards = []
        for gt_text, hyp_text in zip(gt_texts, texts):
            # Normalize both texts, then compare them at the pinyin level so that
            # homophone substitutions are not penalized.
            gt_norm = zh_tn_model.normalize(gt_text).lower()
            hyp_norm = zh_tn_model.normalize(hyp_text).lower()
            gt_pinyin = lazy_pinyin(
                gt_norm,
                style=Style.TONE3,
                tone_sandhi=True,
                neutral_tone_with_five=True,
            )
            hyp_pinyin = lazy_pinyin(
                hyp_norm,
                style=Style.TONE3,
                tone_sandhi=True,
                neutral_tone_with_five=True,
            )
            # WER over pinyin tokens, squashed into [0, 1]: a perfect match gives 1.0
            c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
            reward_val = 1.0 - np.tanh(3.0 * c)
            reward_val = max(0.0, min(1.0, reward_val))
            rewards.append(reward_val)
            print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
        transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
        rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
        return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}


def _infer_function_factory(device_ids: List[int], model_name: str):
    """Creates a list of inference callables, one for each requested device id."""
    infer_funcs = []
    for device_id in device_ids:
        if model_name == "sensevoice":
            infer_funcs.append(_ASR_Server(device_id=device_id))
        else:
            infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
    return infer_funcs


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=32,
        help="Maximal batch size for dynamic batching.",
        required=False,
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--number-of-instances-per-device",
        type=int,
        default=1,
        help="Number of model instances to load per device.",
        required=False,
    )
    parser.add_argument(
        "--number-of-devices",
        type=int,
        default=8,
        help="Number of devices to use.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="token2wav_asr",
        choices=["token2wav_asr", "sensevoice"],
        help="Model name.",
    )
    args = parser.parse_args()
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
    triton_config = TritonConfig(
        http_port=8000,
        grpc_port=8001,
        metrics_port=8002,
    )
    device_ids = list(range(args.number_of_devices))
    device_ids = device_ids * args.number_of_instances_per_device
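    # e.g. 2 devices with 2 instances per device -> [0, 1, 0, 1]; one inference
    # callable is created per entry below, and Triton schedules requests across them.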
    with Triton(config=triton_config) as triton:
        logger.info("Loading %s model on device ids: %s", args.model_name, device_ids)
        if args.model_name == "sensevoice":
            triton.bind(
                model_name="sensevoice",
                infer_func=_infer_function_factory(device_ids, args.model_name),
                inputs=[
                    Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
                    Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
                    Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
                    Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
                ],
                outputs=[
                    Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
                ],
                config=ModelConfig(
                    max_batch_size=args.max_batch_size,
                    batcher=DynamicBatcher(max_queue_delay_microseconds=10000),  # 10 ms
                ),
                strict=True,
            )
        else:
            triton.bind(
                model_name="token2wav_asr",
                infer_func=_infer_function_factory(device_ids, args.model_name),
                inputs=[
                    Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
                    Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
                    Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
                ],
                outputs=[
                    Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
                    Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
                ],
                config=ModelConfig(
                    max_batch_size=args.max_batch_size,
                    batcher=DynamicBatcher(max_queue_delay_microseconds=10000),  # 10 ms
                ),
                strict=True,
            )
        logger.info("Serving inference")
        triton.serve()


if __name__ == "__main__":
    main()
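
# Launch example (a sketch; the script's actual filename is not given here, so
# `token2wav_asr_server.py` is a hypothetical name matching the logger above):
#     python token2wav_asr_server.py --model-name token2wav_asr --number-of-devices 8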