IAMTFRMZA committed
Commit 1c296f6 · verified · 1 Parent(s): 7d12d16

Update app.py

Files changed (1)
  1. app.py +88 -102
app.py CHANGED
@@ -4,36 +4,31 @@ import time
 import re
 import requests
 import tempfile
-from openai import OpenAI
-from streamlit_webrtc import webrtc_streamer, WebRtcMode
+import wave
 import av
 import numpy as np
-import wave
+from openai import OpenAI
+from streamlit_webrtc import webrtc_streamer, WebRtcMode
 
-# ------------------ Configuration ------------------
+# ------------------ Page Config ------------------
 st.set_page_config(page_title="Document AI Assistant", layout="wide")
 st.title("📄 Document AI Assistant")
 st.caption("Chat with an AI Assistant on your medical/pathology documents")
 
-# ------------------ Secrets ------------------
+# ------------------ Load Secrets ------------------
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 ASSISTANT_ID = os.environ.get("ASSISTANT_ID")
 
 if not OPENAI_API_KEY or not ASSISTANT_ID:
-    st.error("❌ Missing secrets. Please set both OPENAI_API_KEY and ASSISTANT_ID in your Hugging Face Space settings.")
+    st.error("❌ Missing secrets. Please set both OPENAI_API_KEY and ASSISTANT_ID in Hugging Face Space settings.")
     st.stop()
 
 client = OpenAI(api_key=OPENAI_API_KEY)
 
-# ------------------ Session State ------------------
-if "messages" not in st.session_state:
-    st.session_state.messages = []
-if "thread_id" not in st.session_state:
-    st.session_state.thread_id = None
-if "image_url" not in st.session_state:
-    st.session_state.image_url = None
-if "audio_buffer" not in st.session_state:
-    st.session_state.audio_buffer = []
+# ------------------ Session State Init ------------------
+for key in ["messages", "thread_id", "image_url", "audio_buffer", "transcript"]:
+    if key not in st.session_state:
+        st.session_state[key] = [] if key == "messages" or key == "audio_buffer" else None
 
 # ------------------ Whisper Transcription ------------------
 def transcribe_audio(file_path, api_key):
@@ -46,16 +41,7 @@ def transcribe_audio(file_path, api_key):
     )
     return response.json().get("text", None)
 
-# ------------------ Audio Recorder ------------------
-class AudioProcessor:
-    def __init__(self):
-        self.frames = []
-
-    def recv(self, frame):
-        audio = frame.to_ndarray()
-        self.frames.append(audio)
-        return av.AudioFrame.from_ndarray(audio, layout="mono")
-
+# ------------------ Audio Save Helper ------------------
 def save_wav(frames, path, rate=48000):
     audio_data = np.concatenate(frames)
     with wave.open(path, 'wb') as wf:
@@ -64,106 +50,106 @@ def save_wav(frames, path, rate=48000):
         wf.setframerate(rate)
         wf.writeframes(audio_data.tobytes())
 
-# ------------------ Sidebar & Image Panel ------------------
+# ------------------ Sidebar Controls ------------------
 st.sidebar.header("🔧 Settings")
 if st.sidebar.button("🔄 Clear Chat"):
     st.session_state.messages = []
     st.session_state.thread_id = None
     st.session_state.image_url = None
+    st.session_state.transcript = None
+    st.session_state.audio_buffer = []
     st.rerun()
 
 show_image = st.sidebar.checkbox("📖 Show Document Image", value=True)
 col1, col2 = st.columns([1, 2])
 
+# ------------------ Image Panel ------------------
 with col1:
     if show_image and st.session_state.image_url:
         st.image(st.session_state.image_url, caption="📑 Extracted Page", use_container_width=True)
 
-# ------------------ Chat & Voice Panel ------------------
+# ------------------ Chat + Voice Panel ------------------
 with col2:
     for message in st.session_state.messages:
         st.chat_message(message["role"]).write(message["content"])
 
-    # 🎤 Real-time voice recorder
-    st.subheader("🎙️ Ask with your voice")
-    audio_ctx = webrtc_streamer(
-        key="speech",
-        mode=WebRtcMode.SENDONLY,
-        in_audio_enabled=True,
-        audio_receiver_size=256
-    )
-
-    if audio_ctx.audio_receiver:
-        audio_processor = AudioProcessor()
-        result = audio_ctx.audio_receiver.recv()
-        audio_data = result.to_ndarray()
-        st.session_state.audio_buffer.append(audio_data)
-
-        # ⏱️ Auto stop after ~3 seconds
-        if len(st.session_state.audio_buffer) > 30:
-            tmp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
-            save_wav(st.session_state.audio_buffer, tmp_path)
-            st.session_state.audio_buffer = []
-
-            with st.spinner("🧠 Transcribing..."):
-                transcript = transcribe_audio(tmp_path, OPENAI_API_KEY)
-
-            if transcript:
-                st.success("📝 " + transcript)
-                st.session_state.messages.append({"role": "user", "content": transcript})
-                st.chat_message("user").write(transcript)
-                prompt = transcript
-
-                try:
-                    if st.session_state.thread_id is None:
-                        thread = client.beta.threads.create()
-                        st.session_state.thread_id = thread.id
-
-                    thread_id = st.session_state.thread_id
-
-                    client.beta.threads.messages.create(
-                        thread_id=thread_id,
-                        role="user",
-                        content=prompt
-                    )
-
-                    run = client.beta.threads.runs.create(
-                        thread_id=thread_id,
-                        assistant_id=ASSISTANT_ID
-                    )
-
-                    with st.spinner("Assistant is thinking..."):
-                        while True:
-                            run_status = client.beta.threads.runs.retrieve(
-                                thread_id=thread_id,
-                                run_id=run.id
-                            )
-                            if run_status.status == "completed":
-                                break
-                            time.sleep(1)
-
-                    messages = client.beta.threads.messages.list(thread_id=thread_id)
-                    assistant_message = None
-                    for message in reversed(messages.data):
-                        if message.role == "assistant":
-                            assistant_message = message.content[0].text.value
+    st.subheader("🎙️ Real-time Voice Input")
+    is_recording = st.checkbox("🎤 Start Recording")
+
+    if is_recording:
+        audio_ctx = webrtc_streamer(key="voice", mode=WebRtcMode.SENDONLY)
+
+        if audio_ctx.audio_receiver:
+            try:
+                audio_frames = []
+                while True:
+                    result = audio_ctx.audio_receiver.recv()
+                    audio_data = result.to_ndarray()
+                    audio_frames.append(audio_data)
+                    if len(audio_frames) > 30:
+                        break
+
+                tmp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
+                save_wav(audio_frames, tmp_path)
+                st.audio(tmp_path, format="audio/wav")
+
+                with st.spinner("🧠 Transcribing..."):
+                    transcript = transcribe_audio(tmp_path, OPENAI_API_KEY)
+
+                if transcript:
+                    st.session_state.transcript = transcript
+                    st.success("📝 Transcript: " + transcript)
+                    with open(tmp_path, "rb") as f:
+                        st.download_button("⬇️ Download Audio", f, file_name="recording.wav", mime="audio/wav")
+
+            except Exception as e:
+                st.error(f"Recording failed: {str(e)}")
+
+    # Confirm & send transcript
+    if st.session_state.transcript:
+        if st.button("✅ Send Transcript to Assistant"):
+            user_input = st.session_state.transcript
+            st.session_state.transcript = None  # reset
+
+            st.session_state.messages.append({"role": "user", "content": user_input})
+            st.chat_message("user").write(user_input)
+
+            try:
+                if st.session_state.thread_id is None:
+                    thread = client.beta.threads.create()
+                    st.session_state.thread_id = thread.id
+
+                thread_id = st.session_state.thread_id
+                client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
+                run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)
+
+                with st.spinner("🤖 Assistant is thinking..."):
+                    while True:
+                        run_status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
+                        if run_status.status == "completed":
                             break
+                        time.sleep(1)
+
+                messages = client.beta.threads.messages.list(thread_id=thread_id)
+                assistant_message = next(
+                    (m.content[0].text.value for m in reversed(messages.data) if m.role == "assistant"), None
+                )
 
-                    st.chat_message("assistant").write(assistant_message)
-                    st.session_state.messages.append({"role": "assistant", "content": assistant_message})
+                st.chat_message("assistant").write(assistant_message)
+                st.session_state.messages.append({"role": "assistant", "content": assistant_message})
 
-                    image_match = re.search(
-                        r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
-                        assistant_message
-                    )
-                    if image_match:
-                        st.session_state.image_url = image_match.group(0)
+                image_match = re.search(
+                    r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png',
+                    assistant_message
+                )
+                if image_match:
+                    st.session_state.image_url = image_match.group(0)
 
-                except Exception as e:
-                    st.error(f"❌ Error: {str(e)}")
+            except Exception as e:
+                st.error(f"❌ Error: {str(e)}")
 
-    # Fallback text input
+    # Text input fallback
     if prompt := st.chat_input("💬 Or type your question..."):
         st.session_state.messages.append({"role": "user", "content": prompt})
         st.chat_message("user").write(prompt)
-        # You can add assistant logic here if you want it to run immediately
+        # Same logic could be duplicated here or modularized
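
The new fallback path still ends in a bare comment ("Same logic could be duplicated here or modularized"), so typed questions are displayed but never sent to the assistant. A minimal sketch of that modularization, not part of this commit: the Assistants API round trip is pulled into one helper that both the transcript button and the chat input could call. The helper name ask_assistant and its signature are hypothetical; it assumes the same client, ASSISTANT_ID, and session-state keys that app.py already defines.

import time
import streamlit as st

def ask_assistant(user_input, client, assistant_id):
    """Send one user message to the assistant thread and return the reply text."""
    # Reuse the thread kept in session state, creating it on first use.
    if st.session_state.thread_id is None:
        st.session_state.thread_id = client.beta.threads.create().id
    thread_id = st.session_state.thread_id

    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
    run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id)

    # Same one-second polling loop the app uses while the run completes.
    while client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id).status != "completed":
        time.sleep(1)

    # Pick out the assistant reply the same way app.py does.
    messages = client.beta.threads.messages.list(thread_id=thread_id)
    return next(
        (m.content[0].text.value for m in reversed(messages.data) if m.role == "assistant"),
        None,
    )

With a helper like this, both the transcript branch and the chat-input fallback would reduce to reply = ask_assistant(user_input, client, ASSISTANT_ID), followed by appending reply to st.session_state.messages and writing it to the chat.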