"""Cancer classification from precomputed 384-dim embeddings.

A 2-D SegResNet backbone feeds a small self-supervised head that produces
both a 2-class logit pair and a 128-dim projection vector.  Pretrained
weights are downloaded from the Hugging Face Hub by ``load_model``.
"""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from monai.networks.nets import SegResNet


class SelfSupervisedHead(nn.Module):
    """Self-supervised learning head for cancer classification.

    Since no coordinates or bounding boxes are available, this head focuses
    on learning from the entire embedding through self-supervision.
    """

    def __init__(self, in_channels, num_classes=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)
        self.bn = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        # Self-supervised projector (MLP): 128 -> 256 -> 128.
        self.projector = nn.Sequential(
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
        )
        # Final classification layer over the projected features.
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return ``(logits, projected_features)`` for feature map ``x``.

        ``x`` is a 4-D (N, C, H, W) feature map; it is reduced by a 1x1
        conv + BN + ReLU, globally average-pooled, projected, and classified.
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        # Projector output doubles as the self-supervised representation.
        features = self.projector(x)
        output = self.fc(features)
        return output, features


class SelfSupervisedCancerModel(nn.Module):
    """SegResNet backbone with a self-supervised head for cancer classification."""

    def __init__(self, num_classes=2):
        super().__init__()
        # 2-D SegResNet configured for 1-channel input and small spatial size.
        self.backbone = SegResNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=2,
            blocks_down=[3, 4, 23, 3],
            blocks_up=[3, 6, 3],
            upsample_mode="deconv",
            init_filters=32,
        )
        # Known from inspecting the backbone: the final conv layer outputs
        # 2 channels (printing self.backbone.conv_final shows Conv2d(8, 2, ...)).
        backbone_out_channels = 2
        # Replace the classifier with our self-supervised head.
        self.ssl_head = SelfSupervisedHead(backbone_out_channels, num_classes)
        # Drop the backbone's own classifier, if it exposes one.
        if hasattr(self.backbone, "class_layers"):
            self.backbone.class_layers = nn.Identity()

    def forward(self, x, return_features=False):
        """Run backbone then SSL head.

        Args:
            x: input tensor of shape (N, 1, H, W).
            return_features: when True, also return the 128-dim projection.

        Returns:
            ``logits`` or ``(logits, projected_features)``.
        """
        features = self.backbone(x)
        output, proj_features = self.ssl_head(features)
        if return_features:
            return output, proj_features
        return output


def load_model():
    """Download pretrained weights from the Hub and return the model in eval mode."""
    path = hf_hub_download(
        repo_id="Arpit-Bansal/Medical-Diagnosing-models",
        filename="cancer_detector_model.pth",
    )
    model = SelfSupervisedCancerModel(num_classes=2)
    # NOTE(review): torch.load unpickles arbitrary objects from a downloaded
    # file; pass weights_only=True once the minimum torch version allows it.
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict=state_dict)
    return model.eval()


def classify(model, embedding):
    """Classify a single embedding using the trained model.

    Args:
        model: a trained ``SelfSupervisedCancerModel``.
        embedding: array-like of 384 floats (one stored embedding).

    Returns:
        ``(prediction, confidence)`` where prediction is ``"positive"`` or
        ``"negative"`` and confidence is the softmax probability of the
        predicted class.
    """
    # Ensure the model really is in evaluation mode (BatchNorm must use
    # running stats for a batch of one); the original comment claimed this
    # but never called eval().
    model.eval()
    # Reshape to (batch=1, channels=1, 384, 1) as the 2-D backbone expects.
    embedding_tensor = torch.as_tensor(embedding, dtype=torch.float32).view(1, 1, 384, 1)
    with torch.no_grad():
        output = model(embedding_tensor)
        probs = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0, predicted_class].item()
    prediction = "positive" if predicted_class == 1 else "negative"
    return prediction, confidence