KeerthiVM committed on
Commit
79cab30
·
1 Parent(s): 1b1c00e
Files changed (3) hide show
  1. SkinCancerDiagnosis.py +268 -0
  2. app.py +4 -344
  3. rag_pipeline.py +59 -0
SkinCancerDiagnosis.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms as transforms
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch.nn.functional as F
7
+ from evo_vit import EvoViTModel
8
+ import io
9
+ import os
10
+ from fpdf import FPDF
11
+ from torchvision.models import resnet50
12
+ import nest_asyncio
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ device='cuda' if torch.cuda.is_available() else 'cpu'
16
+
17
def load_model(repo_id, filename):
    """Fetch one binary-classifier checkpoint from the Hugging Face Hub and prepare it for inference.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Path of the checkpoint file inside that repository.

    Returns:
        An ``EvoViTModel`` with a single-output ``nn.Linear(512, 1)`` head,
        moved to the module-level ``device`` and switched to eval mode.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    net = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
    # Replace the head with a one-logit classifier; predictions are read through a sigmoid.
    net.classifier = nn.Linear(512, 1)

    raw_weights = torch.load(checkpoint_path, map_location=device)

    # Keys saved under a "backbone." prefix are renamed so they match this model's parameter names.
    prefix = "backbone."
    weights = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in raw_weights.items()
    }

    # Keep only the first output row of the stored head so it fits the 1-unit classifier.
    if "classifier.weight" in weights:
        weights["classifier.weight"] = weights["classifier.weight"][0:1, :]
    if "classifier.bias" in weights:
        weights["classifier.bias"] = weights["classifier.bias"][0:1]

    # strict=False: missing or unexpected keys are tolerated rather than raising.
    net.load_state_dict(weights, strict=False)
    net.to(device)
    net.eval()
    return net
43
+
44
def load_binary_models():
    """Load every per-class binary model, in the order the classes are listed below.

    Returns:
        A list of models (one per entry of the mapping) as returned by
        ``load_model``; the list order follows the mapping's insertion order.
    """
    hub_repo = "KeerthiVM/SkinCancerDiagnosis"  # Your Hugging Face repo
    # Class label -> checkpoint path inside the Hub repository.
    class_models_mapping = {
        "Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
        "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
        "Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
        "Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
        "Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
        "Eczema Photos": 'santhosh/5fold_model_eczema.pth',
        "ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
        "Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
        "Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
        "Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
        "Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
        "Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
        "Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
        "Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
        "Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
        "Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
        "Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
        "Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
        "Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
        "Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
        "Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
        "Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
        "Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
    }

    # Only the checkpoint paths are needed; labels document which model is which.
    return [load_model(hub_repo, checkpoint) for checkpoint in class_models_mapping.values()]
78
+
79
+
80
class DynamicCNN(nn.Module):
    """Configurable fully connected classifier.

    Despite the name, the network contains only dense layers: for each width
    in ``fc_layers`` it stacks Linear -> BatchNorm1d -> ReLU -> Dropout, then
    finishes with a Linear projection to ``num_classes``. The layers live in
    ``self.fc`` (an ``nn.Sequential``), so state-dict keys are ``fc.<idx>.*``.

    Args:
        input_channels: Size of each input feature vector.
        fc_layers: Iterable of hidden-layer widths, applied in order.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability used after every hidden layer.
    """

    def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
        super().__init__()
        # widths[i] -> widths[i + 1] describes each hidden Linear layer.
        widths = [input_channels, *fc_layers]

        stack = []
        for n_in, n_out in zip(widths, widths[1:]):
            stack.extend([
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])

        # Output projection from the last hidden width (or the input, if no hidden layers).
        stack.append(nn.Linear(widths[-1], num_classes))
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw logits produced by the dense stack for input ``x``."""
        return self.fc(x)
99
+
100
+
101
class SkinDiseaseClassifier:
    """Stacked skin-disease classifier.

    A prediction combines three feature groups and feeds them to a dense
    meta-model (``DynamicCNN``):

    1. the sigmoid output of every per-class binary model,
    2. globally pooled activations taken after each ResNet-50 stage,
    3. four summary statistics of the binary probabilities.

    Models are not loaded by the constructor; call ``load_models()`` before
    ``predict()``.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Store the target device, the class labels and the image preprocessing.

        Args:
            device: Device string (or anything ``torch.device`` accepts) that
                models and tensors are moved to.
        """
        self.device = torch.device(device)
        # Label order must match the output order of the meta-model and the
        # order in which the binary models are loaded.
        self.class_names = [
            "Acne and Rosacea Photos",
            "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
            "Atopic Dermatitis Photos",
            "Bullous Disease Photos",
            "Cellulitis Impetigo and other Bacterial Infections",
            "Eczema Photos",
            "ExanthemsandDrugEruptions",
            "Hair Loss Photos Alopecia and other Hair Diseases",
            "Herpes HPV and other STDs Photos",
            "Light Diseases and Disorders of Pigmentation",
            "Lupus and other Connective Tissue diseases",
            "Melanoma Skin Cancer Nevi and Moles",
            "Nail Fungus and other Nail Disease",
            "Poison Ivy Photos and other Contact Dermatitis",
            "Psoriasis pictures Lichen Planus and related diseases",
            "Scabies Lyme Disease and other Infestations and Bites",
            "Seborrheic Keratoses and other Benign Tumors",
            "Systemic Disease",
            "Tinea Ringworm Candidiasis and other Fungal Infections",
            "Urticaria Hives",
            "Vascular Tumors",
            "Vasculitis Photos",
            "Warts Molluscum and other Viral Infections"
        ]

        # Populated by load_models(); predict() refuses to run while these are None.
        self.base_models = None
        self.meta_model = None
        self.resnet_feature_extractor = None

        # Image transformations: fixed 224x224 input, normalised with the
        # mean/std constants given below.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def load_models(self):
        """Load the binary models, the ResNet-50 feature extractor and the meta-model.

        Downloads checkpoints from the Hugging Face Hub as needed and leaves
        every model on ``self.device`` in eval mode.
        """
        # Load binary models
        self.base_models = load_binary_models()
        for model in self.base_models:
            model.to(self.device)
            model.eval()

        # Load ResNet feature extractor: stem + the four residual stages,
        # without the original average pool and fully connected head.
        # NOTE(review): `pretrained=True` is deprecated in recent torchvision
        # releases in favour of `weights=`; kept to stay compatible with the
        # version this project pins — confirm before changing.
        model = resnet50(pretrained=True)
        layers = [model.layer1, model.layer2, model.layer3, model.layer4]
        self.resnet_feature_extractor = nn.Sequential(
            model.conv1, model.bn1, model.relu, model.maxpool, *layers
        )
        self.resnet_feature_extractor.to(self.device)
        self.resnet_feature_extractor.eval()

        # Load meta model
        print("=== Loading model with weights_only=False ===")
        meta_model_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="best_meta_model_two_layer_version4.pth"
        )
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, so this checkpoint must only ever come from a trusted repo.
        checkpoint = torch.load(meta_model_path, map_location=self.device, weights_only=False)

        # Read the input width from the checkpoint itself so the meta-model
        # always matches the stored first-layer weights.
        meta_input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
        fc_layers = [1024, 512, 256]  # Hidden widths the stored meta-model was built with

        self.meta_model = DynamicCNN(
            input_channels=meta_input_size,
            fc_layers=fc_layers,
            num_classes=23,
            dropout_rate=0.5
        )

        self.meta_model.load_state_dict(checkpoint['state_dict'])
        self.meta_model.to(self.device)
        self.meta_model.eval()

    def extract_image_features(self, image_tensor):
        """Extract pooled multi-scale features from the ResNet feature extractor.

        The input is passed through the extractor's children in order; after
        every child that is an ``nn.Sequential`` (the residual stages) the
        activation is globally average-pooled and flattened.

        Args:
            image_tensor: Batched image tensor already on the model's device.

        Returns:
            A NumPy array with the pooled vectors of all stages concatenated
            along the feature axis (one row per batch item).
        """
        with torch.no_grad():
            features = []
            x = image_tensor
            for layer in self.resnet_feature_extractor.children():
                x = layer(x)
                if isinstance(layer, nn.Sequential):  # For residual blocks
                    pooled = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
                    features.append(pooled)
            features = torch.cat(features, dim=1)
        return features.cpu().numpy()

    def predict(self, image, top_k=3):
        """Make a prediction for a single image.

        Args:
            image: Image object exposing ``convert('RGB')`` (for example a PIL
                image) that ``self.transform`` can turn into a tensor.
            top_k: Number of highest-probability classes to report.

        Returns:
            A dict with ``"top_predictions"`` (list of ``(class_name,
            probability)`` tuples, most likely first) and
            ``"all_probabilities"`` (class name -> probability).

        Raises:
            RuntimeError: If ``load_models()`` has not been called.
            ValueError: If the image cannot be converted to RGB; the original
                error is attached as ``__cause__``.
        """
        if self.base_models is None or self.meta_model is None:
            raise RuntimeError("Models not loaded - call load_models() first")

        # Load and preprocess image.
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate, and chain the cause so the real failure is visible.
        try:
            image = image.convert('RGB')
        except Exception as exc:
            raise ValueError("Could not load image from path") from exc

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Extract features
        with torch.no_grad():
            # Get probabilities from each binary model (one column per model)
            binary_probs = []
            for model in self.base_models:
                outputs = model(image_tensor)
                probs = torch.sigmoid(outputs).squeeze(1)
                binary_probs.append(probs)

            binary_features = torch.stack(binary_probs, dim=1)

            # Get image features (returned as NumPy, so convert back to a tensor)
            image_features = self.extract_image_features(image_tensor)
            image_features = torch.from_numpy(image_features).float().to(self.device)

            # Calculate probability statistics: mean, std, mean of the top 3,
            # and the gap between the best and third-best probability.
            top3_probs = torch.topk(binary_features, 3, dim=1).values
            prob_stats = torch.stack([
                binary_features.mean(dim=1, keepdim=True),
                binary_features.std(dim=1, keepdim=True),
                top3_probs.mean(dim=1, keepdim=True),
                (top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1)  # Confidence gap
            ], dim=1).squeeze(-1)

            # Combine all features; this column order must match what the
            # meta-model was trained on.
            combined_features = torch.cat([
                binary_features,
                image_features,
                prob_stats
            ], dim=1)

        # Make prediction with meta-model
        with torch.no_grad():
            outputs = self.meta_model(combined_features)
            probabilities = torch.softmax(outputs, dim=1).squeeze().cpu().numpy()

        # Get top predictions: argsort is ascending, so take the tail and reverse it.
        top_indices = np.argsort(probabilities)[-top_k:][::-1]
        top_predictions = [
            (self.class_names[i], float(probabilities[i]))
            for i in top_indices
        ]

        return {
            "top_predictions": top_predictions,
            "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
        }
259
+
260
# Process-wide memo for the fully loaded classifier. Loading downloads and
# instantiates every binary model, ResNet-50 and the meta-model, so it must
# happen only once even when the calling script is re-executed (as a
# Streamlit script is on every user interaction).
_CLASSIFIER_CACHE = None


def initialize_classifier():
    """Return a ready-to-use ``SkinDiseaseClassifier``, building it on first call.

    The first call loads all models and runs one warm-up prediction on a blank
    224x224 RGB image so that any load or shape problem surfaces immediately.
    Later calls return the same instance without reloading anything.

    Returns:
        The shared, fully initialised ``SkinDiseaseClassifier``.
    """
    global _CLASSIFIER_CACHE
    if _CLASSIFIER_CACHE is not None:
        return _CLASSIFIER_CACHE

    print("⚙️ Initializing skin disease classifier...")
    classifier = SkinDiseaseClassifier()
    classifier.load_models()
    # Warm-up / smoke test with a dummy image.
    dummy_img = Image.new('RGB', (224, 224))
    classifier.predict(dummy_img)

    print("⚙️ Initialization successful")
    # Cache only after the warm-up succeeded, so a failed load is retried next time.
    _CLASSIFIER_CACHE = classifier
    return classifier
app.py CHANGED
@@ -1,7 +1,3 @@
1
- from qdrant_client import QdrantClient
2
- from langchain_qdrant import Qdrant
3
- from langchain_community.embeddings import HuggingFaceEmbeddings
4
- from langchain_community.embeddings import SentenceTransformerEmbeddings
5
  import streamlit as st
6
  import torchvision.transforms as transforms
7
  import torch
@@ -13,60 +9,21 @@ from evo_vit import EvoViTModel
13
  import io
14
  import os
15
  from fpdf import FPDF
16
- from langchain_openai import OpenAIEmbeddings, ChatOpenAI
17
- from langchain.chains import RetrievalQA
18
- from langchain.prompts import PromptTemplate
19
  from torchvision.models import resnet50
20
  import nest_asyncio
21
- from sentence_transformers import SentenceTransformer
22
  from huggingface_hub import hf_hub_download
23
-
24
- # model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
 
25
 
26
  nest_asyncio.apply()
27
  device='cuda' if torch.cuda.is_available() else 'cpu'
28
-
29
  st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
30
 
31
- # os.environ["PGVECTOR_CONNECTION_STRING"] = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
32
-
33
  # === Model Selection ===
34
  available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
35
  st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
36
 
37
- # === Qdrant DB Setup ===
38
- qdrant_client = QdrantClient(
39
- url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
40
- api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
41
- )
42
- collection_name = "ks_collection_1.5BE"
43
- # embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
44
- # embedding_model.max_seq_length = 8192
45
- # local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
46
-
47
- model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
48
-
49
-
50
- local_embedding = HuggingFaceEmbeddings(
51
- model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
52
- model_kwargs={"trust_remote_code": True, "device": "cuda" if torch.cuda.is_available() else "cpu"}
53
- )
54
- print(" Qwen2-1.5B local embedding model loaded.")
55
-
56
- vector_store = Qdrant(
57
- client=qdrant_client,
58
- collection_name=collection_name,
59
- embeddings=local_embedding
60
- )
61
- retriever = vector_store.as_retriever()
62
-
63
- '''
64
- # === Init LLM and Vector DB ===
65
-
66
- CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
67
- engine = create_engine(CONNECTION_STRING)
68
- embedding_model = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
69
- '''
70
  # Dynamically initialize LLM based on selection
71
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
72
  selected_model = st.session_state["selected_model"]
@@ -82,304 +39,7 @@ else:
82
  st.error("Unsupported model selected.")
83
  st.stop()
84
 
85
- '''
86
- vector_store = PGVector.from_existing_index(
87
- embedding=embedding_model,
88
- connection=engine,
89
- collection_name="documents"
90
- )
91
- '''
92
- # retriever = vector_store.as_retriever()
93
-
94
- AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
95
- You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
96
-
97
- Guidelines:
98
- 1. Symptoms - Explain in simple terms with proper medical definitions.
99
- 2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
100
- 3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
101
- 4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
102
- 5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
103
-
104
- Query: {question}
105
- Relevant Information: {context}
106
- Answer:
107
- """
108
- prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
109
-
110
- rag_chain = RetrievalQA.from_chain_type(
111
- llm=llm,
112
- retriever=retriever,
113
- chain_type="stuff",
114
- chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
115
- )
116
-
117
-
118
- ''''
119
- Load My models
120
- '''
121
-
122
-
123
- def load_model(repo_id, filename):
124
- model_path = hf_hub_download(
125
- repo_id=repo_id,
126
- filename=filename,
127
- )
128
- model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
129
- model.classifier = nn.Linear(512, 1)
130
- state_dict = torch.load(model_path, map_location=device)
131
- new_state_dict = {}
132
- for key, value in state_dict.items():
133
- if key.startswith("backbone."):
134
- new_key = key[len("backbone."):]
135
- else:
136
- new_key = key
137
- new_state_dict[new_key] = value
138
-
139
- if "classifier.weight" in new_state_dict:
140
- original_weight = new_state_dict["classifier.weight"]
141
- new_state_dict["classifier.weight"] = original_weight[0:1, :]
142
- if "classifier.bias" in new_state_dict:
143
- original_bias = new_state_dict["classifier.bias"]
144
- new_state_dict["classifier.bias"] = original_bias[0:1]
145
- model.load_state_dict(new_state_dict, strict=False)
146
- model.to(device)
147
- model.eval()
148
- return model
149
-
150
- def load_binary_models():
151
- base_models = []
152
- class_models_mapping = {
153
- "Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
154
- "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
155
- "Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
156
- "Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
157
- "Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
158
- "Eczema Photos": 'santhosh/5fold_model_eczema.pth',
159
- "ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
160
- "Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
161
- "Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
162
- "Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
163
- "Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
164
- "Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
165
- "Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
166
- "Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
167
- "Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
168
- "Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
169
- "Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
170
- "Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
171
- "Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
172
- "Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
173
- "Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
174
- "Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
175
- "Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
176
- }
177
- repo_id = "KeerthiVM/SkinCancerDiagnosis" # Your Hugging Face repo
178
-
179
- for class_name, filename in class_models_mapping.items():
180
- # model_path = os.path.join("best_models_overall", rel_path)
181
- model = load_model(repo_id, filename)
182
- base_models.append(model)
183
- return base_models
184
-
185
-
186
- class DynamicCNN(nn.Module):
187
- def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
188
- super(DynamicCNN, self).__init__()
189
- fc_layers_list = []
190
- in_dim = input_channels
191
-
192
- for fc_dim in fc_layers:
193
- fc_layers_list.append(nn.Linear(in_dim, fc_dim))
194
- fc_layers_list.append(nn.BatchNorm1d(fc_dim))
195
- fc_layers_list.append(nn.ReLU())
196
- fc_layers_list.append(nn.Dropout(dropout_rate))
197
- in_dim = fc_dim
198
-
199
- fc_layers_list.append(nn.Linear(in_dim, num_classes))
200
- self.fc = nn.Sequential(*fc_layers_list)
201
-
202
- def forward(self, x):
203
- x = self.fc(x)
204
- return x
205
-
206
-
207
- class SkinDiseaseClassifier:
208
- def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
209
- self.device = torch.device(device)
210
- self.class_names = [
211
- "Acne and Rosacea Photos",
212
- "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
213
- "Atopic Dermatitis Photos",
214
- "Bullous Disease Photos",
215
- "Cellulitis Impetigo and other Bacterial Infections",
216
- "Eczema Photos",
217
- "ExanthemsandDrugEruptions",
218
- "Hair Loss Photos Alopecia and other Hair Diseases",
219
- "Herpes HPV and other STDs Photos",
220
- "Light Diseases and Disorders of Pigmentation",
221
- "Lupus and other Connective Tissue diseases",
222
- "Melanoma Skin Cancer Nevi and Moles",
223
- "Nail Fungus and other Nail Disease",
224
- "Poison Ivy Photos and other Contact Dermatitis",
225
- "Psoriasis pictures Lichen Planus and related diseases",
226
- "Scabies Lyme Disease and other Infestations and Bites",
227
- "Seborrheic Keratoses and other Benign Tumors",
228
- "Systemic Disease",
229
- "Tinea Ringworm Candidiasis and other Fungal Infections",
230
- "Urticaria Hives",
231
- "Vascular Tumors",
232
- "Vasculitis Photos",
233
- "Warts Molluscum and other Viral Infections"
234
- ]
235
-
236
- # Initialize models (they'll be loaded when needed)
237
- self.base_models = None
238
- self.meta_model = None
239
- self.resnet_feature_extractor = None
240
-
241
- # Image transformations
242
- self.transform = transforms.Compose([
243
- transforms.Resize((224, 224)),
244
- transforms.ToTensor(),
245
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
246
- ])
247
-
248
- def load_models(self):
249
- """Load all required models"""
250
- # Load binary models
251
- self.base_models = load_binary_models()
252
- for model in self.base_models:
253
- model.to(self.device)
254
- model.eval()
255
-
256
- # Load ResNet feature extractor
257
- model = resnet50(pretrained=True)
258
- layers = [model.layer1, model.layer2, model.layer3, model.layer4]
259
- self.resnet_feature_extractor = nn.Sequential(
260
- model.conv1, model.bn1, model.relu, model.maxpool, *layers
261
- )
262
- self.resnet_feature_extractor.to(self.device)
263
- self.resnet_feature_extractor.eval()
264
-
265
- # Load meta model
266
- # meta_model_path = 'best_meta_model_two_layer_version4.pth'
267
- # checkpoint = torch.load(meta_model_path, map_location=self.device)
268
- print("=== Loading model with weights_only=False ===")
269
- meta_model_path = hf_hub_download(
270
- repo_id="KeerthiVM/SkinCancerDiagnosis",
271
- filename="best_meta_model_two_layer_version4.pth"
272
- )
273
- checkpoint = torch.load(meta_model_path, map_location=self.device, weights_only=False)
274
-
275
- correct_input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
276
- input_size = 23 + 4 + 7168 # Adjust based on your actual feature size
277
- fc_layers = [1024, 512, 256] # Use whatever was in your best model
278
-
279
- self.meta_model = DynamicCNN(
280
- input_channels=correct_input_size,
281
- fc_layers=fc_layers,
282
- num_classes=23,
283
- dropout_rate=0.5
284
- )
285
-
286
- self.meta_model.load_state_dict(checkpoint['state_dict'])
287
- self.meta_model.to(self.device)
288
- self.meta_model.eval()
289
-
290
- def extract_image_features(self, image_tensor):
291
- """Extract features using ResNet"""
292
- with torch.no_grad():
293
- features = []
294
- x = image_tensor
295
- for layer in self.resnet_feature_extractor.children():
296
- x = layer(x)
297
- if isinstance(layer, nn.Sequential): # For residual blocks
298
- # features.append(F.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
299
- # features = torch.cat(features, dim=1)
300
- pooled = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
301
- features.append(pooled)
302
- features = torch.cat(features, dim=1)
303
- return features.cpu().numpy()
304
-
305
- def predict(self, image, top_k=3):
306
- """Make prediction for a single image"""
307
- if self.base_models is None or self.meta_model is None:
308
- # self.load_models()
309
- raise RuntimeError("Models not loaded - call load_models() first")
310
-
311
- # Load and preprocess image
312
- try:
313
- # image = Image.open(image_path).convert('RGB')
314
- image = image.convert('RGB')
315
- except:
316
- raise ValueError("Could not load image from path")
317
-
318
- image_tensor = self.transform(image).unsqueeze(0).to(self.device)
319
-
320
- # Extract features
321
- with torch.no_grad():
322
- # Get probabilities from each binary model
323
- binary_probs = []
324
- for model in self.base_models:
325
- outputs = model(image_tensor)
326
- probs = torch.sigmoid(outputs).squeeze(1)
327
- binary_probs.append(probs)
328
-
329
- binary_features = torch.stack(binary_probs, dim=1)
330
-
331
- # Get image features
332
- image_features = self.extract_image_features(image_tensor)
333
- image_features = torch.from_numpy(image_features).float().to(self.device)
334
-
335
- # Calculate probability statistics
336
- top3_probs = torch.topk(binary_features, 3, dim=1).values
337
- prob_stats = torch.stack([
338
- binary_features.mean(dim=1, keepdim=True),
339
- binary_features.std(dim=1, keepdim=True),
340
- top3_probs.mean(dim=1, keepdim=True),
341
- (top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1) # Confidence gap
342
- ], dim=1).squeeze(-1)
343
-
344
- # Combine all features
345
- combined_features = torch.cat([
346
- binary_features,
347
- image_features,
348
- prob_stats
349
- ], dim=1)
350
-
351
- # Make prediction with meta-model
352
- with torch.no_grad():
353
- outputs = self.meta_model(combined_features)
354
- probabilities = torch.softmax(outputs, dim=1).squeeze().cpu().numpy()
355
-
356
- # Get top predictions
357
- top_indices = np.argsort(probabilities)[-top_k:][::-1]
358
- top_predictions = [
359
- (self.class_names[i], float(probabilities[i]))
360
- for i in top_indices
361
- ]
362
-
363
- return {
364
- "top_predictions": top_predictions,
365
- "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
366
- }
367
-
368
-
369
- @st.cache_resource
370
- def initialize_classifier():
371
- print("⚙️ Initializing skin disease classifier...")
372
- classifier = SkinDiseaseClassifier()
373
- classifier.load_models()
374
- dummy_img = Image.new('RGB', (224, 224))
375
- classifier.predict(dummy_img)
376
-
377
- print("⚙️ Initialization successful")
378
- return classifier
379
 
380
- # print("⚙️ Initializing skin disease classifier...")
381
- # classifier = initialize_classifier()
382
- # print("⚙️ Initializing classifier done...")
383
  with st.spinner("Loading AI models (one-time operation)..."):
384
  classifier = initialize_classifier()
385
  st.success("Models loaded successfully!")
@@ -441,7 +101,7 @@ if uploaded_file:
441
  st.session_state.messages.append({"role": "user", "content": query})
442
 
443
  with st.spinner("Analyzing the image and retrieving response..."):
444
- response = rag_chain.invoke(query)
445
  st.session_state.messages.append({"role": "assistant", "content": response['result']})
446
 
447
  with st.chat_message("assistant"):
 
 
 
 
 
1
  import streamlit as st
2
  import torchvision.transforms as transforms
3
  import torch
 
9
  import io
10
  import os
11
  from fpdf import FPDF
 
 
 
12
  from torchvision.models import resnet50
13
  import nest_asyncio
 
14
  from huggingface_hub import hf_hub_download
15
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
16
+ from SkinCancerDiagnosis import initialize_classifier
17
+ from rag_pipeline import invoke_rag_chain
18
 
19
  nest_asyncio.apply()
20
  device='cuda' if torch.cuda.is_available() else 'cpu'
 
21
  st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
22
 
 
 
23
  # === Model Selection ===
24
  available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
25
  st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Dynamically initialize LLM based on selection
28
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
29
  selected_model = st.session_state["selected_model"]
 
39
  st.error("Unsupported model selected.")
40
  st.stop()
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
 
 
 
43
  with st.spinner("Loading AI models (one-time operation)..."):
44
  classifier = initialize_classifier()
45
  st.success("Models loaded successfully!")
 
101
  st.session_state.messages.append({"role": "user", "content": query})
102
 
103
  with st.spinner("Analyzing the image and retrieving response..."):
104
+ response = invoke_rag_chain(llm).invoke(query)
105
  st.session_state.messages.append({"role": "assistant", "content": response['result']})
106
 
107
  with st.chat_message("assistant"):
rag_pipeline.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import RetrievalQA
2
+ from langchain.prompts import PromptTemplate
3
+ from sentence_transformers import SentenceTransformer
4
+ from qdrant_client import QdrantClient
5
+ from langchain_qdrant import Qdrant
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
8
+ import os
9
+ import torch
10
+
11
# Prompt used by the "stuff" chain; {question} and {context} are filled in by RetrievalQA.
_AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.

Guidelines:
1. Symptoms - Explain in simple terms with proper medical definitions.
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.

Query: {question}
Relevant Information: {context}
Answer:
"""

# Retriever shared by every chain built in this process. Creating it loads the
# embedding model, which is far too expensive to repeat for every query.
_RETRIEVER_CACHE = None


def _get_retriever():
    """Build the Qdrant-backed retriever on first use and reuse it afterwards."""
    global _RETRIEVER_CACHE
    if _RETRIEVER_CACHE is not None:
        return _RETRIEVER_CACHE

    # === Qdrant DB Setup ===
    # SECURITY: the literal values below are credentials committed to source
    # control. Set QDRANT_URL / QDRANT_API_KEY in the environment instead, then
    # rotate the key and delete these fallbacks.
    qdrant_client = QdrantClient(
        url=os.environ.get(
            "QDRANT_URL",
            "https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
        ),
        api_key=os.environ.get(
            "QDRANT_API_KEY",
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q",
        ),
    )
    collection_name = "ks_collection_1.5BE"

    # The embedding model is loaded once, through HuggingFaceEmbeddings only.
    local_embedding = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        model_kwargs={"trust_remote_code": True, "device": "cuda" if torch.cuda.is_available() else "cpu"}
    )
    print(" Qwen2-1.5B local embedding model loaded.")

    vector_store = Qdrant(
        client=qdrant_client,
        collection_name=collection_name,
        embeddings=local_embedding
    )
    _RETRIEVER_CACHE = vector_store.as_retriever()
    return _RETRIEVER_CACHE


def invoke_rag_chain(llm):
    """Create a RetrievalQA chain that answers dermatology questions with ``llm``.

    Args:
        llm: LangChain-compatible language model used to generate the answer.

    Returns:
        A ``RetrievalQA`` "stuff" chain wired to the shared Qdrant retriever
        and the dermatology prompt; call ``.invoke(query)`` on it.
    """
    prompt_template = PromptTemplate(template=_AI_PROMPT_TEMPLATE, input_variables=["question", "context"])

    # Building the chain is cheap; only the retriever (and its model) is cached.
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=_get_retriever(),
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
    )

    return rag_chain