# Spatial-aware / src/streamlit_app.py
# Last commit: 5bb8fe7 "Update src/streamlit_app.py" by noumanjavaid (file size: 26 kB)
# -*- coding: utf-8 -*-
import streamlit as st
import os
import asyncio
import base64
import io
import threading
import queue # Standard library queue, not asyncio.Queue for thread-safe UI updates if needed
import traceback
import time # Keep time for potential future use (e.g., timestamps)
from dotenv import load_dotenv
# --- Import main libraries ---
import cv2
import pyaudio
import PIL.Image
import mss
from google import genai
from google.genai import types
# --- Configuration ---
# Pull variables such as GEMINI_API_KEY from a local .env file into os.environ.
load_dotenv()

# Audio configuration: 16-bit mono PCM; the microphone is read CHUNK_SIZE frames at a time.
FORMAT = pyaudio.paInt16
CHANNELS = 1
CHUNK_SIZE = 1024
SEND_SAMPLE_RATE = 16_000  # Microphone capture rate (Hz) for audio sent to Gemini
RECEIVE_SAMPLE_RATE = 24_000  # Playback rate (Hz); according to Gemini documentation
AUDIO_QUEUE_MAXSIZE = 20  # Max audio chunks to buffer for playback

# Video configuration
VIDEO_FPS_LIMIT = 1  # Frames per second forwarded to the API
VIDEO_PREVIEW_RESIZE = (640, 480)  # Bounding box for the Streamlit preview thumbnail
VIDEO_API_RESIZE = (1024, 1024)  # Bounding box for frames sent to the API (adjust if needed)

# Gemini model configuration
MODEL = "models/gemini-2.0-flash-live-001"  # Ensure this is the correct model for live capabilities
DEFAULT_MODE = "camera"  # Default video input mode: "camera", "screen" or "none"
# System prompt for the Medical Assistant.
# AudioLoop.run() sends this text as the first message of every Gemini Live
# session with end_of_turn=False, so it acts as context for the user's later
# turns. The safety/disclaimer rules below are the app's only guardrail on
# model behavior; edit them with care.
MEDICAL_ASSISTANT_SYSTEM_PROMPT = """You are an AI Medical Assistant. Your primary function is to analyze visual information from the user's camera or screen.
Your responsibilities are:
1. **Visual Observation and Description:** Carefully examine the images or video feed. Describe relevant details you observe.
2. **General Information (Non-Diagnostic):** Provide general information related to what is visually presented, if applicable. You are not a diagnostic tool.
3. **Safety and Disclaimer (CRITICAL):**
* You are an AI assistant, **NOT a medical doctor or a substitute for one.**
* **DO NOT provide medical diagnoses, treatment advice, or interpret medical results (e.g., X-rays, scans, lab reports).**
* When appropriate, and always if the user seems to be seeking diagnosis or treatment, explicitly state your limitations and **strongly advise the user to consult a qualified healthcare professional.**
* If you see something that *appears* visually concerning (e.g., an unusual skin lesion, signs of injury), you may gently suggest it might be wise to have it looked at by a professional, without speculating on what it is.
4. **Tone:** Maintain a helpful, empathetic, and calm tone.
5. **Interaction:** After this initial instruction, you can make a brief acknowledgment of your role (e.g., "I'm ready to assist by looking at what you show me. Please remember to consult a doctor for medical advice."). Then, focus on responding to the user's visual input and questions.
Example of a disclaimer you might use: "As an AI assistant, I can describe what I see, but I can't provide medical advice or diagnoses. For any health concerns, it's always best to speak with a doctor or other healthcare professional."
"""
# Initialize Streamlit state
def init_session_state():
    """Seed st.session_state with every key this app reads.

    Only missing keys are added; keys that already exist keep their current
    values, so calling this on each script run is harmless.
    """
    defaults = {
        'initialized': False,
        'audio_loop': None,
        'chat_messages': [],
        'current_frame': None,
        'run_loop': False,  # Flag to control the loop from Streamlit
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
# Configure the page first: Streamlit expects st.set_page_config() to be the
# first Streamlit call of a script run.
st.set_page_config(page_title="Real-time Medical Assistant", layout="wide")

# Initialize all session state variables
init_session_state()

# Initialize the Gemini client.
# Ensure the API key is set in environment variables or a .env file.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    st.error("GEMINI_API_KEY not found. Please set it in your environment variables or a .env file.")
    st.stop()

client = genai.Client(
    http_options={"api_version": "v1beta"},
    api_key=GEMINI_API_KEY,
)

# Live session configuration: spoken replies using the prebuilt "Puck" voice.
# NOTE(review): the Gemini Live API docs indicate that only ONE response
# modality may be set per session; confirm that ["audio", "text"] is accepted
# by this model/SDK version, otherwise connecting will fail. LiveConnectConfig
# also appears to accept `system_instruction`, which would be a cleaner place
# for MEDICAL_ASSISTANT_SYSTEM_PROMPT than sending it as the first message.
CONFIG = types.LiveConnectConfig(
    response_modalities=["audio", "text"],
    speech_config=types.SpeechConfig(
        voice_config=types.VoiceConfig(
            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Puck")  # Or other preferred voice
        )
    ),
)


@st.cache_resource(show_spinner=False)
def _get_pyaudio():
    """Return one PyAudio instance shared by all reruns of this server process.

    Streamlit re-executes this script on every interaction. A module-level
    pyaudio.PyAudio() therefore initialized PortAudio again on each rerun and
    never terminated those instances.
    """
    return pyaudio.PyAudio()


pya = _get_pyaudio()
class AudioLoop:
    """Stream microphone audio and optional video to a Gemini Live session.

    ``run()`` connects to Gemini and sends MEDICAL_ASSISTANT_SYSTEM_PROMPT as
    the first message. It then runs cooperating asyncio tasks in a TaskGroup:
    microphone capture, optional camera/screen capture, forwarding queued media
    to Gemini, receiving responses, and playing response audio. Every task loop
    polls ``self.running``. ``stop_loop()`` clears that flag and may be called
    from another thread.

    NOTE(review): several tasks write to ``st.session_state`` or call
    ``st.error`` from the event-loop thread. Streamlit only routes those calls
    to the user's session if that thread has the session's ScriptRunContext
    attached.
    """

    def __init__(self, video_mode=DEFAULT_MODE):
        self.video_mode = video_mode  # "camera" or "screen"; any other value disables video
        self.audio_in_queue = None  # asyncio.Queue of audio bytes for playback (created in run())
        self.out_queue = None  # asyncio.Queue of media payloads for Gemini (created in run())
        self.session = None  # Live session; only set while run() is connected
        self.running = True  # Cooperative stop flag polled by every task loop
        self.audio_stream = None  # PyAudio input stream
        # Event loop and receive task of the active run(). stop_loop() uses them
        # to interrupt a receive that is waiting on the server.
        self._loop = None
        self._receive_task = None

    async def send_text_to_gemini(self, text_input):
        """Send *text_input* as a complete user turn so the model responds.

        If there is no text, no session, or the loop is stopping, a Streamlit
        warning is shown instead of sending. Send failures are reported via
        st.error and their traceback is printed.
        """
        if not text_input or not self.session or not self.running:
            st.warning("Session not active or no text to send.")
            return
        try:
            # User messages end the turn so the AI replies.
            await self.session.send(input=text_input, end_of_turn=True)
        except Exception as e:
            st.error(f"Error sending message to Gemini: {str(e)}")
            traceback.print_exception(e)

    @staticmethod
    def _encode_image(img):
        """Build the preview and API payloads for a PIL image.

        Returns ``{"preview": <PIL image>, "api": {"mime_type", "data"}}``. The
        preview fits VIDEO_PREVIEW_RESIZE. ``data`` is a base64-encoded JPEG
        that fits VIDEO_API_RESIZE.
        """
        preview_img = img.copy()
        preview_img.thumbnail(VIDEO_PREVIEW_RESIZE)
        api_img = img.copy()
        api_img.thumbnail(VIDEO_API_RESIZE)
        image_io = io.BytesIO()
        api_img.save(image_io, format="jpeg")
        return {
            "preview": preview_img,
            "api": {
                "mime_type": "image/jpeg",
                "data": base64.b64encode(image_io.getvalue()).decode(),
            },
        }

    def _get_frame(self, cap):
        """Read one frame from OpenCV capture *cap*; return its payloads, or None if the read fails."""
        ret, frame = cap.read()
        if not ret:
            return None
        # OpenCV delivers BGR pixel order; PIL expects RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return self._encode_image(PIL.Image.fromarray(frame_rgb))

    async def _enqueue_for_gemini(self, payload):
        """Queue *payload* for send_realtime_media, dropping the oldest item when full.

        Dropping instead of waiting keeps the capture loops from blocking when
        sending falls behind.
        """
        if self.out_queue.full():
            self.out_queue.get_nowait()
        await self.out_queue.put(payload)

    async def get_frames_from_camera(self):
        """Capture webcam frames at VIDEO_FPS_LIMIT, publish previews, and queue them for Gemini.

        If the camera cannot be opened or capture fails, the error is reported
        and the whole loop is stopped.
        """
        cap = None
        try:
            cap = await asyncio.to_thread(cv2.VideoCapture, 0)
            if not cap.isOpened():
                st.error("Could not open camera.")
                self.running = False  # Stop the loop if camera fails
                return
            while self.running:
                frame_data = await asyncio.to_thread(self._get_frame, cap)
                if frame_data is None:
                    await asyncio.sleep(0.01)  # Read failed; retry shortly
                    continue
                st.session_state['current_frame'] = frame_data["preview"]
                await self._enqueue_for_gemini(frame_data["api"])
                await asyncio.sleep(1.0 / VIDEO_FPS_LIMIT)
        except Exception as e:
            st.error(f"Camera streaming error: {e}")
            self.running = False
        finally:
            if cap is not None:
                await asyncio.to_thread(cap.release)

    def _get_screen_frame(self):
        """Grab the primary monitor (or all monitors if only one entry exists) and return its payloads."""
        # The context manager releases mss's platform handles; the original
        # created an mss instance per call and never closed it.
        with mss.mss() as sct:
            # sct.monitors[0] is all monitors combined; sct.monitors[1] is the primary.
            monitor = sct.monitors[1] if len(sct.monitors) > 1 else sct.monitors[0]
            screenshot = sct.grab(monitor)
            img = PIL.Image.frombytes("RGB", screenshot.size, screenshot.rgb)
        return self._encode_image(img)

    async def get_frames_from_screen(self):
        """Capture screenshots at VIDEO_FPS_LIMIT, publish previews, and queue them for Gemini.

        A capture error is reported and stops the whole loop.
        """
        try:
            while self.running:
                frame_data = await asyncio.to_thread(self._get_screen_frame)
                if frame_data is None:
                    await asyncio.sleep(0.01)
                    continue
                st.session_state['current_frame'] = frame_data["preview"]
                await self._enqueue_for_gemini(frame_data["api"])
                await asyncio.sleep(1.0 / VIDEO_FPS_LIMIT)
        except Exception as e:
            st.error(f"Screen capture error: {e}")
            self.running = False

    async def send_realtime_media(self):
        """Forward queued audio/video payloads to the Live session until stopped.

        Queue reads time out every 0.5 s so the ``running`` flag is re-checked
        regularly. Send errors are printed and retried rather than fatal.
        """
        try:
            while self.running:
                if not self.session:
                    await asyncio.sleep(0.1)  # Wait for session to be established
                    continue
                try:
                    msg = await asyncio.wait_for(self.out_queue.get(), timeout=0.5)
                    if self.session and self.running:
                        # No end_of_turn: media is a continuous stream.
                        await self.session.send(input=msg)
                    self.out_queue.task_done()
                except asyncio.TimeoutError:
                    continue  # No new media to send
                except Exception as e:
                    if self.running:
                        print(f"Error in send_realtime_media: {e}")
                    await asyncio.sleep(0.1)  # Prevent tight loop on error
        except asyncio.CancelledError:
            print("send_realtime_media task cancelled.")
        finally:
            print("send_realtime_media task finished.")

    async def listen_for_audio(self):
        """Read microphone PCM chunks and queue them for Gemini until stopped.

        If the device cannot be opened, or a non-overflow read error occurs,
        the whole loop is stopped. The input stream is always closed on exit.
        """
        self.audio_stream = None
        try:
            mic_info = await asyncio.to_thread(pya.get_default_input_device_info)
            self.audio_stream = await asyncio.to_thread(
                pya.open,
                format=FORMAT,
                channels=CHANNELS,
                rate=SEND_SAMPLE_RATE,
                input=True,
                input_device_index=mic_info["index"],
                frames_per_buffer=CHUNK_SIZE,
            )
            print("Microphone stream opened.")
            while self.running:
                try:
                    # exception_on_overflow=False avoids crashes on input buffer overflows.
                    data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, exception_on_overflow=False)
                    await self._enqueue_for_gemini({"data": data, "mime_type": "audio/pcm"})
                except IOError as e:  # PyAudio-specific IO errors
                    if e.errno == pyaudio.paInputOverflowed:
                        print("PyAudio Input overflowed. Skipping frame.")
                    else:
                        print(f"PyAudio read error: {e}")
                        self.running = False
                        break
                except Exception as e:
                    print(f"Error in listen_for_audio: {e}")
                    await asyncio.sleep(0.01)  # Prevent tight loop on error
        except Exception as e:
            st.error(f"Failed to open microphone: {e}")
            self.running = False
        finally:
            if self.audio_stream:
                await asyncio.to_thread(self.audio_stream.stop_stream)
                await asyncio.to_thread(self.audio_stream.close)
                print("Microphone stream closed.")

    async def receive_gemini_responses(self):
        """Consume Gemini responses: queue audio for playback and append text to the chat.

        Receive errors are printed, and the loop starts a new ``receive()`` on
        the next iteration. stop_loop() cancels this task so that a receive
        waiting on the server does not keep the loop alive.
        """
        try:
            while self.running:
                if not self.session:
                    await asyncio.sleep(0.1)  # Wait for session
                    continue
                try:
                    async for response in self.session.receive():
                        if not self.running:
                            break  # Stop requested mid-turn
                        if data := response.data:  # Audio data
                            if self.audio_in_queue.full():
                                print("Playback audio queue full, discarding data.")
                            else:
                                self.audio_in_queue.put_nowait(data)
                        if text := response.text:  # Text part of the response
                            # setdefault: don't abort the turn if the key is missing
                            # (e.g. when no Streamlit context is attached to this thread).
                            st.session_state.setdefault('chat_messages', []).append(
                                {"role": "assistant", "content": text}
                            )
                    # The original also caught types.generation_types.StopCandidateException.
                    # google.genai.types has no such attribute, so evaluating that clause
                    # raised AttributeError on any receive error and crashed the TaskGroup.
                except Exception as e:
                    if self.running:
                        print(f"Error receiving from Gemini: {e}")
                    await asyncio.sleep(0.1)  # Prevent tight loop on error
        except asyncio.CancelledError:
            print("receive_gemini_responses task cancelled.")
        finally:
            print("receive_gemini_responses task finished.")

    async def play_audio_responses(self):
        """Play queued response audio at RECEIVE_SAMPLE_RATE until stopped.

        Queue reads time out every 0.5 s so the ``running`` flag is re-checked
        regularly. The output stream is always closed on exit.
        """
        playback_stream = None
        try:
            playback_stream = await asyncio.to_thread(
                pya.open,
                format=FORMAT,  # Assumes Gemini audio matches this sample format
                channels=CHANNELS,
                rate=RECEIVE_SAMPLE_RATE,
                output=True,
            )
            print("Audio playback stream opened.")
            while self.running:
                try:
                    bytestream = await asyncio.wait_for(self.audio_in_queue.get(), timeout=0.5)
                    await asyncio.to_thread(playback_stream.write, bytestream)
                    self.audio_in_queue.task_done()
                except asyncio.TimeoutError:
                    continue  # No audio to play
                except Exception as e:
                    print(f"Error playing audio: {e}")
                    await asyncio.sleep(0.01)  # Prevent tight loop
        except Exception as e:
            st.error(f"Failed to open audio playback: {e}")
            self.running = False
        finally:
            if playback_stream:
                await asyncio.to_thread(playback_stream.stop_stream)
                await asyncio.to_thread(playback_stream.close)
                print("Audio playback stream closed.")

    def stop_loop(self):
        """Request shutdown of all tasks. Safe to call from any thread.

        This clears ``running``; every loop polls it, and the queue reads use
        timeouts. It then cancels the receive task through its own event loop,
        because that task may otherwise stay blocked in ``session.receive()``.
        Unlike the previous version, it does not put None sentinels into the
        asyncio queues. asyncio.Queue is not thread-safe, and put_nowait()
        raised QueueFull when the queue was full.
        """
        print("Stop signal received for AudioLoop.")
        self.running = False
        loop, task = self._loop, self._receive_task
        if loop is None or task is None:
            return
        try:
            loop.call_soon_threadsafe(task.cancel)
        except RuntimeError:
            pass  # Event loop already closed: nothing left to interrupt.

    async def run(self):
        """Connect to Gemini Live and run all streaming tasks until they stop.

        Errors are reported via st.error and printed. On exit, ``running`` and
        ``st.session_state['run_loop']`` are cleared. The Live session is
        closed by its ``async with`` block.
        """
        st.session_state['run_loop'] = True  # Indicate loop is running
        self.running = True
        self._loop = asyncio.get_running_loop()
        print("AudioLoop starting...")
        try:
            async with client.aio.live.connect(model=MODEL, config=CONFIG) as session:
                self.session = session
                print("Gemini session established.")
                try:
                    print("Sending system prompt to Gemini...")
                    # end_of_turn=False: the prompt becomes context for the first real user turn.
                    await self.session.send(input=MEDICAL_ASSISTANT_SYSTEM_PROMPT, end_of_turn=False)
                    print("System prompt sent.")
                except Exception as e:
                    st.error(f"Failed to send system prompt to Gemini: {str(e)}")
                    traceback.print_exception(e)
                    self.running = False  # Stop if system prompt fails critical setup
                    return
                # Queues are created inside the running loop that will use them.
                self.audio_in_queue = asyncio.Queue(maxsize=AUDIO_QUEUE_MAXSIZE)
                self.out_queue = asyncio.Queue(maxsize=10)  # Outgoing media to Gemini API
                async with asyncio.TaskGroup() as tg:
                    print("Starting child tasks...")
                    tg.create_task(self.send_realtime_media(), name="send_realtime_media")
                    tg.create_task(self.listen_for_audio(), name="listen_for_audio")
                    if self.video_mode == "camera":
                        tg.create_task(self.get_frames_from_camera(), name="get_frames_from_camera")
                    elif self.video_mode == "screen":
                        tg.create_task(self.get_frames_from_screen(), name="get_frames_from_screen")
                    # Any other mode (e.g. "none") starts no video task.
                    # Keep a handle so stop_loop() can cancel a receive blocked on the server.
                    self._receive_task = tg.create_task(
                        self.receive_gemini_responses(), name="receive_gemini_responses"
                    )
                    tg.create_task(self.play_audio_responses(), name="play_audio_responses")
                    print("All child tasks created.")
                    # Leaving this block waits for every child task to finish.
                print("TaskGroup finished.")
        except asyncio.CancelledError:
            print("AudioLoop.run() was cancelled.")
        except ExceptionGroup as eg:  # Raised by TaskGroup when child tasks fail
            st.error(f"Error in async tasks: {eg.exceptions[0]}")  # Show first error in UI
            print(f"ExceptionGroup caught in AudioLoop.run(): {eg}")
            for i, exc in enumerate(eg.exceptions):
                print(f" Exception {i+1}/{len(eg.exceptions)} in TaskGroup: {type(exc).__name__}: {exc}")
                traceback.print_exception(type(exc), exc, exc.__traceback__)
        except Exception as e:
            st.error(f"Critical error in session: {str(e)}")
            print(f"Exception caught in AudioLoop.run(): {type(e).__name__}: {e}")
            traceback.print_exception(e)
        finally:
            print("AudioLoop.run() finishing, cleaning up...")
            self.running = False  # Ensure all loops stop
            self._receive_task = None
            st.session_state['run_loop'] = False  # Signal that the loop has stopped
            self.session = None
            # Stream closures are handled in each task's own finally block.
            print("AudioLoop finished.")
def _attach_streamlit_context(thread):
    """Best-effort: give *thread* this script run's Streamlit ScriptRunContext.

    AudioLoop writes to st.session_state (and calls st.error) from its worker
    thread. Streamlit only routes such calls to the user's session when the
    thread carries that session's ScriptRunContext. These helpers live in
    Streamlit's runtime package rather than its stable public API, so a failed
    import is logged and tolerated.
    """
    try:
        from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
    except ImportError:
        print("Streamlit ScriptRunContext helpers unavailable; background updates may not reach the UI.")
        return
    ctx = get_script_run_ctx()
    if ctx is not None:
        add_script_run_ctx(thread, ctx)


def _drive_audio_loop(event_loop, audio_loop):
    """Thread target: run ``audio_loop.run()`` on *event_loop* until it finishes.

    asyncio.Runner shuts down async generators and the default executor (used
    by asyncio.to_thread), then closes the loop, matching asyncio.run().
    """
    with asyncio.Runner(loop_factory=lambda: event_loop) as runner:
        runner.run(audio_loop.run())


def _request_stop(audio_loop, event_loop):
    """Ask *audio_loop* to stop, running its stop_loop() on its own event loop when possible."""
    if audio_loop is None:
        return
    audio_loop.running = False  # Lets the polling task loops exit promptly.
    if event_loop is None:
        audio_loop.stop_loop()
        return
    try:
        # stop_loop() may touch asyncio objects owned by event_loop, which are
        # not thread-safe, so execute it on that loop's own thread.
        event_loop.call_soon_threadsafe(audio_loop.stop_loop)
    except RuntimeError:
        pass  # Loop already closed: the session has already ended.


def _send_prompt(prompt):
    """Schedule *prompt* on the running session's event loop; return True once it was sent.

    On failure, an error or warning is rendered and False is returned, so the
    caller can skip the immediate rerun that would otherwise hide the message.
    """
    audio_loop = st.session_state.get('audio_loop')
    event_loop = st.session_state.get('event_loop')
    if not audio_loop:
        st.error("Session is not active. Please start a session.")
        return False
    if not audio_loop.session or event_loop is None:
        st.error("Session not fully active to send message.")
        return False
    coro = audio_loop.send_text_to_gemini(prompt)
    try:
        # The session's websocket belongs to event_loop, so the coroutine must run there
        # (asyncio.run() here would execute it on a brand-new, unrelated loop).
        future = asyncio.run_coroutine_threadsafe(coro, event_loop)
    except RuntimeError:
        coro.close()  # Never scheduled; close it to avoid a "never awaited" warning.
        st.error("Session is not active. Please start a session.")
        return False
    try:
        future.result(timeout=10)  # Seconds; sending a text turn is normally quick.
    except TimeoutError:
        st.warning("Sending the message to Gemini is taking longer than expected.")
        return False
    except Exception as e:
        st.error(f"Error sending message to Gemini: {e}")
        return False
    return True


def main():
    """Render the Medical Assistant page and control the background AudioLoop.

    The sidebar starts and stops a session. Each session runs AudioLoop.run()
    on its own event loop in a daemon thread. That loop is kept in
    st.session_state['event_loop'] so later reruns can schedule work on it.
    """
    st.title("Gemini Live Medical Assistant")
    with st.sidebar:
        st.subheader("Settings")
        video_mode_options = ["camera", "screen", "none"]
        # Ensure default video mode is in options, find its index
        default_video_index = video_mode_options.index(DEFAULT_MODE) if DEFAULT_MODE in video_mode_options else 0
        video_mode = st.selectbox("Video Source", video_mode_options, index=default_video_index)
        if not st.session_state.get('run_loop', False):
            if st.button("Start Session", key="start_session_button"):
                st.session_state.chat_messages = [{  # Clear chat and add system message
                    "role": "system",
                    "content": (
                        "Medical Assistant activated. The AI has been instructed on its role to visually assist you. "
                        "Please remember, this AI cannot provide medical diagnoses or replace consultation with a healthcare professional."
                    )
                }]
                st.session_state.current_frame = None  # Clear previous frame
                audio_loop = AudioLoop(video_mode=video_mode)
                st.session_state.audio_loop = audio_loop
                # Create the loop here rather than via asyncio.run() inside the
                # thread, so this thread can later hand coroutines to it.
                event_loop = asyncio.new_event_loop()
                st.session_state['event_loop'] = event_loop
                # Flip the UI state now; the worker's own update may land after the rerun.
                st.session_state.run_loop = True
                worker = threading.Thread(
                    target=_drive_audio_loop,
                    args=(event_loop, audio_loop),
                    daemon=True,  # Don't keep the process alive for this thread
                )
                _attach_streamlit_context(worker)
                worker.start()
                st.success("Session started. Initializing assistant...")
                st.rerun()  # Rerun to update button state and messages
        elif st.button("Stop Session", key="stop_session_button"):
            _request_stop(st.session_state.get('audio_loop'), st.session_state.get('event_loop'))
            st.session_state.audio_loop = None
            st.session_state['event_loop'] = None
            # Flip the UI state now instead of waiting for the worker thread.
            st.session_state.run_loop = False
            st.warning("Session stopping...")
            st.rerun()  # Rerun to update UI

    # Main content area
    col1, col2 = st.columns([2, 3])
    with col1:
        st.subheader("Video Feed")
        frame = st.session_state.get('current_frame')
        if st.session_state.get('run_loop', False) and frame is not None:
            # NOTE(review): newer Streamlit releases deprecate use_column_width in
            # favor of use_container_width; switch once the minimum version allows.
            st.image(frame, caption="Live Feed" if video_mode != "none" else "Video Disabled", use_column_width=True)
        elif video_mode != "none":
            st.info("Video feed will appear here when the session starts.")
        else:
            st.info("Video input is disabled.")
    with col2:
        st.subheader("Chat with Medical Assistant")
        chat_container = st.container()
        with chat_container:
            for msg in st.session_state.chat_messages:
                with st.chat_message(msg["role"]):
                    st.write(msg["content"])

    prompt = st.chat_input("Ask about what you're showing...", key="chat_input_box", disabled=not st.session_state.get('run_loop', False))
    if prompt:
        st.session_state.chat_messages.append({"role": "user", "content": prompt})
        # Send BEFORE rerunning: st.rerun() ends the script run immediately, which
        # is why the previous version never actually sent typed prompts.
        if _send_prompt(prompt):
            st.rerun()  # Show the user message immediately
if __name__ == "__main__":
    # Script entry point. A process-wide PyAudio cleanup could optionally be
    # registered here for cleaner shutdowns, e.g. atexit.register(pya.terminate);
    # it is intentionally not enabled.
    main()