Spaces:
Sleeping
Sleeping
File size: 6,205 Bytes
e795ee7 ebeeb35 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff a242e19 e795ee7 7f1ecff e795ee7 ebeeb35 7f1ecff ebeeb35 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff e795ee7 7f1ecff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import streamlit as st
import os
import base64
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.chains import LLMMathChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.utilities import WikipediaAPIWrapper
from langchain.agents.agent_types import AgentType
from langchain.agents import Tool, initialize_agent
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
from groq import Groq
import open_clip
import torch
from PIL import Image
# Load environment variables
# Read GROQ_API_KEY (a local .env file is loaded into the environment first)
# and halt this script run with an on-page error when it is missing or empty.
load_dotenv()
groq_api_key = os.environ.get("GROQ_API_KEY")
if groq_api_key is None or groq_api_key == "":
    st.error("Groq API Key not found in .env file")
    st.stop()
# Configure Streamlit
# page_icon is the man-scientist emoji (U+1F468 ZWJ U+1F52C). It is spelled
# with escapes so the source stays ASCII and cannot be garbled by a wrong
# text encoding, which would leave Streamlit with a string that is not an emoji.
st.set_page_config(page_title="Medical Bot", page_icon="\U0001F468\u200D\U0001F52C")
st.title("Medical Bot")
# Initialize LLM models
# Both chat models authenticate with the same Groq key: llm_text backs the
# text agent and its chains, llm_image backs the image agent.
_groq_auth = {"groq_api_key": groq_api_key}
llm_text = ChatGroq(model="llama-3.3-70b-versatile", **_groq_auth)
llm_image = ChatGroq(model="llama-3.2-90b-vision-preview", **_groq_auth)
# Load BiomedCLIP model for image classification.
# Streamlit re-executes this whole script on every widget interaction, so the
# model is built once per server process via st.cache_resource instead of on
# every rerun. The cached objects are shared by all sessions.
_BIOMEDCLIP_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"


@st.cache_resource
def _load_biomedclip():
    """Return (model, preprocess_train, preprocess_val, tokenizer) for BiomedCLIP."""
    clip_model, train_transform, val_transform = open_clip.create_model_and_transforms(_BIOMEDCLIP_ID)
    clip_tokenizer = open_clip.get_tokenizer(_BIOMEDCLIP_ID)
    return clip_model, train_transform, val_transform, clip_tokenizer


model, preprocess_train, preprocess_val, tokenizer = _load_biomedclip()
def classify_image(image_path: str) -> str:
    """Classify a medical image with BiomedCLIP zero-shot text prompts.

    The image is scored against one prompt per label
    ("this is a photo of <label>") and the best-scoring label is reported.

    Args:
        image_path: Path of the image file. This value arrives as a ReAct
            agent's tool input, so surrounding whitespace and quote
            characters are stripped before the file is opened.

    Returns:
        A sentence naming the top label, e.g. "The image is classified as X-ray."
    """
    labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
    path = image_path.strip().strip("'\"")
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device).eval()
    # Use a context manager so the file handle is released; the tensor is
    # fully materialised by preprocess_val before the file is closed.
    with Image.open(path) as img:
        image = preprocess_val(img).unsqueeze(0).to(device)
    texts = tokenizer([f"this is a photo of {label}" for label in labels], context_length=256).to(device)
    with torch.no_grad():
        image_features, text_features, logit_scale = model(image, texts)
        # Row 0 holds the single image's softmax over the candidate labels.
        probs = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)
    top_class = labels[probs[0].argmax().item()]
    return f"The image is classified as {top_class}."
# Define tools
# The text agent's LLM-backed calculator plus a Wikipedia lookup tool.
math_chain = LLMMathChain.from_llm(llm=llm_text)
wikipedia_tool = Tool(
    name="Wikipedia",
    func=WikipediaAPIWrapper().run,
    description="A tool for searching information.",
)
calculator = Tool(
    name="Calculator",
    func=math_chain.run,
    description="Solves mathematical problems.",
)
# "Reasoning Tool": sends the question to llm_text wrapped in a
# step-by-step prompt and returns the model's text.
_REASONING_TEMPLATE = """
You are a mathematical problem-solving assistant. Solve the question step by step.
Question: {question}
Answer:
"""
prompt_template = PromptTemplate(template=_REASONING_TEMPLATE, input_variables=["question"])
chain = LLMChain(prompt=prompt_template, llm=llm_text)
reasoning_tool = Tool(
    name="Reasoning Tool",
    func=chain.run,
    description="Answers logic-based questions.",
)
# Exposes classify_image to the image agent. A ReAct agent decides what to
# pass as tool input from the description alone, so the description states
# that the input must be the image's file path.
biomed_clip_tool = Tool(
    name="BiomedCLIP Image Classifier",
    func=classify_image,
    description=(
        "Classifies medical images. "
        "Input should be the file path of the image to classify."
    ),
)
# Initialize agents
# Both agents share one zero-shot ReAct configuration; only the tool set and
# the backing chat model differ.
_react_agent_options = {
    "agent": AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    "verbose": False,
    "handle_parsing_errors": True,
}
assistant_agent_text = initialize_agent(
    tools=[wikipedia_tool, calculator, reasoning_tool],
    llm=llm_text,
    **_react_agent_options,
)
assistant_agent_image = initialize_agent(
    tools=[biomed_clip_tool],
    llm=llm_image,
    **_react_agent_options,
)
# Streamlit session state for chat messages
# The greeting is stored only on the first run of a browser session; later
# reruns keep the accumulated history untouched.
st.session_state.setdefault(
    "messages",
    [{"role": "assistant", "content": "Welcome! How can I help you today?"}],
)
# Chat Interface
# Replay the stored conversation on every rerun. User turns that carried an
# upload also re-render the saved image right after their text.
for entry in st.session_state.messages:
    st.chat_message(entry["role"]).write(entry["content"])
    if entry["role"] == "user" and "image" in entry:
        st.image(entry["image"], caption='Uploaded Image', use_column_width=True)
st.sidebar.header("Navigation")
# Each sidebar button selects a section; the selection lives in session state
# so it survives reruns.
for button_label, target_section in (("Text Question", "text"), ("Image Question", "image")):
    if st.sidebar.button(button_label):
        st.session_state["section"] = target_section
# First run with no selection yet: show the text section.
st.session_state.setdefault("section", "text")
def clean_response(response):
    """Strip Markdown code fences from an agent answer.

    A response without a ``` fence is returned unchanged. Otherwise the text
    is split on the fences and the last non-blank segment is returned,
    stripped of surrounding whitespace ("" if every segment is blank).
    """
    if "```" not in response:
        return response
    # Prefer the text after the last fence (e.g. a final answer following a
    # code block). When the reply ends with a closing fence, that trailing
    # segment is empty, so fall back to the nearest earlier non-blank segment
    # instead of returning an empty answer.
    for segment in reversed(response.split("```")):
        stripped = segment.strip()
        if stripped:
            return stripped
    return ""
def _save_uploaded_image(uploaded_file) -> str:
    """Persist an uploaded image to a uniquely named temp file.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.

    Returns:
        Path of the written file. The file is intentionally kept because the
        path is stored in the chat history and re-rendered on later reruns.
    """
    import tempfile

    # Only the extension of the client-supplied filename is reused, so the
    # name cannot influence where the file is written, and two sessions that
    # upload files with the same name no longer overwrite each other's image.
    suffix = os.path.splitext(os.path.basename(uploaded_file.name))[1]
    with tempfile.NamedTemporaryFile(prefix="temp_", suffix=suffix, delete=False) as tmp:
        # getvalue() returns the whole upload regardless of the stream position.
        tmp.write(uploaded_file.getvalue())
    return tmp.name


if st.session_state["section"] == "text":
    st.header("Text Question")
    question = st.text_area("Your Question:")
    if st.button("Get Answer"):
        if question.strip():
            with st.spinner("Generating response..."):
                st.session_state.messages.append({"role": "user", "content": question})
                st.chat_message("user").write(question)
                response = assistant_agent_text.run(question)
                cleaned_response = clean_response(response)
                st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
            st.write('### Response:')
            st.success(cleaned_response)
        else:
            st.warning("Please enter a question.")
elif st.session_state["section"] == "image":
    st.header("Image Question")
    question = st.text_area("Your Question:")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if st.button("Get Answer"):
        if question.strip() and uploaded_file is not None:
            with st.spinner("Generating response..."):
                image_path = _save_uploaded_image(uploaded_file)
                st.session_state.messages.append({"role": "user", "content": question, "image": image_path})
                st.chat_message("user").write(question)
                st.image(image_path, caption='Uploaded Image', use_column_width=True)
                # The agent receives only this one string, so it must carry
                # both the user's question and the path the image tool reads.
                agent_input = f"{question}\n\nImage file path: {image_path}"
                response = assistant_agent_image.run(agent_input)
                cleaned_response = clean_response(response)
                st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
            st.write('### Response:')
            st.success(cleaned_response)
        else:
            st.warning("Please enter a question and upload an image.")
|