Spaces:
Sleeping
Sleeping
import torchvision.transforms as transforms | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from PIL import Image | |
import torch.nn.functional as F | |
from evo_vit import EvoViTModel | |
import io | |
import os | |
from fpdf import FPDF | |
from torchvision.models import resnet50 | |
import nest_asyncio | |
from huggingface_hub import hf_hub_download | |
# Module-level default device used by load_model(): CUDA when PyTorch reports
# it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_model(repo_id, filename):
    """Download one binary EvoViT checkpoint from the Hugging Face Hub and build its model.

    The model is built with a single-logit ``classifier`` head
    (``nn.Linear(512, 1)``). Checkpoint keys starting with ``"backbone."``
    have that prefix stripped. A checkpoint classifier with several output
    rows is cut down to its first row, so that it fits the single-logit head.

    Args:
        repo_id: Hugging Face repository to download from.
        filename: path of the checkpoint file inside that repository.

    Returns:
        The ``EvoViTModel``, moved to the module-level ``device`` and put in
        eval mode.

    Warns:
        UserWarning: if some model parameters are absent from the checkpoint.
            Loading is non-strict, so those parameters keep their initial
            (random) values.
    """
    import warnings

    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
    )
    model = EvoViTModel(
        img_size=224, patch_size=16, in_channels=3,
        embed_dim=768, num_classes=2, hidden_dim=512,
    )
    # Replace the head with a single output logit (one binary decision per model).
    model.classifier = nn.Linear(512, 1)

    # NOTE(review): torch.load unpickles the downloaded file, which can run
    # arbitrary code. Only load checkpoints from a trusted repository, or pass
    # weights_only=True once the files are confirmed to contain only tensors.
    state_dict = torch.load(model_path, map_location=device)

    prefix = "backbone."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # Keep only the first output row so that a multi-output classifier
    # fits the 1-logit head created above.
    if "classifier.weight" in new_state_dict:
        new_state_dict["classifier.weight"] = new_state_dict["classifier.weight"][0:1, :]
    if "classifier.bias" in new_state_dict:
        new_state_dict["classifier.bias"] = new_state_dict["classifier.bias"][0:1]

    # strict=False tolerates extra checkpoint keys. Missing ones, however, would
    # silently leave part of the network untrained, so report them.
    incompatible = model.load_state_dict(new_state_dict, strict=False)
    if incompatible.missing_keys:
        warnings.warn(
            f"{filename}: {len(incompatible.missing_keys)} model parameter(s) not found "
            f"in the checkpoint were left at their initial values, "
            f"e.g. {incompatible.missing_keys[:3]}",
            stacklevel=2,
        )

    model.to(device)
    model.eval()
    return model
def load_binary_models(repo_id="KeerthiVM/SkinCancerDiagnosis"):
    """Load one binary model per skin-disease class via ``load_model``.

    The insertion order of ``class_models_mapping`` fixes the order of the
    returned list. ``SkinDiseaseClassifier.predict`` stacks the models'
    outputs in that order, so keep the mapping aligned with
    ``SkinDiseaseClassifier.class_names`` (the two are currently listed in
    the same order).

    Args:
        repo_id: Hugging Face repository holding the checkpoints. Defaults to
            the repository that was previously hard-coded here.

    Returns:
        A list of models, as returned by ``load_model``, in mapping order.
    """
    # Class name -> checkpoint path inside ``repo_id``. The keys only document
    # which class each checkpoint belongs to; the order is what matters.
    class_models_mapping = {
        "Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
        "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
        "Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
        "Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
        "Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
        "Eczema Photos": 'santhosh/5fold_model_eczema.pth',
        "ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
        "Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
        "Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
        "Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
        "Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
        "Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
        "Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
        "Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
        "Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
        "Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
        "Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
        "Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
        "Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
        "Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
        "Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
        "Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
        "Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth',
    }
    return [load_model(repo_id, filename) for filename in class_models_mapping.values()]
class DynamicCNN(nn.Module):
    """Fully connected classifier head built from a list of hidden widths.

    Each hidden width adds a ``Linear -> BatchNorm1d -> ReLU -> Dropout``
    group, and a final ``Linear`` maps to ``num_classes``. All layers live in
    a single ``nn.Sequential`` named ``fc``, so state-dict keys are
    ``fc.<index>.<param>`` (for example ``fc.0.weight``).

    Args:
        input_channels: width of the input feature vector.
        fc_layers: iterable of hidden-layer widths, applied in order.
        num_classes: number of output logits.
        dropout_rate: dropout probability used after every hidden layer.
    """

    def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
        super().__init__()
        widths = [input_channels, *fc_layers]
        modules = []
        # Pair each width with the next one to get every hidden Linear's (in, out).
        for in_features, out_features in zip(widths, widths[1:]):
            modules += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
        modules.append(nn.Linear(widths[-1], num_classes))
        self.fc = nn.Sequential(*modules)

    def forward(self, x):
        """Return the logits produced by ``self.fc`` for input ``x``."""
        return self.fc(x)
class SkinDiseaseClassifier:
    """Stacked skin-disease classifier plus a multi-label lesion-attribute model.

    ``predict`` works in three steps:

    1. Run every per-class binary model from ``load_binary_models`` on the
       image.
    2. Concatenate their sigmoid outputs with pooled ResNet-50 stage features
       and four summary statistics of those probabilities.
    3. Score the classes in ``class_names`` with a fully connected
       meta-model (``DynamicCNN``).

    ``predict_skincon`` runs a separate 48-output EvoViT model and returns
    the names from ``multilabel_class_names`` whose sigmoid probability
    exceeds 0.5.

    ``__init__`` only sets up names and preprocessing; call ``load_models``
    before predicting.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Set the target device, class names and image preprocessing.

        Args:
            device: anything ``torch.device`` accepts. Defaults to CUDA when
                available, otherwise CPU.
        """
        self.device = torch.device(device)
        # Meta-model output classes, in output-index order (see predict()).
        self.class_names = [
            "Acne and Rosacea Photos",
            "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
            "Atopic Dermatitis Photos",
            "Bullous Disease Photos",
            "Cellulitis Impetigo and other Bacterial Infections",
            "Eczema Photos",
            "ExanthemsandDrugEruptions",
            "Hair Loss Photos Alopecia and other Hair Diseases",
            "Herpes HPV and other STDs Photos",
            "Light Diseases and Disorders of Pigmentation",
            "Lupus and other Connective Tissue diseases",
            "Melanoma Skin Cancer Nevi and Moles",
            "Nail Fungus and other Nail Disease",
            "Poison Ivy Photos and other Contact Dermatitis",
            "Psoriasis pictures Lichen Planus and related diseases",
            "Scabies Lyme Disease and other Infestations and Bites",
            "Seborrheic Keratoses and other Benign Tumors",
            "Systemic Disease",
            "Tinea Ringworm Candidiasis and other Fungal Infections",
            "Urticaria Hives",
            "Vascular Tumors",
            "Vasculitis Photos",
            "Warts Molluscum and other Viral Infections"
        ]
        # Populated by load_models().
        self.base_models = None
        self.meta_model = None
        self.resnet_feature_extractor = None
        self.skincon_model = None
        # Resize to 224x224, convert to a tensor and normalise with the
        # standard ImageNet per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Names of the 48 skincon model outputs, in output-index order.
        self.multilabel_class_names = [
            "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
            "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
            "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
            "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
            "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
            "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
            "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
        ]

    def load_models(self):
        """Download and initialise every model used by ``predict`` and ``predict_skincon``.

        Every model is moved to ``self.device`` and switched to eval mode.
        """
        # Per-class binary models.
        self.base_models = load_binary_models()
        for model in self.base_models:
            model.to(self.device)
            model.eval()

        # ResNet-50 convolutional trunk, without its final pooling and FC head.
        # NOTE(review): newer torchvision deprecates ``pretrained=`` in favour of ``weights=``.
        backbone = resnet50(pretrained=True)
        self.resnet_feature_extractor = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
            backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
        )
        self.resnet_feature_extractor.to(self.device)
        self.resnet_feature_extractor.eval()

        # Meta-model.
        print("=== Loading model with weights_only=False ===")
        meta_model_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="best_meta_model_two_layer_version4.pth"
        )
        # NOTE(review): weights_only=False fully unpickles the file, which can
        # execute arbitrary code. It is acceptable only because the repository is trusted.
        checkpoint = torch.load(meta_model_path, map_location=self.device, weights_only=False)
        # Take the input width from the checkpoint itself so it always matches.
        input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
        # Hidden widths must match the architecture the checkpoint was saved with.
        fc_layers = [1024, 512, 256]
        self.meta_model = DynamicCNN(
            input_channels=input_size,
            fc_layers=fc_layers,
            num_classes=23,
            dropout_rate=0.5
        )
        self.meta_model.load_state_dict(checkpoint['state_dict'])
        self.meta_model.to(self.device)
        self.meta_model.eval()

        # Multi-label skincon model (48 outputs).
        skincon_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="skincon.pth"
        )
        self.skincon_model = EvoViTModel(
            img_size=224, patch_size=16, in_channels=3,
            embed_dim=768, num_classes=48, hidden_dim=512,
        )
        state_dict = torch.load(skincon_path, map_location=self.device)
        self.skincon_model.load_state_dict(state_dict, strict=False)
        # Without this move, the model's parameters stayed on the CPU while
        # inputs went to self.device, so predict_skincon failed on CUDA.
        self.skincon_model.to(self.device)
        self.skincon_model.eval()

    def _to_input_tensor(self, image):
        """Convert ``image`` to RGB and return it as a preprocessed batch of one on ``self.device``."""
        try:
            image = image.convert('RGB')
        except (AttributeError, OSError, ValueError) as err:
            raise ValueError("Could not load image from path") from err
        return self.transform(image).unsqueeze(0).to(self.device)

    def extract_image_features(self, image_tensor):
        """Return the global-average-pooled output of every ResNet stage.

        Each ``nn.Sequential`` child of ``resnet_feature_extractor`` (the
        ``layer1``..``layer4`` stages) contributes one ``(batch, channels)``
        block. The blocks are concatenated along dim 1 and returned as a NumPy
        array on the CPU.
        """
        with torch.no_grad():
            pooled_stages = []
            x = image_tensor
            for layer in self.resnet_feature_extractor.children():
                x = layer(x)
                if isinstance(layer, nn.Sequential):  # ResNet stages (residual blocks)
                    pooled_stages.append(F.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
            return torch.cat(pooled_stages, dim=1).cpu().numpy()

    def predict(self, image, top_k=3):
        """Classify a single image into the classes listed in ``class_names``.

        Args:
            image: object with a PIL-style ``convert`` method, normally a
                ``PIL.Image.Image``.
            top_k: number of highest-probability classes to return.

        Returns:
            A dict with two entries:

            - ``"top_predictions"``: at most ``top_k`` ``(class_name,
              probability)`` pairs, highest probability first.
            - ``"all_probabilities"``: maps every class name to its softmax
              probability.

        Raises:
            RuntimeError: if ``load_models`` has not been called.
            ValueError: if ``image`` cannot be converted to RGB.
        """
        if self.base_models is None or self.meta_model is None:
            raise RuntimeError("Models not loaded - call load_models() first")
        image_tensor = self._to_input_tensor(image)

        with torch.no_grad():
            # One sigmoid probability per binary model -> (batch, n_models).
            binary_features = torch.stack(
                [torch.sigmoid(model(image_tensor)).squeeze(1) for model in self.base_models],
                dim=1,
            )
            image_features = torch.from_numpy(
                self.extract_image_features(image_tensor)
            ).float().to(self.device)

            # Summary statistics of the binary probabilities: mean, std,
            # mean of the top three, and the gap between the 1st and 3rd
            # highest (a confidence gap).
            top3_probs = torch.topk(binary_features, 3, dim=1).values
            prob_stats = torch.cat([
                binary_features.mean(dim=1, keepdim=True),
                binary_features.std(dim=1, keepdim=True),
                top3_probs.mean(dim=1, keepdim=True),
                (top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1),
            ], dim=1)

            combined_features = torch.cat([binary_features, image_features, prob_stats], dim=1)
            outputs = self.meta_model(combined_features)
            probabilities = torch.softmax(outputs, dim=1).squeeze(0).cpu().numpy()

        # Sort descending first, then slice. The previous
        # ``argsort(...)[-top_k:]`` returned every class when top_k == 0.
        top_indices = np.argsort(probabilities)[::-1][:top_k]
        top_predictions = [
            (self.class_names[i], float(probabilities[i]))
            for i in top_indices
        ]
        return {
            "top_predictions": top_predictions,
            "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
        }

    def predict_skincon(self, image, top_k=3):
        """Return the skincon attribute names predicted for a single image.

        An attribute is predicted when its sigmoid probability exceeds 0.5.
        ``top_k`` is accepted for signature symmetry with ``predict`` but is
        not used.

        Raises:
            RuntimeError: if ``load_models`` has not been called.
            ValueError: if ``image`` cannot be converted to RGB.
        """
        if self.base_models is None or self.skincon_model is None:
            raise RuntimeError("Models not loaded - call load_models() first")
        self.skincon_model.eval()
        image_tensor = self._to_input_tensor(image)
        with torch.no_grad():
            output_multi = self.skincon_model(image_tensor)
        # .cpu() is required before .numpy() when the model runs on CUDA.
        probs_multi = torch.sigmoid(output_multi).flatten().cpu().numpy()
        print(f"Probabilities : {probs_multi}")
        threshold = 0.5
        predicted_labels_multi = [self.multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > threshold]
        print("Predicted labels multi : ", predicted_labels_multi)
        return predicted_labels_multi
def initialize_classifier():
    """Create a SkinDiseaseClassifier, load its models and warm it up.

    A blank 224x224 RGB image is pushed through ``predict`` once before the
    classifier is returned, so any failure in the prediction path surfaces
    here rather than on the first real request.

    Returns:
        The ready-to-use ``SkinDiseaseClassifier``.
    """
    print("⚙️ Initializing skin disease classifier...")
    clf = SkinDiseaseClassifier()
    clf.load_models()
    warmup_image = Image.new('RGB', (224, 224))
    clf.predict(warmup_image)
    print("⚙️ Initialization successful")
    return clf