IAMTFRMZA committed
Commit 0bb8b62 · verified · 1 Parent(s): d6d49d6

Update app.py

Files changed (1):
  1. app.py (+86 -84)
app.py CHANGED
@@ -1,103 +1,105 @@
 import gradio as gr
 import os
+import json
 import uuid
+import threading
+import time
+import re
+from dotenv import load_dotenv
 from openai import OpenAI
 from realtime_transcriber import WebSocketClient, connections, WEBSOCKET_URI, WEBSOCKET_HEADERS
 
-# Load OpenAI API key
-OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
-if not OPENAI_API_KEY:
-    raise ValueError("OPENAI_API_KEY environment variable must be set")
+# ------------------ Load Secrets ------------------
+load_dotenv()
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ASSISTANT_ID = os.getenv("ASSISTANT_ID")
+
+if not OPENAI_API_KEY or not ASSISTANT_ID:
+    raise ValueError("Missing OPENAI_API_KEY or ASSISTANT_ID")
+
 client = OpenAI(api_key=OPENAI_API_KEY)
+session_threads = {}
 
-# Session state
-session_id = str(uuid.uuid4())
-if session_id not in connections:
-    connections[session_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, session_id)
-    import threading
-    threading.Thread(target=connections[session_id].run, daemon=True).start()
-
-# Functions for Document Assistant
-def process_user_input(message, history):
-    if not message:
-        return "Please enter a message.", history
-
-    try:
-        thread = client.beta.threads.create()
-        client.beta.threads.messages.create(
-            thread_id=thread.id,
-            role="user",
-            content=message
-        )
-        run = client.beta.threads.runs.create(
-            thread_id=thread.id,
-            assistant_id=os.environ.get("ASSISTANT_ID")
-        )
-        while True:
-            run_status = client.beta.threads.runs.retrieve(
-                thread_id=thread.id,
-                run_id=run.id
-            )
-            if run_status.status == "completed":
-                break
-        messages = client.beta.threads.messages.list(thread_id=thread.id)
-        assistant_reply = next((m.content[0].text.value for m in reversed(messages.data) if m.role == "assistant"), "No response.")
-        history.append((message, assistant_reply))
-        return "", history
-    except Exception as e:
-        return f"❌ Error: {str(e)}", history
-
-# Functions for Realtime Voice Transcription
-def send_audio_chunk_realtime(mic_chunk):
-    if session_id not in connections:
-        return "Initializing voice session..."
-    if mic_chunk is not None:
-        sr, y = mic_chunk
-        connections[session_id].enqueue_audio_chunk(sr, y)
-    return connections[session_id].transcript
-
-def clear_transcript():
-    if session_id in connections:
-        connections[session_id].transcript = ""
-    return ""
-
-# Gradio UI Components
-doc_image = gr.Image(label="📘 Extracted Document Image", show_label=True, elem_id="docimg", height=500, width=360)
-chatbot = gr.Chatbot(label="🧠 Document Assistant", elem_id="chatbox", bubble_full_width=False)
-prompt = gr.Textbox(placeholder="Ask about the document...", label="Ask about the document")
-send_btn = gr.Button("Send")
-
-# Voice Section
-audio_in = gr.Audio(label="🎵 Audio", type="numpy", streaming=True)
-live_transcript = gr.Textbox(label="Live Transcript", lines=6)
-clear_btn = gr.Button("Clear Transcript")
-
-with gr.Blocks(theme=gr.themes.Base(), css="""
-#docimg img { object-fit: contain !important; }
-#chatbox { height: 500px; }
-.gr-box { border-radius: 12px; }
-""") as demo:
-
+# ------------------ Chat Logic ------------------
+def reset_session():
+    session_id = str(uuid.uuid4())
+    session_threads[session_id] = client.beta.threads.create().id
+    return session_id
+
+def process_chat(message, history, session_id):
+    thread_id = session_threads.get(session_id)
+    if not thread_id:
+        thread_id = client.beta.threads.create().id
+        session_threads[session_id] = thread_id
+
+    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=message)
+    run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)
+
+    while client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id).status != "completed":
+        time.sleep(1)
+
+    messages = client.beta.threads.messages.list(thread_id=thread_id)
+    for msg in reversed(messages.data):
+        if msg.role == "assistant":
+            return msg.content[0].text.value
+    return "⚠️ Assistant did not respond."
+
+def extract_image_url(text):
+    match = re.search(r'https://raw\.githubusercontent\.com/[^\s"]+\.png', text)
+    return match.group(0) if match else None
+
+def handle_chat(message, history, session_id):
+    response = process_chat(message, history, session_id)
+    history.append((message, response))
+    image = extract_image_url(response)
+    return history, image
+
+# ------------------ Voice Logic ------------------
+def create_websocket_client():
+    client_id = str(uuid.uuid4())
+    connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
+    threading.Thread(target=connections[client_id].run, daemon=True).start()
+    return client_id
+
+def clear_transcript(client_id):
+    if client_id in connections:
+        connections[client_id].transcript = ""
+    return ""
+
+def send_audio_chunk(audio, client_id):
+    if client_id not in connections:
+        return "Initializing connection..."
+    sr, y = audio
+    connections[client_id].enqueue_audio_chunk(sr, y)
+    return connections[client_id].transcript
+
+# ------------------ UI ------------------
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🧠 Document AI + 🎙️ Voice Assistant")
-    with gr.Row():
-        with gr.Column(scale=1):
-            doc_image.render()
-        with gr.Column(scale=2):
-            chatbot.render()
+
+    session_id = gr.State(value=reset_session())
+    client_id = gr.State()
 
     with gr.Row():
-        prompt.render()
-        send_btn.render()
+        image_display = gr.Image(label="📑 Extracted Document Image", show_label=True, height=360)
+        with gr.Column():
+            chatbot = gr.Chatbot(label="💬 Document Assistant", height=360)
+            text_input = gr.Textbox(label="Ask about the document", placeholder="e.g. What is clause 3.2?")
+            send_btn = gr.Button("Send")
 
-    send_btn.click(fn=process_user_input, inputs=[prompt, chatbot], outputs=[prompt, chatbot])
+    send_btn.click(handle_chat, inputs=[text_input, chatbot, session_id], outputs=[chatbot, image_display])
+    text_input.submit(handle_chat, inputs=[text_input, chatbot, session_id], outputs=[chatbot, image_display])
 
-    with gr.Accordion("🎙️ Or Use Voice Instead", open=False):
-        live_transcript.render()
+    # Toggle Section
+    with gr.Accordion("🎤 Or Use Voice Instead", open=False):
+        with gr.Row():
+            transcript_box = gr.Textbox(label="Live Transcript", lines=7, interactive=False, autoscroll=True)
         with gr.Row():
-            audio_in.render()
-            clear_btn.render()
-        audio_in.stream(fn=send_audio_chunk_realtime, inputs=audio_in, outputs=live_transcript)
-        clear_btn.click(fn=clear_transcript, outputs=live_transcript)
+            mic_input = gr.Audio(streaming=True)
+            clear_button = gr.Button("Clear Transcript")
+
+        mic_input.stream(fn=send_audio_chunk, inputs=[mic_input, client_id], outputs=transcript_box)
+        clear_button.click(fn=clear_transcript, inputs=[client_id], outputs=transcript_box)
+    demo.load(fn=create_websocket_client, outputs=client_id)
 
-# LAUNCH WITH SHARE ENABLED FOR PUBLIC URL
-demo.launch(share=True)
+demo.launch()
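
Note that the rewrite swaps `os.environ.get` for `python-dotenv`, so the app now also reads secrets from a local `.env` file when one is present. A placeholder file matching the two names the new code checks (the values below are illustrative stand-ins, not real credentials):

    # .env  (keep out of version control)
    OPENAI_API_KEY=sk-...
    ASSISTANT_ID=asst_...

On a Hugging Face Space the same two variables can be configured as repository secrets instead; `os.getenv` picks them up either way.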
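The diff never shows `realtime_transcriber.py`, but the call sites pin down the interface that module must expose: module-level `connections`, `WEBSOCKET_URI`, and `WEBSOCKET_HEADERS`, plus a `WebSocketClient(uri, headers, client_id)` whose instances provide a blocking `run()`, an `enqueue_audio_chunk(sr, y)` method, and a `transcript` string. A minimal stand-in honoring that contract (the method bodies and constant values here are assumptions; only the names and signatures come from `app.py`) could look like:

    # Hypothetical stand-in for realtime_transcriber.py, reconstructed from the
    # calls app.py makes. The real module presumably streams audio to a realtime
    # transcription websocket; this sketch only mimics the surface.
    import queue

    WEBSOCKET_URI = "wss://example.invalid/realtime"  # assumed placeholder
    WEBSOCKET_HEADERS = {}                            # assumed placeholder

    connections = {}  # client_id -> WebSocketClient, shared with app.py

    class WebSocketClient:
        def __init__(self, uri, headers, client_id):
            self.uri = uri
            self.headers = headers
            self.client_id = client_id
            self.transcript = ""          # read back by send_audio_chunk()
            self._chunks = queue.Queue()  # filled by enqueue_audio_chunk()

        def enqueue_audio_chunk(self, sr, y):
            # app.py unpacks Gradio's (sample_rate, numpy_array) tuple first
            self._chunks.put((sr, y))

        def run(self):
            # Runs forever on a daemon thread. The real implementation would
            # send each chunk over the websocket and append transcription
            # deltas to self.transcript; here we only note how much arrived.
            while True:
                sr, y = self._chunks.get()
                self.transcript += f"[{len(y) / sr:.1f}s of audio received]\n"

Any drop-in replacement (a local Whisper loop, for instance) only has to keep those touch points stable, since `app.py` re-reads `transcript` on every streamed chunk.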