File size: 5,269 Bytes
e760d91
 
 
 
e1336eb
 
1c296f6
e760d91
1c296f6
81240ab
e760d91
1c296f6
e760d91
 
 
 
1c296f6
e760d91
 
9d8abce
e760d91
1c296f6
e760d91
 
 
 
1c296f6
81240ab
1c296f6
81240ab
e760d91
 
e1336eb
 
 
 
 
 
 
9d8abce
e1336eb
e760d91
81240ab
e760d91
 
 
 
 
1c296f6
e760d91
 
 
 
 
1c296f6
e760d91
 
 
 
81240ab
e760d91
 
 
 
81240ab
1c296f6
81240ab
 
 
 
 
 
 
1c296f6
81240ab
1c296f6
81240ab
 
1c296f6
81240ab
 
 
 
 
1c296f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e760d91
1c296f6
 
 
 
 
 
e760d91
1c296f6
 
e760d91
81240ab
1c296f6
 
 
 
 
 
e760d91
1c296f6
 
e760d91
81240ab
e760d91
 
 
81240ab
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import hashlib
import os
import re
import tempfile
import time
import wave

import numpy as np
import requests
import streamlit as st
from openai import OpenAI
from streamlit_audio_recorder import audio_recorder

# ------------------ Page Config ------------------
# set_page_config is issued before any other Streamlit call in this script.
# The emoji below were previously stored as mis-decoded byte sequences and
# rendered as garbage in the browser; they are restored to real code points.
st.set_page_config(page_title="Document AI Assistant", layout="wide")
st.title("📄 Document AI Assistant")
st.caption("Chat with an AI Assistant on your medical/pathology documents")

# ------------------ Load Secrets ------------------
# Both values come from the process environment; nothing is hard-coded here.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
ASSISTANT_ID = os.environ.get("ASSISTANT_ID")

if not OPENAI_API_KEY or not ASSISTANT_ID:
    # Without credentials nothing below can work, so halt this script run.
    st.error("❌ Missing secrets. Please set both OPENAI_API_KEY and ASSISTANT_ID in Hugging Face Space settings.")
    st.stop()

client = OpenAI(api_key=OPENAI_API_KEY)

# ------------------ Session State Init ------------------
# Seed each key only when absent so values survive Streamlit reruns.
# "messages" starts as an empty chat history; every other key starts unset.
for key in ["messages", "thread_id", "image_url", "transcript"]:
    if key not in st.session_state:
        st.session_state[key] = [] if key == "messages" else None

# ------------------ Whisper Transcription ------------------
def transcribe_audio(file_path, api_key):
    """Transcribe an audio file with OpenAI's ``whisper-1`` model.

    Args:
        file_path: Path of the audio file to upload.
        api_key: OpenAI API key, sent as a Bearer token.

    Returns:
        The transcript text, or ``None`` when the response carries no
        ``text`` field (an error payload, a non-JSON body, or a JSON value
        that is not an object).

    Raises:
        OSError: If ``file_path`` cannot be opened.
        requests.RequestException: On connection failure or timeout.
    """
    with open(file_path, "rb") as f:
        response = requests.post(
            "https://api.openai.com/v1/audio/transcriptions",
            headers={"Authorization": f"Bearer {api_key}"},
            files={"file": f},
            data={"model": "whisper-1"},
            # (connect, read) seconds. Without a timeout requests waits
            # indefinitely, which would freeze the whole Streamlit run.
            timeout=(10, 120),
        )

    try:
        payload = response.json()
    except ValueError:
        # Proxies and gateways may answer with HTML or plain text; that is
        # reported the same way as any other "no transcript" outcome.
        return None

    if not isinstance(payload, dict):
        return None
    return payload.get("text")

# ------------------ Sidebar & Layout ------------------
# Labels previously contained mis-decoded emoji bytes; restored to real emoji.
st.sidebar.header("🔧 Settings")
if st.sidebar.button("🔄 Clear Chat"):
    # Drop the chat history, the stored thread id, the displayed page image
    # and any pending transcript, then restart the script with clean state.
    st.session_state.messages = []
    st.session_state.thread_id = None
    st.session_state.image_url = None
    st.session_state.transcript = None
    st.rerun()

show_image = st.sidebar.checkbox("📖 Show Document Image", value=True)
# Two columns with a 1:2 width ratio: image on the left, chat on the right.
col1, col2 = st.columns([1, 2])

# ------------------ Image Panel ------------------
with col1:
    # Only draw the image when the user wants it and a URL has been stored.
    if show_image and st.session_state.image_url:
        st.image(st.session_state.image_url, caption="📑 Extracted Page", use_container_width=True)

# ------------------ Chat + Mic Panel ------------------
# Run states from which a run will never reach "completed" on its own.
# "requires_action" is included because this app never submits tool outputs.
_DEAD_RUN_STATES = frozenset({"failed", "cancelled", "expired", "incomplete", "requires_action"})
# Upper bound on how long one assistant run may be polled.
_RUN_TIMEOUT_SECONDS = 120
# Page images the assistant may cite in its reply.
_PAGE_IMAGE_PATTERN = re.compile(
    r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png'
)
# Session key remembering which recording has already been transcribed.
_AUDIO_DIGEST_KEY = "transcribed_audio_digest"


def _wait_for_run(thread_id, run_id):
    """Poll a run once per second until it completes.

    Raises:
        RuntimeError: The run ended in a state that cannot complete.
        TimeoutError: The run did not complete within _RUN_TIMEOUT_SECONDS.
    """
    deadline = time.monotonic() + _RUN_TIMEOUT_SECONDS
    while True:
        status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id).status
        if status == "completed":
            return
        if status in _DEAD_RUN_STATES:
            raise RuntimeError(f"Assistant run ended with status '{status}'")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Assistant did not finish within {_RUN_TIMEOUT_SECONDS} seconds")
        time.sleep(1)


def _newest_assistant_text(thread_messages):
    """Return the text of the first assistant message in a newest-first list.

    Non-text content parts (e.g. images) are skipped. Returns None when there
    is no assistant message or it has no text part.
    """
    for msg in thread_messages:
        if msg.role != "assistant":
            continue
        texts = []
        for part in msg.content:
            text_part = getattr(part, "text", None)
            if text_part is not None:
                texts.append(text_part.value)
        return "\n\n".join(texts) if texts else None
    return None


def _ask_assistant(user_input):
    """Send ``user_input`` to the assistant and render both sides of the exchange.

    Appends the user message and, on success, the assistant reply to
    ``st.session_state.messages``; stores a cited page-image URL in
    ``st.session_state.image_url``.

    Returns:
        True when a reply was obtained, False when an error was shown instead.
    """
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.chat_message("user").write(user_input)

    try:
        # Create the conversation thread lazily, on the first message.
        if st.session_state.thread_id is None:
            st.session_state.thread_id = client.beta.threads.create().id
        thread_id = st.session_state.thread_id

        client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
        run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)

        with st.spinner("🤖 Assistant is thinking..."):
            _wait_for_run(thread_id, run.id)

        # Request newest-first explicitly so the first assistant entry is the
        # reply to this run rather than the first reply of the whole thread.
        listing = client.beta.threads.messages.list(thread_id=thread_id, order="desc")
        reply = _newest_assistant_text(listing.data)
        if reply is None:
            raise RuntimeError("Assistant returned no text reply")
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
        return False

    st.chat_message("assistant").write(reply)
    st.session_state.messages.append({"role": "assistant", "content": reply})

    # Extract GitHub image if available
    image_match = _PAGE_IMAGE_PATTERN.search(reply)
    if image_match:
        st.session_state.image_url = image_match.group(0)
    return True


def _transcribe_recording(audio_bytes):
    """Transcribe ``audio_bytes`` via a temporary WAV file that is always removed."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpfile:
        tmpfile.write(audio_bytes)
        tmp_path = tmpfile.name
    try:
        with st.spinner("🧠 Transcribing..."):
            return transcribe_audio(tmp_path, OPENAI_API_KEY)
    finally:
        # delete=False was needed so the file can be reopened by path;
        # clean it up ourselves so recordings do not pile up on disk.
        os.remove(tmp_path)


with col2:
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])

    st.subheader("🎙️ Ask with Your Voice")

    audio_bytes = audio_recorder(pause_threshold=3.0, energy_threshold=-1.0, sample_rate=44100)

    if audio_bytes:
        st.audio(audio_bytes, format="audio/wav")

        # The recorder keeps returning the same bytes on every rerun.
        # Transcribe each recording once, otherwise every interaction would
        # repeat the API call and resurrect an already-sent transcript.
        digest = hashlib.sha256(audio_bytes).hexdigest()
        if st.session_state.get(_AUDIO_DIGEST_KEY) != digest:
            transcript = _transcribe_recording(audio_bytes)
            st.session_state[_AUDIO_DIGEST_KEY] = digest
            if transcript:
                st.session_state.transcript = transcript
            else:
                st.warning("Could not transcribe the recording. Please try again.")

    # Submit Transcript to Assistant
    if st.session_state.transcript:
        st.success("📝 Transcript: " + st.session_state.transcript)
        if st.button("✅ Send Transcript to Assistant"):
            user_input = st.session_state.transcript
            st.session_state.transcript = None  # reset
            if _ask_assistant(user_input):
                # The image panel is drawn earlier in the script; rerun so it
                # and the chat history reflect the new reply.
                st.rerun()

    # Fallback text input: goes straight to the assistant, same path as voice.
    if prompt := st.chat_input("💬 Or type your question..."):
        if _ask_assistant(prompt):
            st.rerun()