noumanjavaid committed on
Commit
a4d774b
·
verified ·
1 Parent(s): ed44aa7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +47 -30
src/streamlit_app.py CHANGED
@@ -47,12 +47,10 @@ VIDEO_FPS_TO_GEMINI = 2
47
  VIDEO_API_RESIZE = (1024, 1024)
48
 
49
  # !!! IMPORTANT: Verify this model name is correct for the Live API !!!
50
- # This is a common point of failure for ConnectionClosedError.
51
  MODEL_NAME = "models/gemini-2.0-flash-live-001"
52
  logging.info(f"Using Gemini Model: {MODEL_NAME}")
53
 
54
  MEDICAL_ASSISTANT_SYSTEM_PROMPT = """You are an AI Medical Assistant. Your primary function is to analyze visual information from the user's camera or screen and respond via voice.
55
-
56
  Your responsibilities are:
57
  1. **Visual Observation and Description:** Carefully examine the images or video feed. Describe relevant details you observe.
58
  2. **General Information (Non-Diagnostic):** Provide general information related to what is visually presented, if applicable. You are not a diagnostic tool.
@@ -63,7 +61,6 @@ Your responsibilities are:
63
  * If you see something that *appears* visually concerning (e.g., an unusual skin lesion, signs of injury), you may gently suggest it might be wise to have it looked at by a professional, without speculating on what it is.
64
  4. **Tone:** Maintain a helpful, empathetic, and calm tone.
65
  5. **Interaction:** After this initial instruction, you can make a brief acknowledgment of your role (e.g., "I'm ready to assist by looking at what you show me. Please remember to consult a doctor for medical advice."). Then, focus on responding to the user's visual input and questions.
66
-
67
  Example of a disclaimer you might use: "As an AI assistant, I can describe what I see, but I can't provide medical advice or diagnoses. For any health concerns, it's always best to speak with a doctor or other healthcare professional."
68
  """
69
 
@@ -101,16 +98,18 @@ else:
101
  logging.critical("GEMINI_API_KEY not found.")
102
  st.stop()
103
 
104
- # Gemini LiveConnectConfig - HIGHLY SIMPLIFIED FOR DEBUGGING ConnectionClosedError
105
- # Start with the absolute minimum. If this connects, incrementally add back features.
106
- # If this still fails, the issue is likely MODEL_NAME or API Key/Project permissions.
107
  LIVE_CONNECT_CONFIG = types.LiveConnectConfig(
108
- response_modalities=["audio"], # Start with text only
109
  speech_config=types.SpeechConfig(
110
  voice_config=types.VoiceConfig(
111
- prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Zephyr")
112
- ),
113
- )
 
 
 
114
 
115
  # --- Backend Gemini Interaction Loop ---
116
  class GeminiInteractionLoop:
@@ -126,6 +125,8 @@ class GeminiInteractionLoop:
126
  return
127
  try:
128
  logging.info(f"Sending text to Gemini: '{user_text[:50]}...'")
 
 
129
  await self.gemini_session.send(input=user_text, end_of_turn=True)
130
  except Exception as e:
131
  logging.error(f"Error sending text message to Gemini: {e}", exc_info=True)
@@ -153,7 +154,10 @@ class GeminiInteractionLoop:
153
  media_data = await get_media_from_queues()
154
  if media_data is None and not self.is_running: break # Sentinel and stop signal
155
  if media_data and self.gemini_session and self.is_running:
156
- try: await self.gemini_session.send(input=media_data)
 
 
 
157
  except Exception as e: logging.error(f"Error sending media chunk to Gemini: {e}", exc_info=True)
158
  elif not media_data: await asyncio.sleep(0.05) # No data, yield
159
  except asyncio.CancelledError: logging.info("Task cancelled: stream_media_to_gemini.")
@@ -176,7 +180,6 @@ class GeminiInteractionLoop:
176
  logging.info(f"Gemini text response: {text_response[:100]}")
177
  if 'chat_messages' not in st.session_state: st.session_state.chat_messages = []
178
  st.session_state.chat_messages = st.session_state.chat_messages + [{"role": "assistant", "content": text_response}]
179
- # Consider using st.rerun() via a thread-safe mechanism if immediate UI update is critical
180
  except types.generation_types.StopCandidateException: logging.info("Gemini response stream ended normally.")
181
  except Exception as e:
182
  if self.is_running: logging.error(f"Error receiving from Gemini: {e}", exc_info=True)
@@ -221,8 +224,8 @@ class GeminiInteractionLoop:
221
  def signal_stop(self):
222
  logging.info("Signal to stop GeminiInteractionLoop received.")
223
  self.is_running = False
224
- for q_name, q_obj_ref in [("video_q", video_frames_to_gemini_q),
225
- ("audio_in_q", audio_chunks_to_gemini_q),
226
  ("audio_out_q", audio_from_gemini_playback_q)]:
227
  if q_obj_ref:
228
  try: q_obj_ref.put_nowait(None)
@@ -255,7 +258,7 @@ class GeminiInteractionLoop:
255
  logging.error(f"Failed to send system prompt: {e}", exc_info=True)
256
  self.is_running = False; return
257
 
258
- # Python 3.9 does not have asyncio.TaskGroup, so manage tasks individually
259
  tasks = []
260
  try:
261
  logging.info("Creating async tasks for Gemini interaction...")
@@ -263,12 +266,23 @@ class GeminiInteractionLoop:
263
  tasks.append(asyncio.create_task(self.process_gemini_responses(), name="process_gemini_responses"))
264
  tasks.append(asyncio.create_task(self.play_gemini_audio(), name="play_gemini_audio"))
265
  logging.info("All Gemini interaction tasks created.")
266
- await asyncio.gather(*tasks) # Wait for all tasks to complete
267
- except Exception as e_gather: # Catch errors from tasks gathered
268
- logging.error(f"Error during asyncio.gather: {e_gather}", exc_info=True)
 
 
 
 
 
 
 
 
 
 
269
  for task in tasks:
270
- if not task.done(): task.cancel() # Cancel pending tasks
271
- await asyncio.gather(*tasks, return_exceptions=True) # Wait for cancellations
 
272
  logging.info("Gemini interaction tasks finished or cancelled.")
273
 
274
  except asyncio.CancelledError: logging.info("GeminiInteractionLoop.run_main_loop() was cancelled.")
@@ -276,15 +290,22 @@ class GeminiInteractionLoop:
276
  logging.error(f"Exception in GeminiInteractionLoop run_main_loop: {type(e).__name__}: {e}", exc_info=True)
277
  finally:
278
  logging.info("GeminiInteractionLoop.run_main_loop() finishing...")
279
- self.is_running = False # Ensure flag is set for all tasks
280
- self.signal_stop() # Send sentinels again to be sure
 
 
 
 
 
 
 
281
  self.gemini_session = None
282
- # Clear global queues by setting them to None
283
  video_frames_to_gemini_q = None
284
  audio_chunks_to_gemini_q = None
285
  audio_from_gemini_playback_q = None
286
  logging.info("GeminiInteractionLoop finished and global queues set to None.")
287
 
 
288
  # --- WebRTC Media Processors ---
289
  class VideoProcessor(VideoProcessorBase):
290
  def __init__(self):
@@ -311,8 +332,7 @@ class VideoProcessor(VideoProcessorBase):
311
  video_frames_to_gemini_q.put_nowait(api_data)
312
  except Exception as e: logging.error(f"Error processing/queueing video frame: {e}", exc_info=True)
313
 
314
- async def recv(self
315
- , frame):
316
  img_bgr = frame.to_ndarray(format="bgr24")
317
  try:
318
  loop = asyncio.get_running_loop()
@@ -325,11 +345,8 @@ class AudioProcessor(AudioProcessorBase):
325
  if audio_chunks_to_gemini_q is None: return
326
  for frame in audio_frames:
327
  audio_data = frame.planes[0].to_bytes()
328
- # Note: Ensure this mime_type and the actual audio data format (sample rate, channels, bit depth)
329
- # are compatible with what the Gemini Live API expects for PCM audio.
330
  mime_type = f"audio/L16;rate={frame.sample_rate};channels={frame.layout.channels}"
331
- api_data = {"data"
332
- : audio_data, "mime_type": mime_type}
333
  try:
334
  if audio_chunks_to_gemini_q.full():
335
  try: await asyncio.wait_for(audio_chunks_to_gemini_q.get(), timeout=0.01)
@@ -414,7 +431,7 @@ def run_streamlit_app():
414
 
415
  if webrtc_ctx.state.playing:
416
  st.caption("WebRTC connected. Streaming your camera and microphone.")
417
- elif st.session_state.gemini_session_active: # Check if session is supposed to be active
418
  st.caption("WebRTC attempting to connect. Ensure camera/microphone permissions are granted in your browser.")
419
  if hasattr(webrtc_ctx.state, 'error') and webrtc_ctx.state.error:
420
  st.error(f"WebRTC Connection Error: {webrtc_ctx.state.error}")
 
47
  VIDEO_API_RESIZE = (1024, 1024)
48
 
49
  # !!! IMPORTANT: Verify this model name is correct for the Live API !!!
 
50
  MODEL_NAME = "models/gemini-2.0-flash-live-001"
51
  logging.info(f"Using Gemini Model: {MODEL_NAME}")
52
 
53
  MEDICAL_ASSISTANT_SYSTEM_PROMPT = """You are an AI Medical Assistant. Your primary function is to analyze visual information from the user's camera or screen and respond via voice.
 
54
  Your responsibilities are:
55
  1. **Visual Observation and Description:** Carefully examine the images or video feed. Describe relevant details you observe.
56
  2. **General Information (Non-Diagnostic):** Provide general information related to what is visually presented, if applicable. You are not a diagnostic tool.
 
61
  * If you see something that *appears* visually concerning (e.g., an unusual skin lesion, signs of injury), you may gently suggest it might be wise to have it looked at by a professional, without speculating on what it is.
62
  4. **Tone:** Maintain a helpful, empathetic, and calm tone.
63
  5. **Interaction:** After this initial instruction, you can make a brief acknowledgment of your role (e.g., "I'm ready to assist by looking at what you show me. Please remember to consult a doctor for medical advice."). Then, focus on responding to the user's visual input and questions.
 
64
  Example of a disclaimer you might use: "As an AI assistant, I can describe what I see, but I can't provide medical advice or diagnoses. For any health concerns, it's always best to speak with a doctor or other healthcare professional."
65
  """
66
 
 
98
  logging.critical("GEMINI_API_KEY not found.")
99
  st.stop()
100
 
101
+ # Gemini LiveConnectConfig - Using audio response and Puck voice as in your latest code
102
+ # Ensure this configuration is valid for your API key and model.
 
103
  LIVE_CONNECT_CONFIG = types.LiveConnectConfig(
104
+ response_modalities=["audio"], # Requesting audio response
105
  speech_config=types.SpeechConfig(
106
  voice_config=types.VoiceConfig(
107
+ prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Puck") # Using Puck voice
108
+ )
109
+ ) # <---------------------------------- CORRECTED: Added missing closing parenthesis
110
+ )
111
+ logging.info(f"Attempting connection with LiveConnectConfig: {LIVE_CONNECT_CONFIG}")
112
+
113
 
114
  # --- Backend Gemini Interaction Loop ---
115
  class GeminiInteractionLoop:
 
125
  return
126
  try:
127
  logging.info(f"Sending text to Gemini: '{user_text[:50]}...'")
128
+ # Use the specific method as suggested by the deprecation warning if possible
129
+ # For now, keeping session.send as it was working functionally
130
  await self.gemini_session.send(input=user_text, end_of_turn=True)
131
  except Exception as e:
132
  logging.error(f"Error sending text message to Gemini: {e}", exc_info=True)
 
154
  media_data = await get_media_from_queues()
155
  if media_data is None and not self.is_running: break # Sentinel and stop signal
156
  if media_data and self.gemini_session and self.is_running:
157
+ try:
158
+ # Use the specific method as suggested by the deprecation warning if possible
159
+ # For now, keeping session.send as it was working functionally
160
+ await self.gemini_session.send(input=media_data)
161
  except Exception as e: logging.error(f"Error sending media chunk to Gemini: {e}", exc_info=True)
162
  elif not media_data: await asyncio.sleep(0.05) # No data, yield
163
  except asyncio.CancelledError: logging.info("Task cancelled: stream_media_to_gemini.")
 
180
  logging.info(f"Gemini text response: {text_response[:100]}")
181
  if 'chat_messages' not in st.session_state: st.session_state.chat_messages = []
182
  st.session_state.chat_messages = st.session_state.chat_messages + [{"role": "assistant", "content": text_response}]
 
183
  except types.generation_types.StopCandidateException: logging.info("Gemini response stream ended normally.")
184
  except Exception as e:
185
  if self.is_running: logging.error(f"Error receiving from Gemini: {e}", exc_info=True)
 
224
  def signal_stop(self):
225
  logging.info("Signal to stop GeminiInteractionLoop received.")
226
  self.is_running = False
227
+ for q_name, q_obj_ref in [("video_q", video_frames_to_gemini_q),
228
+ ("audio_in_q", audio_chunks_to_gemini_q),
229
  ("audio_out_q", audio_from_gemini_playback_q)]:
230
  if q_obj_ref:
231
  try: q_obj_ref.put_nowait(None)
 
258
  logging.error(f"Failed to send system prompt: {e}", exc_info=True)
259
  self.is_running = False; return
260
 
261
+ # Using asyncio.gather for Python 3.9 compatibility
262
  tasks = []
263
  try:
264
  logging.info("Creating async tasks for Gemini interaction...")
 
266
  tasks.append(asyncio.create_task(self.process_gemini_responses(), name="process_gemini_responses"))
267
  tasks.append(asyncio.create_task(self.play_gemini_audio(), name="play_gemini_audio"))
268
  logging.info("All Gemini interaction tasks created.")
269
+ # Wait for tasks to complete or raise an exception
270
+ done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
271
+ # Check results of completed tasks for errors
272
+ for future in done:
273
+ try:
274
+ future.result() # Raise exception if task failed
275
+ except Exception as task_exc:
276
+ logging.error(f"Task {future.get_name()} failed: {task_exc}", exc_info=True)
277
+ # Optionally cancel remaining tasks if one fails critically
278
+ for p_task in pending: p_task.cancel()
279
+ # If loop completes normally (e.g., user stops), pending tasks will be handled by finally block
280
+ except Exception as e_gather: # Catch errors during task creation/gathering
281
+ logging.error(f"Error during task management: {e_gather}", exc_info=True)
282
  for task in tasks:
283
+ if not task.done(): task.cancel()
284
+ # Wait for cancellations to complete
285
+ await asyncio.gather(*tasks, return_exceptions=True)
286
  logging.info("Gemini interaction tasks finished or cancelled.")
287
 
288
  except asyncio.CancelledError: logging.info("GeminiInteractionLoop.run_main_loop() was cancelled.")
 
290
  logging.error(f"Exception in GeminiInteractionLoop run_main_loop: {type(e).__name__}: {e}", exc_info=True)
291
  finally:
292
  logging.info("GeminiInteractionLoop.run_main_loop() finishing...")
293
+ self.is_running = False
294
+ self.signal_stop() # Ensure sentinels are sent
295
+ # Clean up any remaining tasks (important if gather didn't complete)
296
+ # current_tasks = [t for t in asyncio.all_tasks(self.async_event_loop) if t is not asyncio.current_task()]
297
+ # if current_tasks:
298
+ # logging.info(f"Cancelling {len(current_tasks)} remaining tasks...")
299
+ # for task in current_tasks: task.cancel()
300
+ # await asyncio.gather(*current_tasks, return_exceptions=True)
301
+
302
  self.gemini_session = None
 
303
  video_frames_to_gemini_q = None
304
  audio_chunks_to_gemini_q = None
305
  audio_from_gemini_playback_q = None
306
  logging.info("GeminiInteractionLoop finished and global queues set to None.")
307
 
308
+
309
  # --- WebRTC Media Processors ---
310
  class VideoProcessor(VideoProcessorBase):
311
  def __init__(self):
 
332
  video_frames_to_gemini_q.put_nowait(api_data)
333
  except Exception as e: logging.error(f"Error processing/queueing video frame: {e}", exc_info=True)
334
 
335
+ async def recv(self, frame):
 
336
  img_bgr = frame.to_ndarray(format="bgr24")
337
  try:
338
  loop = asyncio.get_running_loop()
 
345
  if audio_chunks_to_gemini_q is None: return
346
  for frame in audio_frames:
347
  audio_data = frame.planes[0].to_bytes()
 
 
348
  mime_type = f"audio/L16;rate={frame.sample_rate};channels={frame.layout.channels}"
349
+ api_data = {"data": audio_data, "mime_type": mime_type}
 
350
  try:
351
  if audio_chunks_to_gemini_q.full():
352
  try: await asyncio.wait_for(audio_chunks_to_gemini_q.get(), timeout=0.01)
 
431
 
432
  if webrtc_ctx.state.playing:
433
  st.caption("WebRTC connected. Streaming your camera and microphone.")
434
+ elif st.session_state.gemini_session_active:
435
  st.caption("WebRTC attempting to connect. Ensure camera/microphone permissions are granted in your browser.")
436
  if hasattr(webrtc_ctx.state, 'error') and webrtc_ctx.state.error:
437
  st.error(f"WebRTC Connection Error: {webrtc_ctx.state.error}")