atulyaatul committed · Commit e4ce772 · verified · 1 Parent(s): 8504e20

Fix PyTorch inference code: use libdf for correct feature extraction

Files changed (1)
  1. README.md +40 -17
README.md CHANGED
@@ -31,18 +31,6 @@ pipeline_tag: audio-to-audio

  ---

- ## Listen to the Model (Use headphones)
-
- **Raw Audio (Noisy Environment):**
-
- <audio controls src="https://huggingface.co/weya-ai/hush/resolve/main/assets/audio/sample_00006_raw.wav"></audio>
-
- **Denoised Audio (Hush Output):**
-
- <audio controls src="https://huggingface.co/weya-ai/hush/resolve/main/assets/audio/sample_00006_denoised.wav"></audio>
-
- ---
-
  ## Model Overview

  Hush is designed from the ground up for **Voice AI applications** — phone-based voice agents, call centre bots, voice assistants, real-time transcription pipelines, and conversational AI systems. It isolates exactly one speaker from a live audio stream, in real time, under production conditions.
@@ -146,25 +134,60 @@ ERB gain mask Complex filter

  ## Quick Start: PyTorch Inference

+ > **Important:** PyTorch inference requires `DeepFilterLib` for correct feature extraction.
+ > Install it with `pip install DeepFilterLib`.
+
+ The simplest way is the CLI script from the [GitHub repo](https://github.com/pulp-vision/Hush):
+
+ ```bash
+ python scripts/infer_single.py \
+ --checkpoint model_best.ckpt \
+ --input noisy_speech.wav \
+ --output enhanced.wav
+ ```
+
+ Or use the Python API directly:
+
  ```python
  import torch
+ import numpy as np
  import soundfile as sf
- from model.dfnet_se import DfNetSE, get_config
+ from libdf import DF, erb, erb_norm, unit_norm
+ from model.dfnet_se import DfNetSE, as_complex, as_real, get_config, get_norm_alpha

+ # Load model
  config = get_config()
  model = DfNetSE(config)
  checkpoint = torch.load("model_best.ckpt", map_location="cpu")
  model.model.load_state_dict(checkpoint)
  model.eval()

+ # Load audio
  audio, sr = sf.read("noisy_speech.wav")
  assert sr == 16000, "Input must be 16 kHz"
-
- wav = torch.tensor(audio).float().unsqueeze(0).unsqueeze(0) # [1, 1, T]
+ wav = torch.tensor(audio, dtype=torch.float32).unsqueeze(0) # [1, T]
+
+ # Feature extraction via libdf (must match training pipeline)
+ df_state = DF(sr=16000, fft_size=320, hop_size=160, nb_bands=32, min_nb_erb_freqs=2)
+ alpha = get_norm_alpha(16000, 160, config.norm_tau)
+ wav_padded = torch.nn.functional.pad(wav, (0, 320))
+ spec_np = df_state.analysis(wav_padded.numpy(), reset=True)
+ erb_feat = torch.as_tensor(erb_norm(erb(spec_np, df_state.erb_widths()), alpha)).unsqueeze(1)
+ spec_feat = as_real(torch.as_tensor(unit_norm(spec_np[..., :64], alpha))).unsqueeze(1)
+ spec_t = as_real(torch.as_tensor(spec_np)).unsqueeze(1)
+
+ # Enhance
  with torch.no_grad():
-     enhanced = model(wav) # [1, 1, T]
+     spec_enh = model.model(spec_t.clone(), erb_feat, spec_feat)[0]
+     spec_enh_c = as_complex(spec_enh.squeeze(1))
+
+ # Synthesize and compensate delay
+ enh_np = df_state.synthesis(spec_enh_c.numpy(), reset=True)
+ enh = torch.from_numpy(np.asarray(enh_np, dtype=np.float32))
+ delay = 320 - 160 # fft_size - hop_size
+ enh = enh[:, delay : len(audio) + delay]

- sf.write("enhanced.wav", enhanced.squeeze().numpy(), 16000)
+ sf.write("enhanced.wav", enh.squeeze().numpy(), 16000)
  ```

  ## Quick Start: Production (ONNX, No PyTorch)
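
Note on the numbers in the added snippet: the constants `320`, `160`, and `64` are easier to follow once the frame arithmetic is spelled out. The sketch below is plain Python and is not part of the commit. The helper names `stft_layout`, `trim_bounds`, and `synthesized_len` are hypothetical. It assumes a one-sided spectrum with `fft_size // 2 + 1` bins, and that analysis yields `floor(len / hop_size)` frames with synthesis returning one hop of samples per frame. Both points are assumptions about `libdf`, not confirmed by this diff.

```python
def stft_layout(sr=16000, fft_size=320, hop_size=160, nb_df=64):
    """Derived quantities for the STFT settings used in the snippet."""
    bin_hz = sr / fft_size
    return {
        "window_ms": 1000.0 * fft_size / sr,
        "hop_ms": 1000.0 * hop_size / sr,
        "freq_bins": fft_size // 2 + 1,  # assumption: one-sided spectrum
        "bin_hz": bin_hz,
        "df_band_hz": nb_df * bin_hz,    # approximate band covered by spec_np[..., :64]
    }


def trim_bounds(n_samples, fft_size=320, hop_size=160):
    """Slice bounds matching enh[:, delay : len(audio) + delay] in the snippet."""
    delay = fft_size - hop_size
    return delay, n_samples + delay


def synthesized_len(n_samples, pad=320, hop_size=160):
    """Output length after padding, under the frame-count assumption above."""
    n_frames = (n_samples + pad) // hop_size
    return n_frames * hop_size


if __name__ == "__main__":
    n = 16000  # one second of 16 kHz audio
    start, stop = trim_bounds(n)
    print(stft_layout())
    print(start, stop, synthesized_len(n))
```

Under these assumptions, padding the input by `fft_size` samples leaves enough synthesized output for the slice to return exactly `len(audio)` samples after the `fft_size - hop_size` delay is skipped.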