tykiww commited on
Commit
33216f6
·
verified ·
1 Parent(s): 6aee0cc

Create diarization.py

Browse files
Files changed (1) hide show
  1. services/diarization.py +63 -0
services/diarization.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from pyannote.audio import Pipeline
4
+
5
+ def extract_files(files):
6
+ filepaths = [file.name for file in files]
7
+ return filepaths
8
+
9
+ class Diarizer:
10
+ def __init__(self, conf):
11
+ self.conf = conf
12
+ self.pipeline = self.pyannote_pipeline()
13
+
14
+ def pyannote_pipeline(self):
15
+ pipeline = Pipeline.from_pretrained(
16
+ self.conf["model"]["diarizer"],
17
+ use_auth_token=os.environ["HUGGINGFACE_TOKEN"]
18
+ )
19
+ return pipeline
20
+
21
+ def add_device(self, pipeline):
22
+ """Offloaded to allow for best timing when working with GPUs"""
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ pipeline.to(device)
25
+ return pipeline
26
+
27
+ def diarize_audio(self, temp_file, num_speakers):
28
+ pipeline = self.add_device(self.pipeline)
29
+ diarization = pipeline(temp_file, num_speakers=num_speakers)
30
+ # os.remove(temp_file) # Uncomment if you want to remove the temp file after processing
31
+ return str(diarization)
32
+
33
+ def extract_seconds(self, timestamp):
34
+ h, m, s = map(float, timestamp.split(':'))
35
+ return 3600 * h + 60 * m + s
36
+
37
+ def generate_labels_from_diarization(self, diarized_output):
38
+ labels_path = 'labels.txt'
39
+ lines = diarized_output.strip().split('\n')
40
+ plaintext = ""
41
+ for line in lines:
42
+ try:
43
+ parts = line.strip()[1:-1].split(' --> ')
44
+ if len(parts) == 2:
45
+ label = line.split()[-1].strip()
46
+
47
+ start_seconds = self.extract_seconds(parts[0].strip())
48
+ end_seconds = self.extract_seconds(parts[1].split(']')[0].strip())
49
+ plaintext += f"{start_seconds}\t{end_seconds}\t{label}\n"
50
+ else:
51
+ raise ValueError("Unexpected format in diarized output")
52
+ except Exception as e:
53
+ print(f"Error processing line: '{line.strip()}'. Error: {e}")
54
+
55
+ with open(labels_path, "w") as file:
56
+ file.write(plaintext)
57
+
58
+ return labels_path
59
+
60
+ def run(self, temp_file, num_speakers):
61
+ diarization_result = self.diarize_audio(temp_file, num_speakers)
62
+ label_file = self.generate_labels_from_diarization(diarization_result)
63
+ return diarization_result, label_file