from io import BytesIO
import json
import os

import gradio as gr
import requests
import spaces
import torch
import torchaudio
from pydub import AudioSegment
from pyannote.audio import Pipeline

# Authenticate with Hugging Face
AUTH_TOKEN = os.getenv("HF_TOKEN")

# Load the diarization pipeline onto the GPU
device = torch.device("cuda")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1", token=AUTH_TOKEN
).to(device)


def preprocess_audio(audio_source):
    """Convert audio (file path or raw bytes) to a mono, 16 kHz WAV file suitable for pyannote."""
    try:
        is_bytes = not isinstance(audio_source, str)
        # Load audio with pydub, from raw bytes or from a file path
        audio = AudioSegment.from_file(BytesIO(audio_source) if is_bytes else audio_source)
        # Convert to mono and resample to 16 kHz
        audio = audio.set_channels(1).set_frame_rate(16000)
        # Export to a temporary WAV file
        temp_wav = "temp_audio.wav"
        audio.export(temp_wav, format="wav")
        return temp_wav
    except Exception as e:
        raise ValueError(f"Error preprocessing audio: {e}") from e


def handle_audio(url, audio_path, num_speakers):
    """Fetch/preprocess the audio, run diarization, and clean up."""
    if url:
        # Download the audio and pass the raw bytes to the preprocessor
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        audio_path = response.content
    if audio_path is None:
        return "Error: please provide a URL or upload an audio file."
    audio_path = preprocess_audio(audio_path)
    res = diarize_audio(audio_path, num_speakers)
    # Clean up the temporary WAV file
    if os.path.exists(audio_path):
        os.remove(audio_path)
    return res


@spaces.GPU(duration=180)
def diarize_audio(audio_path, num_speakers):
    """Perform speaker diarization and return the results as a JSON string."""
    try:
        # Load audio for pyannote
        waveform, sample_rate = torchaudio.load(audio_path)
        audio_dict = {"waveform": waveform, "sample_rate": sample_rate}

        # Configure the pipeline: fixed speaker count if given, otherwise auto-detect 2-6 speakers
        num_speakers = int(num_speakers)
        if num_speakers > 0:
            pipeline_params = {"num_speakers": num_speakers}
        else:
            pipeline_params = {"min_speakers": 2, "max_speakers": 6}

        diarization = pipeline(audio_dict, **pipeline_params)

        # Format each speaker turn as a start/end/speaker_id record
        results = []
        for turn, speaker in diarization.exclusive_speaker_diarization:
            results.append({
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "speaker_id": speaker,
            })

        return json.dumps(results, indent=2)
    except Exception as e:
        return f"Error: {e}"


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with speaker-diarization-community-1")
    gr.Markdown("Provide an audio URL or upload an audio file, then specify the number of speakers to diarize the audio.")
    with gr.Row():
        url_input = gr.Textbox(label="Audio URL (optional)")
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        num_speakers = gr.Slider(
            minimum=0, maximum=10, step=1, value=2,
            label="Number of Speakers (0 = auto-detect)",
        )
    submit_btn = gr.Button("Diarize")
    with gr.Row():
        json_output = gr.Textbox(label="Diarization Results (JSON)")

    submit_btn.click(
        fn=handle_audio,
        inputs=[url_input, audio_input, num_speakers],
        outputs=[json_output],
        concurrency_limit=2,
    )

# Launch the Gradio app
demo.launch()
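
# Example client usage (a commented-out sketch, not part of the app): this shows one
# way to call the "Diarize" endpoint programmatically with the gradio_client package,
# assuming the app is running locally on the default port. The endpoint name
# "/handle_audio" and the sample file "meeting.wav" are illustrative assumptions.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(
#       "",                          # url: leave empty to use the uploaded file instead
#       handle_file("meeting.wav"),  # local audio file (hypothetical path)
#       2,                           # number of speakers (0 = auto-detect)
#       api_name="/handle_audio",    # assumed endpoint name derived from the handler function
#   )
#   print(result)                    # JSON string of start/end/speaker_id turns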