KIFF commited on
Commit
f6dd816
·
verified ·
1 Parent(s): 7797ec6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +103 -59
handler.py CHANGED
@@ -1,65 +1,109 @@
1
- from typing import Dict
2
- from pyannote.audio import Pipeline
3
- import torch
4
  import base64
 
5
  import numpy as np
 
6
 
7
- SAMPLE_RATE = 16000
 
 
 
 
8
 
9
- class EndpointHandler():
10
- def __init__(self, path=""):
11
- # Initialize the pipeline (no authentication needed for public models)
12
- self.pipeline = Pipeline.from_pretrained(
13
- "pyannote/speaker-diarization-3.1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Move the pipeline to the appropriate device (CPU or GPU)
17
- self.pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
18
-
19
- # Instantiate the pipeline with its parameters
20
- self.pipeline = self.pipeline.instantiate(self.pipeline.parameters)
21
-
22
- def __call__(self, data: Dict) -> Dict:
23
- """
24
- Args:
25
- data (Dict):
26
- 'inputs': Base64-encoded audio bytes
27
- 'parameters': Additional diarization parameters (currently unused)
28
- Return:
29
- Dict: Speaker diarization results
30
- """
31
- inputs = data.get("inputs")
32
- parameters = data.get("parameters", {}) # We are not using them now
33
-
34
- # Decode the base64 audio data
35
- audio_data = base64.b64decode(inputs)
36
- audio_nparray = np.frombuffer(audio_data, dtype=np.int16)
37
-
38
- # Handle multi-channel audio (convert to mono)
39
- if audio_nparray.ndim > 1:
40
- audio_nparray = audio_nparray.mean(axis=0) # Average channels to create mono
41
-
42
- # Convert to PyTorch tensor
43
- audio_tensor = torch.from_numpy(audio_nparray).float().unsqueeze(0)
44
- if audio_tensor.dim() == 1:
45
- audio_tensor = audio_tensor.unsqueeze(0)
46
-
47
- pyannote_input = {"waveform": audio_tensor, "sample_rate": SAMPLE_RATE}
48
-
49
- # Run diarization pipeline
50
- try:
51
- diarization = self.pipeline(pyannote_input) # No num_speakers parameter
52
- except Exception as e:
53
- print(f"An unexpected error occurred: {e}")
54
- return {"error": "Diarization failed unexpectedly"}
55
-
56
- # Build a friendly JSON response
57
- processed_diarization = [
58
- {
59
- "label": str(label),
60
- "start": str(segment.start),
61
- "stop": str(segment.end),
62
- }
63
- for segment, _, label in diarization.itertracks(yield_label=True)
64
- ]
65
- return {"diarization": processed_diarization}
 
1
+ import os
2
+ import requests
3
+ import json
4
  import base64
5
+ import soundfile as sf
6
  import numpy as np
7
+ from scipy.signal import resample
8
 
9
+ # --- Configuration ---
10
+ # Replace with your actual API key/token
11
+ HF_TOKEN = os.environ.get("HF_API_TOKEN") # Get the token from environment variable
12
+ # Replace with your actual endpoint URL
13
+ STG_API_URL = "https://YOUR_ENDPOINT_URL"
14
 
15
+ # --- Functions ---
16
+
17
+ def query_to_hf(filename):
18
+ """Sends audio file to Hugging Face API using requests."""
19
+ try:
20
+ data, sr = sf.read(filename)
21
+ except sf.LibsndfileError as e:
22
+ print(f"Error reading audio file: {e}")
23
+ return None
24
+
25
+ # Handle multi-channel audio (convert to mono)
26
+ if len(data.shape) > 1:
27
+ data = data.mean(axis=1) # Average channels to create mono
28
+
29
+ data = resample(data, num=int(len(data) * 16000 / sr))
30
+ data = (data * np.iinfo(np.int16).max).astype(np.int16)
31
+
32
+ # Prepare the data payload
33
+ data_payload = {
34
+ "inputs": base64.b64encode(data.tobytes()).decode("utf-8")
35
+ # No parameters needed
36
+ }
37
+ json_data = json.dumps(data_payload)
38
+
39
+ # Use requests to send the POST request
40
+ try:
41
+ response = requests.post(
42
+ url=STG_API_URL,
43
+ data=json_data,
44
+ headers={
45
+ "Content-Type": "application/json",
46
+ "Authorization": f"Bearer {HF_TOKEN}"
47
+ },
48
  )
49
+ response.raise_for_status()
50
+ return response.json()
51
+ except requests.exceptions.RequestException as e:
52
+ print(f"Error during API request: {e}")
53
+ print(f"Response content: {response.content}")
54
+ return None
55
+
56
+ def format_timecode(seconds):
57
+ """Formats seconds into HH:MM:SS:mmm format."""
58
+ m, s = divmod(seconds, 60)
59
+ h, m = divmod(m, 60)
60
+ return f"{int(h):02}:{int(m):02}:{int(s):02}:{int((s%1)*1000):03}"
61
+
62
+ def process_and_format_output(output, input_file):
63
+ """Formats the API response (now a dict) and saves it to a file."""
64
+ if output is None:
65
+ print("No output received from API.")
66
+ return None
67
+
68
+ # Check if the output is a dictionary and has the expected key
69
+ if not isinstance(output, dict) or "diarization" not in output:
70
+ print(f"Unexpected output format: {output}")
71
+ return None
72
+
73
+ try:
74
+ formatted_output = []
75
+ for speaker in output["diarization"]:
76
+ start_time = format_timecode(float(speaker["start"]))
77
+ end_time = format_timecode(float(speaker["stop"]))
78
+ formatted_output.append(f"{speaker['label']} START: {start_time} END: {end_time}")
79
+
80
+ base_filename = os.path.splitext(os.path.basename(input_file))[0]
81
+ output_dir = "TMP_STG"
82
+ os.makedirs(output_dir, exist_ok=True)
83
+ output_filename = os.path.join(output_dir, base_filename + "_voicerec-output.txt")
84
+
85
+ with open(output_filename, "w", encoding="utf-8") as f:
86
+ for line in formatted_output:
87
+ f.write(line + "\n")
88
+
89
+ return output_filename
90
+ except (KeyError, ValueError) as e:
91
+ print(f"Error processing API output: {e}")
92
+ return None
93
+
94
+ # --- Main Script ---
95
+
96
+ if __name__ == "__main__":
97
+ # --- Configuration for Standalone Testing ---
98
+ SAMPLE_AUDIO_FILE = "sample.wav" # Put your sample audio file in the same directory
99
+
100
+ # --- Main Script Logic ---
101
+ print(f"Sending {SAMPLE_AUDIO_FILE} to Hugging Face API...")
102
+ api_output = query_to_hf(SAMPLE_AUDIO_FILE)
103
 
104
+ if api_output:
105
+ output_file = process_and_format_output(api_output, SAMPLE_AUDIO_FILE)
106
+ if output_file:
107
+ print(f"Output saved to: {output_file}")
108
+ else:
109
+ print("API request failed.")