import gradio as gr
import os
import json
import uuid
import threading
import time
import re

from openai import OpenAI
from realtime_transcriber import WebSocketClient, connections, WEBSOCKET_URI, WEBSOCKET_HEADERS
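# NOTE: realtime_transcriber is a local module in this Space. From its usage
# below, it is assumed to expose a shared `connections` dict plus a
# WebSocketClient with .run(), .enqueue_audio_chunk(sample_rate, samples),
# and a .transcript string attribute.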
# ------------------ Load API Key ------------------
from dotenv import load_dotenv

load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID")

if not OPENAI_API_KEY or not ASSISTANT_ID:
    raise ValueError("Missing OPENAI_API_KEY or ASSISTANT_ID in environment variables")

client = OpenAI(api_key=OPENAI_API_KEY)
# ------------------ Chat Logic ------------------
session_threads = {}   # session_id -> Assistants API thread id
session_messages = {}  # session_id -> message history (currently unused)

def reset_session():
    """Create a fresh session backed by its own Assistants API thread."""
    session_id = str(uuid.uuid4())
    thread = client.beta.threads.create()
    session_threads[session_id] = thread.id
    session_messages[session_id] = []
    return session_id
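# NOTE: nothing evicts entries from session_threads / session_messages, so a
# long-running Space accumulates one Assistants thread per visit.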
def process_chat(message, history, session_id):
    thread_id = session_threads.get(session_id)
    if not thread_id:
        thread_id = client.beta.threads.create().id
        session_threads[session_id] = thread_id

    # Store user message
    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=message
    )

    # Run assistant and poll until the run reaches a terminal state
    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=ASSISTANT_ID
    )
    while True:
        run_status = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id
        )
        if run_status.status == "completed":
            break
        if run_status.status in ("failed", "cancelled", "expired"):
            return f"⚠️ Assistant run {run_status.status}.", None
        time.sleep(1)

    # Retrieve the latest assistant message (the list is newest-first by default)
    messages = client.beta.threads.messages.list(thread_id=thread_id)
    for msg in messages.data:
        if msg.role == "assistant":
            assistant_response = msg.content[0].text.value
            break
    else:
        assistant_response = "⚠️ Assistant did not respond."

    # Detect an embedded document-page image URL, if present
    image_url = None
    match = re.search(
        r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
        assistant_response
    )
    if match:
        image_url = match.group(0)

    return assistant_response, image_url
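# NOTE: newer openai-python releases bundle the create + poll steps above into
# a single helper; a minimal sketch, assuming openai >= 1.14:
#
#   run = client.beta.threads.runs.create_and_poll(
#       thread_id=thread_id,
#       assistant_id=ASSISTANT_ID,
#   )
#   # run.status is then already terminal ("completed", "failed", ...)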
# ------------------ Transcription Logic ------------------
def create_websocket_client():
    client_id = str(uuid.uuid4())
    connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
    # Daemon thread so the socket loop dies with the main process
    threading.Thread(target=connections[client_id].run, daemon=True).start()
    return client_id

def clear_transcript(client_id):
    if client_id in connections:
        connections[client_id].transcript = ""
    return ""

def send_audio_chunk(audio, client_id):
    if client_id not in connections:
        return "Initializing connection..."
    sr, y = audio  # Gradio streams audio as a (sample_rate, numpy_array) tuple
    connections[client_id].enqueue_audio_chunk(sr, y)
    return connections[client_id].transcript
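# send_audio_chunk returns the full transcript accumulated so far, so each
# streamed chunk simply overwrites the Textbox contents rather than appending.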
# ------------------ Gradio Interface ------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Document AI + 🎙️ Voice Assistant")

    # NOTE: gr.State's initial value is computed once at startup; demo.load
    # below re-runs reset_session so each browser session gets its own thread.
    session_id = gr.State(value=reset_session())
    client_id = gr.State()

    # ---------- Section 1: Document Image Display ----------
    with gr.Row():
        image_display = gr.Image(label="📄 Document Page (auto-extracted if available)", interactive=False, visible=False)
    # ---------- Section 2: Chat Interface ----------
    # Cache the last extracted image per session so the image panel can be
    # refreshed without re-running the assistant.
    session_image_urls = {}

    def chat_fn(message, history, session_id):
        response, image_url = process_chat(message, history, session_id)
        session_image_urls[session_id] = image_url
        return response  # ChatInterface expects the reply text only

    with gr.Row():
        chatbot = gr.ChatInterface(
            fn=chat_fn,
            additional_inputs=[session_id],
            render_markdown=True,
            examples=["What does clause 3.2 mean?", "Summarize the timeline from the image."],
            title="💬 Document Assistant",
            retry_btn="🔄 Retry",  # these three params exist in Gradio 3.x/4.x
            undo_btn="↩️ Undo",    # but were removed in Gradio 5
            clear_btn="🗑️ Clear",
        )
    # Link image preview if extracted, reading from the per-session cache
    def update_image_display(session_id):
        image_url = session_image_urls.get(session_id)
        return gr.update(value=image_url, visible=bool(image_url))

    # ChatInterface exposes its inner Chatbot component as `.chatbot`
    chatbot.chatbot.change(fn=update_image_display, inputs=[session_id], outputs=[image_display])
    # ---------- Section 3: Voice Transcription ----------
    gr.Markdown("## 🎙️ Realtime Voice Transcription")
    with gr.Row():
        transcript_box = gr.Textbox(label="Live Transcript", lines=7, interactive=False, autoscroll=True)
    with gr.Row():
        # `source=` is the Gradio 3.x parameter; Gradio 4+ renamed it to sources=["microphone"]
        mic_input = gr.Audio(source="microphone", streaming=True)
        clear_button = gr.Button("Clear Transcript")

    mic_input.stream(fn=send_audio_chunk, inputs=[mic_input, client_id], outputs=transcript_box)
    clear_button.click(fn=clear_transcript, inputs=[client_id], outputs=transcript_box)
    demo.load(fn=create_websocket_client, outputs=client_id)
    # Re-create the session per visitor; otherwise every browser shares the
    # single thread created when the Space booted.
    demo.load(fn=reset_session, outputs=session_id)

demo.launch()
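# NOTE: in some Gradio versions, streaming events are only served once the
# request queue is enabled; if the live transcript stalls, try:
#
#   demo.queue().launch()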