Y-Mangoes committed on
Commit 84035c8 · verified · 1 Parent(s): 4873c49

Update app.py

Files changed (1)
  app.py +83 -115
app.py CHANGED
@@ -1,133 +1,101 @@
  import gradio as gr
  import torch
  from pyannote.audio import Pipeline
- from pyannote.core import Segment, Annotation
- import os
  from huggingface_hub import login
- import tempfile
- import librosa
- import soundfile as sf
  import numpy as np
- import warnings
-
- # Suppress torchaudio backend warning
- warnings.filterwarnings("ignore", category=UserWarning, module="pyannote.audio.core.io")

- # Authenticate with Hugging Face
- os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")  # Set in Hugging Face Space secrets
- login(token=os.environ["HF_TOKEN"])
-
- # Initialize the pyannote pipeline with pre-trained model
- pipeline = Pipeline.from_pretrained(
-     "pyannote/speaker-diarization-3.1",
-     use_auth_token=True
- )
-
- # Optimize for GPU if available
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- pipeline.to(device)

- def process_audio(audio_file):
-     """
-     Process the input audio file and return diarization results.
-
-     Args:
-         audio_file: Path to the input audio file
-
-     Returns:
-         Tuple containing:
-         - Diarization text output
-         - Path to visualization plot
-         - Number of speakers detected
-     """
      try:
-         # Load and preprocess audio
-         audio, sr = librosa.load(audio_file, sr=16000, mono=True)
-
-         # Save temporary audio file in WAV format (pyannote requirement)
-         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
-             sf.write(temp_file.name, audio, sr)
-             temp_file_path = temp_file.name
-
-         # Perform speaker diarization
-         diarization = pipeline({"uri": "audio", "audio": temp_file_path})
-
-         # Clean up temporary file
-         os.unlink(temp_file_path)
-
-         # Process diarization results
-         output_text = []
-         speakers = set()
          for turn, _, speaker in diarization.itertracks(yield_label=True):
-             start = turn.start
-             end = turn.end
-             output_text.append(
-                 f"Speaker {speaker}: {start:.2f}s - {end:.2f}s"
-             )
-             speakers.add(speaker)
-
-         # Generate visualization
-         plot_path = visualize_diarization(diarization, audio, sr)
-
-         return (
-             "\n".join(output_text),
-             plot_path,
-             len(speakers)
-         )
-
      except Exception as e:
-         return f"Error processing audio: {str(e)}", None, 0

- def visualize_diarization(diarization, audio, sr):
-     """
-     Create a visualization of the diarization results.
-
-     Args:
-         diarization: Pyannote diarization object
-         audio: Audio waveform
-         sr: Sample rate
-
-     Returns:
-         Path to saved visualization plot
-     """
-     import matplotlib.pyplot as plt
-
-     plt.figure(figsize=(12, 4))
-
-     # Plot waveform
-     time = np.linspace(0, len(audio)/sr, num=len(audio))
-     plt.plot(time, audio, alpha=0.3, color='gray')
-
-     # Plot diarization segments
-     for turn, _, speaker in diarization.itertracks(yield_label=True):
-         plt.axvspan(turn.start, turn.end, alpha=0.2, label=f'Speaker {speaker}')
-
-     plt.xlabel('Time (s)')
-     plt.ylabel('Amplitude')
-     plt.title('Speaker Diarization')
-     plt.legend()
-
-     # Save plot
-     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_plot:
-         plt.savefig(temp_plot.name)
-         plot_path = temp_plot.name
-
-     plt.close()
-     return plot_path
-
- # Create Gradio interface
- iface = gr.Interface(
-     fn=process_audio,
-     inputs=gr.Audio(type="filepath", label="Upload Audio File"),
-     outputs=[
-         gr.Textbox(label="Diarization Results"),
-         gr.Image(label="Visualization"),
-         gr.Number(label="Number of Speakers")
-     ],
-     title="Speaker Diarization with Pyannote 3.1",
-     description="Upload an audio file to perform speaker diarization. Results show speaker segments and a visualization."
- )
-
- # Launch the interface
- if __name__ == "__main__":
-     iface.launch()
+ import os
  import gradio as gr
  import torch
+ import torchaudio
+ from pydub import AudioSegment
  from pyannote.audio import Pipeline
  from huggingface_hub import login
  import numpy as np
+ import json
+
+ # Authenticate with Hugging Face
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ if HF_TOKEN:
+     login(HF_TOKEN)
+ else:
+     raise ValueError("Hugging Face token not found. Set the HF_TOKEN environment variable.")
+
+ # Load the diarization pipeline
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0").to(device)
+ def preprocess_audio(audio_path):
+     """Convert audio to mono, 16 kHz WAV format suitable for pyannote."""
      try:
+         # Load audio with pydub (needs ffmpeg for compressed formats)
+         audio = AudioSegment.from_file(audio_path)
+         # Convert to mono and resample to 16 kHz
+         audio = audio.set_channels(1).set_frame_rate(16000)
+         # Export to a temporary WAV file (fixed name, so concurrent
+         # requests would overwrite each other; fine for a single user)
+         temp_wav = "temp_audio.wav"
+         audio.export(temp_wav, format="wav")
+         return temp_wav
+     except Exception as e:
+         raise ValueError(f"Error preprocessing audio: {str(e)}")
+
+ def diarize_audio(audio_path, num_speakers):
+     """Perform speaker diarization and return formatted results."""
+     try:
+         # Validate inputs (Gradio sliders can deliver floats, so coerce first)
+         if not audio_path or not os.path.exists(audio_path):
+             raise ValueError("Audio file not found.")
+         num_speakers = int(num_speakers)
+         if num_speakers < 1:
+             raise ValueError("Number of speakers must be a positive integer.")
+
+         # Preprocess audio
+         wav_path = preprocess_audio(audio_path)
+
+         # Load audio for pyannote
+         waveform, sample_rate = torchaudio.load(wav_path)
+         audio_dict = {"waveform": waveform.to(device), "sample_rate": sample_rate}
+
+         # Configure pipeline with the expected number of speakers
+         pipeline_params = {"num_speakers": num_speakers}
+         diarization = pipeline(audio_dict, **pipeline_params)
+
+         # Format results
+         results = []
+         text_output = ""
          for turn, _, speaker in diarization.itertracks(yield_label=True):
+             result = {
+                 "start": round(turn.start, 3),
+                 "end": round(turn.end, 3),
+                 "speaker_id": speaker
+             }
+             results.append(result)
+             text_output += f"Speaker {speaker}: {result['start']}s - {result['end']}s\n"
+
+         # Clean up temporary file
+         if os.path.exists(wav_path):
+             os.remove(wav_path)
+
+         # Return text and JSON results
+         json_output = json.dumps(results, indent=2)
+         return text_output, json_output
+
      except Exception as e:
+         return f"Error: {str(e)}", ""

+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Speaker Diarization with Pyannote 3.0")
+     gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.")
+
+     with gr.Row():
+         audio_input = gr.Audio(label="Upload Audio File", type="filepath")
+         num_speakers = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Speakers", value=2)
+
+     submit_btn = gr.Button("Diarize")
+
+     with gr.Row():
+         text_output = gr.Textbox(label="Diarization Results (Text)")
+         json_output = gr.Textbox(label="Diarization Results (JSON)")
+
+     submit_btn.click(
+         fn=diarize_audio,
+         inputs=[audio_input, num_speakers],
+         outputs=[text_output, json_output]
+     )
+
+ # Launch the Gradio app
+ demo.launch()
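
Note that pydub defers to ffmpeg for anything that is not plain WAV, so the Space needs ffmpeg available (on Hugging Face Spaces, typically via a packages.txt). A self-contained sketch of the same conversion preprocess_audio performs, checking that the output has the shape pyannote expects; sample.mp3 is a placeholder path, not a file in this repo:

    from pydub import AudioSegment
    import torchaudio

    # Same conversion as preprocess_audio(): mono, 16 kHz WAV
    audio = AudioSegment.from_file("sample.mp3")  # placeholder input
    audio.set_channels(1).set_frame_rate(16000).export("temp_audio.wav", format="wav")

    # Confirm the result is what the pipeline will receive
    waveform, sample_rate = torchaudio.load("temp_audio.wav")
    assert sample_rate == 16000 and waveform.shape[0] == 1  # 16 kHz, mono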
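
The switch from passing a file path to passing an in-memory dict follows the pyannote.audio 3.x input convention. A minimal sketch of that call outside Gradio, assuming a valid HF_TOKEN has already been registered via login(); sample.wav is again a placeholder:

    import torch
    import torchaudio
    from pyannote.audio import Pipeline

    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0")
    pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    # pyannote 3.x accepts {"waveform": (channel, time) tensor, "sample_rate": int}
    waveform, sample_rate = torchaudio.load("sample.wav")  # placeholder path
    diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate},
                           num_speakers=2)

    for turn, _, speaker in diarization.itertracks(yield_label=True):
        print(f"Speaker {speaker}: {turn.start:.2f}s - {turn.end:.2f}s")

If the speaker count is unknown, num_speakers can be omitted and the pipeline will estimate it.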