File size: 6,788 Bytes
bfc585e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
# import h5py
import numpy as np
# import pandas as pd
# from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
from monai.networks.nets import SegResNet
from huggingface_hub import hf_hub_download
# from tqdm.notebook import tqdm, trange

# class EmbeddingsDataset(Dataset):
#     """Helper class to load and work with the stored embeddings"""
    
#     def __init__(self, embeddings_path, metadata_path, transform=None):
#         """
#         Initialize the dataset
        
#         Args:
#             embeddings_path: Path to the directory containing H5 embedding files
#             metadata_path: Path to the directory containing metadata files
#             transform: Optional transforms to apply to the data
#         """
#         self.embeddings_path = embeddings_path
#         self.metadata_path = metadata_path
#         self.transform = transform
#         self.master_metadata = pd.read_parquet(os.path.join(metadata_path, "master_metadata.parquet"))
#         # Limit to data with labels
#         self.metadata = self.master_metadata.dropna(subset=['label'])
        
#     def __len__(self):
#         return len(self.metadata)
    
#     def __getitem__(self, idx):
#         """Get embedding and label for a specific index"""
#         row = self.metadata.iloc[idx]
#         batch_name = row['embedding_batch']
#         embedding_index = row['embedding_index']
#         label = row['label']
        
#         # Load the embedding
#         h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
#         with h5py.File(h5_path, 'r') as h5f:
#             embedding = h5f['embeddings'][embedding_index]
        
#         # Convert to PyTorch tensor
#         embedding = torch.tensor(embedding, dtype=torch.float32)
        
#         # Reshape for CNN input - we expect embeddings of shape (384,)
#         # Reshape to (1, 384, 1, 1) for network input
#         embedding = embedding.view(1, 384, 1)
        
#         # Convert label to tensor (0=negative, 1=positive)
#         label = torch.tensor(label, dtype=torch.long)
        
#         if self.transform:
#             embedding = self.transform(embedding)
            
#         return embedding, label
    
#     def get_embedding(self, file_id):
#         """Get embedding for a specific file ID"""
#         # Find the file in metadata
#         file_info = self.master_metadata[self.master_metadata['file_id'] == file_id]
        
#         if len(file_info) == 0:
#             raise ValueError(f"File ID {file_id} not found in metadata")
        
#         # Get the batch and index
#         batch_name = file_info['embedding_batch'].iloc[0]
#         embedding_index = file_info['embedding_index'].iloc[0]
        
#         # Load the embedding
#         h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
#         with h5py.File(h5_path, 'r') as h5f:
#             embedding = h5f['embeddings'][embedding_index]
            
#         return embedding, file_info['label'].iloc[0] if 'label' in file_info.columns else None

class SelfSupervisedHead(nn.Module):
    """Projection + classification head for cancer classification.

    No coordinates or bounding boxes are available, so the head learns
    from the entire embedding: a 1x1 conv, global average pooling, an MLP
    projector (for the self-supervised objective) and a linear classifier.
    """
    def __init__(self, in_channels, num_classes=2):
        super(SelfSupervisedHead, self).__init__()
        # Channel-mixing conv -> batch norm -> ReLU -> global pooling.
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)
        self.bn = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # MLP projector producing features for the self-supervised loss.
        self.projector = nn.Sequential(
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
        )

        # Linear classifier applied to the projected features.
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return (logits, projected_features) for feature map ``x``."""
        pooled = self.global_pool(self.relu(self.bn(self.conv(x))))
        flat = pooled.view(pooled.size(0), -1)

        # Projected features feed both the SSL objective and the classifier.
        features = self.projector(flat)
        logits = self.fc(features)
        return logits, features

class SelfSupervisedCancerModel(nn.Module):
    """SegResNet backbone with a self-supervised head for cancer classification."""
    def __init__(self, num_classes=2):
        super(SelfSupervisedCancerModel, self).__init__()
        # 2D SegResNet backbone configured for 1-channel, small-size input.
        self.backbone = SegResNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=2,
            blocks_down=[3, 4, 23, 3],
            blocks_up=[3, 6, 3],
            upsample_mode="deconv",
            init_filters=32,
        )

        # The backbone's final conv emits 2 channels — observed by printing
        # self.backbone.conv_final, which shows Conv2d(8, 2, ...).
        backbone_out_channels = 2

        # Our self-supervised head replaces the original classifier.
        self.ssl_head = SelfSupervisedHead(backbone_out_channels, num_classes)

        # Neutralize the backbone's own classification layers when present.
        if hasattr(self.backbone, 'class_layers'):
            self.backbone.class_layers = nn.Identity()

    def forward(self, x, return_features=False):
        """Return logits; also projector features when ``return_features``."""
        feature_map = self.backbone(x)
        logits, projected = self.ssl_head(feature_map)
        return (logits, projected) if return_features else logits

def load_model():
    """Download and load the pretrained cancer-detector model.

    Fetches ``cancer_detector_model.pth`` from the Hugging Face Hub (cached
    locally by ``hf_hub_download``), loads the state dict onto CPU, and
    returns the model in eval mode.

    Returns:
        SelfSupervisedCancerModel: the loaded model, set to ``eval()``.
    """
    path = hf_hub_download(
        repo_id="Arpit-Bansal/Medical-Diagnosing-models",
        filename="cancer_detector_model.pth",
    )
    model = SelfSupervisedCancerModel(num_classes=2)
    # weights_only=True restricts unpickling to tensors/containers — safer
    # than the legacy full-pickle torch.load on a downloaded file.
    state_dict = torch.load(path, map_location=torch.device('cpu'), weights_only=True)

    model.load_state_dict(state_dict=state_dict)

    return model.eval()


def classify(model, embedding):
    """Classify a single embedding using the trained model.

    Args:
        model: Trained classifier, expected to already be in eval mode
            (as returned by ``load_model``).
        embedding: 1-D embedding vector (sequence, ndarray, or tensor).
            The pipeline produces 384-dim embeddings, but any length is
            accepted and reshaped to ``(1, 1, D, 1)``.

    Returns:
        tuple: ``(prediction, confidence)`` where prediction is
        ``"positive"`` or ``"negative"`` and confidence is the softmax
        probability of the predicted class.
    """
    # as_tensor avoids a copy when `embedding` is already a tensor;
    # view(1, 1, -1, 1) generalizes the previously hard-coded 384.
    embedding_tensor = torch.as_tensor(embedding, dtype=torch.float32).view(1, 1, -1, 1)

    with torch.no_grad():  # inference only: no autograd bookkeeping
        output = model(embedding_tensor)
        probs = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0, predicted_class].item()

    # Class index 1 -> "positive", 0 -> "negative".
    prediction = "positive" if predicted_class == 1 else "negative"

    return prediction, confidence