# NOTE(review): a Hugging Face "Spaces: Sleeping" status banner was captured
# together with this source when it was copied; it is not code.
import os
# import h5py
import numpy as np
# import pandas as pd
# from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
from monai.networks.nets import SegResNet
from huggingface_hub import hf_hub_download
# from tqdm.notebook import tqdm, trange
# class EmbeddingsDataset(Dataset):
#     """Helper class to load and work with the stored embeddings"""
#     def __init__(self, embeddings_path, metadata_path, transform=None):
#         """
#         Initialize the dataset
#         Args:
#             embeddings_path: Path to the directory containing H5 embedding files
#             metadata_path: Path to the directory containing metadata files
#             transform: Optional transforms to apply to the data
#         """
#         self.embeddings_path = embeddings_path
#         self.metadata_path = metadata_path
#         self.transform = transform
#         self.master_metadata = pd.read_parquet(os.path.join(metadata_path, "master_metadata.parquet"))
#         # Limit to data with labels
#         self.metadata = self.master_metadata.dropna(subset=['label'])
#     def __len__(self):
#         return len(self.metadata)
#     def __getitem__(self, idx):
#         """Get embedding and label for a specific index"""
#         row = self.metadata.iloc[idx]
#         batch_name = row['embedding_batch']
#         embedding_index = row['embedding_index']
#         label = row['label']
#         # Load the embedding
#         h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
#         with h5py.File(h5_path, 'r') as h5f:
#             embedding = h5f['embeddings'][embedding_index]
#         # Convert to PyTorch tensor
#         embedding = torch.tensor(embedding, dtype=torch.float32)
#         # Reshape for CNN input - we expect embeddings of shape (384,)
#         # Reshape to (1, 384, 1, 1) for network input
#         embedding = embedding.view(1, 384, 1)
#         # Convert label to tensor (0=negative, 1=positive)
#         label = torch.tensor(label, dtype=torch.long)
#         if self.transform:
#             embedding = self.transform(embedding)
#         return embedding, label
#     def get_embedding(self, file_id):
#         """Get embedding for a specific file ID"""
#         # Find the file in metadata
#         file_info = self.master_metadata[self.master_metadata['file_id'] == file_id]
#         if len(file_info) == 0:
#             raise ValueError(f"File ID {file_id} not found in metadata")
#         # Get the batch and index
#         batch_name = file_info['embedding_batch'].iloc[0]
#         embedding_index = file_info['embedding_index'].iloc[0]
#         # Load the embedding
#         h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
#         with h5py.File(h5_path, 'r') as h5f:
#             embedding = h5f['embeddings'][embedding_index]
#         return embedding, file_info['label'].iloc[0] if 'label' in file_info.columns else None
class SelfSupervisedHead(nn.Module):
    """Self-supervised learning head for cancer classification.

    Since no coordinates or bounding boxes are available, this head focuses
    on learning from the entire embedding through self-supervision.

    Args:
        in_channels: Number of channels in the incoming feature map.
        num_classes: Number of classification outputs (default 2).
    """

    def __init__(self, in_channels, num_classes=2):
        super(SelfSupervisedHead, self).__init__()
        # 1x1 conv projects the backbone feature map to a fixed 128 channels.
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)
        self.bn = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        # Collapse all spatial positions into one 128-d vector per sample.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        # Self-supervised projector (MLP): 128 -> 256 -> 128.
        self.projector = nn.Sequential(
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128)
        )
        # Classification layer operating on the projected features.
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a 4-D feature map (N, C, H, W) to (logits, projector features).

        NOTE(review): the BatchNorm1d inside the projector requires a batch
        size > 1 while in training mode; single-sample inference in eval mode
        is fine.
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 128)
        # Projector output doubles as the classification input.
        features = self.projector(x)
        output = self.fc(features)
        return output, features
class SelfSupervisedCancerModel(nn.Module):
    """SegResNet backbone with a self-supervised head for cancer classification.

    Args:
        num_classes: Number of classification outputs (default 2:
            negative/positive).
    """

    def __init__(self, num_classes=2):
        super(SelfSupervisedCancerModel, self).__init__()
        # 2-D SegResNet backbone, configured for 1-channel input and the
        # small (384 x 1) embedding "images" fed in by classify().
        self.backbone = SegResNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=2,
            blocks_down=[3, 4, 23, 3],
            blocks_up=[3, 6, 3],
            upsample_mode="deconv",
            init_filters=32,
        )
        # The backbone's final conv layer emits 2 channels (out_channels=2
        # above; its conv_final prints as Conv2d(8, 2, ...)), so the SSL head
        # consumes a 2-channel feature map.
        backbone_out_channels = 2
        self.ssl_head = SelfSupervisedHead(backbone_out_channels, num_classes)
        # Neutralize the backbone's own classifier, if any; the SSL head
        # replaces it.
        if hasattr(self.backbone, 'class_layers'):
            self.backbone.class_layers = nn.Identity()

    def forward(self, x, return_features=False):
        """Run the backbone and SSL head.

        Args:
            x: Input tensor for the backbone (expected 1-channel 2-D input).
            return_features: If True, also return the projector features.

        Returns:
            Class logits, or (logits, projector_features) when
            return_features is True.
        """
        features = self.backbone(x)
        output, proj_features = self.ssl_head(features)
        if return_features:
            return output, proj_features
        return output
def load_model():
    """Download the pretrained checkpoint from the Hugging Face Hub and
    return the model in evaluation mode.

    Returns:
        A SelfSupervisedCancerModel with the published weights loaded,
        switched to eval() mode.
    """
    path = hf_hub_download(
        repo_id="Arpit-Bansal/Medical-Diagnosing-models",
        filename="cancer_detector_model.pth",
    )
    model = SelfSupervisedCancerModel(num_classes=2)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is safer for a file fetched over the network; a state_dict needs
    # nothing more.
    state_dict = torch.load(path, map_location=torch.device('cpu'),
                            weights_only=True)
    model.load_state_dict(state_dict=state_dict)
    return model.eval()
def classify(model, embedding):
    """Classify a single 384-d embedding using the trained model.

    Args:
        model: Callable/module mapping a (1, 1, 384, 1) float tensor to a
            (1, num_classes) logits tensor. Expected to already be in eval
            mode (see load_model()).
        embedding: Sequence or array of 384 floats; an existing tensor is
            also accepted.

    Returns:
        Tuple (prediction, confidence) where prediction is "positive" when
        class 1 wins and "negative" otherwise, and confidence is the softmax
        probability of the predicted class.
    """
    # as_tensor accepts lists, numpy arrays and tensors alike without the
    # copy warning torch.tensor() emits for tensor input.
    embedding_tensor = torch.as_tensor(embedding, dtype=torch.float32).view(1, 1, 384, 1)
    with torch.no_grad():
        output = model(embedding_tensor)
        probs = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0, predicted_class].item()
    prediction = "positive" if predicted_class == 1 else "negative"
    return prediction, confidence