Update handler.py
Browse files- handler.py +103 -59
handler.py
CHANGED
@@ -1,65 +1,109 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
import
|
4 |
import base64
|
|
|
5 |
import numpy as np
|
|
|
6 |
|
7 |
-
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
def __call__(self, data: Dict) -> Dict:
|
23 |
-
"""
|
24 |
-
Args:
|
25 |
-
data (Dict):
|
26 |
-
'inputs': Base64-encoded audio bytes
|
27 |
-
'parameters': Additional diarization parameters (currently unused)
|
28 |
-
Return:
|
29 |
-
Dict: Speaker diarization results
|
30 |
-
"""
|
31 |
-
inputs = data.get("inputs")
|
32 |
-
parameters = data.get("parameters", {}) # We are not using them now
|
33 |
-
|
34 |
-
# Decode the base64 audio data
|
35 |
-
audio_data = base64.b64decode(inputs)
|
36 |
-
audio_nparray = np.frombuffer(audio_data, dtype=np.int16)
|
37 |
-
|
38 |
-
# Handle multi-channel audio (convert to mono)
|
39 |
-
if audio_nparray.ndim > 1:
|
40 |
-
audio_nparray = audio_nparray.mean(axis=0) # Average channels to create mono
|
41 |
-
|
42 |
-
# Convert to PyTorch tensor
|
43 |
-
audio_tensor = torch.from_numpy(audio_nparray).float().unsqueeze(0)
|
44 |
-
if audio_tensor.dim() == 1:
|
45 |
-
audio_tensor = audio_tensor.unsqueeze(0)
|
46 |
-
|
47 |
-
pyannote_input = {"waveform": audio_tensor, "sample_rate": SAMPLE_RATE}
|
48 |
-
|
49 |
-
# Run diarization pipeline
|
50 |
-
try:
|
51 |
-
diarization = self.pipeline(pyannote_input) # No num_speakers parameter
|
52 |
-
except Exception as e:
|
53 |
-
print(f"An unexpected error occurred: {e}")
|
54 |
-
return {"error": "Diarization failed unexpectedly"}
|
55 |
-
|
56 |
-
# Build a friendly JSON response
|
57 |
-
processed_diarization = [
|
58 |
-
{
|
59 |
-
"label": str(label),
|
60 |
-
"start": str(segment.start),
|
61 |
-
"stop": str(segment.end),
|
62 |
-
}
|
63 |
-
for segment, _, label in diarization.itertracks(yield_label=True)
|
64 |
-
]
|
65 |
-
return {"diarization": processed_diarization}
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
import json
|
4 |
import base64
|
5 |
+
import soundfile as sf
|
6 |
import numpy as np
|
7 |
+
from scipy.signal import resample
|
8 |
|
9 |
+
# --- Configuration ---
|
10 |
+
# Replace with your actual API key/token
|
11 |
+
HF_TOKEN = os.environ.get("HF_API_TOKEN") # Get the token from environment variable
|
12 |
+
# Replace with your actual endpoint URL
|
13 |
+
STG_API_URL = "https://YOUR_ENDPOINT_URL"
|
14 |
|
15 |
+
# --- Functions ---
|
16 |
+
|
17 |
+
def query_to_hf(filename):
|
18 |
+
"""Sends audio file to Hugging Face API using requests."""
|
19 |
+
try:
|
20 |
+
data, sr = sf.read(filename)
|
21 |
+
except sf.LibsndfileError as e:
|
22 |
+
print(f"Error reading audio file: {e}")
|
23 |
+
return None
|
24 |
+
|
25 |
+
# Handle multi-channel audio (convert to mono)
|
26 |
+
if len(data.shape) > 1:
|
27 |
+
data = data.mean(axis=1) # Average channels to create mono
|
28 |
+
|
29 |
+
data = resample(data, num=int(len(data) * 16000 / sr))
|
30 |
+
data = (data * np.iinfo(np.int16).max).astype(np.int16)
|
31 |
+
|
32 |
+
# Prepare the data payload
|
33 |
+
data_payload = {
|
34 |
+
"inputs": base64.b64encode(data.tobytes()).decode("utf-8")
|
35 |
+
# No parameters needed
|
36 |
+
}
|
37 |
+
json_data = json.dumps(data_payload)
|
38 |
+
|
39 |
+
# Use requests to send the POST request
|
40 |
+
try:
|
41 |
+
response = requests.post(
|
42 |
+
url=STG_API_URL,
|
43 |
+
data=json_data,
|
44 |
+
headers={
|
45 |
+
"Content-Type": "application/json",
|
46 |
+
"Authorization": f"Bearer {HF_TOKEN}"
|
47 |
+
},
|
48 |
)
|
49 |
+
response.raise_for_status()
|
50 |
+
return response.json()
|
51 |
+
except requests.exceptions.RequestException as e:
|
52 |
+
print(f"Error during API request: {e}")
|
53 |
+
print(f"Response content: {response.content}")
|
54 |
+
return None
|
55 |
+
|
56 |
+
def format_timecode(seconds):
|
57 |
+
"""Formats seconds into HH:MM:SS:mmm format."""
|
58 |
+
m, s = divmod(seconds, 60)
|
59 |
+
h, m = divmod(m, 60)
|
60 |
+
return f"{int(h):02}:{int(m):02}:{int(s):02}:{int((s%1)*1000):03}"
|
61 |
+
|
62 |
+
def process_and_format_output(output, input_file):
|
63 |
+
"""Formats the API response (now a dict) and saves it to a file."""
|
64 |
+
if output is None:
|
65 |
+
print("No output received from API.")
|
66 |
+
return None
|
67 |
+
|
68 |
+
# Check if the output is a dictionary and has the expected key
|
69 |
+
if not isinstance(output, dict) or "diarization" not in output:
|
70 |
+
print(f"Unexpected output format: {output}")
|
71 |
+
return None
|
72 |
+
|
73 |
+
try:
|
74 |
+
formatted_output = []
|
75 |
+
for speaker in output["diarization"]:
|
76 |
+
start_time = format_timecode(float(speaker["start"]))
|
77 |
+
end_time = format_timecode(float(speaker["stop"]))
|
78 |
+
formatted_output.append(f"{speaker['label']} START: {start_time} END: {end_time}")
|
79 |
+
|
80 |
+
base_filename = os.path.splitext(os.path.basename(input_file))[0]
|
81 |
+
output_dir = "TMP_STG"
|
82 |
+
os.makedirs(output_dir, exist_ok=True)
|
83 |
+
output_filename = os.path.join(output_dir, base_filename + "_voicerec-output.txt")
|
84 |
+
|
85 |
+
with open(output_filename, "w", encoding="utf-8") as f:
|
86 |
+
for line in formatted_output:
|
87 |
+
f.write(line + "\n")
|
88 |
+
|
89 |
+
return output_filename
|
90 |
+
except (KeyError, ValueError) as e:
|
91 |
+
print(f"Error processing API output: {e}")
|
92 |
+
return None
|
93 |
+
|
94 |
+
# --- Main Script ---
|
95 |
+
|
96 |
+
if __name__ == "__main__":
|
97 |
+
# --- Configuration for Standalone Testing ---
|
98 |
+
SAMPLE_AUDIO_FILE = "sample.wav" # Put your sample audio file in the same directory
|
99 |
+
|
100 |
+
# --- Main Script Logic ---
|
101 |
+
print(f"Sending {SAMPLE_AUDIO_FILE} to Hugging Face API...")
|
102 |
+
api_output = query_to_hf(SAMPLE_AUDIO_FILE)
|
103 |
|
104 |
+
if api_output:
|
105 |
+
output_file = process_and_format_output(api_output, SAMPLE_AUDIO_FILE)
|
106 |
+
if output_file:
|
107 |
+
print(f"Output saved to: {output_file}")
|
108 |
+
else:
|
109 |
+
print("API request failed.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|