import gradio as gr
import os
import json
import uuid
import threading
import time
import re
from openai import OpenAI
from dotenv import load_dotenv
from realtime_transcriber import WebSocketClient, connections, WEBSOCKET_URI, WEBSOCKET_HEADERS
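
# Note: realtime_transcriber is a local module (not shown here). Based on how
# it is used below, it is assumed to expose a shared `connections` dict and a
# WebSocketClient class with a blocking run() method, a `transcript` string
# attribute, and an enqueue_audio_chunk(sample_rate, samples) method.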

# ------------------ Load API Key ------------------
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID")

if not OPENAI_API_KEY or not ASSISTANT_ID:
    raise ValueError("Missing OPENAI_API_KEY or ASSISTANT_ID in environment variables")

client = OpenAI(api_key=OPENAI_API_KEY)

# ------------------ Chat Logic ------------------
session_threads = {}


def reset_session():
    session_id = str(uuid.uuid4())
    thread = client.beta.threads.create()
    session_threads[session_id] = thread.id
    return session_id


def process_chat(message, history, session_id):
    thread_id = session_threads.get(session_id)
    if not thread_id:
        thread_id = client.beta.threads.create().id
        session_threads[session_id] = thread_id

    client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=message
    )

    run = client.beta.threads.runs.create(
        thread_id=thread_id,
        assistant_id=ASSISTANT_ID
    )

    # Poll until the run reaches a terminal state; bail out on failure
    # instead of looping forever.
    while True:
        run_status = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id
        )
        if run_status.status == "completed":
            break
        if run_status.status in ("failed", "cancelled", "expired"):
            return f"⚠️ Run ended with status: {run_status.status}"
        time.sleep(1)

    # messages.list returns newest-first by default, so the first assistant
    # message encountered is the latest reply.
    messages = client.beta.threads.messages.list(thread_id=thread_id)
    assistant_response = "⚠️ Assistant did not respond."
    for msg in messages.data:
        if msg.role == "assistant":
            assistant_response = msg.content[0].text.value
            break

    return assistant_response


def extract_image_url(text):
    match = re.search(
        r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
        text
    )
    return match.group(0) if match else None
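
# Example (hypothetical filename, for illustration only):
#   extract_image_url("See https://raw.githubusercontent.com/AndrewLORTech/"
#                     "surgical-pathology-manual/main/fig1.png")
# returns the matched .png URL, or None when no such link is present.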

# ------------------ Transcription Logic ------------------
def create_websocket_client():
    client_id = str(uuid.uuid4())
    connections[client_id] = WebSocketClient(WEBSOCKET_URI, WEBSOCKET_HEADERS, client_id)
    threading.Thread(target=connections[client_id].run, daemon=True).start()
    return client_id


def clear_transcript(client_id):
    if client_id in connections:
        connections[client_id].transcript = ""
    return ""


def send_audio_chunk(audio, client_id):
    if client_id not in connections:
        return "Initializing connection..."
    sr, y = audio
    connections[client_id].enqueue_audio_chunk(sr, y)
    return connections[client_id].transcript
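
# With gr.Audio(streaming=True), Gradio delivers each chunk as a
# (sample_rate, numpy_array) tuple, which send_audio_chunk unpacks and
# forwards to the WebSocket client; returning the cumulative transcript
# refreshes the textbox on every chunk.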

# ------------------ Gradio App ------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Document AI + 🎙️ Voice Assistant")

    session_id = gr.State(value=reset_session())
    client_id = gr.State()
    image_url = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=1):
            image_display = gr.Image(label="📄 Extracted Document Image", show_label=True, height=400)
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface(
                fn=process_chat,
                additional_inputs=[session_id],
                examples=[
                    ["What does clause 3.2 mean?"],
                    ["Summarize the timeline from the image."]
                ],
                title="💬 Document Assistant"
            )
    # When the assistant replies, scan its latest message for a document image
    # link and show it beside the chat. Reading the URL out of the chat history
    # avoids re-running process_chat (and a second Assistants API call).
    def update_image_from_history(history):
        # History entries are (user, assistant) pairs or role/content dicts,
        # depending on the Gradio version.
        for entry in reversed(history or []):
            if isinstance(entry, dict):
                text = entry.get("content") if entry.get("role") == "assistant" else None
            else:
                text = entry[1]
            if isinstance(text, str):
                url = extract_image_url(text)
                if url:
                    return url
        return gr.update()

    chatbot.chatbot.change(
        fn=update_image_from_history,
        inputs=chatbot.chatbot,
        outputs=image_display
    )

    # ------------------ Voice Transcription ------------------
    gr.Markdown("## 🎙️ Realtime Voice Transcription")

    with gr.Row():
        transcript_box = gr.Textbox(label="Live Transcript", lines=7, interactive=False, autoscroll=True)

    with gr.Row():
        # Streaming audio input requires a microphone source.
        mic_input = gr.Audio(sources=["microphone"], streaming=True)
        clear_button = gr.Button("Clear Transcript")

    mic_input.stream(fn=send_audio_chunk, inputs=[mic_input, client_id], outputs=transcript_box)
    clear_button.click(fn=clear_transcript, inputs=[client_id], outputs=transcript_box)

    demo.load(fn=create_websocket_client, outputs=client_id)

demo.launch()
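
# To run locally (a minimal sketch; exact dependency versions may differ):
#   pip install gradio openai python-dotenv
#   (plus whatever realtime_transcriber itself requires, e.g. websockets/numpy)
# and provide a .env file alongside the script:
#   OPENAI_API_KEY=sk-...
#   ASSISTANT_ID=asst_...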