SkinCancerDiagnosis / SkinCancerDiagnosis.py
KeerthiVM's picture
Fix added
a3f66c1
raw
history blame contribute delete
13.7 kB
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torch.nn.functional as F
from evo_vit import EvoViTModel
import io
import os
from fpdf import FPDF
from torchvision.models import resnet50
import nest_asyncio
from huggingface_hub import hf_hub_download
# Default device used by the module-level loader `load_model`:
# "cuda" when torch can see a GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(repo_id, filename):
    """Download one binary EvoViT checkpoint from the Hugging Face Hub and load it.

    The model is built with a single-output ``classifier`` head
    (``nn.Linear(512, 1)``). It is placed on the module-level ``device`` and
    switched to eval mode.

    Args:
        repo_id: Hugging Face repository id holding the checkpoint.
        filename: path of the checkpoint file inside that repository.

    Returns:
        The loaded ``EvoViTModel``.
    """
    checkpoint_file = hf_hub_download(repo_id=repo_id, filename=filename)

    net = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
    # Replace the head so the model emits one logit (callers apply a sigmoid).
    net.classifier = nn.Linear(512, 1)

    prefix = "backbone."
    raw_state = torch.load(checkpoint_file, map_location=device)
    # Strip a leading "backbone." from parameter names so they match EvoViTModel's own keys.
    remapped = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in raw_state.items()
    }

    # Keep only the first output row of the saved classifier so it fits the 1-logit head.
    # NOTE(review): presumably the checkpoint head has more than one output and row 0
    # is the positive class - confirm against the training code.
    for head_key in ("classifier.weight", "classifier.bias"):
        if head_key in remapped:
            remapped[head_key] = remapped[head_key][0:1]

    # NOTE(review): strict=False silently ignores missing/unexpected keys; a key-name
    # mismatch would leave layers randomly initialised without any warning.
    net.load_state_dict(remapped, strict=False)
    net.to(device)
    net.eval()
    return net
def load_binary_models():
    """Download and load every per-class binary EvoViT model.

    Returns:
        A list with one model per disease class, in the insertion order of the
        mapping below. ``SkinDiseaseClassifier.class_names`` uses the same order,
        and ``SkinDiseaseClassifier.predict`` consumes the models' outputs
        positionally.
    """
    repo_id = "KeerthiVM/SkinCancerDiagnosis"  # Hugging Face repo hosting the checkpoints
    # Class name -> checkpoint path inside the repo. Only the paths are used; the
    # names document which class each binary model scores.
    class_models_mapping = {
        "Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
        "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
        "Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
        "Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
        "Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
        "Eczema Photos": 'santhosh/5fold_model_eczema.pth',
        "ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
        "Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
        "Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
        "Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
        "Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
        "Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
        "Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
        "Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
        "Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
        "Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
        "Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
        "Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
        "Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
        "Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
        "Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
        "Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
        "Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
    }
    return [load_model(repo_id, checkpoint) for checkpoint in class_models_mapping.values()]
class DynamicCNN(nn.Module):
    """Fully connected classifier head built from a list of hidden widths.

    Despite the name there are no convolutions. Each hidden width adds a
    Linear -> BatchNorm1d -> ReLU -> Dropout group, and a final Linear projects
    to ``num_classes`` logits. Every module lives in ``self.fc`` (an
    ``nn.Sequential``), so state-dict keys look like ``fc.<index>.*``.

    Args:
        input_channels: length of each input feature vector.
        fc_layers: hidden-layer widths, applied in order (may be empty).
        num_classes: number of output logits.
        dropout_rate: dropout probability after every hidden group.
    """

    def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
        super().__init__()
        widths = [input_channels, *fc_layers]
        modules = []
        # Consecutive (in, out) width pairs define the hidden groups.
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            modules += [
                nn.Linear(d_in, d_out),
                nn.BatchNorm1d(d_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
        modules.append(nn.Linear(widths[-1], num_classes))
        self.fc = nn.Sequential(*modules)

    def forward(self, x):
        """Return raw (unnormalised) class logits for the batch ``x``."""
        return self.fc(x)
class SkinDiseaseClassifier:
    """Two-stage skin-disease classifier plus a multi-label SkinCon attribute model.

    ``predict`` works in two stages:

    1. It runs one binary EvoViT model per disease class (``base_models``).
    2. It feeds their sigmoid probabilities, pooled ResNet-50 features and four
       summary statistics into a ``DynamicCNN`` meta-model that scores every
       entry of ``class_names``.

    ``predict_skincon`` separately scores the 48 morphology labels in
    ``multilabel_class_names``.

    Call ``load_models()`` before either prediction method.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Set up label tables and preprocessing; models are loaded by ``load_models``.

        Args:
            device: device string (or ``torch.device``) for all models and input
                tensors. Defaults to CUDA when torch reports it available.
        """
        self.device = torch.device(device)
        self.class_names = [
            "Acne and Rosacea Photos",
            "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
            "Atopic Dermatitis Photos",
            "Bullous Disease Photos",
            "Cellulitis Impetigo and other Bacterial Infections",
            "Eczema Photos",
            "ExanthemsandDrugEruptions",
            "Hair Loss Photos Alopecia and other Hair Diseases",
            "Herpes HPV and other STDs Photos",
            "Light Diseases and Disorders of Pigmentation",
            "Lupus and other Connective Tissue diseases",
            "Melanoma Skin Cancer Nevi and Moles",
            "Nail Fungus and other Nail Disease",
            "Poison Ivy Photos and other Contact Dermatitis",
            "Psoriasis pictures Lichen Planus and related diseases",
            "Scabies Lyme Disease and other Infestations and Bites",
            "Seborrheic Keratoses and other Benign Tumors",
            "Systemic Disease",
            "Tinea Ringworm Candidiasis and other Fungal Infections",
            "Urticaria Hives",
            "Vascular Tumors",
            "Vasculitis Photos",
            "Warts Molluscum and other Viral Infections"
        ]
        # Populated by load_models().
        self.base_models = None
        self.meta_model = None
        self.resnet_feature_extractor = None
        self.skincon_model = None
        # Resize to the 224x224 size the EvoViT models are built for, then apply
        # the standard ImageNet mean/std normalisation.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.multilabel_class_names = [
            "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
            "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
            "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
            "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
            "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
            "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
            "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
        ]

    def load_models(self):
        """Download and initialise every model used by the prediction methods.

        Loads:
        - the binary EvoViT models;
        - a ResNet-50 trunk used as a feature extractor;
        - the ``DynamicCNN`` meta-model;
        - the SkinCon multi-label EvoViT model.

        All of them are placed on ``self.device`` and switched to eval mode.
        """
        # Binary per-class models.
        self.base_models = load_binary_models()
        for model in self.base_models:
            model.to(self.device)
            model.eval()

        # ResNet-50 stem plus the four residual stages (avgpool/fc dropped).
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour of
        # `weights=ResNet50_Weights.IMAGENET1K_V1`; kept for older torchvision releases.
        resnet = resnet50(pretrained=True)
        self.resnet_feature_extractor = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
        )
        self.resnet_feature_extractor.to(self.device)
        self.resnet_feature_extractor.eval()

        # Meta-model. weights_only=False unpickles arbitrary objects, so only load
        # this checkpoint from a trusted repository.
        print("=== Loading model with weights_only=False ===")
        meta_model_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="best_meta_model_two_layer_version4.pth"
        )
        checkpoint = torch.load(meta_model_path, map_location=self.device, weights_only=False)
        # Take the input width from the checkpoint's first Linear layer so the
        # network always matches the saved weights.
        input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
        self.meta_model = DynamicCNN(
            input_channels=input_size,
            fc_layers=[1024, 512, 256],
            num_classes=23,
            dropout_rate=0.5
        )
        self.meta_model.load_state_dict(checkpoint['state_dict'])
        self.meta_model.to(self.device)
        self.meta_model.eval()

        # SkinCon multi-label model (one output per multilabel_class_names entry).
        skincon_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="skincon.pth"
        )
        self.skincon_model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768,
                                         num_classes=48, hidden_dim=512)
        state_dict = torch.load(skincon_path, map_location=self.device)
        # NOTE(review): strict=False silently ignores key mismatches - confirm the
        # checkpoint keys match EvoViTModel.
        self.skincon_model.load_state_dict(state_dict, strict=False)
        # Put the model on the same device as the inputs built by predict_skincon;
        # previously it was never moved to self.device, which fails on CUDA.
        self.skincon_model.to(self.device)
        self.skincon_model.eval()

    def _prepare_image(self, image):
        """Convert ``image`` to RGB and return a preprocessed batch of one on ``self.device``.

        Raises:
            ValueError: if ``image`` cannot be converted (e.g. it is not a PIL image).
        """
        try:
            rgb = image.convert('RGB')
        except Exception as err:
            raise ValueError("Could not load image from path") from err
        return self.transform(rgb).unsqueeze(0).to(self.device)

    def _pooled_resnet_features(self, image_tensor):
        """Return globally average-pooled outputs of each residual stage, concatenated.

        The result is a tensor on the same device as ``image_tensor``.
        """
        pooled = []
        x = image_tensor
        with torch.no_grad():
            for layer in self.resnet_feature_extractor.children():
                x = layer(x)
                # Only the residual stages are nn.Sequential; pool each of their outputs.
                if isinstance(layer, nn.Sequential):
                    pooled.append(F.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
        return torch.cat(pooled, dim=1)

    def extract_image_features(self, image_tensor):
        """Extract pooled ResNet features for ``image_tensor`` as a NumPy array on the CPU."""
        return self._pooled_resnet_features(image_tensor).cpu().numpy()

    def predict(self, image, top_k=3):
        """Classify a single image into one of ``class_names``.

        Args:
            image: a PIL image (anything with ``.convert('RGB')``).
            top_k: number of highest-probability classes to return (0 returns none).

        Returns:
            A dict with two keys:
            - ``"top_predictions"``: a list of ``(class_name, probability)`` pairs
              in descending probability order;
            - ``"all_probabilities"``: a mapping from every class name to its
              probability.

        Raises:
            RuntimeError: if ``load_models()`` has not been called.
            ValueError: if ``image`` cannot be converted to RGB.
        """
        if self.base_models is None or self.meta_model is None:
            raise RuntimeError("Models not loaded - call load_models() first")
        image_tensor = self._prepare_image(image)

        with torch.no_grad():
            # Stage 1: one sigmoid probability per binary model -> (batch, n_models).
            binary_features = torch.stack(
                [torch.sigmoid(model(image_tensor)).squeeze(1) for model in self.base_models],
                dim=1,
            )
            # Kept as a tensor on self.device (no NumPy round trip).
            image_features = self._pooled_resnet_features(image_tensor).float()

            # Summary statistics over the binary probabilities.
            top3 = torch.topk(binary_features, 3, dim=1).values
            prob_stats = torch.cat([
                binary_features.mean(dim=1, keepdim=True),
                binary_features.std(dim=1, keepdim=True),
                top3.mean(dim=1, keepdim=True),
                (top3[:, 0] - top3[:, 2]).unsqueeze(1),  # confidence gap
            ], dim=1)

            # NOTE(review): presumably this concatenation order must match the
            # meta-model's training features - confirm before changing it.
            combined_features = torch.cat([binary_features, image_features, prob_stats], dim=1)
            logits = self.meta_model(combined_features)
            probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()

        # Sort descending, then take top_k. This is also correct for top_k=0,
        # unlike the previous [-top_k:] slice, which returned every class.
        top_indices = np.argsort(probabilities)[::-1][:top_k]
        top_predictions = [(self.class_names[i], float(probabilities[i])) for i in top_indices]
        return {
            "top_predictions": top_predictions,
            "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
        }

    def predict_skincon(self, image, top_k=3, threshold=0.5):
        """Predict SkinCon morphology labels for a single image.

        Args:
            image: a PIL image (anything with ``.convert('RGB')``).
            top_k: unused; kept for backward compatibility.
            threshold: a label is returned when its sigmoid probability is
                strictly greater than this value.

        Returns:
            The names from ``multilabel_class_names`` whose probability exceeds
            ``threshold``, in label order.

        Raises:
            RuntimeError: if the SkinCon model has not been loaded.
            ValueError: if ``image`` cannot be converted to RGB.
        """
        if self.skincon_model is None:
            raise RuntimeError("Models not loaded - call load_models() first")
        self.skincon_model.eval()
        image_tensor = self._prepare_image(image)
        with torch.no_grad():
            output_multi = self.skincon_model(image_tensor)
        # .cpu() is required before .numpy() when the model runs on CUDA.
        probs_multi = torch.sigmoid(output_multi).squeeze().cpu().numpy()
        print(f"Probabilities : {probs_multi}")
        predicted_labels_multi = [
            name for name, p in zip(self.multilabel_class_names, probs_multi) if p > threshold
        ]
        print("Predicted labels multi : ", predicted_labels_multi)
        return predicted_labels_multi
def initialize_classifier():
    """Create a SkinDiseaseClassifier, load its models and run a warm-up prediction.

    The warm-up ``predict`` on a blank 224x224 RGB image runs the whole pipeline
    once, so loading or shape errors surface here rather than on the first real
    request.

    Returns:
        The ready-to-use ``SkinDiseaseClassifier``.
    """
    print("⚙️ Initializing skin disease classifier...")
    ready = SkinDiseaseClassifier()
    ready.load_models()
    _ = ready.predict(Image.new('RGB', (224, 224)))  # warm-up / smoke test; result discarded
    print("⚙️ Initialization successful")
    return ready