import streamlit as st
import os
import time
import re
import requests
import tempfile
from openai import OpenAI
from streamlit_webrtc import webrtc_streamer, WebRtcMode
import av
import numpy as np
import wave

# ------------------ Configuration ------------------
st.set_page_config(page_title="Document AI Assistant", layout="wide")
st.title("πŸ“„ Document AI Assistant")
st.caption("Chat with an AI Assistant about your medical/pathology documents")

# ------------------ Secrets ------------------
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
ASSISTANT_ID = os.environ.get("ASSISTANT_ID")

if not OPENAI_API_KEY or not ASSISTANT_ID:
    st.error("❌ Missing secrets. Please set both OPENAI_API_KEY and ASSISTANT_ID in your Hugging Face Space settings.")
    st.stop()

client = OpenAI(api_key=OPENAI_API_KEY)

# ------------------ Session State ------------------
if "messages" not in st.session_state:
    st.session_state.messages = []
if "thread_id" not in st.session_state:
    st.session_state.thread_id = None
if "image_url" not in st.session_state:
    st.session_state.image_url = None
if "audio_buffer" not in st.session_state:
    st.session_state.audio_buffer = []

# ------------------ Whisper Transcription ------------------
def transcribe_audio(file_path, api_key):
    """Send a WAV file to OpenAI's Whisper endpoint and return the transcript text."""
    with open(file_path, "rb") as f:
        response = requests.post(
            "https://api.openai.com/v1/audio/transcriptions",
            headers={"Authorization": f"Bearer {api_key}"},
            files={"file": f},
            data={"model": "whisper-1"}
        )
    response.raise_for_status()
    return response.json().get("text")
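
# An equivalent sketch using the already-imported OpenAI SDK client instead of
# raw requests (transcribe_audio_sdk is illustrative and not called below):
def transcribe_audio_sdk(file_path):
    with open(file_path, "rb") as f:
        result = client.audio.transcriptions.create(model="whisper-1", file=f)
    return result.text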

# ------------------ Audio Recorder ------------------
class AudioProcessor:
    """Collects raw audio frames (currently unused; kept for a callback-based pipeline)."""
    def __init__(self):
        self.frames = []

    def recv(self, frame):
        audio = frame.to_ndarray()
        self.frames.append(audio)
        return av.AudioFrame.from_ndarray(audio, format="s16", layout="mono")
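
# If a callback-based pipeline were preferred, the class above could be wired in
# via streamlit-webrtc's audio_frame_callback (a sketch; this app instead reads
# frames directly from audio_receiver below):
#
#   webrtc_streamer(key="speech", mode=WebRtcMode.SENDONLY,
#                   audio_frame_callback=AudioProcessor().recv)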

def save_wav(frames, path, rate=48000):
    """Concatenate buffered frames and write them out as a 16-bit mono WAV file."""
    audio_data = np.concatenate([np.asarray(f).flatten() for f in frames]).astype(np.int16)
    with wave.open(path, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 16-bit samples
        wf.setframerate(rate)
        wf.writeframes(audio_data.tobytes())

# ------------------ Sidebar & Image Panel ------------------
st.sidebar.header("πŸ”§ Settings")
if st.sidebar.button("πŸ”„ Clear Chat"):
    st.session_state.messages = []
    st.session_state.thread_id = None
    st.session_state.image_url = None
    st.rerun()

show_image = st.sidebar.checkbox("πŸ“– Show Document Image", value=True)
col1, col2 = st.columns([1, 2])

with col1:
    if show_image and st.session_state.image_url:
        st.image(st.session_state.image_url, caption="πŸ“‘ Extracted Page", use_container_width=True)

# ------------------ Chat & Voice Panel ------------------
with col2:
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])

    # 🎀 Real-time voice recorder
    st.subheader("πŸŽ™οΈ Ask with your voice")
    audio_ctx = webrtc_streamer(
        key="speech",
        mode=WebRtcMode.SENDONLY,
        media_stream_constraints={"video": False, "audio": True},
        audio_receiver_size=256
    )

    if audio_ctx.audio_receiver:
        # Pull the next audio frame from the WebRTC receiver and buffer it
        result = audio_ctx.audio_receiver.recv()
        audio_data = result.to_ndarray()
        st.session_state.audio_buffer.append(audio_data)

        # ⏱️ Flush and transcribe once enough frames have been buffered
        if len(st.session_state.audio_buffer) > 30:
            tmp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
            save_wav(st.session_state.audio_buffer, tmp_path)
            st.session_state.audio_buffer = []

            with st.spinner("🧠 Transcribing..."):
                transcript = transcribe_audio(tmp_path, OPENAI_API_KEY)

            if transcript:
                st.success("πŸ“ " + transcript)
                st.session_state.messages.append({"role": "user", "content": transcript})
                st.chat_message("user").write(transcript)
                prompt = transcript

                try:
                    if st.session_state.thread_id is None:
                        thread = client.beta.threads.create()
                        st.session_state.thread_id = thread.id

                    thread_id = st.session_state.thread_id

                    client.beta.threads.messages.create(
                        thread_id=thread_id,
                        role="user",
                        content=prompt
                    )

                    run = client.beta.threads.runs.create(
                        thread_id=thread_id,
                        assistant_id=ASSISTANT_ID
                    )

                    with st.spinner("Assistant is thinking..."):
                        while True:
                            run_status = client.beta.threads.runs.retrieve(
                                thread_id=thread_id,
                                run_id=run.id
                            )
                            if run_status.status in ("completed", "failed", "cancelled", "expired"):
                                break
                            time.sleep(1)

                    if run_status.status != "completed":
                        raise RuntimeError(f"Run ended with status: {run_status.status}")

                    messages = client.beta.threads.messages.list(thread_id=thread_id)
                    assistant_message = None
                    # The list is returned newest-first, so the first assistant
                    # entry is the latest reply
                    for message in messages.data:
                        if message.role == "assistant":
                            assistant_message = message.content[0].text.value
                            break

                    st.chat_message("assistant").write(assistant_message)
                    st.session_state.messages.append({"role": "assistant", "content": assistant_message})

                    image_match = re.search(
                        r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
                        assistant_message
                    )
                    if image_match:
                        st.session_state.image_url = image_match.group(0)

                except Exception as e:
                    st.error(f"❌ Error: {str(e)}")

    # Fallback text input
    if prompt := st.chat_input("πŸ’¬ Or type your question..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        # You can add assistant logic here if you want it to run immediately
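        # One way to do that: a helper mirroring the voice flow above (a sketch;
        # run_assistant is not part of the original app)
        def run_assistant(user_prompt):
            if st.session_state.thread_id is None:
                st.session_state.thread_id = client.beta.threads.create().id
            thread_id = st.session_state.thread_id
            client.beta.threads.messages.create(
                thread_id=thread_id, role="user", content=user_prompt
            )
            run = client.beta.threads.runs.create(
                thread_id=thread_id, assistant_id=ASSISTANT_ID
            )
            while True:
                run_status = client.beta.threads.runs.retrieve(
                    thread_id=thread_id, run_id=run.id
                )
                if run_status.status in ("completed", "failed", "cancelled", "expired"):
                    break
                time.sleep(1)
            messages = client.beta.threads.messages.list(thread_id=thread_id)
            for message in messages.data:
                if message.role == "assistant":
                    return message.content[0].text.value
            return None

        with st.spinner("Assistant is thinking..."):
            reply = run_assistant(prompt)
        if reply:
            st.chat_message("assistant").write(reply)
            st.session_state.messages.append({"role": "assistant", "content": reply})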