IAMTFRMZA commited on
Commit
fa361c9
·
verified ·
1 Parent(s): 126ed75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -138
app.py CHANGED
@@ -1,164 +1,161 @@
1
  import gradio as gr
2
- import asyncio
3
- from websockets import connect, Data, ClientConnection
4
- from dotenv import load_dotenv
5
- import json
6
- import os
7
- import threading
8
  import numpy as np
9
- import base64
10
  import soundfile as sf
11
- import io
12
  from pydub import AudioSegment
13
- import time
14
- import uuid
15
-
16
- class LogColors:
17
- OK = '\033[94m'
18
- SUCCESS = '\033[92m'
19
- WARNING = '\033[93m'
20
- ERROR = '\033[91m'
21
- ENDC = '\033[0m'
22
-
23
- load_dotenv()
24
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
25
- if not OPENAI_API_KEY:
26
- raise ValueError("OPENAI_API_KEY environment variable must be set")
27
-
28
- WEBSOCKET_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
29
- WEBSOCKET_HEADERS = {
30
- "Authorization": "Bearer " + OPENAI_API_KEY,
31
- "OpenAI-Beta": "realtime=v1"
32
- }
33
-
34
- css = """
35
- """
36
 
 
 
 
 
 
 
 
 
37
  connections = {}
38
 
 
39
  class WebSocketClient:
40
- def __init__(self, uri: str, headers: dict, client_id: str):
41
- self.uri = uri
42
- self.headers = headers
43
- self.websocket: ClientConnection = None
44
  self.queue = asyncio.Queue(maxsize=10)
45
- self.loop = None
46
- self.client_id = client_id
47
  self.transcript = ""
48
 
49
  async def connect(self):
50
- try:
51
- self.websocket = await connect(self.uri, additional_headers=self.headers)
52
- print(f"{LogColors.SUCCESS}Connected to OpenAI WebSocket{LogColors.ENDC}\n")
53
-
54
- # Send session settings to OpenAI
55
- with open("openai_transcription_settings.json", "r") as f:
56
- settings = f.read()
57
- await self.websocket.send(settings)
58
-
59
- await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
60
- except Exception as e:
61
- print(f"{LogColors.ERROR}WebSocket Connection Error: {e}{LogColors.ENDC}")
62
 
63
  def run(self):
64
- self.loop = asyncio.new_event_loop()
65
- asyncio.set_event_loop(self.loop)
66
- self.loop.run_until_complete(self.connect())
67
-
68
- def process_websocket_message(self, message: Data):
69
- message_object = json.loads(message)
70
- if message_object["type"] != "error":
71
- print(f"{LogColors.OK}Received message: {LogColors.ENDC} {message}")
72
- if message_object["type"] == "conversation.item.input_audio_transcription.delta":
73
- delta = message_object["delta"]
74
- self.transcript += delta
75
- elif message_object["type"] == "conversation.item.input_audio_transcription.completed":
76
- self.transcript += ' ' if len(self.transcript) and self.transcript[-1] != ' ' else ''
77
- else:
78
- print(f"{LogColors.ERROR}Error: {message}{LogColors.ENDC}")
79
 
80
  async def send_audio_chunks(self):
81
  while True:
82
- audio_data = await self.queue.get()
83
- sample_rate, audio_array = audio_data
84
- if self.websocket:
85
- # Convert to mono if stereo
86
- if audio_array.ndim > 1:
87
- audio_array = audio_array.mean(axis=1)
88
-
89
- # Convert to float32 and normalize
90
- audio_array = audio_array.astype(np.float32)
91
- audio_array /= np.max(np.abs(audio_array)) if np.max(np.abs(audio_array)) > 0 else 1.0
92
-
93
- # Convert to 16-bit PCM
94
- audio_array_int16 = (audio_array * 32767).astype(np.int16)
95
-
96
- audio_buffer = io.BytesIO()
97
- sf.write(audio_buffer, audio_array_int16, sample_rate, format='WAV', subtype='PCM_16')
98
- audio_buffer.seek(0)
99
- audio_segment = AudioSegment.from_file(audio_buffer, format="wav")
100
- resampled_audio = audio_segment.set_frame_rate(24000)
101
-
102
- output_buffer = io.BytesIO()
103
- resampled_audio.export(output_buffer, format="wav")
104
- output_buffer.seek(0)
105
- base64_audio = base64.b64encode(output_buffer.read()).decode("utf-8")
106
-
107
- await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": base64_audio}))
108
- print(f"{LogColors.OK}Sent audio chunk{LogColors.ENDC}")
109
 
110
  async def receive_messages(self):
111
- async for message in self.websocket:
112
- self.process_websocket_message(message)
 
 
113
 
114
- def enqueue_audio_chunk(self, sample_rate: int, chunk_array: np.ndarray):
115
  if not self.queue.full():
116
- asyncio.run_coroutine_threadsafe(self.queue.put((sample_rate, chunk_array)), self.loop)
117
- else:
118
- print(f"{LogColors.WARNING}Queue is full, dropping audio chunk{LogColors.ENDC}")
119
-
120
- async def close(self):
121
- if self.websocket:
122
- await self.websocket.close()
123
- connections.pop(self.client_id)
124
- print(f"{LogColors.WARNING}WebSocket connection closed{LogColors.ENDC}")
125
-
126
-
127
- def send_audio_chunk(new_chunk: gr.Audio, client_id: str):
128
- if client_id not in connections:
129
- return "Connection is being established, please try again in a few seconds."
130
- sr, y = new_chunk
131
- connections[client_id].enqueue_audio_chunk(sr, y)
132
- return connections[client_id].transcript
133
-
134
- def create_new_websocket_connection():
135
- client_id = str(uuid.uuid4())
136
- connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
137
- threading.Thread(target=connections[client_id].run, daemon=True).start()
138
- return client_id
139
-
140
- def clear_transcript(client_id):
141
- if client_id in connections:
142
- connections[client_id].transcript = ""
143
  return ""
144
 
145
- if __name__ == "__main__":
146
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
147
- gr.Markdown(f"# Realtime transcription demo")
148
- with gr.Row():
149
- with gr.Column():
150
- output_textbox = gr.Textbox(label="Transcript", value="", lines=7, interactive=False, autoscroll=True)
151
- with gr.Row():
152
- with gr.Column(scale=5):
153
- audio_input = gr.Audio(streaming=True, format="wav")
154
- with gr.Column():
155
- clear_button = gr.Button("Clear")
156
-
157
- client_id = gr.State()
158
- clear_button.click(clear_transcript, inputs=[client_id], outputs=[output_textbox])
159
- audio_input.stream(send_audio_chunk, [audio_input, client_id], [output_textbox], stream_every=0.5, concurrency_limit=None)
160
- demo.load(create_new_websocket_connection, outputs=[client_id])
161
 
162
- demo.launch()
 
 
 
163
 
 
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os, time, re, json, base64, asyncio, threading, uuid, io
 
 
 
 
 
3
  import numpy as np
 
4
  import soundfile as sf
 
5
  from pydub import AudioSegment
6
+ from openai import OpenAI
7
+ from websockets import connect
8
+ from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Load environment variables
11
+ load_dotenv()
12
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
13
+ ASSISTANT_ID = os.getenv("ASSISTANT_ID")
14
+ client = OpenAI(api_key=OPENAI_API_KEY)
15
+
16
+ HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
17
+ WS_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
18
  connections = {}
19
 
20
+ # ---------------- WebSocket Client for Voice ----------------
21
  class WebSocketClient:
22
+ def __init__(self, uri, headers, client_id):
23
+ self.uri, self.headers, self.client_id = uri, headers, client_id
24
+ self.websocket = None
 
25
  self.queue = asyncio.Queue(maxsize=10)
 
 
26
  self.transcript = ""
27
 
28
  async def connect(self):
29
+ self.websocket = await connect(self.uri, additional_headers=self.headers)
30
+ with open("openai_transcription_settings.json", "r") as f:
31
+ await self.websocket.send(f.read())
32
+ await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
 
 
 
 
 
 
 
 
33
 
34
  def run(self):
35
+ loop = asyncio.new_event_loop()
36
+ asyncio.set_event_loop(loop)
37
+ loop.run_until_complete(self.connect())
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  async def send_audio_chunks(self):
40
  while True:
41
+ sr, arr = await self.queue.get()
42
+ if arr.ndim > 1: arr = arr.mean(axis=1)
43
+ arr = (arr / np.max(np.abs(arr))) if np.max(np.abs(arr)) > 0 else arr
44
+ int16 = (arr * 32767).astype(np.int16)
45
+ buf = io.BytesIO(); sf.write(buf, int16, sr, format='WAV', subtype='PCM_16')
46
+ audio = AudioSegment.from_file(buf, format="wav").set_frame_rate(24000)
47
+ out = io.BytesIO(); audio.export(out, format="wav"); out.seek(0)
48
+ await self.websocket.send(json.dumps({
49
+ "type": "input_audio_buffer.append",
50
+ "audio": base64.b64encode(out.read()).decode()
51
+ }))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  async def receive_messages(self):
54
+ async for msg in self.websocket:
55
+ data = json.loads(msg)
56
+ if data["type"] == "conversation.item.input_audio_transcription.delta":
57
+ self.transcript += data["delta"]
58
 
59
+ def enqueue_audio_chunk(self, sr, arr):
60
  if not self.queue.full():
61
+ asyncio.run_coroutine_threadsafe(self.queue.put((sr, arr)), asyncio.get_event_loop())
62
+
63
+ def create_ws():
64
+ cid = str(uuid.uuid4())
65
+ client = WebSocketClient(WS_URI, HEADERS, cid)
66
+ threading.Thread(target=client.run, daemon=True).start()
67
+ connections[cid] = client
68
+ return cid
69
+
70
+ def send_audio(chunk, cid):
71
+ if cid not in connections: return "Connecting..."
72
+ sr, arr = chunk
73
+ connections[cid].enqueue_audio_chunk(sr, arr)
74
+ return connections[cid].transcript
75
+
76
+ def clear_transcript(cid):
77
+ if cid in connections: connections[cid].transcript = ""
 
 
 
 
 
 
 
 
 
 
78
  return ""
79
 
80
+ # ---------------- Chat Assistant Logic ----------------
81
+ def handle_chat(user_input, history, thread_id, image_url):
82
+ if not OPENAI_API_KEY or not ASSISTANT_ID:
83
+ return "❌ Missing API key or Assistant ID.", history, thread_id, image_url
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ try:
86
+ if thread_id is None:
87
+ thread = client.beta.threads.create()
88
+ thread_id = thread.id
89
 
90
+ client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
91
+ run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)
92
 
93
+ while True:
94
+ status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
95
+ if status.status == "completed": break
96
+ time.sleep(1)
97
+
98
+ msgs = client.beta.threads.messages.list(thread_id=thread_id)
99
+ for msg in reversed(msgs.data):
100
+ if msg.role == "assistant":
101
+ content = msg.content[0].text.value
102
+ history.append((user_input, content))
103
+ match = re.search(
104
+ r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
105
+ content
106
+ )
107
+ if match:
108
+ image_url = match.group(0)
109
+ break
110
+
111
+ return "", history, thread_id, image_url
112
+
113
+ except Exception as e:
114
+ return f"❌ {e}", history, thread_id, image_url
115
+
116
+ # ---------------- UI ----------------
117
+ with gr.Blocks(theme="lone17/kotaemon") as app:
118
+ gr.Markdown("# 📄 Document AI Assistant")
119
+
120
+ # States
121
+ chat_state = gr.State([])
122
+ thread_state = gr.State()
123
+ image_state = gr.State()
124
+ client_id = gr.State()
125
+ mic_shown = gr.State(False)
126
+
127
+ with gr.Row(equal_height=True):
128
+ # Left: Document Viewer
129
+ with gr.Column(scale=1):
130
+ image_display = gr.Image(label="🖼️ Document Preview", type="filepath", show_download_button=False)
131
+
132
+ # Right: Chat + Mic
133
+ with gr.Column(scale=1.4):
134
+ chat = gr.Chatbot(label="💬 Chat", height=450)
135
+
136
+ with gr.Row():
137
+ user_input = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=6)
138
+ mic_btn = gr.Button("🎙️", scale=1)
139
+ send_btn = gr.Button("Send", scale=2)
140
+
141
+ # Hidden Voice Section
142
+ with gr.Row(visible=False) as mic_row:
143
+ with gr.Column(scale=4):
144
+ audio = gr.Audio(label="🎤 Speak", streaming=True)
145
+ with gr.Column(scale=5):
146
+ transcript = gr.Textbox(label="Transcript", lines=2, interactive=False)
147
+ with gr.Column(scale=2):
148
+ clear_btn = gr.Button("🧹 Clear")
149
+
150
+ # Logic Wiring
151
+ def toggle_mic(state): return not state, gr.update(visible=not state)
152
+ mic_btn.click(toggle_mic, inputs=mic_shown, outputs=[mic_shown, mic_row])
153
+ send_btn.click(handle_chat,
154
+ inputs=[user_input, chat_state, thread_state, image_state],
155
+ outputs=[user_input, chat, thread_state, image_state])
156
+ image_state.change(fn=lambda x: x, inputs=image_state, outputs=image_display)
157
+ audio.stream(fn=send_audio, inputs=[audio, client_id], outputs=transcript, stream_every=0.5)
158
+ clear_btn.click(fn=clear_transcript, inputs=[client_id], outputs=transcript)
159
+ app.load(fn=create_ws, outputs=[client_id])
160
+
161
+ app.launch()