architecture.py
ADDED
@@ -0,0 +1,178 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os
import copy
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models import resnet50, ResNet50_Weights
import ssl

# work around SSL certificate errors when torchvision downloads the pretrained weights
ssl._create_default_https_context = ssl._create_unverified_context

# data transformations with augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


class ResNetLungCancer(nn.Module):
    def __init__(self, num_classes, use_pretrained=True):
        super(ResNetLungCancer, self).__init__()
        if use_pretrained:
            weights = ResNet50_Weights.IMAGENET1K_V1
        else:
            weights = None
        self.resnet = resnet50(weights=weights)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # remove the final fully connected layer
        self.fc = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.resnet(x)
        return self.fc(x)


# training loop: tracks the best validation accuracy and restores those weights at the end
def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=50, device='cuda'):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                dataloader = train_loader
            else:
                model.eval()
                dataloader = valid_loader

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # only track gradients during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects.double() / len(dataloader.dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'valid':
                scheduler.step(epoch_acc)
                current_lr = optimizer.param_groups[0]['lr']
                print(f'Learning rate: {current_lr}')
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        print()

    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

# evaluate accuracy on the held-out test set
def evaluate_model(model, test_loader, device='cuda'):
    model.eval()
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)

    test_acc = running_corrects.double() / len(test_loader.dataset)
    print(f'Test Acc: {test_acc:.4f}')

if __name__ == "__main__":
    # device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # data
    data_dir = 'Processed_Data'
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform)
    valid_dataset = datasets.ImageFolder(os.path.join(data_dir, 'valid'), transform=val_test_transform)
    test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=val_test_transform)

    # dataloaders
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    print(f"Number of training images: {len(train_dataset)}")
    print(f"Number of validation images: {len(valid_dataset)}")
    print(f"Number of test images: {len(test_dataset)}")

    # initialize model, loss, and optimizer
    num_classes = len(train_dataset.classes)
    model = ResNetLungCancer(num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()

    # use a lower learning rate for the pretrained backbone than for the new classification head
    pretrained_params = list(model.resnet.parameters())
    new_params = list(model.fc.parameters())

    optimizer = optim.Adam([
        {'params': pretrained_params, 'lr': 1e-5},
        {'params': new_params, 'lr': 1e-4}
    ], weight_decay=1e-6)

    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=7)

    # train the model
    trained_model = train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=50, device=device)

    # eval the model
    evaluate_model(trained_model, test_loader, device=device)

    # save the model weights
    torch.save(trained_model.state_dict(), 'lung_cancer_detection_model.pth')

    # save the model in ONNX format
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    torch.onnx.export(trained_model, dummy_input, "lung_cancer_detection_model.onnx", input_names=['input'], output_names=['output'])

    print("Training completed. Model saved.")
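
For reference, below is a minimal inference sketch that is not part of architecture.py: it shows how the saved lung_cancer_detection_model.pth checkpoint could be reloaded with the class and preprocessing defined above. The class count, class order, and image path are hypothetical placeholders and must match your own training run.

# Minimal inference sketch (assumption: not part of the original architecture.py).
import torch
from PIL import Image

# reuse the model class and the validation/test preprocessing from architecture.py
from architecture import ResNetLungCancer, val_test_transform

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = ResNetLungCancer(num_classes=3)  # placeholder: use len(train_dataset.classes) from training
model.load_state_dict(torch.load('lung_cancer_detection_model.pth', map_location=device))
model.to(device)
model.eval()

image = Image.open('example_scan.png').convert('RGB')  # placeholder image path
batch = val_test_transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    logits = model(batch)
    pred = logits.argmax(dim=1).item()
print(f"Predicted class index: {pred}")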