"""Gradio demo that streams microphone audio to the OpenAI Realtime
transcription endpoint over a WebSocket and shows the running transcript.

Each browser session gets its own WebSocketClient, which runs its own asyncio
event loop on a dedicated background thread.
"""
import asyncio
import base64
import io
import json
import os
import threading
import time
import uuid

import gradio as gr
import numpy as np
import soundfile as sf
from dotenv import load_dotenv
from pydub import AudioSegment
from websockets import ClientConnection, Data, connect


class LogColors:
    """ANSI escape codes used to colour console log output."""

    OK = '\033[94m'
    SUCCESS = '\033[92m'
    WARNING = '\033[93m'
    ERROR = '\033[91m'
    ENDC = '\033[0m'


load_dotenv()

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY environment variable must be set")

WEBSOCKET_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
WEBSOCKET_HEADERS = {
    "Authorization": "Bearer " + OPENAI_API_KEY,
    "OpenAI-Beta": "realtime=v1",
}

# Rate that outgoing audio is resampled to. NOTE(review): this assumes the
# session's input_audio_format (set in openai_transcription_settings.json) is
# "pcm16", i.e. raw 16-bit mono at 24 kHz -- confirm against that file.
TARGET_SAMPLE_RATE = 24000

css = """ """

# Live clients keyed by the per-session id that create_new_websocket_connection()
# returns. WebSocketClient.close() removes an entry when its session ends.
connections = {}


def _to_pcm16_mono(sample_rate, audio_array, target_rate=TARGET_SAMPLE_RATE) -> bytes:
    """Convert one captured chunk to raw 16-bit mono PCM at ``target_rate``.

    Args:
        sample_rate: Sample rate of ``audio_array`` in Hz.
        audio_array: Samples shaped ``(n,)`` or ``(n, channels)``. Integer
            dtypes are scaled by their full-scale value. Float input is
            assumed to already lie in [-1.0, 1.0] and is clipped to it.
        target_rate: Output sample rate in Hz.

    Returns:
        Raw PCM bytes with no container or header. Returns ``b""`` for an
        empty chunk.
    """
    samples = np.asarray(audio_array)
    if samples.size == 0:
        return b""

    if np.issubdtype(samples.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) is centred on the middle of its range.
        midpoint = (float(np.iinfo(samples.dtype).max) + 1.0) / 2.0
        samples = (samples.astype(np.float32) - midpoint) / midpoint
    elif np.issubdtype(samples.dtype, np.integer):
        # Scale by the dtype's full-scale value, not the chunk's own peak.
        # Per-chunk peak normalisation would amplify the noise floor of quiet
        # chunks to full volume and make the gain jump from chunk to chunk.
        samples = samples.astype(np.float32) / -float(np.iinfo(samples.dtype).min)
    else:
        samples = samples.astype(np.float32)

    if samples.ndim > 1:
        # Down-mix (samples, channels) to mono.
        samples = samples.mean(axis=1)

    pcm16 = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    segment = AudioSegment(
        data=pcm16.tobytes(),
        sample_width=2,
        frame_rate=int(sample_rate),
        channels=1,
    )
    # raw_data is the bare sample buffer. The realtime API expects raw PCM,
    # so no WAV header may be included in the payload.
    return segment.set_frame_rate(target_rate).raw_data


class WebSocketClient:
    """One OpenAI Realtime transcription session, driven from its own thread.

    Gradio callbacks (running on other threads) pass audio to
    enqueue_audio_chunk(). The client's event loop sends that audio upstream
    and accumulates transcription deltas in ``self.transcription``.
    """

    def __init__(self, uri: str, headers: dict, client_id: str, max_queue_size: int = 10):
        self.uri = uri
        self.headers = headers
        self.websocket: ClientConnection = None
        # Created in run() on the loop's own thread. Until self.loop is set,
        # enqueue_audio_chunk() drops incoming audio.
        self.queue: asyncio.Queue = None
        self.max_queue_size = max_queue_size
        self.loop: asyncio.AbstractEventLoop = None
        self.client_id = client_id
        # Transcript for this session only. The loop thread appends to it;
        # Gradio callback threads read it and reset it.
        self.transcription = ""

    async def connect(self):
        """Open the socket, send session settings, then pump audio and messages.

        Returns when either the receive loop or the send loop stops. Errors
        are logged, not raised. The client is always closed and unregistered
        on exit.
        """
        try:
            self.websocket = await connect(self.uri, additional_headers=self.headers)
            print(f"{LogColors.SUCCESS}Connected to OpenAI WebSocket{LogColors.ENDC}\n")

            # Send session settings to OpenAI before any audio.
            with open("openai_transcription_settings.json", "r", encoding="utf-8") as f:
                settings = f.read()
            await self.websocket.send(settings)

            tasks = [
                asyncio.create_task(self.receive_messages()),
                asyncio.create_task(self.send_audio_chunks()),
            ]
            try:
                # When the server closes the socket the receive loop ends, but
                # the send loop would wait on the queue forever. Stop on
                # whichever finishes first.
                done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
            for task in done:
                task.result()  # re-raise a failure from either loop
        except Exception as e:
            print(f"{LogColors.ERROR}WebSocket Connection Error: {e}{LogColors.ENDC}")
        finally:
            await self.close()

    def run(self):
        """Thread entry point: run a private event loop until the session ends."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Create the queue on the thread that owns the loop. Before Python
        # 3.10, asyncio.Queue() binds to the current thread's loop when it
        # is constructed.
        self.queue = asyncio.Queue(maxsize=self.max_queue_size)
        # Publish the loop last: enqueue_audio_chunk() treats a non-None loop
        # as "ready to accept audio".
        self.loop = loop
        try:
            loop.run_until_complete(self.connect())
        finally:
            loop.close()

    def process_websocket_message(self, message: Data):
        """Log one server event and fold transcription events into the transcript."""
        message_object = json.loads(message)
        message_type = message_object.get("type")

        if message_type == "error":
            print(f"{LogColors.ERROR}Error: {message}{LogColors.ENDC}")
            return

        print(f"{LogColors.OK}Received message: {LogColors.ENDC} {message}")
        if message_type == "conversation.item.input_audio_transcription.delta":
            self.transcription += message_object.get("delta", "")
        elif message_type == "conversation.item.input_audio_transcription.completed":
            # Separate finished utterances with exactly one space.
            if self.transcription and not self.transcription.endswith(" "):
                self.transcription += " "

    async def send_audio_chunks(self):
        """Forward queued audio to the server as input_audio_buffer.append events."""
        while True:
            sample_rate, audio_array = await self.queue.get()
            pcm = _to_pcm16_mono(sample_rate, audio_array)
            if not pcm:
                continue
            base64_audio = base64.b64encode(pcm).decode("utf-8")
            await self.websocket.send(
                json.dumps({"type": "input_audio_buffer.append", "audio": base64_audio})
            )
            print(f"{LogColors.OK}Sent audio chunk{LogColors.ENDC}")

    async def receive_messages(self):
        """Handle server messages until the socket is closed."""
        async for message in self.websocket:
            self.process_websocket_message(message)

    def enqueue_audio_chunk(self, sample_rate: int, chunk_array: np.ndarray):
        """Queue one chunk for sending. Safe to call from any thread.

        The chunk is dropped, with a warning, if the loop has not started,
        the session has ended, or the queue is full.
        """
        loop = self.loop
        if loop is None:
            print(f"{LogColors.WARNING}Connection not ready, dropping audio chunk{LogColors.ENDC}")
            return
        try:
            # asyncio.Queue is not thread-safe, so the put runs on the
            # loop's own thread.
            loop.call_soon_threadsafe(self._put_or_drop, (sample_rate, chunk_array))
        except RuntimeError:
            # call_soon_threadsafe raises once the loop has been closed.
            print(f"{LogColors.WARNING}Connection closed, dropping audio chunk{LogColors.ENDC}")

    def _put_or_drop(self, item):
        """Runs on the loop thread: enqueue ``item`` unless the queue is full."""
        try:
            self.queue.put_nowait(item)
        except asyncio.QueueFull:
            print(f"{LogColors.WARNING}Queue is full, dropping audio chunk{LogColors.ENDC}")

    async def close(self):
        """Close the socket, if one was opened, and unregister this client."""
        if self.websocket:
            await self.websocket.close()
        connections.pop(self.client_id, None)
        print(f"{LogColors.WARNING}WebSocket connection closed{LogColors.ENDC}")


def send_audio_chunk(new_chunk: gr.Audio, client_id: str):
    """Gradio stream callback: forward a ``(sample_rate, samples)`` chunk.

    Returns:
        The session's transcript, or a status message when there is no live
        connection.
    """
    if client_id is None:
        # demo.load has not yet returned this session's id.
        return "Connection is being established, please try again in a few seconds."
    client = connections.get(client_id)
    if client is None:
        # The id was issued, but the session has since ended and
        # unregistered itself.
        return "Connection closed, please reload the page."
    if new_chunk is None:
        return client.transcription
    sample_rate, samples = new_chunk
    client.enqueue_audio_chunk(sample_rate, samples)
    return client.transcription


def clear_transcription(client_id: str):
    """Gradio click callback: reset this session's transcript and blank the textbox."""
    client = connections.get(client_id)
    if client is not None:
        client.transcription = ""
    return ""


def create_new_websocket_connection():
    """Gradio load callback: start a client thread and return its session id."""
    client_id = str(uuid.uuid4())
    # Register before starting the thread so the id is valid as soon as it
    # is returned.
    connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
    threading.Thread(target=connections[client_id].run, daemon=True).start()
    return client_id


if __name__ == "__main__":
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Realtime transcription demo")
        with gr.Row():
            with gr.Column():
                output_textbox = gr.Textbox(
                    label="Transcription",
                    value="",
                    lines=7,
                    interactive=False,
                    autoscroll=True,
                )
                with gr.Row():
                    with gr.Column(scale=5):
                        audio_input = gr.Audio(streaming=True, format="wav")
                    with gr.Column():
                        clear_button = gr.Button("Clear")

        # Holds this browser session's key into `connections`.
        client_id_state = gr.State()

        clear_button.click(clear_transcription, inputs=[client_id_state], outputs=[output_textbox])
        audio_input.stream(
            send_audio_chunk,
            [audio_input, client_id_state],
            [output_textbox],
            stream_every=0.5,
            concurrency_limit=None,
        )
        demo.load(create_new_websocket_connection, outputs=[client_id_state])

    demo.launch()