KeerthiVM commited on
Commit
87c2216
·
1 Parent(s): 154407c
Files changed (2) hide show
  1. SkinCancerDiagnosis.py +43 -2
  2. app.py +5 -3
SkinCancerDiagnosis.py CHANGED
@@ -139,6 +139,16 @@ class SkinDiseaseClassifier:
139
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
140
  ])
141
 
 
 
 
 
 
 
 
 
 
 
142
  def load_models(self):
143
  """Load all required models"""
144
  # Load binary models
@@ -179,6 +189,18 @@ class SkinDiseaseClassifier:
179
  self.meta_model.to(self.device)
180
  self.meta_model.eval()
181
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  def extract_image_features(self, image_tensor):
183
  """Extract features using ResNet"""
184
  with torch.no_grad():
@@ -197,12 +219,10 @@ class SkinDiseaseClassifier:
197
  def predict(self, image, top_k=3):
198
  """Make prediction for a single image"""
199
  if self.base_models is None or self.meta_model is None:
200
- # self.load_models()
201
  raise RuntimeError("Models not loaded - call load_models() first")
202
 
203
  # Load and preprocess image
204
  try:
205
- # image = Image.open(image_path).convert('RGB')
206
  image = image.convert('RGB')
207
  except:
208
  raise ValueError("Could not load image from path")
@@ -257,6 +277,27 @@ class SkinDiseaseClassifier:
257
  "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
258
  }
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  def initialize_classifier():
261
  print("⚙️ Initializing skin disease classifier...")
262
  classifier = SkinDiseaseClassifier()
 
139
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
140
  ])
141
 
142
+ self.multilabel_class_names = [
143
+ "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
144
+ "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
145
+ "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
146
+ "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
147
+ "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
148
+ "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
149
+ "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
150
+ ]
151
+
152
  def load_models(self):
153
  """Load all required models"""
154
  # Load binary models
 
189
  self.meta_model.to(self.device)
190
  self.meta_model.eval()
191
 
192
+ skincon_path = hf_hub_download(
193
+ repo_id="KeerthiVM/SkinCancerDiagnosis",
194
+ filename="skincon.pth"
195
+ )
196
+ self.skincon_model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=48,
197
+ hidden_dim=512)
198
+ state_dict = torch.load(skincon_path, map_location=device)
199
+ self.skincon_model.load_state_dict(state_dict, strict=False)
200
+
201
+ self.skincon_model.eval()
202
+
203
+
204
  def extract_image_features(self, image_tensor):
205
  """Extract features using ResNet"""
206
  with torch.no_grad():
 
219
  def predict(self, image, top_k=3):
220
  """Make prediction for a single image"""
221
  if self.base_models is None or self.meta_model is None:
 
222
  raise RuntimeError("Models not loaded - call load_models() first")
223
 
224
  # Load and preprocess image
225
  try:
 
226
  image = image.convert('RGB')
227
  except:
228
  raise ValueError("Could not load image from path")
 
277
  "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
278
  }
279
 
280
+ def predict_skincon(self, image, top_k=3):
281
+ """Make prediction for a single image"""
282
+ if self.base_models is None or self.skincon_model is None:
283
+ raise RuntimeError("Models not loaded - call load_models() first")
284
+ self.skincon_model.eval()
285
+ try:
286
+ image = image.convert('RGB')
287
+ except:
288
+ raise ValueError("Could not load image from path")
289
+
290
+ image_tensor = self.transform(image).unsqueeze(0).to(self.device)
291
+ with torch.no_grad():
292
+ output_multi = self.skincon_model(image_tensor)
293
+ probs_multi = torch.sigmoid(output_multi).squeeze().numpy()
294
+ # print(f"Probabilities : {probs_multi}")
295
+ threshold = 0.5
296
+ predicted_labels_multi = [self.multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > threshold]
297
+ print("Predicted labels multi : ",predicted_labels_multi)
298
+ return predicted_labels_multi
299
+
300
+
301
  def initialize_classifier():
302
  print("⚙️ Initializing skin disease classifier...")
303
  classifier = SkinDiseaseClassifier()
app.py CHANGED
@@ -102,7 +102,8 @@ if "current_image" not in st.session_state:
102
  def run_inference(image):
103
  result = classifier.predict(image, top_k=1)
104
  predicted_label = result["top_predictions"][0][0]
105
- return predicted_label
 
106
 
107
 
108
  # === PDF Export ===
@@ -136,11 +137,12 @@ if uploaded_file is not None and uploaded_file != st.session_state.current_image
136
  image = Image.open(uploaded_file).convert("RGB")
137
  st.image(image, caption="Uploaded image", use_column_width=True)
138
  with st.spinner("Analyzing the image..."):
139
- predicted_label = run_inference(image)
140
 
 
141
  st.markdown(f" Most Likely Diagnosis : {predicted_label}")
142
 
143
- initial_query = f"What are my treatment options for {predicted_label}?"
144
  st.session_state.messages.append({"role": "user", "content": initial_query})
145
  with st.spinner("Retrieving medical information..."):
146
  response = get_reranked_response(initial_query, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])
 
102
  def run_inference(image):
103
  result = classifier.predict(image, top_k=1)
104
  predicted_label = result["top_predictions"][0][0]
105
+ predicted_label_multi = classifier.predict_skincon(image, top_k=1)
106
+ return predicted_label, predicted_label_multi
107
 
108
 
109
  # === PDF Export ===
 
137
  image = Image.open(uploaded_file).convert("RGB")
138
  st.image(image, caption="Uploaded image", use_column_width=True)
139
  with st.spinner("Analyzing the image..."):
140
+ predicted_label, predicted_label_multi = run_inference(image)
141
 
142
+ st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_label_multi)}")
143
  st.markdown(f" Most Likely Diagnosis : {predicted_label}")
144
 
145
+ initial_query = f"What are my treatment options for {predicted_label} & {predicted_label_multi}?"
146
  st.session_state.messages.append({"role": "user", "content": initial_query})
147
  with st.spinner("Retrieving medical information..."):
148
  response = get_reranked_response(initial_query, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])