IAMTFRMZA commited on
Commit
f5736a5
·
verified ·
1 Parent(s): 9850ad3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -135
app.py CHANGED
@@ -1,157 +1,179 @@
1
  import gradio as gr
2
- import os, time, re, json, base64, asyncio, threading, uuid, io
 
 
 
 
 
3
  import numpy as np
 
4
  import soundfile as sf
 
5
  from pydub import AudioSegment
6
- from openai import OpenAI
7
- from websockets import connect, Data, ClientConnection
8
- from dotenv import load_dotenv
9
-
10
- # ============ Load Secrets ============
11
- load_dotenv()
12
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
13
- ASSISTANT_ID = os.getenv("ASSISTANT_ID")
14
- client = OpenAI(api_key=OPENAI_API_KEY)
15
-
16
- HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
17
- WS_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  connections = {}
19
 
20
- # ============ WebSocket Client ============
 
 
 
21
  class WebSocketClient:
22
- def __init__(self, uri, headers, client_id):
23
- self.uri, self.headers, self.client_id = uri, headers, client_id
24
- self.websocket = None
 
25
  self.queue = asyncio.Queue(maxsize=10)
 
 
26
  self.transcript = ""
27
 
28
  async def connect(self):
29
- self.websocket = await connect(self.uri, additional_headers=self.headers)
30
- with open("openai_transcription_settings.json", "r") as f:
31
- await self.websocket.send(f.read())
32
- await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
 
 
 
 
 
 
 
33
 
34
  def run(self):
35
- loop = asyncio.new_event_loop()
36
- asyncio.set_event_loop(loop)
37
- loop.run_until_complete(self.connect())
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  async def send_audio_chunks(self):
40
  while True:
41
- sr, arr = await self.queue.get()
42
- if arr.ndim > 1: arr = arr.mean(axis=1)
43
- arr = (arr / np.max(np.abs(arr))) if np.max(np.abs(arr)) > 0 else arr
44
- int16 = (arr * 32767).astype(np.int16)
45
- buf = io.BytesIO(); sf.write(buf, int16, sr, format='WAV', subtype='PCM_16')
46
- audio = AudioSegment.from_file(buf, format="wav").set_frame_rate(24000)
47
- out = io.BytesIO(); audio.export(out, format="wav"); out.seek(0)
48
- await self.websocket.send(json.dumps({
49
- "type": "input_audio_buffer.append",
50
- "audio": base64.b64encode(out.read()).decode()
51
- }))
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  async def receive_messages(self):
54
- async for msg in self.websocket:
55
- data = json.loads(msg)
56
- if data["type"] == "conversation.item.input_audio_transcription.delta":
57
- self.transcript += data["delta"]
58
 
59
- def enqueue_audio_chunk(self, sr, arr):
60
  if not self.queue.full():
61
- asyncio.run_coroutine_threadsafe(self.queue.put((sr, arr)), asyncio.get_event_loop())
62
-
63
- def create_ws():
64
- cid = str(uuid.uuid4())
65
- client = WebSocketClient(WS_URI, HEADERS, cid)
66
- threading.Thread(target=client.run, daemon=True).start()
67
- connections[cid] = client
68
- return cid
69
-
70
- def send_audio(chunk, cid):
71
- if cid not in connections: return "Connecting..."
72
- sr, arr = chunk
73
- connections[cid].enqueue_audio_chunk(sr, arr)
74
- return connections[cid].transcript
75
-
76
- def clear_transcript(cid):
77
- if cid in connections: connections[cid].transcript = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return ""
79
 
80
- # ============ Chat Assistant ============
81
- def handle_chat(user_input, history, thread_id, image_url):
82
- if not OPENAI_API_KEY or not ASSISTANT_ID:
83
- return "❌ Missing secrets!", history, thread_id, image_url
84
-
85
- try:
86
- if thread_id is None:
87
- thread = client.beta.threads.create()
88
- thread_id = thread.id
89
-
90
- client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
91
- run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)
92
-
93
- while True:
94
- status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
95
- if status.status == "completed": break
96
- time.sleep(1)
97
-
98
- msgs = client.beta.threads.messages.list(thread_id=thread_id)
99
- for msg in reversed(msgs.data):
100
- if msg.role == "assistant":
101
- content = msg.content[0].text.value
102
- history.append((user_input, content))
103
- match = re.search(
104
- r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
105
- content
106
- )
107
- if match: image_url = match.group(0)
108
- break
109
-
110
- return "", history, thread_id, image_url
111
-
112
- except Exception as e:
113
- return f"❌ {e}", history, thread_id, image_url
114
-
115
- # ============ Gradio UI ============
116
- with gr.Blocks(theme=gr.themes.Soft()) as app:
117
- gr.Markdown("# 📄 Document AI Assistant")
118
-
119
- # STATES
120
- chat_state = gr.State([])
121
- thread_state = gr.State()
122
- image_state = gr.State()
123
- client_id = gr.State()
124
- voice_enabled = gr.State(False)
125
-
126
- with gr.Row(equal_height=True):
127
- with gr.Column(scale=1):
128
- image_display = gr.Image(label="🖼️ Document", type="filepath", show_download_button=False)
129
-
130
- with gr.Column(scale=1.4):
131
- chat = gr.Chatbot(label="💬 Chat", height=460)
132
-
133
- with gr.Row():
134
- user_prompt = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=6)
135
- mic_toggle_btn = gr.Button("🎙️", scale=1)
136
- send_btn = gr.Button("Send", variant="primary", scale=2)
137
-
138
- with gr.Accordion("🎤 Voice Transcription", open=False) as voice_section:
139
- with gr.Row():
140
- voice_input = gr.Audio(label="Mic", streaming=True)
141
- voice_transcript = gr.Textbox(label="Transcript", lines=2, interactive=False)
142
- clear_btn = gr.Button("🧹 Clear Transcript")
143
-
144
- # FUNCTIONAL CONNECTIONS
145
- def toggle_voice(curr):
146
- return not curr, gr.update(visible=not curr)
147
-
148
- mic_toggle_btn.click(fn=toggle_voice, inputs=voice_enabled, outputs=[voice_enabled, voice_section])
149
- send_btn.click(fn=handle_chat,
150
- inputs=[user_prompt, chat_state, thread_state, image_state],
151
- outputs=[user_prompt, chat, thread_state, image_state])
152
- image_state.change(fn=lambda x: x, inputs=image_state, outputs=image_display)
153
- voice_input.stream(fn=send_audio, inputs=[voice_input, client_id], outputs=voice_transcript, stream_every=0.5)
154
- clear_btn.click(fn=clear_transcript, inputs=[client_id], outputs=voice_transcript)
155
- app.load(fn=create_ws, outputs=[client_id])
156
-
157
- app.launch()
 
1
  import gradio as gr
2
+ import asyncio
3
+ from websockets import connect, Data, ClientConnection
4
+ from dotenv import load_dotenv
5
+ import json
6
+ import os
7
+ import threading
8
  import numpy as np
9
+ import base64
10
  import soundfile as sf
11
+ import io
12
  from pydub import AudioSegment
13
+ import time
14
+ import uuid
15
+
16
+ # =========================
17
+ # Setup & Configuration
18
+ # =========================
19
+
20
+ class LogColors:
21
+ OK = '\033[94m'
22
+ SUCCESS = '\033[92m'
23
+ WARNING = '\033[93m'
24
+ ERROR = '\033[91m'
25
+ ENDC = '\033[0m'
26
+
27
+ load_dotenv()
28
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
29
+ if not OPENAI_API_KEY:
30
+ raise ValueError("OPENAI_API_KEY environment variable must be set")
31
+
32
+ WEBSOCKET_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
33
+ WEBSOCKET_HEADERS = {
34
+ "Authorization": "Bearer " + OPENAI_API_KEY,
35
+ "OpenAI-Beta": "realtime=v1"
36
+ }
37
+
38
+ css = ""
39
  connections = {}
40
 
41
+ # =========================
42
+ # WebSocket Client Class
43
+ # =========================
44
+
45
  class WebSocketClient:
46
+ def __init__(self, uri: str, headers: dict, client_id: str):
47
+ self.uri = uri
48
+ self.headers = headers
49
+ self.websocket: ClientConnection = None
50
  self.queue = asyncio.Queue(maxsize=10)
51
+ self.loop = None
52
+ self.client_id = client_id
53
  self.transcript = ""
54
 
55
  async def connect(self):
56
+ try:
57
+ self.websocket = await connect(self.uri, additional_headers=self.headers)
58
+ print(f"{LogColors.SUCCESS}Connected to OpenAI WebSocket{LogColors.ENDC}\n")
59
+
60
+ with open("openai_transcription_settings.json", "r") as f:
61
+ settings = f.read()
62
+ await self.websocket.send(settings)
63
+
64
+ await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
65
+ except Exception as e:
66
+ print(f"{LogColors.ERROR}WebSocket Connection Error: {e}{LogColors.ENDC}")
67
 
68
  def run(self):
69
+ self.loop = asyncio.new_event_loop()
70
+ asyncio.set_event_loop(self.loop)
71
+ self.loop.run_until_complete(self.connect())
72
+
73
+ def process_websocket_message(self, message: Data):
74
+ message_object = json.loads(message)
75
+ if message_object["type"] != "error":
76
+ print(f"{LogColors.OK}Received message: {LogColors.ENDC} {message}")
77
+ if message_object["type"] == "conversation.item.input_audio_transcription.delta":
78
+ delta = message_object["delta"]
79
+ self.transcript += delta
80
+ elif message_object["type"] == "conversation.item.input_audio_transcription.completed":
81
+ self.transcript += ' ' if self.transcript and self.transcript[-1] != ' ' else ''
82
+ else:
83
+ print(f"{LogColors.ERROR}Error: {message}{LogColors.ENDC}")
84
 
85
  async def send_audio_chunks(self):
86
  while True:
87
+ audio_data = await self.queue.get()
88
+ sample_rate, audio_array = audio_data
89
+ if self.websocket:
90
+ if audio_array.ndim > 1:
91
+ audio_array = audio_array.mean(axis=1)
92
+ audio_array = audio_array.astype(np.float32)
93
+ audio_array /= np.max(np.abs(audio_array)) if np.max(np.abs(audio_array)) > 0 else 1.0
94
+ audio_array_int16 = (audio_array * 32767).astype(np.int16)
95
+
96
+ audio_buffer = io.BytesIO()
97
+ sf.write(audio_buffer, audio_array_int16, sample_rate, format='WAV', subtype='PCM_16')
98
+ audio_buffer.seek(0)
99
+ audio_segment = AudioSegment.from_file(audio_buffer, format="wav")
100
+ resampled_audio = audio_segment.set_frame_rate(24000)
101
+
102
+ output_buffer = io.BytesIO()
103
+ resampled_audio.export(output_buffer, format="wav")
104
+ output_buffer.seek(0)
105
+ base64_audio = base64.b64encode(output_buffer.read()).decode("utf-8")
106
+
107
+ await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": base64_audio}))
108
+ print(f"{LogColors.OK}Sent audio chunk{LogColors.ENDC}")
109
 
110
  async def receive_messages(self):
111
+ async for message in self.websocket:
112
+ self.process_websocket_message(message)
 
 
113
 
114
+ def enqueue_audio_chunk(self, sample_rate: int, chunk_array: np.ndarray):
115
  if not self.queue.full():
116
+ asyncio.run_coroutine_threadsafe(self.queue.put((sample_rate, chunk_array)), self.loop)
117
+ else:
118
+ print(f"{LogColors.WARNING}Queue is full, dropping audio chunk{LogColors.ENDC}")
119
+
120
+ async def close(self):
121
+ if self.websocket:
122
+ await self.websocket.close()
123
+ connections.pop(self.client_id)
124
+ print(f"{LogColors.WARNING}WebSocket connection closed{LogColors.ENDC}")
125
+
126
+
127
+ # =========================
128
+ # Helper Functions
129
+ # =========================
130
+
131
+ def send_audio_chunk(new_chunk: gr.Audio, client_id: str):
132
+ if client_id not in connections:
133
+ return "Connection is being established, please try again in a few seconds."
134
+ sr, y = new_chunk
135
+ connections[client_id].enqueue_audio_chunk(sr, y)
136
+ return connections[client_id].transcript
137
+
138
+ def create_new_websocket_connection():
139
+ client_id = str(uuid.uuid4())
140
+ connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
141
+ threading.Thread(target=connections[client_id].run, daemon=True).start()
142
+ return client_id
143
+
144
+ def clear_transcript(client_id):
145
+ if client_id in connections:
146
+ connections[client_id].transcript = ""
147
  return ""
148
 
149
+ # =========================
150
+ # Gradio UI Sections
151
+ # =========================
152
+
153
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
154
+
155
+ with gr.Tab("💬 Chat Assistant"):
156
+ gr.Markdown("### Chat Section (Coming Soon)")
157
+ gr.Textbox(label="Your question")
158
+ gr.Button("Send")
159
+
160
+ with gr.Tab("📄 Document Viewer"):
161
+ gr.Markdown("### Upload and View Documents")
162
+ gr.File(label="Upload Document", file_types=[".pdf", ".txt", ".docx"])
163
+ gr.Textbox(label="Document Preview", lines=10)
164
+
165
+ with gr.Tab("🎤 Voice Transcription"):
166
+ gr.Markdown("### Realtime Voice-to-Text Transcription")
167
+ with gr.Row():
168
+ output_textbox = gr.Textbox(label="Transcript", lines=7, interactive=False, autoscroll=True)
169
+ with gr.Row():
170
+ with gr.Column(scale=5):
171
+ audio_input = gr.Audio(streaming=True, format="wav")
172
+ with gr.Column():
173
+ clear_button = gr.Button("Clear Transcript")
174
+ client_id = gr.State()
175
+ clear_button.click(clear_transcript, inputs=[client_id], outputs=[output_textbox])
176
+ audio_input.stream(send_audio_chunk, [audio_input, client_id], [output_textbox], stream_every=0.5)
177
+ demo.load(create_new_websocket_connection, outputs=[client_id])
178
+
179
+ demo.launch()