Update handler.py
handler.py  CHANGED  (+20 -6)
import os
from pyannote.audio import Pipeline, Audio
import torch


class EndpointHandler:
    def __init__(self, path=""):
        # Get the Hugging Face authentication token from the environment variable
        auth_token = os.getenv("MY_KEY")
        if not auth_token:
            raise ValueError("Hugging Face authentication token (MY_KEY) is missing.")

        # Initialize pretrained pipeline with the token
        self._pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", use_auth_token=auth_token
        )

        # Send pipeline to GPU if available
        if torch.cuda.is_available():
            self._pipeline.to(torch.device("cuda"))

        # Initialize audio reader
        self._io = Audio()

    def __call__(self, data):
        # Extract inputs from request data
        inputs = data.pop("inputs", data)
        waveform, sample_rate = self._io(inputs)

        # Extract pipeline parameters if provided
        parameters = data.pop("parameters", dict())

        # Run speaker diarization
        diarization = self._pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}, **parameters
        )

        # Process diarization results
        processed_diarization = [
            {
                "speaker": speaker,
                "start": f"{turn.start:.3f}",
                "end": f"{turn.end:.3f}",
            }
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]

        # Return results as JSON
        return {"diarization": processed_diarization}
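
For a quick local sanity check, the handler can be invoked directly as a callable. A minimal sketch, assuming a local audio file sample.wav and a valid Hugging Face token exported as MY_KEY (both the file name and the token value are placeholders):

import json
from handler import EndpointHandler

# EndpointHandler.__init__ reads the token from the MY_KEY environment
# variable, so export it first, e.g.: export MY_KEY=hf_xxx (placeholder)
handler = EndpointHandler()

# "inputs" can be anything pyannote's Audio() accepts, such as a path to a WAV file
result = handler({"inputs": "sample.wav"})
print(json.dumps(result, indent=2))
# Expected shape:
# {"diarization": [{"speaker": "SPEAKER_00", "start": "0.497", "end": "1.353"}, ...]}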