import streamlit as st
import os
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.chains import LLMMathChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.utilities import WikipediaAPIWrapper
from langchain.agents.agent_types import AgentType
from langchain.agents import Tool, initialize_agent
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
import open_clip
import torch
from PIL import Image

# Load environment variables
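# (expects a .env file containing, e.g., GROQ_API_KEY=gsk_...)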
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    st.error("Groq API Key not found in .env file")
    st.stop()

# Configure Streamlit
st.set_page_config(page_title="Medical Bot", page_icon="👨‍🔬")
st.title("Medical Bot")

# Initialize LLM models
llm_text = ChatGroq(model="llama-3.3-70b-versatile", groq_api_key=groq_api_key)
llm_image = ChatGroq(model="llama-3.2-90b-vision-preview", groq_api_key=groq_api_key)
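# NOTE: these are Groq-hosted model IDs; Groq rotates its catalogue over time,
# so swap in currently available text and vision models if these are deprecated.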

# Load BiomedCLIP for zero-shot medical image classification (st.cache_resource
# keeps the weights in memory instead of reloading them on every Streamlit rerun)
@st.cache_resource
def load_biomedclip():
    name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
    model, _, preprocess_val = open_clip.create_model_and_transforms(name)
    return model, preprocess_val, open_clip.get_tokenizer(name)

model, preprocess_val, tokenizer = load_biomedclip()

def classify_image(image_path: str) -> str:
    """Classify a medical image into one of several modality labels using BiomedCLIP."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()

    # Preprocess the image and build one text prompt per candidate label
    image = preprocess_val(Image.open(image_path)).unsqueeze(0).to(device)
    labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
    texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)

    # Zero-shot classification: softmax over scaled image-text similarity scores
    with torch.no_grad():
        image_features, text_features, logit_scale = model(image, texts)
        logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
        sorted_indices = torch.argsort(logits, dim=-1, descending=True)

    # Report the highest-scoring label
    top_class = labels[sorted_indices[0][0].item()]
    return f"The image is classified as {top_class}."

# Define tools
wikipedia_tool = Tool(name="Wikipedia", func=WikipediaAPIWrapper().run, description="Searches Wikipedia for background and factual information on a topic.")
math_chain = LLMMathChain.from_llm(llm=llm_text)
calculator = Tool(name="Calculator", func=math_chain.run, description="Evaluates mathematical expressions and solves numeric problems.")

prompt_template = PromptTemplate(input_variables=["question"], template="""
You are a mathematical problem-solving assistant. Solve the question step by step.
Question: {question}
Answer:
""")
chain = LLMChain(llm=llm_text, prompt=prompt_template)
reasoning_tool = Tool(name="Reasoning Tool", func=chain.run, description="Answers logic and reasoning questions with step-by-step explanations.")

biomed_clip_tool = Tool(name="BiomedCLIP Image Classifier", func=classify_image, description="Classifies a medical image; input must be a local image file path.")
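
# The ZERO_SHOT_REACT_DESCRIPTION agents below pick a tool at each step based
# solely on the descriptions above (ReAct prompting), so keep them specific.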

# Initialize agents
assistant_agent_text = initialize_agent(
    tools=[wikipedia_tool, calculator, reasoning_tool],
    llm=llm_text,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
    handle_parsing_errors=True
)

assistant_agent_image = initialize_agent(
    tools=[biomed_clip_tool],
    llm=llm_image,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
    handle_parsing_errors=True
)

# Streamlit session state for chat messages
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]

# Chat Interface
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])
    if msg["role"] == "user" and "image" in msg:
        st.image(msg["image"], caption="Uploaded Image", use_column_width=True)

st.sidebar.header("Navigation")
if st.sidebar.button("Text Question"):
    st.session_state["section"] = "text"
if st.sidebar.button("Image Question"):
    st.session_state["section"] = "image"
if "section" not in st.session_state:
    st.session_state["section"] = "text"

def clean_response(response):
    """Strip any fenced-code preamble the agent emits, keeping the text after the last ``` fence."""
    return response.split("```")[-1].strip() if "```" in response else response
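
# e.g. clean_response("Thought ... ``` code ``` Final answer") -> "Final answer"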

if st.session_state["section"] == "text":
    st.header("Text Question")
    question = st.text_area("Your Question:")
    if st.button("Get Answer"):
        if question:
            with st.spinner("Generating response..."):
                st.session_state.messages.append({"role": "user", "content": question})
                st.chat_message("user").write(question)
                
                # Stream the agent's intermediate steps into the app while it runs
                st_cb = StreamlitCallbackHandler(st.container())
                response = assistant_agent_text.run(question, callbacks=[st_cb])
                cleaned_response = clean_response(response)
                st.session_state.messages.append({"role": "assistant", "content": cleaned_response})
                st.write('### Response:')
                st.success(cleaned_response)
        else:
            st.warning("Please enter a question.")

elif st.session_state["section"] == "image":
    st.header("Image Question")
    question = st.text_area("Your Question:")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if st.button("Get Answer"):
        if question and uploaded_file:
            with st.spinner("Generating response..."):
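                # Persist the upload to a local temp file so BiomedCLIP can read it
                # by path and the chat history can re-display it on later reruns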
                image_path = f"temp_{uploaded_file.name}"
                with open(image_path, "wb") as f:
                    f.write(uploaded_file.read())
                
                st.session_state.messages.append({"role": "user", "content": question, "image": image_path})
                st.chat_message("user").write(question)
                st.image(image_path, caption='Uploaded Image', use_column_width=True)
                
                # Give the agent both the question and the file path so it can
                # route the path to the BiomedCLIP tool
                response = assistant_agent_image.run(f"{question}\nImage file path: {image_path}")
                cleaned_response = clean_response(response)
                st.session_state.messages.append({"role": "assistant", "content": cleaned_response})
                st.write('### Response:')
                st.success(cleaned_response)
        else:
            st.warning("Please enter a question and upload an image.")