Spaces:
Sleeping
Sleeping
Create diarization.py
Browse files- services/diarization.py +63 -0
services/diarization.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from pyannote.audio import Pipeline
|
4 |
+
|
5 |
+
def extract_files(files):
|
6 |
+
filepaths = [file.name for file in files]
|
7 |
+
return filepaths
|
8 |
+
|
9 |
+
class Diarizer:
|
10 |
+
def __init__(self, conf):
|
11 |
+
self.conf = conf
|
12 |
+
self.pipeline = self.pyannote_pipeline()
|
13 |
+
|
14 |
+
def pyannote_pipeline(self):
|
15 |
+
pipeline = Pipeline.from_pretrained(
|
16 |
+
self.conf["model"]["diarizer"],
|
17 |
+
use_auth_token=os.environ["HUGGINGFACE_TOKEN"]
|
18 |
+
)
|
19 |
+
return pipeline
|
20 |
+
|
21 |
+
def add_device(self, pipeline):
|
22 |
+
"""Offloaded to allow for best timing when working with GPUs"""
|
23 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
+
pipeline.to(device)
|
25 |
+
return pipeline
|
26 |
+
|
27 |
+
def diarize_audio(self, temp_file, num_speakers):
|
28 |
+
pipeline = self.add_device(self.pipeline)
|
29 |
+
diarization = pipeline(temp_file, num_speakers=num_speakers)
|
30 |
+
# os.remove(temp_file) # Uncomment if you want to remove the temp file after processing
|
31 |
+
return str(diarization)
|
32 |
+
|
33 |
+
def extract_seconds(self, timestamp):
|
34 |
+
h, m, s = map(float, timestamp.split(':'))
|
35 |
+
return 3600 * h + 60 * m + s
|
36 |
+
|
37 |
+
def generate_labels_from_diarization(self, diarized_output):
|
38 |
+
labels_path = 'labels.txt'
|
39 |
+
lines = diarized_output.strip().split('\n')
|
40 |
+
plaintext = ""
|
41 |
+
for line in lines:
|
42 |
+
try:
|
43 |
+
parts = line.strip()[1:-1].split(' --> ')
|
44 |
+
if len(parts) == 2:
|
45 |
+
label = line.split()[-1].strip()
|
46 |
+
|
47 |
+
start_seconds = self.extract_seconds(parts[0].strip())
|
48 |
+
end_seconds = self.extract_seconds(parts[1].split(']')[0].strip())
|
49 |
+
plaintext += f"{start_seconds}\t{end_seconds}\t{label}\n"
|
50 |
+
else:
|
51 |
+
raise ValueError("Unexpected format in diarized output")
|
52 |
+
except Exception as e:
|
53 |
+
print(f"Error processing line: '{line.strip()}'. Error: {e}")
|
54 |
+
|
55 |
+
with open(labels_path, "w") as file:
|
56 |
+
file.write(plaintext)
|
57 |
+
|
58 |
+
return labels_path
|
59 |
+
|
60 |
+
def run(self, temp_file, num_speakers):
|
61 |
+
diarization_result = self.diarize_audio(temp_file, num_speakers)
|
62 |
+
label_file = self.generate_labels_from_diarization(diarization_result)
|
63 |
+
return diarization_result, label_file
|