KIFF committed on
Commit
a4109dd
·
verified ·
1 Parent(s): c06fa1f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +20 -6
handler.py CHANGED
@@ -1,28 +1,41 @@
 
1
  from pyannote.audio import Pipeline, Audio
2
  import torch
3
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
- # initialize pretrained pipeline
8
- self._pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
 
 
 
 
 
 
 
9
 
10
- # send pipeline to GPU if available
11
  if torch.cuda.is_available():
12
  self._pipeline.to(torch.device("cuda"))
13
 
14
- # initialize audio reader
15
  self._io = Audio()
16
 
17
  def __call__(self, data):
 
18
  inputs = data.pop("inputs", data)
19
  waveform, sample_rate = self._io(inputs)
20
 
 
21
  parameters = data.pop("parameters", dict())
22
- diarization = self.pipeline(
 
 
23
  {"waveform": waveform, "sample_rate": sample_rate}, **parameters
24
  )
25
 
 
26
  processed_diarization = [
27
  {
28
  "speaker": speaker,
@@ -32,4 +45,5 @@ class EndpointHandler:
32
  for turn, _, speaker in diarization.itertracks(yield_label=True)
33
  ]
34
 
35
- return {"diarization": processed_diarization}
 
 
1
+ import os
2
  from pyannote.audio import Pipeline, Audio
3
  import torch
4
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
+ # Get the Hugging Face authentication token from the environment variable
9
+ auth_token = os.getenv("MY_KEY")
10
+ if not auth_token:
11
+ raise ValueError("Hugging Face authentication token (MY_KEY) is missing.")
12
+
13
+ # Initialize pretrained pipeline with the token
14
+ self._pipeline = Pipeline.from_pretrained(
15
+ "pyannote/speaker-diarization-3.1", use_auth_token=auth_token
16
+ )
17
 
18
+ # Send pipeline to GPU if available
19
  if torch.cuda.is_available():
20
  self._pipeline.to(torch.device("cuda"))
21
 
22
+ # Initialize audio reader
23
  self._io = Audio()
24
 
25
  def __call__(self, data):
26
+ # Extract inputs from request data
27
  inputs = data.pop("inputs", data)
28
  waveform, sample_rate = self._io(inputs)
29
 
30
+ # Extract pipeline parameters if provided
31
  parameters = data.pop("parameters", dict())
32
+
33
+ # Run speaker diarization
34
+ diarization = self._pipeline(
35
  {"waveform": waveform, "sample_rate": sample_rate}, **parameters
36
  )
37
 
38
+ # Process diarization results
39
  processed_diarization = [
40
  {
41
  "speaker": speaker,
 
45
  for turn, _, speaker in diarization.itertracks(yield_label=True)
46
  ]
47
 
48
+ # Return results as JSON
49
+ return {"diarization": processed_diarization}