Arpit-Bansal committed on
Commit bfc585e · 1 Parent(s): 8568b2c

push files

Files changed (6)
  1. .gitignore +3 -0
  2. Dockerfile +14 -0
  3. __init__.py +0 -0
  4. main.py +40 -0
  5. model.py +183 -0
  6. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ venv/
+ models/
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,14 @@
+ FROM python:3.12
+ RUN useradd -m -u 1000 user
+ USER user
+ WORKDIR /code
+ RUN chown -R user:user /code
+ ENV HOME=/home/user
+ ENV PATH=/home/user/.local/bin:$PATH
+ WORKDIR $HOME/app
+ COPY ./requirements.txt ./
+ RUN pip install --no-cache-dir -r ./requirements.txt
+ COPY --chown=user . $HOME/app
+ # COPY --chown=user:user . /code
+ # COPY . .
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
__init__.py ADDED
File without changes
main.py ADDED
@@ -0,0 +1,40 @@
+ from model import load_model, classify
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import List
+ import numpy as np
+ import uvicorn
+
+ app = FastAPI()
+
+ class InputData(BaseModel):
+     features: List[float]
+
+ # class InputData(BaseModel):
+ #     features: List[float]
+
+ #     @field_validator('features')
+ #     def check_features_length(cls, v):
+ #         if len(v) != 384:
+ #             raise ValueError('Features must be a list of length 384')
+ #         return v
+
+ # Load the model once at import time so every request reuses it
+ model = load_model()
+
+ @app.post("/classify")
+ async def classify_data(data: InputData):
+     try:
+         # Convert the input to a numpy array for the model
+         features = np.array(data.features)
+
+         # Get prediction using the imported classify function
+         prediction, confidence = classify(model, features)
+
+         return {"prediction": prediction, "confidence": confidence}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error during classification: {str(e)}")
+
+ if __name__ == "__main__":
+     # The model is already loaded above; just start the server
+     uvicorn.run(app, host="0.0.0.0", port=8000)
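
For reference, here is a minimal client sketch for the /classify endpoint. It assumes the server is running locally on port 8000 (as in the __main__ block above; the Docker image serves on 7860 instead), that inputs are 384-float embeddings (as the commented-out validator and the reshape in model.py suggest), and that the requests package is installed; requests is illustrative only and not part of this commit's requirements.

    import requests  # illustrative client dependency, not in requirements.txt

    # Placeholder 384-dimensional embedding; real inputs would come from an
    # upstream embedding pipeline.
    payload = {"features": [0.0] * 384}

    resp = requests.post("http://localhost:8000/classify", json=payload)
    resp.raise_for_status()
    print(resp.json())  # {"prediction": "positive" or "negative", "confidence": <float>}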
model.py ADDED
@@ -0,0 +1,183 @@
+ import os
+ # import h5py
+ import numpy as np
+ # import pandas as pd
+ # from sklearn.model_selection import train_test_split
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ # from torch.utils.data import Dataset, DataLoader
+ from monai.networks.nets import SegResNet
+ from huggingface_hub import hf_hub_download
+ # from tqdm.notebook import tqdm, trange
+
+ # class EmbeddingsDataset(Dataset):
+ #     """Helper class to load and work with the stored embeddings"""
+
+ #     def __init__(self, embeddings_path, metadata_path, transform=None):
+ #         """
+ #         Initialize the dataset
+
+ #         Args:
+ #             embeddings_path: Path to the directory containing H5 embedding files
+ #             metadata_path: Path to the directory containing metadata files
+ #             transform: Optional transforms to apply to the data
+ #         """
+ #         self.embeddings_path = embeddings_path
+ #         self.metadata_path = metadata_path
+ #         self.transform = transform
+ #         self.master_metadata = pd.read_parquet(os.path.join(metadata_path, "master_metadata.parquet"))
+ #         # Limit to data with labels
+ #         self.metadata = self.master_metadata.dropna(subset=['label'])
+
+ #     def __len__(self):
+ #         return len(self.metadata)
+
+ #     def __getitem__(self, idx):
+ #         """Get embedding and label for a specific index"""
+ #         row = self.metadata.iloc[idx]
+ #         batch_name = row['embedding_batch']
+ #         embedding_index = row['embedding_index']
+ #         label = row['label']
+
+ #         # Load the embedding
+ #         h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
+ #         with h5py.File(h5_path, 'r') as h5f:
+ #             embedding = h5f['embeddings'][embedding_index]
+
+ #         # Convert to PyTorch tensor
+ #         embedding = torch.tensor(embedding, dtype=torch.float32)
+
+ #         # Reshape for CNN input - we expect embeddings of shape (384,)
+ #         # Reshape to (1, 384, 1) for network input
+ #         embedding = embedding.view(1, 384, 1)
+
+ #         # Convert label to tensor (0=negative, 1=positive)
+ #         label = torch.tensor(label, dtype=torch.long)
+
+ #         if self.transform:
+ #             embedding = self.transform(embedding)
+
+ #         return embedding, label
+
+ #     def get_embedding(self, file_id):
+ #         """Get embedding for a specific file ID"""
+ #         # Find the file in metadata
+ #         file_info = self.master_metadata[self.master_metadata['file_id'] == file_id]
+
+ #         if len(file_info) == 0:
+ #             raise ValueError(f"File ID {file_id} not found in metadata")
+
+ #         # Get the batch and index
+ #         batch_name = file_info['embedding_batch'].iloc[0]
+ #         embedding_index = file_info['embedding_index'].iloc[0]
+
+ #         # Load the embedding
+ #         h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
+ #         with h5py.File(h5_path, 'r') as h5f:
+ #             embedding = h5f['embeddings'][embedding_index]
+
+ #         return embedding, file_info['label'].iloc[0] if 'label' in file_info.columns else None
+
+ class SelfSupervisedHead(nn.Module):
+     """Self-supervised learning head for cancer classification
+
+     Since no coordinates or bounding boxes are available, this head focuses on
+     learning from the entire embedding through self-supervision.
+     """
+     def __init__(self, in_channels, num_classes=2):
+         super(SelfSupervisedHead, self).__init__()
+         self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)
+         self.bn = nn.BatchNorm2d(128)
+         self.relu = nn.ReLU(inplace=True)
+         self.global_pool = nn.AdaptiveAvgPool2d(1)
+
+         # Self-supervised projector (MLP)
+         self.projector = nn.Sequential(
+             nn.Linear(128, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(inplace=True),
+             nn.Linear(256, 128)
+         )
+
+         # Classification layer
+         self.fc = nn.Linear(128, num_classes)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         x = self.global_pool(x)
+         x = x.view(x.size(0), -1)
+
+         # Apply projector for self-supervised learning
+         features = self.projector(x)
+
+         # Classification output
+         output = self.fc(features)
+         return output, features
+
+ class SelfSupervisedCancerModel(nn.Module):
+     """SegResNet with self-supervised learning head for cancer classification"""
+     def __init__(self, num_classes=2):
+         super(SelfSupervisedCancerModel, self).__init__()
+         # Initialize SegResNet as the backbone,
+         # configured for 1-channel input and a small input size
+         self.backbone = SegResNet(
+             spatial_dims=2,
+             in_channels=1,
+             out_channels=2,
+             blocks_down=[3, 4, 23, 3],
+             blocks_up=[3, 6, 3],
+             upsample_mode="deconv",
+             init_filters=32,
+         )
+
+         # The backbone's final conv layer outputs 2 channels
+         # (printing self.backbone.conv_final shows Conv2d(8, 2, ...))
+         backbone_out_channels = 2
+
+         # Replace the classifier with our self-supervised head
+         self.ssl_head = SelfSupervisedHead(backbone_out_channels, num_classes)
+
+         # Remove the original classifier if present
+         if hasattr(self.backbone, 'class_layers'):
+             self.backbone.class_layers = nn.Identity()
+
+     def forward(self, x, return_features=False):
+         # Run through the backbone
+         features = self.backbone(x)
+
+         # Apply the self-supervised head
+         output, proj_features = self.ssl_head(features)
+
+         if return_features:
+             return output, proj_features
+         return output
+
+ def load_model():
+     path = hf_hub_download(repo_id="Arpit-Bansal/Medical-Diagnosing-models",
+                            filename="cancer_detector_model.pth")
+     model = SelfSupervisedCancerModel(num_classes=2)
+     state_dict = torch.load(path, map_location=torch.device('cpu'))
+
+     model.load_state_dict(state_dict=state_dict)
+
+     return model.eval()
+
+
+ def classify(model, embedding):
+     """Classify a single embedding using the trained model"""
+     # The model is already in eval mode; reshape the embedding to (N, C, H, W)
+     embedding_tensor = torch.tensor(embedding, dtype=torch.float32).view(1, 1, 384, 1)
+
+     with torch.no_grad():
+         output = model(embedding_tensor)
+         probs = torch.softmax(output, dim=1)
+         predicted_class = torch.argmax(probs, dim=1).item()
+         confidence = probs[0, predicted_class].item()
+         prediction = "positive" if predicted_class == 1 else "negative"
+
+     return prediction, confidence
+
+
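
A quick local smoke test for load_model and classify might look like the sketch below. The random vector is only a stand-in for a real 384-dimensional embedding, and the first call downloads cancer_detector_model.pth from the Hub.

    import numpy as np
    from model import load_model, classify

    # Downloads the checkpoint from Arpit-Bansal/Medical-Diagnosing-models on first use
    model = load_model()

    # Stand-in input; a real embedding would come from the upstream encoder
    embedding = np.random.rand(384).astype(np.float32)

    prediction, confidence = classify(model, embedding)
    print(prediction, confidence)  # e.g. "negative" and its softmax probability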
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==2.7.0
+ monai==1.4.0
+ fastapi[all]
+ huggingface-hub==0.30.2