KeerthiVM committed
Commit 50e8a0a · verified · 1 Parent(s): 6f8eac5

Create app.py

Files changed (1): app.py +454 -0
app.py ADDED
@@ -0,0 +1,454 @@
+ import io
+ import os
+ import numpy as np
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import nest_asyncio
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.models import resnet50, ResNet50_Weights
+ from fpdf import FPDF
+ from qdrant_client import QdrantClient
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from langchain_community.vectorstores import Qdrant
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_openai import ChatOpenAI
+ from evo_vit import EvoViTModel
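+
+ # Launch with: streamlit run app.py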
+
+ # Allow async client libraries to run inside Streamlit's already-running event loop
+ nest_asyncio.apply()
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
+
+ # === Model Selection ===
+ available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
+ st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
+
+ # === Qdrant DB Setup ===
+ # The API key is read from Streamlit secrets (assumes a QDRANT_API_KEY entry in
+ # .streamlit/secrets.toml); credentials should never be hard-coded in source control.
+ qdrant_client = QdrantClient(
+     url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
+     api_key=st.secrets["QDRANT_API_KEY"]
+ )
+ collection_name = "ks_collection_1.5BE"
+
+ # HuggingFaceEmbeddings loads gte-Qwen2-1.5B-instruct itself, so instantiating a
+ # separate SentenceTransformer would only load the 1.5B model a second time.
+ local_embedding = HuggingFaceEmbeddings(
+     model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+     model_kwargs={"trust_remote_code": True, "device": device}
+ )
+ print("Qwen2-1.5B local embedding model loaded.")
+
+ vector_store = Qdrant(
+     client=qdrant_client,
+     collection_name=collection_name,
+     embeddings=local_embedding
+ )
+ retriever = vector_store.as_retriever()
+
+ # === Alternative: Postgres/PGVector vector store (disabled) ===
+ # Re-enabling this requires additional imports: create_engine from sqlalchemy,
+ # OpenAIEmbeddings from langchain_openai, and PGVector from langchain_postgres.
+ # CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
+ # engine = create_engine(CONNECTION_STRING)
+ # embedding_model = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
+ # vector_store = PGVector.from_existing_index(
+ #     embedding=embedding_model,
+ #     connection=engine,
+ #     collection_name="documents"
+ # )
+ # retriever = vector_store.as_retriever()
+
+ # Dynamically initialize the LLM based on the sidebar selection
+ OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
+ selected_model = st.session_state["selected_model"]
+ if "OpenAI" in selected_model:
+     llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY)
+ elif "LLaMA" in selected_model:
+     st.warning("LLaMA integration is not implemented yet.")
+     st.stop()
+ elif "Gemini" in selected_model:
+     st.warning("Gemini integration is not implemented yet.")
+     st.stop()
+ else:
+     st.error("Unsupported model selected.")
+     st.stop()
+
+ AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
+ You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
+
+ Guidelines:
+ 1. Symptoms - Explain in simple terms with proper medical definitions.
+ 2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
+ 3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
+ 4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
+ 5. Emergency Note - If symptoms worsen or include difficulty breathing, advise calling 911 immediately.
+
+ Query: {question}
+ Relevant Information: {context}
+ Answer:
+ """
+ prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
+
+ rag_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     retriever=retriever,
+     chain_type="stuff",
+     chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
+ )
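+
+ # chain_type="stuff" concatenates all retrieved documents into the prompt's {context}
+ # slot in a single LLM call. Illustrative invocation (hypothetical query string):
+ #   rag_chain.invoke("What are my treatment options for psoriasis?")["result"]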
+
+
+ # === Load classifier models ===
+
+ def load_model(model_path):
+     model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
+     model.classifier = nn.Linear(512, 1)  # binary head: one logit per image
+     state_dict = torch.load(model_path, map_location=device)
+
+     # Checkpoints may prefix keys with "backbone."; strip it so they match the model
+     new_state_dict = {}
+     for key, value in state_dict.items():
+         new_key = key[len("backbone."):] if key.startswith("backbone.") else key
+         new_state_dict[new_key] = value
+
+     # If the checkpoint carries a two-class head, keep only its first row/element
+     # so it fits the single-logit classifier
+     if "classifier.weight" in new_state_dict:
+         new_state_dict["classifier.weight"] = new_state_dict["classifier.weight"][0:1, :]
+     if "classifier.bias" in new_state_dict:
+         new_state_dict["classifier.bias"] = new_state_dict["classifier.bias"][0:1]
+
+     model.load_state_dict(new_state_dict, strict=False)
+     model.to(device)
+     model.eval()
+     return model
+
+ def load_binary_models():
+     base_models = []
+     class_models_mapping = {
+         "Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
+         "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
+         "Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
+         "Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
+         "Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
+         "Eczema Photos": 'santhosh/5fold_model_eczema.pth',
+         "ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
+         "Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
+         "Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
+         "Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
+         "Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
+         "Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
+         "Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
+         "Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
+         "Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
+         "Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
+         "Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
+         "Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
+         "Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
+         "Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
+         "Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
+         "Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
+         "Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
+     }
+     for class_name, rel_path in class_models_mapping.items():
+         model_path = os.path.join("best_models_overall", rel_path)
+         model = load_model(model_path)
+         base_models.append(model)
+     return base_models
+
+
+ class DynamicCNN(nn.Module):
+     def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
+         super(DynamicCNN, self).__init__()
+         fc_layers_list = []
+         in_dim = input_channels
+
+         for fc_dim in fc_layers:
+             fc_layers_list.append(nn.Linear(in_dim, fc_dim))
+             fc_layers_list.append(nn.BatchNorm1d(fc_dim))
+             fc_layers_list.append(nn.ReLU())
+             fc_layers_list.append(nn.Dropout(dropout_rate))
+             in_dim = fc_dim
+
+         fc_layers_list.append(nn.Linear(in_dim, num_classes))
+         self.fc = nn.Sequential(*fc_layers_list)
+
+     def forward(self, x):
+         return self.fc(x)
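+
+ # Minimal shape sketch (illustrative; the real input width D is read from the
+ # meta-model checkpoint in load_models below):
+ #   meta = DynamicCNN(input_channels=D, fc_layers=[1024, 512, 256], num_classes=23)
+ #   meta(torch.randn(4, D)).shape  # -> torch.Size([4, 23])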
+
+
+ class SkinDiseaseClassifier:
+     def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
+         self.device = torch.device(device)
+         self.class_names = [
+             "Acne and Rosacea Photos",
+             "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
+             "Atopic Dermatitis Photos",
+             "Bullous Disease Photos",
+             "Cellulitis Impetigo and other Bacterial Infections",
+             "Eczema Photos",
+             "ExanthemsandDrugEruptions",
+             "Hair Loss Photos Alopecia and other Hair Diseases",
+             "Herpes HPV and other STDs Photos",
+             "Light Diseases and Disorders of Pigmentation",
+             "Lupus and other Connective Tissue diseases",
+             "Melanoma Skin Cancer Nevi and Moles",
+             "Nail Fungus and other Nail Disease",
+             "Poison Ivy Photos and other Contact Dermatitis",
+             "Psoriasis pictures Lichen Planus and related diseases",
+             "Scabies Lyme Disease and other Infestations and Bites",
+             "Seborrheic Keratoses and other Benign Tumors",
+             "Systemic Disease",
+             "Tinea Ringworm Candidiasis and other Fungal Infections",
+             "Urticaria Hives",
+             "Vascular Tumors",
+             "Vasculitis Photos",
+             "Warts Molluscum and other Viral Infections"
+         ]
+
+         # Initialize models (they'll be loaded when needed)
+         self.base_models = None
+         self.meta_model = None
+         self.resnet_feature_extractor = None
+
+         # Image transformations (ImageNet mean/std normalization)
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+     def load_models(self):
+         """Load all required models"""
+         # Load the 23 per-class binary models
+         self.base_models = load_binary_models()
+         for model in self.base_models:
+             model.to(self.device)
+             model.eval()
+
+         # Load ResNet feature extractor (pretrained=True is deprecated in torchvision)
+         model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+         layers = [model.layer1, model.layer2, model.layer3, model.layer4]
+         self.resnet_feature_extractor = nn.Sequential(
+             model.conv1, model.bn1, model.relu, model.maxpool, *layers
+         )
+         self.resnet_feature_extractor.to(self.device)
+         self.resnet_feature_extractor.eval()
+
+         # Load meta model; its input width is read directly from the checkpoint,
+         # so the feature size never needs to be hand-computed
+         meta_model_path = 'best_meta_model_two_layer_version4.pth'
+         checkpoint = torch.load(meta_model_path, map_location=self.device)
+         correct_input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
+         fc_layers = [1024, 512, 256]  # must match the architecture used in training
+
+         self.meta_model = DynamicCNN(
+             input_channels=correct_input_size,
+             fc_layers=fc_layers,
+             num_classes=23,
+             dropout_rate=0.5
+         )
+         self.meta_model.load_state_dict(checkpoint['state_dict'])
+         self.meta_model.to(self.device)
+         self.meta_model.eval()
+
+     def extract_image_features(self, image_tensor):
+         """Extract multi-scale features using the ResNet backbone"""
+         with torch.no_grad():
+             features = []
+             x = image_tensor
+             for layer in self.resnet_feature_extractor.children():
+                 x = layer(x)
+                 if isinstance(layer, nn.Sequential):  # residual stages layer1-layer4
+                     pooled = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
+                     features.append(pooled)
+             features = torch.cat(features, dim=1)
+             return features.cpu().numpy()
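+
+     # Note: with torchvision's ResNet50 the four residual stages pool to
+     # 256 + 512 + 1024 + 2048 = 3840 features per image; the meta-model checkpoint
+     # determines the exact input width it actually expects.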
+
+     def predict(self, image, top_k=3):
+         """Make a prediction for a single PIL image"""
+         if self.base_models is None or self.meta_model is None:
+             self.load_models()
+
+         # Preprocess image
+         try:
+             image = image.convert('RGB')
+         except Exception as e:
+             raise ValueError(f"Could not process the input image: {e}")
+
+         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             # Probability from each per-class binary model
+             binary_probs = []
+             for model in self.base_models:
+                 outputs = model(image_tensor)
+                 probs = torch.sigmoid(outputs).squeeze(1)
+                 binary_probs.append(probs)
+             binary_features = torch.stack(binary_probs, dim=1)
+
+             # Pooled ResNet image features
+             image_features = self.extract_image_features(image_tensor)
+             image_features = torch.from_numpy(image_features).float().to(self.device)
+
+             # Probability statistics over the 23 binary scores
+             top3_probs = torch.topk(binary_features, 3, dim=1).values
+             prob_stats = torch.stack([
+                 binary_features.mean(dim=1, keepdim=True),
+                 binary_features.std(dim=1, keepdim=True),
+                 top3_probs.mean(dim=1, keepdim=True),
+                 (top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1)  # confidence gap
+             ], dim=1).squeeze(-1)
+
+             # Meta-model input: [binary probabilities | image features | statistics]
+             combined_features = torch.cat([
+                 binary_features,
+                 image_features,
+                 prob_stats
+             ], dim=1)
+
+             # Make prediction with the meta-model
+             outputs = self.meta_model(combined_features)
+             probabilities = torch.softmax(outputs, dim=1).squeeze().cpu().numpy()
+
+         # Get top predictions
+         top_indices = np.argsort(probabilities)[-top_k:][::-1]
+         top_predictions = [
+             (self.class_names[i], float(probabilities[i]))
+             for i in top_indices
+         ]
+
+         return {
+             "top_predictions": top_predictions,
+             "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
+         }
+
+
+ classifier = SkinDiseaseClassifier()
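+
+ # Illustrative standalone use (hypothetical file name; assumes the weights are on disk):
+ #   result = classifier.predict(Image.open("sample.jpg"), top_k=3)
+ #   result["top_predictions"]  # e.g. [("Eczema Photos", 0.72), ...]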
+
+
+ # === Session Init ===
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+
+ # === Image Processing Function ===
+ def run_inference(image):
+     result = classifier.predict(image, top_k=1)
+
+     # Log predictions server-side for debugging
+     print("Top Predictions:")
+     for name, prob in result["top_predictions"]:
+         print(f"{name}: {prob:.2%}")
+     print("\nAll probabilities:")
+     print(result["all_probabilities"])
+
+     predicted_label = result["top_predictions"][0][0]
+     return predicted_label
+
+
+ # === PDF Export ===
+ def export_chat_to_pdf(messages):
+     pdf = FPDF()
+     pdf.add_page()
+     pdf.set_font("Arial", size=12)
+     for msg in messages:
+         role = "You" if msg["role"] == "user" else "AI"
+         pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
+     # With fpdf2, output() returns the document as a bytearray when no file name is given
+     return io.BytesIO(bytes(pdf.output()))
+
+
+ # === App UI ===
+
+ st.title("🧬 DermBOT — Skin AI Assistant")
+ st.caption(f"🧠 Using model: {selected_model}")
+ uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file:
+     st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
+     image = Image.open(uploaded_file).convert("RGB")
+
+     # NOTE: Streamlit reruns the script on every interaction, so this block repeats
+     # inference and re-appends the diagnosis query; guard with st.session_state if needed
+     predicted_label = run_inference(image)
+
+     # Show the prediction clearly to the user
+     st.markdown(f"**Most Likely Diagnosis:** {predicted_label}")
+
+     query = f"What are my treatment options for {predicted_label}?"
+     st.session_state.messages.append({"role": "user", "content": query})
+
+     with st.spinner("Analyzing the image and retrieving response..."):
+         response = rag_chain.invoke(query)
+         st.session_state.messages.append({"role": "assistant", "content": response['result']})
+
+     with st.chat_message("assistant"):
+         st.markdown(response['result'])
+
+ # === Chat Interface ===
+ if prompt := st.chat_input("Ask a follow-up..."):
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
+     st.session_state.messages.append({"role": "assistant", "content": response.content})
+     with st.chat_message("assistant"):
+         st.markdown(response.content)
+
+ # === PDF Button ===
+ if st.button("📄 Download Chat as PDF"):
+     pdf_file = export_chat_to_pdf(st.session_state.messages)
+     st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")