IAMTFRMZA committed on
Commit 2e95bed · verified · 1 Parent(s): f5736a5

Update app.py

Files changed (1)
  1. app.py +139 -157
app.py CHANGED
@@ -1,179 +1,161 @@
  import gradio as gr
- import asyncio
- from websockets import connect, Data, ClientConnection
- from dotenv import load_dotenv
- import json
- import os
- import threading
  import numpy as np
- import base64
  import soundfile as sf
- import io
  from pydub import AudioSegment
- import time
- import uuid
-
- # =========================
- # Setup & Configuration
- # =========================
-
- class LogColors:
-     OK = '\033[94m'
-     SUCCESS = '\033[92m'
-     WARNING = '\033[93m'
-     ERROR = '\033[91m'
-     ENDC = '\033[0m'
-
- load_dotenv()
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
- if not OPENAI_API_KEY:
-     raise ValueError("OPENAI_API_KEY environment variable must be set")
-
- WEBSOCKET_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
- WEBSOCKET_HEADERS = {
-     "Authorization": "Bearer " + OPENAI_API_KEY,
-     "OpenAI-Beta": "realtime=v1"
- }
-
- css = ""
- connections = {}

- # =========================
- # WebSocket Client Class
- # =========================

  class WebSocketClient:
-     def __init__(self, uri: str, headers: dict, client_id: str):
-         self.uri = uri
-         self.headers = headers
-         self.websocket: ClientConnection = None
          self.queue = asyncio.Queue(maxsize=10)
-         self.loop = None
-         self.client_id = client_id
          self.transcript = ""

      async def connect(self):
-         try:
-             self.websocket = await connect(self.uri, additional_headers=self.headers)
-             print(f"{LogColors.SUCCESS}Connected to OpenAI WebSocket{LogColors.ENDC}\n")
-
-             with open("openai_transcription_settings.json", "r") as f:
-                 settings = f.read()
-                 await self.websocket.send(settings)
-
-             await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
-         except Exception as e:
-             print(f"{LogColors.ERROR}WebSocket Connection Error: {e}{LogColors.ENDC}")

      def run(self):
-         self.loop = asyncio.new_event_loop()
-         asyncio.set_event_loop(self.loop)
-         self.loop.run_until_complete(self.connect())
-
-     def process_websocket_message(self, message: Data):
-         message_object = json.loads(message)
-         if message_object["type"] != "error":
-             print(f"{LogColors.OK}Received message: {LogColors.ENDC} {message}")
-             if message_object["type"] == "conversation.item.input_audio_transcription.delta":
-                 delta = message_object["delta"]
-                 self.transcript += delta
-             elif message_object["type"] == "conversation.item.input_audio_transcription.completed":
-                 self.transcript += ' ' if self.transcript and self.transcript[-1] != ' ' else ''
-         else:
-             print(f"{LogColors.ERROR}Error: {message}{LogColors.ENDC}")

      async def send_audio_chunks(self):
          while True:
-             audio_data = await self.queue.get()
-             sample_rate, audio_array = audio_data
-             if self.websocket:
-                 if audio_array.ndim > 1:
-                     audio_array = audio_array.mean(axis=1)
-                 audio_array = audio_array.astype(np.float32)
-                 audio_array /= np.max(np.abs(audio_array)) if np.max(np.abs(audio_array)) > 0 else 1.0
-                 audio_array_int16 = (audio_array * 32767).astype(np.int16)
-
-                 audio_buffer = io.BytesIO()
-                 sf.write(audio_buffer, audio_array_int16, sample_rate, format='WAV', subtype='PCM_16')
-                 audio_buffer.seek(0)
-                 audio_segment = AudioSegment.from_file(audio_buffer, format="wav")
-                 resampled_audio = audio_segment.set_frame_rate(24000)
-
-                 output_buffer = io.BytesIO()
-                 resampled_audio.export(output_buffer, format="wav")
-                 output_buffer.seek(0)
-                 base64_audio = base64.b64encode(output_buffer.read()).decode("utf-8")
-
-                 await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": base64_audio}))
-                 print(f"{LogColors.OK}Sent audio chunk{LogColors.ENDC}")

      async def receive_messages(self):
-         async for message in self.websocket:
-             self.process_websocket_message(message)

-     def enqueue_audio_chunk(self, sample_rate: int, chunk_array: np.ndarray):
          if not self.queue.full():
-             asyncio.run_coroutine_threadsafe(self.queue.put((sample_rate, chunk_array)), self.loop)
-         else:
-             print(f"{LogColors.WARNING}Queue is full, dropping audio chunk{LogColors.ENDC}")
-
-     async def close(self):
-         if self.websocket:
-             await self.websocket.close()
-         connections.pop(self.client_id)
-         print(f"{LogColors.WARNING}WebSocket connection closed{LogColors.ENDC}")
-
-
- # =========================
- # Helper Functions
- # =========================
-
- def send_audio_chunk(new_chunk: gr.Audio, client_id: str):
-     if client_id not in connections:
-         return "Connection is being established, please try again in a few seconds."
-     sr, y = new_chunk
-     connections[client_id].enqueue_audio_chunk(sr, y)
-     return connections[client_id].transcript
-
- def create_new_websocket_connection():
-     client_id = str(uuid.uuid4())
-     connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
-     threading.Thread(target=connections[client_id].run, daemon=True).start()
-     return client_id
-
- def clear_transcript(client_id):
-     if client_id in connections:
-         connections[client_id].transcript = ""
      return ""

- # =========================
- # Gradio UI Sections
- # =========================
-
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
-
-     with gr.Tab("💬 Chat Assistant"):
-         gr.Markdown("### Chat Section (Coming Soon)")
-         gr.Textbox(label="Your question")
-         gr.Button("Send")
-
-     with gr.Tab("📄 Document Viewer"):
-         gr.Markdown("### Upload and View Documents")
-         gr.File(label="Upload Document", file_types=[".pdf", ".txt", ".docx"])
-         gr.Textbox(label="Document Preview", lines=10)
-
-     with gr.Tab("🎤 Voice Transcription"):
-         gr.Markdown("### Realtime Voice-to-Text Transcription")
-         with gr.Row():
-             output_textbox = gr.Textbox(label="Transcript", lines=7, interactive=False, autoscroll=True)
-         with gr.Row():
-             with gr.Column(scale=5):
-                 audio_input = gr.Audio(streaming=True, format="wav")
-             with gr.Column():
-                 clear_button = gr.Button("Clear Transcript")
-         client_id = gr.State()
-         clear_button.click(clear_transcript, inputs=[client_id], outputs=[output_textbox])
-         audio_input.stream(send_audio_chunk, [audio_input, client_id], [output_textbox], stream_every=0.5)
-     demo.load(create_new_websocket_connection, outputs=[client_id])
-
- demo.launch()
  import gradio as gr
+ import os, time, re, json, base64, asyncio, threading, uuid, io
  import numpy as np
  import soundfile as sf
  from pydub import AudioSegment
+ from openai import OpenAI
+ from websockets import connect
+ from dotenv import load_dotenv

+ # Load environment variables
+ load_dotenv()
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ ASSISTANT_ID = os.getenv("ASSISTANT_ID")
+ client = OpenAI(api_key=OPENAI_API_KEY)

+ HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
+ WS_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
+ connections = {}
+
+ # ---------------- WebSocket Client for Voice ----------------
  class WebSocketClient:
+     def __init__(self, uri, headers, client_id):
+         self.uri, self.headers, self.client_id = uri, headers, client_id
+         self.websocket = None
+         self.loop = None  # event loop owned by this client's background thread
          self.queue = asyncio.Queue(maxsize=10)
          self.transcript = ""

      async def connect(self):
+         self.websocket = await connect(self.uri, additional_headers=self.headers)
+         with open("openai_transcription_settings.json", "r") as f:
+             await self.websocket.send(f.read())
+         await asyncio.gather(self.receive_messages(), self.send_audio_chunks())

      def run(self):
+         self.loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(self.loop)
+         self.loop.run_until_complete(self.connect())

      async def send_audio_chunks(self):
          while True:
+             sr, arr = await self.queue.get()
+             if arr.ndim > 1:
+                 arr = arr.mean(axis=1)
+             arr = (arr / np.max(np.abs(arr))) if np.max(np.abs(arr)) > 0 else arr
+             int16 = (arr * 32767).astype(np.int16)
+             buf = io.BytesIO()
+             sf.write(buf, int16, sr, format='WAV', subtype='PCM_16')
+             buf.seek(0)  # rewind before re-reading, otherwise pydub sees an empty stream
+             audio = AudioSegment.from_file(buf, format="wav").set_frame_rate(24000)
+             out = io.BytesIO()
+             audio.export(out, format="wav")
+             out.seek(0)
+             await self.websocket.send(json.dumps({
+                 "type": "input_audio_buffer.append",
+                 "audio": base64.b64encode(out.read()).decode()
+             }))

      async def receive_messages(self):
+         async for msg in self.websocket:
+             data = json.loads(msg)
+             if data["type"] == "conversation.item.input_audio_transcription.delta":
+                 self.transcript += data["delta"]

+     def enqueue_audio_chunk(self, sr, arr):
          if not self.queue.full():
+             # Called from Gradio's thread: schedule onto this client's own loop,
+             # not asyncio.get_event_loop(), which is unset on the calling thread.
+             asyncio.run_coroutine_threadsafe(self.queue.put((sr, arr)), self.loop)
+
+ def create_ws():
+     cid = str(uuid.uuid4())
+     ws_client = WebSocketClient(WS_URI, HEADERS, cid)  # avoid shadowing the OpenAI `client`
+     threading.Thread(target=ws_client.run, daemon=True).start()
+     connections[cid] = ws_client
+     return cid
+
+ def send_audio(chunk, cid):
+     if cid not in connections:
+         return "Connecting..."
+     sr, arr = chunk
+     connections[cid].enqueue_audio_chunk(sr, arr)
+     return connections[cid].transcript
+
+ def clear_transcript(cid):
+     if cid in connections:
+         connections[cid].transcript = ""
      return ""

+ # ---------------- Chat Assistant Logic ----------------
+ def handle_chat(user_input, history, thread_id, image_url):
+     if not OPENAI_API_KEY or not ASSISTANT_ID:
+         return "❌ Missing API key or Assistant ID.", history, thread_id, image_url
+
+     try:
+         if thread_id is None:
+             thread = client.beta.threads.create()
+             thread_id = thread.id
+
+         client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
+         run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)
+
+         # Poll until the assistant run finishes
+         while True:
+             status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
+             if status.status == "completed":
+                 break
+             time.sleep(1)
+
+         # messages.list returns newest first, so the first assistant match is the latest reply
+         msgs = client.beta.threads.messages.list(thread_id=thread_id)
+         for msg in msgs.data:
+             if msg.role == "assistant":
+                 content = msg.content[0].text.value
+                 history.append((user_input, content))
+                 match = re.search(
+                     r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
+                     content
+                 )
+                 if match:
+                     image_url = match.group(0)
+                 break
+
+         return "", history, thread_id, image_url
+
+     except Exception as e:
+         return f"❌ {e}", history, thread_id, image_url
+
+ # ---------------- UI ----------------
+ with gr.Blocks(theme="lone17/kotaemon") as app:
+     gr.Markdown("# 📄 Document AI Assistant")
+
+     # States
+     chat_state = gr.State([])
+     thread_state = gr.State()
+     image_state = gr.State()
+     client_id = gr.State()
+     mic_shown = gr.State(False)
+
+     with gr.Row(equal_height=True):
+         # Left: Document Viewer
+         with gr.Column(scale=1):
+             image_display = gr.Image(label="🖼️ Document Preview", type="filepath", show_download_button=False)
+
+         # Right: Chat + Mic
+         with gr.Column(scale=2):  # integer scale; Gradio warns on floats like 1.4
+             chat = gr.Chatbot(label="💬 Chat", height=450)
+
+             with gr.Row():
+                 user_input = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=6)
+                 mic_btn = gr.Button("🎙️", scale=1)
+                 send_btn = gr.Button("Send", scale=2)
+
+             # Hidden Voice Section
+             with gr.Row(visible=False) as mic_row:
+                 with gr.Column(scale=4):
+                     audio = gr.Audio(label="🎤 Speak", streaming=True)
+                 with gr.Column(scale=5):
+                     transcript = gr.Textbox(label="Transcript", lines=2, interactive=False)
+                 with gr.Column(scale=2):
+                     clear_btn = gr.Button("🧹 Clear")
+
+     # Logic Wiring
+     def toggle_mic(state):
+         return not state, gr.update(visible=not state)
+
+     mic_btn.click(toggle_mic, inputs=mic_shown, outputs=[mic_shown, mic_row])
+     send_btn.click(handle_chat,
+                    inputs=[user_input, chat_state, thread_state, image_state],
+                    outputs=[user_input, chat, thread_state, image_state])
+     image_state.change(fn=lambda x: x, inputs=image_state, outputs=image_display)
+     audio.stream(fn=send_audio, inputs=[audio, client_id], outputs=transcript, stream_every=0.5)
+     clear_btn.click(fn=clear_transcript, inputs=[client_id], outputs=transcript)
+     app.load(fn=create_ws, outputs=[client_id])
+
+ app.launch()
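
Both the old and the new code configure the realtime session by sending the raw contents of openai_transcription_settings.json as the first WebSocket message; that file is not part of this commit. A minimal sketch of what it plausibly contains, assuming the realtime transcription-session schema that matches the `conversation.item.input_audio_transcription.*` events the client listens for (the actual model and turn-detection settings in the Space may differ):

{
    "type": "transcription_session.update",
    "session": {
        "input_audio_format": "pcm16",
        "input_audio_transcription": { "model": "gpt-4o-transcribe" },
        "turn_detection": { "type": "server_vad" }
    }
}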