Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torchvision import transforms
|
6 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
7 |
+
import pandas as pd
|
8 |
+
import io
|
9 |
+
import os
|
10 |
+
import base64
|
11 |
+
from fpdf import FPDF
|
12 |
+
from sqlalchemy import create_engine
|
13 |
+
from langchain.chains import RetrievalQA
|
14 |
+
from langchain.prompts import PromptTemplate
|
15 |
+
from qdrant_client import QdrantClient
|
16 |
+
from qdrant_client.http.models import Distance, VectorParams
|
17 |
+
from sentence_transformers import SentenceTransformer
|
18 |
+
from langchain_community.vectorstores.pgvector import PGVector
|
19 |
+
from langchain_postgres import PGVector
|
20 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
21 |
+
from langchain_community.vectorstores import Qdrant
|
22 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
23 |
+
from langchain_community.embeddings import SentenceTransformerEmbeddings
|
24 |
+
import streamlit as st
|
25 |
+
import torchvision.transforms as transforms
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import numpy as np
|
29 |
+
from PIL import Image
|
30 |
+
import torch.nn.functional as F
|
31 |
+
from evo_vit import EvoViTModel
|
32 |
+
import io
|
33 |
+
import os
|
34 |
+
from fpdf import FPDF
|
35 |
+
from openai import OpenAI
|
36 |
+
from sqlalchemy import create_engine
|
37 |
+
from langchain_community.vectorstores.pgvector import PGVector
|
38 |
+
from langchain_postgres import PGVector
|
39 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
40 |
+
from langchain.chains import RetrievalQA
|
41 |
+
from langchain.prompts import PromptTemplate
|
42 |
+
from torchvision.models import resnet50
|
43 |
+
import nest_asyncio
|
44 |
+
from sentence_transformers import SentenceTransformer
|
45 |
+
|
46 |
+
# model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
|
47 |
+
|
48 |
+
nest_asyncio.apply()
|
49 |
+
device='cuda' if torch.cuda.is_available() else 'cpu'
|
50 |
+
|
51 |
+
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
|
52 |
+
|
53 |
+
# os.environ["PGVECTOR_CONNECTION_STRING"] = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
|
54 |
+
|
55 |
+
# === Model Selection ===
|
56 |
+
available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
|
57 |
+
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
|
58 |
+
|
59 |
+
# === Qdrant DB Setup ===
|
60 |
+
qdrant_client = QdrantClient(
|
61 |
+
url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
|
62 |
+
api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
|
63 |
+
)
|
64 |
+
collection_name = "ks_collection_1.5BE"
|
65 |
+
# embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
|
66 |
+
# embedding_model.max_seq_length = 8192
|
67 |
+
# local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
|
68 |
+
|
69 |
+
model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
|
70 |
+
|
71 |
+
|
72 |
+
local_embedding = HuggingFaceEmbeddings(
|
73 |
+
model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
74 |
+
model_kwargs={"trust_remote_code": True, "device": "cuda" if torch.cuda.is_available() else "cpu"}
|
75 |
+
)
|
76 |
+
print(" Qwen2-1.5B local embedding model loaded.")
|
77 |
+
|
78 |
+
vector_store = Qdrant(
|
79 |
+
client=qdrant_client,
|
80 |
+
collection_name=collection_name,
|
81 |
+
embeddings=local_embedding
|
82 |
+
)
|
83 |
+
retriever = vector_store.as_retriever()
|
84 |
+
|
85 |
+
'''
|
86 |
+
# === Init LLM and Vector DB ===
|
87 |
+
|
88 |
+
CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
|
89 |
+
engine = create_engine(CONNECTION_STRING)
|
90 |
+
embedding_model = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
|
91 |
+
'''
|
92 |
+
# Dynamically initialize LLM based on selection
|
93 |
+
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
94 |
+
selected_model = st.session_state["selected_model"]
|
95 |
+
if "OpenAI" in selected_model:
|
96 |
+
llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY)
|
97 |
+
elif "LLaMA" in selected_model:
|
98 |
+
st.warning("LLaMA integration is not implemented yet.")
|
99 |
+
st.stop()
|
100 |
+
elif "Gemini" in selected_model:
|
101 |
+
st.warning("Gemini integration is not implemented yet.")
|
102 |
+
st.stop()
|
103 |
+
else:
|
104 |
+
st.error("Unsupported model selected.")
|
105 |
+
st.stop()
|
106 |
+
|
107 |
+
'''
|
108 |
+
vector_store = PGVector.from_existing_index(
|
109 |
+
embedding=embedding_model,
|
110 |
+
connection=engine,
|
111 |
+
collection_name="documents"
|
112 |
+
)
|
113 |
+
'''
|
114 |
+
# retriever = vector_store.as_retriever()
|
115 |
+
|
116 |
+
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
|
117 |
+
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
|
118 |
+
|
119 |
+
Guidelines:
|
120 |
+
1. Symptoms - Explain in simple terms with proper medical definitions.
|
121 |
+
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
|
122 |
+
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
|
123 |
+
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
|
124 |
+
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
|
125 |
+
|
126 |
+
Query: {question}
|
127 |
+
Relevant Information: {context}
|
128 |
+
Answer:
|
129 |
+
"""
|
130 |
+
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
|
131 |
+
|
132 |
+
rag_chain = RetrievalQA.from_chain_type(
|
133 |
+
llm=llm,
|
134 |
+
retriever=retriever,
|
135 |
+
chain_type="stuff",
|
136 |
+
chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
''''
|
141 |
+
Load My models
|
142 |
+
'''
|
143 |
+
|
144 |
+
def load_model(model_path):
|
145 |
+
model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
|
146 |
+
model.classifier = nn.Linear(512, 1)
|
147 |
+
state_dict = torch.load(model_path, map_location=device)
|
148 |
+
new_state_dict = {}
|
149 |
+
for key, value in state_dict.items():
|
150 |
+
if key.startswith("backbone."):
|
151 |
+
new_key = key[len("backbone."):]
|
152 |
+
else:
|
153 |
+
new_key = key
|
154 |
+
new_state_dict[new_key] = value
|
155 |
+
|
156 |
+
if "classifier.weight" in new_state_dict:
|
157 |
+
original_weight = new_state_dict["classifier.weight"]
|
158 |
+
new_state_dict["classifier.weight"] = original_weight[0:1, :]
|
159 |
+
if "classifier.bias" in new_state_dict:
|
160 |
+
original_bias = new_state_dict["classifier.bias"]
|
161 |
+
new_state_dict["classifier.bias"] = original_bias[0:1]
|
162 |
+
model.load_state_dict(new_state_dict, strict=False)
|
163 |
+
model.to(device)
|
164 |
+
model.eval()
|
165 |
+
return model
|
166 |
+
|
167 |
+
def load_binary_models():
|
168 |
+
base_models = []
|
169 |
+
class_models_mapping = {
|
170 |
+
"Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
|
171 |
+
"Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
|
172 |
+
"Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
|
173 |
+
"Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
|
174 |
+
"Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
|
175 |
+
"Eczema Photos": 'santhosh/5fold_model_eczema.pth',
|
176 |
+
"ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
|
177 |
+
"Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
|
178 |
+
"Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
|
179 |
+
"Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
|
180 |
+
"Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
|
181 |
+
"Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
|
182 |
+
"Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
|
183 |
+
"Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
|
184 |
+
"Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
|
185 |
+
"Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
|
186 |
+
"Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
|
187 |
+
"Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
|
188 |
+
"Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
|
189 |
+
"Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
|
190 |
+
"Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
|
191 |
+
"Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
|
192 |
+
"Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
|
193 |
+
}
|
194 |
+
for class_name, rel_path in class_models_mapping.items():
|
195 |
+
model_path = os.path.join("best_models_overall", rel_path)
|
196 |
+
model = load_model(model_path)
|
197 |
+
base_models.append(model)
|
198 |
+
return base_models
|
199 |
+
|
200 |
+
|
201 |
+
class DynamicCNN(nn.Module):
|
202 |
+
def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
|
203 |
+
super(DynamicCNN, self).__init__()
|
204 |
+
fc_layers_list = []
|
205 |
+
in_dim = input_channels
|
206 |
+
|
207 |
+
for fc_dim in fc_layers:
|
208 |
+
fc_layers_list.append(nn.Linear(in_dim, fc_dim))
|
209 |
+
fc_layers_list.append(nn.BatchNorm1d(fc_dim))
|
210 |
+
fc_layers_list.append(nn.ReLU())
|
211 |
+
fc_layers_list.append(nn.Dropout(dropout_rate))
|
212 |
+
in_dim = fc_dim
|
213 |
+
|
214 |
+
fc_layers_list.append(nn.Linear(in_dim, num_classes))
|
215 |
+
self.fc = nn.Sequential(*fc_layers_list)
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
x = self.fc(x)
|
219 |
+
return x
|
220 |
+
|
221 |
+
|
222 |
+
class SkinDiseaseClassifier:
|
223 |
+
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
224 |
+
self.device = torch.device(device)
|
225 |
+
self.class_names = [
|
226 |
+
"Acne and Rosacea Photos",
|
227 |
+
"Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
|
228 |
+
"Atopic Dermatitis Photos",
|
229 |
+
"Bullous Disease Photos",
|
230 |
+
"Cellulitis Impetigo and other Bacterial Infections",
|
231 |
+
"Eczema Photos",
|
232 |
+
"ExanthemsandDrugEruptions",
|
233 |
+
"Hair Loss Photos Alopecia and other Hair Diseases",
|
234 |
+
"Herpes HPV and other STDs Photos",
|
235 |
+
"Light Diseases and Disorders of Pigmentation",
|
236 |
+
"Lupus and other Connective Tissue diseases",
|
237 |
+
"Melanoma Skin Cancer Nevi and Moles",
|
238 |
+
"Nail Fungus and other Nail Disease",
|
239 |
+
"Poison Ivy Photos and other Contact Dermatitis",
|
240 |
+
"Psoriasis pictures Lichen Planus and related diseases",
|
241 |
+
"Scabies Lyme Disease and other Infestations and Bites",
|
242 |
+
"Seborrheic Keratoses and other Benign Tumors",
|
243 |
+
"Systemic Disease",
|
244 |
+
"Tinea Ringworm Candidiasis and other Fungal Infections",
|
245 |
+
"Urticaria Hives",
|
246 |
+
"Vascular Tumors",
|
247 |
+
"Vasculitis Photos",
|
248 |
+
"Warts Molluscum and other Viral Infections"
|
249 |
+
]
|
250 |
+
|
251 |
+
# Initialize models (they'll be loaded when needed)
|
252 |
+
self.base_models = None
|
253 |
+
self.meta_model = None
|
254 |
+
self.resnet_feature_extractor = None
|
255 |
+
|
256 |
+
# Image transformations
|
257 |
+
self.transform = transforms.Compose([
|
258 |
+
transforms.Resize((224, 224)),
|
259 |
+
transforms.ToTensor(),
|
260 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
261 |
+
])
|
262 |
+
|
263 |
+
def load_models(self):
|
264 |
+
"""Load all required models"""
|
265 |
+
# Load binary models
|
266 |
+
self.base_models = load_binary_models()
|
267 |
+
for model in self.base_models:
|
268 |
+
model.to(self.device)
|
269 |
+
model.eval()
|
270 |
+
|
271 |
+
# Load ResNet feature extractor
|
272 |
+
model = resnet50(pretrained=True)
|
273 |
+
layers = [model.layer1, model.layer2, model.layer3, model.layer4]
|
274 |
+
self.resnet_feature_extractor = nn.Sequential(
|
275 |
+
model.conv1, model.bn1, model.relu, model.maxpool, *layers
|
276 |
+
)
|
277 |
+
self.resnet_feature_extractor.to(self.device)
|
278 |
+
self.resnet_feature_extractor.eval()
|
279 |
+
|
280 |
+
# Load meta model
|
281 |
+
meta_model_path = 'best_meta_model_two_layer_version4.pth'
|
282 |
+
checkpoint = torch.load(meta_model_path, map_location=self.device)
|
283 |
+
correct_input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
|
284 |
+
input_size = 23 + 4 + 7168 # Adjust based on your actual feature size
|
285 |
+
fc_layers = [1024, 512, 256] # Use whatever was in your best model
|
286 |
+
|
287 |
+
self.meta_model = DynamicCNN(
|
288 |
+
input_channels=correct_input_size,
|
289 |
+
fc_layers=fc_layers,
|
290 |
+
num_classes=23,
|
291 |
+
dropout_rate=0.5
|
292 |
+
)
|
293 |
+
|
294 |
+
self.meta_model.load_state_dict(checkpoint['state_dict'])
|
295 |
+
self.meta_model.to(self.device)
|
296 |
+
self.meta_model.eval()
|
297 |
+
|
298 |
+
def extract_image_features(self, image_tensor):
|
299 |
+
"""Extract features using ResNet"""
|
300 |
+
with torch.no_grad():
|
301 |
+
features = []
|
302 |
+
x = image_tensor
|
303 |
+
for layer in self.resnet_feature_extractor.children():
|
304 |
+
x = layer(x)
|
305 |
+
if isinstance(layer, nn.Sequential): # For residual blocks
|
306 |
+
# features.append(F.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
|
307 |
+
# features = torch.cat(features, dim=1)
|
308 |
+
pooled = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
|
309 |
+
features.append(pooled)
|
310 |
+
features = torch.cat(features, dim=1)
|
311 |
+
return features.cpu().numpy()
|
312 |
+
|
313 |
+
def predict(self, image, top_k=3):
|
314 |
+
"""Make prediction for a single image"""
|
315 |
+
if self.base_models is None or self.meta_model is None:
|
316 |
+
self.load_models()
|
317 |
+
|
318 |
+
# Load and preprocess image
|
319 |
+
try:
|
320 |
+
# image = Image.open(image_path).convert('RGB')
|
321 |
+
image = image.convert('RGB')
|
322 |
+
except:
|
323 |
+
raise ValueError("Could not load image from path")
|
324 |
+
|
325 |
+
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
|
326 |
+
|
327 |
+
# Extract features
|
328 |
+
with torch.no_grad():
|
329 |
+
# Get probabilities from each binary model
|
330 |
+
binary_probs = []
|
331 |
+
for model in self.base_models:
|
332 |
+
outputs = model(image_tensor)
|
333 |
+
probs = torch.sigmoid(outputs).squeeze(1)
|
334 |
+
binary_probs.append(probs)
|
335 |
+
|
336 |
+
binary_features = torch.stack(binary_probs, dim=1)
|
337 |
+
|
338 |
+
# Get image features
|
339 |
+
image_features = self.extract_image_features(image_tensor)
|
340 |
+
image_features = torch.from_numpy(image_features).float().to(self.device)
|
341 |
+
|
342 |
+
# Calculate probability statistics
|
343 |
+
top3_probs = torch.topk(binary_features, 3, dim=1).values
|
344 |
+
prob_stats = torch.stack([
|
345 |
+
binary_features.mean(dim=1, keepdim=True),
|
346 |
+
binary_features.std(dim=1, keepdim=True),
|
347 |
+
top3_probs.mean(dim=1, keepdim=True),
|
348 |
+
(top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1) # Confidence gap
|
349 |
+
], dim=1).squeeze(-1)
|
350 |
+
|
351 |
+
# Combine all features
|
352 |
+
combined_features = torch.cat([
|
353 |
+
binary_features,
|
354 |
+
image_features,
|
355 |
+
prob_stats
|
356 |
+
], dim=1)
|
357 |
+
|
358 |
+
# Make prediction with meta-model
|
359 |
+
with torch.no_grad():
|
360 |
+
outputs = self.meta_model(combined_features)
|
361 |
+
probabilities = torch.softmax(outputs, dim=1).squeeze().cpu().numpy()
|
362 |
+
|
363 |
+
# Get top predictions
|
364 |
+
top_indices = np.argsort(probabilities)[-top_k:][::-1]
|
365 |
+
top_predictions = [
|
366 |
+
(self.class_names[i], float(probabilities[i]))
|
367 |
+
for i in top_indices
|
368 |
+
]
|
369 |
+
|
370 |
+
return {
|
371 |
+
"top_predictions": top_predictions,
|
372 |
+
"all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
|
373 |
+
}
|
374 |
+
|
375 |
+
|
376 |
+
classifier = SkinDiseaseClassifier()
|
377 |
+
|
378 |
+
|
379 |
+
# === Session Init ===
|
380 |
+
if "messages" not in st.session_state:
|
381 |
+
st.session_state.messages = []
|
382 |
+
|
383 |
+
|
384 |
+
# === Image Processing Function ===
|
385 |
+
def run_inference(image):
|
386 |
+
|
387 |
+
result = classifier.predict(image, top_k=1)
|
388 |
+
|
389 |
+
# Display results
|
390 |
+
print("Top Predictions:")
|
391 |
+
for name, prob in result["top_predictions"]:
|
392 |
+
print(f"{name}: {prob:.2%}")
|
393 |
+
|
394 |
+
print("\nYou can also access all probabilities:")
|
395 |
+
print(result["all_probabilities"])
|
396 |
+
|
397 |
+
predicted_label = result["top_predictions"][0][0]
|
398 |
+
return predicted_label
|
399 |
+
|
400 |
+
|
401 |
+
# === PDF Export ===
|
402 |
+
def export_chat_to_pdf(messages):
|
403 |
+
pdf = FPDF()
|
404 |
+
pdf.add_page()
|
405 |
+
pdf.set_font("Arial", size=12)
|
406 |
+
for msg in messages:
|
407 |
+
role = "You" if msg["role"] == "user" else "AI"
|
408 |
+
pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
|
409 |
+
buf = io.BytesIO()
|
410 |
+
pdf.output(buf)
|
411 |
+
buf.seek(0)
|
412 |
+
return buf
|
413 |
+
|
414 |
+
|
415 |
+
# === App UI ===
|
416 |
+
|
417 |
+
st.title("🧬 DermBOT — Skin AI Assistant")
|
418 |
+
st.caption(f"🧠 Using model: {selected_model}")
|
419 |
+
uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
|
420 |
+
|
421 |
+
if uploaded_file:
|
422 |
+
st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
|
423 |
+
image = Image.open(uploaded_file).convert("RGB")
|
424 |
+
|
425 |
+
predicted_label = run_inference(image)
|
426 |
+
|
427 |
+
# Show predictions clearly to the user
|
428 |
+
st.markdown(f" Most Likely Diagnosis : {predicted_label}")
|
429 |
+
|
430 |
+
query = f"What are my treatment options for {predicted_label}?"
|
431 |
+
st.session_state.messages.append({"role": "user", "content": query})
|
432 |
+
|
433 |
+
with st.spinner("Analyzing the image and retrieving response..."):
|
434 |
+
response = rag_chain.invoke(query)
|
435 |
+
st.session_state.messages.append({"role": "assistant", "content": response['result']})
|
436 |
+
|
437 |
+
with st.chat_message("assistant"):
|
438 |
+
st.markdown(response['result'])
|
439 |
+
|
440 |
+
# === Chat Interface ===
|
441 |
+
if prompt := st.chat_input("Ask a follow-up..."):
|
442 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
443 |
+
with st.chat_message("user"):
|
444 |
+
st.markdown(prompt)
|
445 |
+
|
446 |
+
response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
|
447 |
+
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
448 |
+
with st.chat_message("assistant"):
|
449 |
+
st.markdown(response.content)
|
450 |
+
|
451 |
+
# === PDF Button ===
|
452 |
+
if st.button("📄 Download Chat as PDF"):
|
453 |
+
pdf_file = export_chat_to_pdf(st.session_state.messages)
|
454 |
+
st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")
|