File size: 3,888 Bytes
e1b59aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import asyncio
from websockets import connect, Data, ClientConnection
import json
import numpy as np
import base64
import soundfile as sf
import io
from pydub import AudioSegment

# Load the OpenAI API key from the environment. A local .env file is also
# honoured when python-dotenv is installed; the package itself is optional.
import os

try:
    from dotenv import load_dotenv
except ImportError:  # python-dotenv not installed: rely on the real environment only
    pass
else:
    load_dotenv()

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Fail at import time: nothing below can work without credentials.
    raise ValueError("OPENAI_API_KEY environment variable must be set")

# Realtime API endpoint opened in transcription-only mode.
WEBSOCKET_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
WEBSOCKET_HEADERS = {
    "Authorization": f"Bearer {OPENAI_API_KEY}",
    "OpenAI-Beta": "realtime=v1",
}

# Registry of live WebSocketClient instances keyed by client id.
# WebSocketClient.close() removes its own entry; where entries are added is
# not visible here (presumably the code that creates each client).
connections = {}

class WebSocketClient:
    """Stream audio chunks to the OpenAI Realtime transcription WebSocket.

    Each instance owns one connection and one private event loop. ``run()``
    blocks until the session ends, and ``enqueue_audio_chunk()`` hands chunks
    to that loop with thread-safe calls, so the client is presumably driven
    from a dedicated worker thread. Transcribed text accumulates in
    ``self.transcript``.
    """

    # Rate every chunk is resampled to before sending. The Realtime API's
    # pcm16 input format is 24 kHz mono little-endian PCM; this assumes
    # openai_transcription_settings.json selects pcm16 -- confirm.
    TARGET_SAMPLE_RATE = 24000

    def __init__(self, uri: str, headers: dict, client_id: str):
        """Store connection parameters; no I/O happens until ``run()``.

        Args:
            uri: WebSocket endpoint to connect to.
            headers: Extra HTTP headers sent with the handshake (auth, beta flag).
            client_id: Key of this client in the module-level ``connections`` registry.
        """
        self.uri = uri
        self.headers = headers
        self.websocket: ClientConnection = None  # set once connect() succeeds
        # Bounded: when the sender falls behind, new chunks are dropped rather
        # than buffered without limit.
        self.queue = asyncio.Queue(maxsize=10)
        self.loop = None  # created by run()
        self.client_id = client_id
        self.transcript = ""

    async def connect(self):
        """Open the socket, send the session settings, then pump audio and events.

        Returns as soon as either direction stops: the server closed the
        stream, ``close()`` was called, or an error occurred. Errors are
        printed, not raised. The socket is always closed on the way out.
        """
        try:
            self.websocket = await connect(self.uri, additional_headers=self.headers)
            print("✅ Connected to OpenAI WebSocket")

            with open("openai_transcription_settings.json", "r", encoding="utf-8") as f:
                settings = f.read()
            await self.websocket.send(settings)

            # gather() would wait for both loops, but send_audio_chunks never
            # returns on its own, so a server-side close would hang run()
            # forever. Stop when either side finishes and cancel the other.
            tasks = [
                asyncio.create_task(self.receive_messages()),
                asyncio.create_task(self.send_audio_chunks()),
            ]
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            for task in done:
                task.result()  # re-raise a task failure so it is reported below
        except Exception as e:
            print(f"❌ WebSocket error: {e}")
        finally:
            if self.websocket is not None:
                await self.websocket.close()

    def run(self):
        """Create this client's private event loop and block until connect() returns."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.connect())

    def process_websocket_message(self, message: Data):
        """Apply one server event (a JSON text frame) to ``self.transcript``.

        Transcription deltas are appended. A "completed" event appends a single
        separating space unless the transcript is empty or already ends with
        one. Error events are printed; all other event types are ignored.
        """
        event = json.loads(message)
        event_type = event["type"]
        if event_type == "error":
            print(f"⚠️ Error received: {message}")
        elif event_type == "conversation.item.input_audio_transcription.delta":
            self.transcript += event["delta"]
        elif event_type == "conversation.item.input_audio_transcription.completed":
            if self.transcript and not self.transcript.endswith(" "):
                self.transcript += " "

    def _encode_chunk(self, sample_rate: int, audio_array: np.ndarray) -> str:
        """Return *audio_array* as base64 raw PCM16 mono at TARGET_SAMPLE_RATE.

        The chunk is peak-normalised on its own before conversion. Because each
        chunk is scaled independently, quiet chunks are boosted to full scale.
        """
        if audio_array.ndim > 1:
            # Downmix; assumes a (samples, channels) layout -- TODO confirm with caller.
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)  # copy: caller's array is untouched
        peak = np.max(np.abs(audio_array))
        if peak > 0:
            audio_array /= peak
        audio_array_int16 = (audio_array * 32767).astype(np.int16)

        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, audio_array_int16, sample_rate, format='WAV', subtype='PCM_16')
        wav_buffer.seek(0)
        segment = AudioSegment.from_file(wav_buffer, format="wav")
        resampled = segment.set_frame_rate(self.TARGET_SAMPLE_RATE)

        # input_audio_buffer.append carries raw samples. Exporting a WAV file
        # here would prepend a RIFF header to every chunk, which the server
        # would then decode as audio.
        return base64.b64encode(resampled.raw_data).decode("utf-8")

    async def send_audio_chunks(self):
        """Forever: take queued chunks, encode them, and append them to the server buffer."""
        while True:
            sample_rate, audio_array = await self.queue.get()
            if self.websocket is None or audio_array.size == 0:
                continue  # nothing to send to, or nothing to send
            b64_audio = self._encode_chunk(sample_rate, audio_array)
            await self.websocket.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": b64_audio
            }))

    async def receive_messages(self):
        """Pass every incoming server event to process_websocket_message until the socket closes."""
        async for message in self.websocket:
            self.process_websocket_message(message)

    def _offer_chunk(self, item):
        """Queue *item* without waiting; drop it when the queue is full. Runs on self.loop."""
        try:
            self.queue.put_nowait(item)
        except asyncio.QueueFull:
            pass  # sender is behind; dropping keeps latency bounded

    def enqueue_audio_chunk(self, sample_rate: int, chunk_array: np.ndarray):
        """Hand one audio chunk to the sender loop; callable from any thread.

        The chunk is dropped when the queue is full or when run() has not yet
        created the loop. asyncio.Queue is not thread-safe, so the fullness
        check and the put both happen on the loop itself.
        """
        loop = self.loop
        if loop is None:
            return
        loop.call_soon_threadsafe(self._offer_chunk, (sample_rate, chunk_array))

    async def close(self):
        """Close the socket (if one was opened) and remove this client from ``connections``.

        Tolerates being called before connect() succeeded, or after the
        registry entry is already gone. Must be awaited on this client's own
        loop, since the socket belongs to it.
        """
        if self.websocket is not None:
            await self.websocket.close()
        connections.pop(self.client_id, None)