File size: 13,700 Bytes
79cab30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd5144d
79cab30
 
 
 
 
 
 
 
87c2216
 
 
 
 
 
 
 
 
 
79cab30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c2216
 
 
 
a3f66c1
87c2216
a3f66c1
 
c540b50
 
a3f66c1
 
c540b50
 
87c2216
 
79cab30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c2216
 
 
 
 
 
 
 
 
 
 
 
 
 
fd5144d
87c2216
 
 
 
 
 
79cab30
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torch.nn.functional as F
from evo_vit import EvoViTModel
import io
import os
from fpdf import FPDF
from torchvision.models import resnet50
import nest_asyncio
from huggingface_hub import hf_hub_download

# Module-wide default device: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(repo_id, filename):
    """Download one binary EvoViT checkpoint from the Hugging Face Hub and build the model.

    The network is constructed with a 2-class configuration and its
    ``classifier`` is then replaced by a single-output ``nn.Linear(512, 1)``.
    Checkpoint keys have a leading ``"backbone."`` stripped, and any stored
    classifier weight/bias is truncated to its first row so that it fits the
    1-output head. Loading is non-strict (``strict=False``).

    Args:
        repo_id: Hugging Face repository id to download from.
        filename: path of the checkpoint file inside that repository.

    Returns:
        The model, moved to the module-level ``device`` and set to eval mode.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    net = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=2, hidden_dim=512)
    # Single-logit head; callers apply a sigmoid to its output.
    net.classifier = nn.Linear(512, 1)

    raw_state = torch.load(checkpoint_path, map_location=device)

    # Drop a leading "backbone." so keys line up with the bare EvoViTModel.
    prefix = "backbone."
    cleaned = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in raw_state.items()
    }

    # Keep only the first output row of the saved head so it matches Linear(512, 1).
    # NOTE(review): which class row 0 represents depends on how the checkpoint was trained.
    if "classifier.weight" in cleaned:
        cleaned["classifier.weight"] = cleaned["classifier.weight"][0:1, :]
    if "classifier.bias" in cleaned:
        cleaned["classifier.bias"] = cleaned["classifier.bias"][0:1]

    # strict=False: mismatched keys are ignored rather than raising.
    net.load_state_dict(cleaned, strict=False)
    net.to(device)
    net.eval()
    return net

def load_binary_models():
    """Load one binary model per disease class and return them as a list.

    The list follows the insertion order of ``class_models_mapping``. That order
    is the same order used by ``SkinDiseaseClassifier.class_names``, so the i-th
    model corresponds to the i-th class name.

    Returns:
        list of models produced by ``load_model``, one per mapping entry.
    """
    repo_id = "KeerthiVM/SkinCancerDiagnosis"  # Hugging Face repo holding all checkpoints
    # Class name -> checkpoint path inside ``repo_id``; keep this order stable.
    class_models_mapping = {
        "Acne and Rosacea Photos": 'santhosh/10fold_model_acne.pth',
        "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions": 'santhosh/5fold_model_actinic.pth',
        "Atopic Dermatitis Photos": 'keerthi/Atopic/best_global_model_5fold.pth',
        "Bullous Disease Photos": 'santhosh/10fold_model_bullous.pth',
        "Cellulitis Impetigo and other Bacterial Infections": 'santhosh/10fold_model_cellulitis.pth',
        "Eczema Photos": 'santhosh/5fold_model_eczema.pth',
        "ExanthemsandDrugEruptions": 'santhosh/10fold_model_exantherms.pth',
        "Hair Loss Photos Alopecia and other Hair Diseases": 'keerthi/HairLoss/best_global_model_5fold.pth',
        "Herpes HPV and other STDs Photos": 'keerthi/Herpes/best_global_model_5fold.pth',
        "Light Diseases and Disorders of Pigmentation": 'santhosh/5fold_model_light.pth',
        "Lupus and other Connective Tissue diseases": 'keerthi/Lupus/best_global_model_5fold.pth',
        "Melanoma Skin Cancer Nevi and Moles": 'keerthi/Melanoma/best_global_model_10fold.pth',
        "Nail Fungus and other Nail Disease": 'santhosh/5fold_model_nail.pth',
        "Poison Ivy Photos and other Contact Dermatitis": 'santhosh/5fold_model_poison.pth',
        "Psoriasis pictures Lichen Planus and related diseases": 'santhosh/10fold_model_psoriasis.pth',
        "Scabies Lyme Disease and other Infestations and Bites": 'santhosh/5fold_model_scabies.pth',
        "Seborrheic Keratoses and other Benign Tumors": 'santhosh/10fold_model_seboh.pth',
        "Systemic Disease": 'keerthi/Systemic/best_global_model_5fold.pth',
        "Tinea Ringworm Candidiasis and other Fungal Infections": 'santhosh/10fold_model_tinea.pth',
        "Urticaria Hives": 'keerthi/Urticaria/best_global_model_10fold.pth',
        "Vascular Tumors": 'keerthi/Vascular/best_global_model_5fold.pth',
        "Vasculitis Photos": 'keerthi/Vasculitis/best_global_model_10fold.pth',
        "Warts Molluscum and other Viral Infections": 'santhosh/10fold_model_warts.pth'
    }
    return [load_model(repo_id, checkpoint) for checkpoint in class_models_mapping.values()]


class DynamicCNN(nn.Module):
    """Fully connected classifier with a configurable stack of hidden layers.

    Despite the name, no convolutions are involved. Each hidden width in
    ``fc_layers`` contributes ``Linear -> BatchNorm1d -> ReLU -> Dropout``, and a
    final ``Linear`` maps to ``num_classes`` logits. Every layer lives in
    ``self.fc`` (an ``nn.Sequential``), so state-dict keys have the form
    ``fc.<index>.*``; saved checkpoints depend on that layout.

    Args:
        input_channels: width of the input feature vector.
        fc_layers: iterable of hidden-layer widths (may be empty).
        num_classes: number of output logits.
        dropout_rate: dropout probability after every hidden layer.
    """

    def __init__(self, input_channels, fc_layers, num_classes, dropout_rate=0.3):
        super().__init__()
        widths = [input_channels, *fc_layers]
        stack = []
        # One hidden block per consecutive (in, out) width pair.
        for width_in, width_out in zip(widths, widths[1:]):
            stack += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
        # Output projection from the last hidden width (or the input width if no hidden layers).
        stack.append(nn.Linear(widths[-1], num_classes))
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Map a (batch, input_channels) tensor to (batch, num_classes) logits."""
        return self.fc(x)


class SkinDiseaseClassifier:
    """Stacked-ensemble skin disease classifier plus a multi-label morphology model.

    ``predict`` builds a feature vector from three parts: the sigmoid outputs
    of one binary model per class, average-pooled ResNet-50 stage features,
    and four summary statistics of the binary outputs. It feeds that vector to
    a fully connected meta-model (``DynamicCNN``) and applies a softmax.
    ``predict_skincon`` runs a separate 48-output model and returns the labels
    whose sigmoid probability exceeds 0.5.

    The constructor only sets up class names and preprocessing; call
    ``load_models()`` before predicting.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        # Output classes of the meta-model, in output order. This mirrors the key
        # order used in load_binary_models(); keep the two in sync.
        self.class_names = [
            "Acne and Rosacea Photos",
            "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
            "Atopic Dermatitis Photos",
            "Bullous Disease Photos",
            "Cellulitis Impetigo and other Bacterial Infections",
            "Eczema Photos",
            "ExanthemsandDrugEruptions",
            "Hair Loss Photos Alopecia and other Hair Diseases",
            "Herpes HPV and other STDs Photos",
            "Light Diseases and Disorders of Pigmentation",
            "Lupus and other Connective Tissue diseases",
            "Melanoma Skin Cancer Nevi and Moles",
            "Nail Fungus and other Nail Disease",
            "Poison Ivy Photos and other Contact Dermatitis",
            "Psoriasis pictures Lichen Planus and related diseases",
            "Scabies Lyme Disease and other Infestations and Bites",
            "Seborrheic Keratoses and other Benign Tumors",
            "Systemic Disease",
            "Tinea Ringworm Candidiasis and other Fungal Infections",
            "Urticaria Hives",
            "Vascular Tumors",
            "Vasculitis Photos",
            "Warts Molluscum and other Viral Infections"
        ]

        # Populated by load_models().
        self.base_models = None
        self.meta_model = None
        self.resnet_feature_extractor = None
        self.skincon_model = None

        # Resize to 224x224 and normalize with the ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Labels of the 48-output SkinCon model, in output order.
        self.multilabel_class_names = [
            "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
            "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
            "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
            "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
            "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
            "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
            "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
        ]

    def load_models(self):
        """Load every model onto ``self.device`` and switch it to eval mode.

        This loads:
        - the binary per-class models (via ``load_binary_models``);
        - a ResNet-50 trunk with torchvision pretrained weights;
        - the meta-model and the SkinCon multi-label model, both downloaded from
          the ``KeerthiVM/SkinCancerDiagnosis`` Hugging Face repo.
        """
        # Binary models; their order defines the first block of meta-model inputs.
        self.base_models = load_binary_models()
        for model in self.base_models:
            model.to(self.device)
            model.eval()

        # ResNet-50 without avgpool/fc. extract_image_features pools the outputs
        # of layer1..layer4 (the nn.Sequential children).
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
        # favour of `weights=...`.
        backbone = resnet50(pretrained=True)
        self.resnet_feature_extractor = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
            backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
        )
        self.resnet_feature_extractor.to(self.device)
        self.resnet_feature_extractor.eval()

        # Meta model
        print("=== Loading model with weights_only=False ===")
        meta_model_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="best_meta_model_two_layer_version4.pth"
        )
        # SECURITY: weights_only=False unpickles arbitrary Python objects. This is
        # only acceptable because the checkpoint comes from a trusted repository.
        checkpoint = torch.load(meta_model_path, map_location=self.device, weights_only=False)

        # Take the input width from the checkpoint instead of hard-coding it, so
        # it always matches the feature layout the meta-model was saved with.
        correct_input_size = checkpoint['state_dict']['fc.0.weight'].shape[1]
        fc_layers = [1024, 512, 256]  # hidden widths of the saved meta-model (strict load below)

        self.meta_model = DynamicCNN(
            input_channels=correct_input_size,
            fc_layers=fc_layers,
            num_classes=23,
            dropout_rate=0.5
        )
        self.meta_model.load_state_dict(checkpoint['state_dict'])
        self.meta_model.to(self.device)
        self.meta_model.eval()

        # SkinCon multi-label model
        skincon_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="skincon.pth"
        )
        self.skincon_model = EvoViTModel(img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=48, hidden_dim=512)
        state_dict = torch.load(skincon_path, map_location=self.device)
        # NOTE(review): strict=False silently ignores missing/unexpected keys;
        # confirm the checkpoint matches this architecture.
        self.skincon_model.load_state_dict(state_dict, strict=False)
        # Must sit on the same device as the input tensors built in predict_skincon().
        self.skincon_model.to(self.device)
        self.skincon_model.eval()

    def extract_image_features(self, image_tensor):
        """Return average-pooled ResNet stage features as a NumPy array.

        The input passes through ``self.resnet_feature_extractor`` one child at
        a time. After every ``nn.Sequential`` child (the residual stages as
        built in ``load_models``), the activation is average-pooled to 1x1 and
        flattened. The per-stage vectors are then concatenated along dim 1.

        Args:
            image_tensor: image batch on ``self.device``; presumably
                (B, 3, H, W), so confirm against callers.

        Returns:
            ``np.ndarray`` of shape (B, total pooled channels), copied to the CPU.
        """
        with torch.no_grad():
            x = image_tensor
            pooled_stages = []
            for layer in self.resnet_feature_extractor.children():
                x = layer(x)
                if isinstance(layer, nn.Sequential):  # residual stage
                    pooled_stages.append(F.adaptive_avg_pool2d(x, (1, 1)).flatten(1))
            features = torch.cat(pooled_stages, dim=1)
        return features.cpu().numpy()

    def predict(self, image, top_k=3):
        """Classify a single image with the stacked ensemble.

        Args:
            image: a PIL image, or any object with ``.convert('RGB')``.
            top_k: number of highest-probability classes to return. Values <= 0
                yield an empty ``top_predictions`` list.

        Returns:
            dict with two entries:
            - ``"top_predictions"``: list of ``(class_name, probability)``
              sorted by descending probability;
            - ``"all_probabilities"``: mapping from class name to probability.

        Raises:
            RuntimeError: if ``load_models()`` has not been called.
            ValueError: if ``image`` cannot be converted to RGB.
        """
        if self.base_models is None or self.meta_model is None:
            raise RuntimeError("Models not loaded - call load_models() first")

        try:
            image = image.convert('RGB')
        except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
            raise ValueError("Could not load image from path") from err

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # (B, num_binary_models): one sigmoid score column per binary model.
            binary_features = torch.stack(
                [torch.sigmoid(model(image_tensor)).squeeze(1) for model in self.base_models],
                dim=1,
            )

            image_features = torch.from_numpy(
                self.extract_image_features(image_tensor)
            ).float().to(self.device)

            # Four summary statistics of the binary scores:
            # mean, std, mean of the top 3, and the gap between the 1st and 3rd score.
            top3_probs = torch.topk(binary_features, 3, dim=1).values
            prob_stats = torch.stack([
                binary_features.mean(dim=1, keepdim=True),
                binary_features.std(dim=1, keepdim=True),
                top3_probs.mean(dim=1, keepdim=True),
                (top3_probs[:, 0] - top3_probs[:, 2]).unsqueeze(1)  # confidence gap
            ], dim=1).squeeze(-1)

            # Presumably this is the feature order used when the meta-model was
            # trained, so do not reorder.
            combined_features = torch.cat([binary_features, image_features, prob_stats], dim=1)

            outputs = self.meta_model(combined_features)
            probabilities = torch.softmax(outputs, dim=1).squeeze().cpu().numpy()

        # Sort descending, then take the first top_k. Unlike `[-top_k:]`, this is
        # correct for top_k == 0.
        top_indices = np.argsort(probabilities)[::-1][:top_k]
        top_predictions = [
            (self.class_names[i], float(probabilities[i]))
            for i in top_indices
        ]

        return {
            "top_predictions": top_predictions,
            "all_probabilities": {name: float(prob) for name, prob in zip(self.class_names, probabilities)}
        }

    def predict_skincon(self, image, top_k=3):
        """Predict SkinCon morphology labels for a single image.

        A label is returned when its sigmoid probability is strictly greater
        than 0.5. Labels come back in ``multilabel_class_names`` order.

        Args:
            image: a PIL image, or any object with ``.convert('RGB')``.
            top_k: unused; kept for signature compatibility.

        Returns:
            list of predicted label names.

        Raises:
            RuntimeError: if ``load_models()`` has not been called.
            ValueError: if ``image`` cannot be converted to RGB.
        """
        if self.base_models is None or self.skincon_model is None:
            raise RuntimeError("Models not loaded - call load_models() first")
        self.skincon_model.eval()
        try:
            image = image.convert('RGB')
        except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
            raise ValueError("Could not load image from path") from err

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        threshold = 0.5
        with torch.no_grad():
            output_multi = self.skincon_model(image_tensor)
            # .cpu() is required before .numpy() when the model runs on CUDA.
            probs_multi = torch.sigmoid(output_multi).squeeze().cpu().numpy()
            print(f"Probabilities : {probs_multi}")
            predicted_labels_multi = [self.multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > threshold]
        print("Predicted labels multi : ", predicted_labels_multi)
        return predicted_labels_multi


def initialize_classifier():
    """Create a SkinDiseaseClassifier, load its models and run one warm-up prediction.

    The warm-up runs ``predict`` once on a blank 224x224 RGB image and discards
    the result. Presumably this makes loading or shape errors surface at start-up
    rather than on the first real request.

    Returns:
        The ready-to-use ``SkinDiseaseClassifier``.
    """
    print("⚙️ Initializing skin disease classifier...")
    clf = SkinDiseaseClassifier()
    clf.load_models()
    # Warm-up pass through the full predict() pipeline; the output is ignored.
    clf.predict(Image.new('RGB', (224, 224)))

    print("⚙️ Initialization successful")
    return clf