import asyncio
import base64
import io
import json
import os
import threading
import uuid

import gradio as gr
import numpy as np
import soundfile as sf
from dotenv import load_dotenv
from pydub import AudioSegment
from websockets import connect, Data, ClientConnection


class LogColors:
    OK = '\033[94m'
    SUCCESS = '\033[92m'
    WARNING = '\033[93m'
    ERROR = '\033[91m'
    ENDC = '\033[0m'


load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY environment variable must be set")

WEBSOCKET_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
WEBSOCKET_HEADERS = {
    "Authorization": "Bearer " + OPENAI_API_KEY,
    "OpenAI-Beta": "realtime=v1"
}
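
# The session settings sent on connect come from openai_transcription_settings.json
# (kept next to this file, not shown here). A minimal sketch of what it might
# contain, assuming the Realtime transcription-session schema; the model and
# turn-detection values below are illustrative, not prescribed by this app:
#
# {
#   "type": "transcription_session.update",
#   "session": {
#     "input_audio_format": "pcm16",
#     "input_audio_transcription": {"model": "gpt-4o-transcribe"},
#     "turn_detection": {"type": "server_vad"}
#   }
# }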

# Shared transcript: one module-level string, appended to by every client
transcription = ""
css = """
"""
# Active WebSocketClient instances, keyed by per-session client id
connections = {}


class WebSocketClient:
    def __init__(self, uri: str, headers: dict, client_id: str):
        self.uri = uri
        self.headers = headers
        self.websocket: ClientConnection | None = None
        self.queue = asyncio.Queue(maxsize=10)
        self.loop = None
        self.client_id = client_id

    async def connect(self):
        try:
            self.websocket = await connect(self.uri, additional_headers=self.headers)
            print(f"{LogColors.SUCCESS}Connected to OpenAI WebSocket{LogColors.ENDC}\n")

            # Send session settings to OpenAI
            with open("openai_transcription_settings.json", "r") as f:
                settings = f.read()
            await self.websocket.send(settings)

            await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
        except Exception as e:
            print(f"{LogColors.ERROR}WebSocket Connection Error: {e}{LogColors.ENDC}")

    def run(self):
        # Each client drives its own event loop on a dedicated daemon thread
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.connect())

    def process_websocket_message(self, message: Data):
        global transcription
        message_object = json.loads(message)
        if message_object["type"] != "error":
            print(f"{LogColors.OK}Received message: {LogColors.ENDC} {message}")
            if message_object["type"] == "conversation.item.input_audio_transcription.delta":
                delta = message_object["delta"]
                transcription += delta
            elif message_object["type"] == "conversation.item.input_audio_transcription.completed":
                # Separate completed utterances with a single space
                transcription += ' ' if len(transcription) and transcription[-1] != ' ' else ''
        else:
            print(f"{LogColors.ERROR}Error: {message}{LogColors.ENDC}")

    async def send_audio_chunks(self):
        while True:
            audio_data = await self.queue.get()
            sample_rate, audio_array = audio_data
            if self.websocket:
                # Convert to mono if stereo
                if audio_array.ndim > 1:
                    audio_array = audio_array.mean(axis=1)

                # Convert to float32 and normalize to [-1, 1]
                audio_array = audio_array.astype(np.float32)
                peak = np.max(np.abs(audio_array))
                audio_array /= peak if peak > 0 else 1.0

                # Convert to 16-bit PCM
                audio_array_int16 = (audio_array * 32767).astype(np.int16)
                audio_buffer = io.BytesIO()
                sf.write(audio_buffer, audio_array_int16, sample_rate, format='WAV', subtype='PCM_16')
                audio_buffer.seek(0)

                # Resample to the 24 kHz the Realtime API expects
                audio_segment = AudioSegment.from_file(audio_buffer, format="wav")
                resampled_audio = audio_segment.set_frame_rate(24000)

                # input_audio_buffer.append expects base64-encoded raw PCM16
                # frames, so send the sample data without a WAV header
                base64_audio = base64.b64encode(resampled_audio.raw_data).decode("utf-8")
                await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": base64_audio}))
                print(f"{LogColors.OK}Sent audio chunk{LogColors.ENDC}")

    async def receive_messages(self):
        async for message in self.websocket:
            self.process_websocket_message(message)

    def enqueue_audio_chunk(self, sample_rate: int, chunk_array: np.ndarray):
        # Called from Gradio's worker thread; hand the chunk to this client's loop
        if self.loop is None:
            return  # connection thread has not started its loop yet
        if not self.queue.full():
            asyncio.run_coroutine_threadsafe(self.queue.put((sample_rate, chunk_array)), self.loop)
        else:
            print(f"{LogColors.WARNING}Queue is full, dropping audio chunk{LogColors.ENDC}")

    async def close(self):
        if self.websocket:
            await self.websocket.close()
            connections.pop(self.client_id, None)
            print(f"{LogColors.WARNING}WebSocket connection closed{LogColors.ENDC}")


def send_audio_chunk(new_chunk: tuple[int, np.ndarray], client_id: str):
    if client_id not in connections:
        return "Connection is being established, please try again in a few seconds."
    sr, y = new_chunk
    connections[client_id].enqueue_audio_chunk(sr, y)
    return transcription


def create_new_websocket_connection():
    client_id = str(uuid.uuid4())
    connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
    threading.Thread(target=connections[client_id].run, daemon=True).start()
    return client_id


def clear_transcription():
    # Reset the shared transcript as well as the textbox; otherwise the next
    # streamed chunk would repopulate the old text
    global transcription
    transcription = ""
    return ""


if __name__ == "__main__":
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Realtime transcription demo")
        with gr.Row():
            with gr.Column():
                output_textbox = gr.Textbox(label="Transcription", value="", lines=7, interactive=False, autoscroll=True)
        with gr.Row():
            with gr.Column(scale=5):
                audio_input = gr.Audio(streaming=True, format="wav")
            with gr.Column():
                clear_button = gr.Button("Clear")

        client_id = gr.State()
        clear_button.click(clear_transcription, outputs=[output_textbox])
        audio_input.stream(send_audio_chunk, [audio_input, client_id], [output_textbox], stream_every=0.5, concurrency_limit=None)
        demo.load(create_new_websocket_connection, outputs=[client_id])

    demo.launch()
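
# To run locally (assuming this file is saved as app.py and the dependencies
# imported above are installed):
#
#     export OPENAI_API_KEY=sk-...
#     python app.py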