Jarbas's picture
Upload folder using huggingface_hub
57aa504 verified
#!/usr/bin/env python3
"""CTC inference with CoreML parakeet-ctc-0.6b-vi.
Usage: python infer.py audio.wav [--compute-units ALL|CPU_ONLY|CPU_AND_NE]
"""
import json, sys
from pathlib import Path
import coremltools as ct
import numpy as np
import soundfile as sf
REPO_DIR = Path(__file__).parent
SAMPLE_RATE = 16_000
def load_models(compute_units="ALL"):
    """Load the mel-encoder and CTC-decoder CoreML packages beside this script.

    ``compute_units`` is matched case-insensitively against ``ALL``,
    ``CPU_ONLY`` and ``CPU_AND_NE``; any other name falls back to
    ``ct.ComputeUnit.ALL``.

    Returns a ``(mel_encoder, ctc_decoder)`` pair of ``MLModel`` objects.
    """
    units = ct.ComputeUnit
    choices = dict(
        ALL=units.ALL,
        CPU_ONLY=units.CPU_ONLY,
        CPU_AND_NE=units.CPU_AND_NE,
    )
    requested = compute_units.upper()
    # Unrecognised names are not an error; they select all compute units.
    selected = choices[requested] if requested in choices else units.ALL

    package_names = (
        "parakeet_mel_encoder.mlpackage",
        "parakeet_ctc_decoder.mlpackage",
    )
    # Loaded in order: encoder first, then decoder.
    mel_enc, ctc_dec = (
        ct.models.MLModel(str(REPO_DIR / name), compute_units=selected)
        for name in package_names
    )
    return mel_enc, ctc_dec
def load_audio(path, max_samples):
    """Read an audio file into a fixed-length, single-channel float32 batch.

    Input with more than one dimension keeps only column 0. Clips shorter
    than ``max_samples`` are zero-padded at the end; longer clips are cut
    to ``max_samples``.

    Returns ``(audio, actual)`` where ``audio`` has shape
    ``(1, max_samples)`` and ``actual`` is the number of real (unpadded)
    samples it contains.

    Raises ``ValueError`` when the file's sample rate is not ``SAMPLE_RATE``.
    """
    samples, rate = sf.read(path, dtype="float32", always_2d=False)
    if rate != SAMPLE_RATE:
        raise ValueError(f"Expected {SAMPLE_RATE} Hz, got {rate} Hz.")
    if samples.ndim > 1:
        samples = samples[:, 0]

    total = len(samples)
    # Never negative: clips that are already long enough get zero padding.
    shortfall = max(0, max_samples - total)
    fixed = np.pad(samples, (0, shortfall))[:max_samples]
    batch = fixed.reshape(1, -1).astype(np.float32)
    return batch, min(total, max_samples)
def decode_ctc(log_probs, vocab, blank_id):
    """Greedy CTC decode of the first item in ``log_probs``.

    Takes the arg-max token of every frame of ``log_probs[0]``, collapses
    consecutive repeats, drops ``blank_id``, and joins the remaining
    ``vocab`` entries. The "▁" marker is turned into a space and the result
    is stripped of leading/trailing whitespace.
    """
    frame_ids = np.argmax(log_probs[0], axis=-1).tolist()
    # Pair every frame with the one before it (None ahead of the first).
    previous_ids = [None, *frame_ids[:-1]]
    kept = [
        current
        for before, current in zip(previous_ids, frame_ids)
        if current != blank_id and current != before
    ]
    pieces = [vocab[token] for token in kept]
    text = "".join(pieces)
    return text.replace("▁", " ").strip()
def transcribe(audio_path, compute_units="ALL"):
    """Transcribe one audio file with the CoreML mel-encoder and CTC decoder.

    Args:
        audio_path: Path of the audio file, passed to ``load_audio``.
        compute_units: Compute-unit name, passed to ``load_models``.

    Returns:
        The greedy CTC transcript produced by ``decode_ctc``.
    """
    # Decode the JSON files explicitly as UTF-8: the vocabulary holds
    # non-ASCII tokens, so relying on the locale's default encoding could
    # corrupt them or fail outright.
    meta = json.loads((REPO_DIR / "metadata.json").read_text(encoding="utf-8"))
    vocab = json.loads((REPO_DIR / "vocab.json").read_text(encoding="utf-8"))
    blank = meta["blank_id"]
    n = meta["max_audio_samples"]

    mel_enc, ctc_dec = load_models(compute_units)
    audio, actual = load_audio(audio_path, n)
    # The encoder is told how many samples are real audio rather than padding.
    length = np.array([actual], dtype=np.int32)

    enc_out = mel_enc.predict({"audio_signal": audio, "audio_length": length})
    encoder = enc_out["encoder"]
    enc_len = int(enc_out["encoder_length"][0])
    # Keep only the first enc_len positions of the last axis (presumably the
    # time axis -- confirm against the exported model) before decoding.
    ctc_out = ctc_dec.predict({"encoder": encoder[:, :, :enc_len]})
    return decode_ctc(ctc_out["log_probs"], vocab, blank)
if __name__ == "__main__":
    # Command line: <audio.wav> plus an optional "--compute-units NAME" pair
    # that may appear before or after the audio path.
    usage = "Usage: python infer.py <audio.wav> [--compute-units ALL|CPU_ONLY|CPU_AND_NE]"
    args = sys.argv[1:]
    cu = "ALL"
    if "--compute-units" in args:
        flag_pos = args.index("--compute-units")
        if flag_pos + 1 >= len(args):
            # Flag given without a value: show usage instead of an IndexError.
            print(usage)
            sys.exit(1)
        cu = args[flag_pos + 1]
        # Remove the flag and its value so only positional arguments remain;
        # otherwise a leading flag would be mistaken for the audio path.
        del args[flag_pos:flag_pos + 2]
    if not args:
        print(usage)
        sys.exit(1)
    print(transcribe(args[0], cu))