dorsar commited on
Commit
a92c35a
·
1 Parent(s): c42cd15

architecture

Browse files
Files changed (1) hide show
  1. architecture.py +178 -0
architecture.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms, models
5
+ from torch.utils.data import DataLoader
6
+ import os
7
+ import copy
8
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
9
+ from torchvision.models import resnet50, ResNet50_Weights
10
+ import ssl
11
+ ssl._create_default_https_context = ssl._create_unverified_context
12
+
13
+ # data transformations with augmentation
14
+ train_transform = transforms.Compose([
15
+ transforms.RandomResizedCrop(224),
16
+ transforms.RandomHorizontalFlip(),
17
+ transforms.RandomRotation(10),
18
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
+ ])
22
+
23
+ val_test_transform = transforms.Compose([
24
+ transforms.Resize(256),
25
+ transforms.CenterCrop(224),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
28
+ ])
29
+
30
+
31
+ class ResNetLungCancer(nn.Module):
32
+ def __init__(self, num_classes, use_pretrained=True):
33
+ super(ResNetLungCancer, self).__init__()
34
+ if use_pretrained:
35
+ weights = ResNet50_Weights.IMAGENET1K_V1
36
+ else:
37
+ weights = None
38
+ self.resnet = resnet50(weights=weights)
39
+ num_ftrs = self.resnet.fc.in_features
40
+ self.resnet.fc = nn.Identity() # remove the final fully connected layer
41
+ self.fc = nn.Sequential(
42
+ nn.Linear(num_ftrs, 256),
43
+ nn.ReLU(),
44
+ nn.Dropout(0.5),
45
+ nn.Linear(256, num_classes)
46
+ )
47
+
48
+ def forward(self, x):
49
+ x = self.resnet(x)
50
+ return self.fc(x)
51
+
52
+
53
+ # train function
54
+ def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=50, device='cuda'):
55
+ best_model_wts = copy.deepcopy(model.state_dict())
56
+ best_acc = 0.0
57
+
58
+ for epoch in range(num_epochs):
59
+ print(f'Epoch {epoch}/{num_epochs - 1}')
60
+ print('-' * 10)
61
+
62
+ for phase in ['train', 'valid']:
63
+ if phase == 'train':
64
+ model.train()
65
+ dataloader = train_loader
66
+ else:
67
+ model.eval()
68
+ dataloader = valid_loader
69
+
70
+ running_loss = 0.0
71
+ running_corrects = 0
72
+
73
+ for inputs, labels in dataloader:
74
+ inputs = inputs.to(device)
75
+ labels = labels.to(device)
76
+
77
+ optimizer.zero_grad()
78
+
79
+ with torch.set_grad_enabled(phase == 'train'):
80
+ outputs = model(inputs)
81
+ _, preds = torch.max(outputs, 1)
82
+ loss = criterion(outputs, labels)
83
+
84
+ if phase == 'train':
85
+ loss.backward()
86
+ optimizer.step()
87
+
88
+ running_loss += loss.item() * inputs.size(0)
89
+ running_corrects += torch.sum(preds == labels.data)
90
+
91
+ epoch_loss = running_loss / len(dataloader.dataset)
92
+ epoch_acc = running_corrects.double() / len(dataloader.dataset)
93
+
94
+ print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
95
+
96
+ if phase == 'valid':
97
+ scheduler.step(epoch_acc)
98
+ current_lr = optimizer.param_groups[0]['lr']
99
+ print(f'Learning rate: {current_lr}')
100
+ if epoch_acc > best_acc:
101
+ best_acc = epoch_acc
102
+ best_model_wts = copy.deepcopy(model.state_dict())
103
+
104
+ print()
105
+
106
+ print(f'Best val Acc: {best_acc:.4f}')
107
+ model.load_state_dict(best_model_wts)
108
+ return model
109
+
110
+ # eval the model
111
+ def evaluate_model(model, test_loader, device='cuda'):
112
+ model.eval()
113
+ running_corrects = 0
114
+
115
+ with torch.no_grad():
116
+ for inputs, labels in test_loader:
117
+ inputs = inputs.to(device)
118
+ labels = labels.to(device)
119
+
120
+ outputs = model(inputs)
121
+ _, preds = torch.max(outputs, 1)
122
+ running_corrects += torch.sum(preds == labels.data)
123
+
124
+ test_acc = running_corrects.double() / len(test_loader.dataset)
125
+ print(f'Test Acc: {test_acc:.4f}')
126
+
127
+ if __name__ == "__main__":
128
+ # device
129
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
130
+ print(f"Using device: {device}")
131
+
132
+ # data
133
+ data_dir = 'Processed_Data'
134
+ train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform)
135
+ valid_dataset = datasets.ImageFolder(os.path.join(data_dir, 'valid'), transform=val_test_transform)
136
+ test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=val_test_transform)
137
+
138
+ # dataloaders
139
+ batch_size = 32
140
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
141
+ valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
142
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
143
+
144
+ print(f"Number of training images: {len(train_dataset)}")
145
+ print(f"Number of validation images: {len(valid_dataset)}")
146
+ print(f"Number of test images: {len(test_dataset)}")
147
+
148
+ # initialize model, loss, and optimizer
149
+ num_classes = len(train_dataset.classes)
150
+ model = ResNetLungCancer(num_classes)
151
+ model = model.to(device)
152
+
153
+ criterion = nn.CrossEntropyLoss()
154
+
155
+ pretrained_params = list(model.resnet.parameters())
156
+ new_params = list(model.fc.parameters())
157
+
158
+ optimizer = optim.Adam([
159
+ {'params': pretrained_params, 'lr': 1e-5},
160
+ {'params': new_params, 'lr': 1e-4}
161
+ ], weight_decay=1e-6)
162
+
163
+ scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=7)
164
+
165
+ # train the model
166
+ trained_model = train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=50, device=device)
167
+
168
+ # eval the model
169
+ evaluate_model(trained_model, test_loader, device=device)
170
+
171
+ # save the model weights
172
+ torch.save(trained_model.state_dict(), 'lung_cancer_detection_model.pth')
173
+
174
+ # save the model in ONNX format
175
+ dummy_input = torch.randn(1, 3, 224, 224).to(device)
176
+ torch.onnx.export(trained_model, dummy_input, "lung_cancer_detection_model.onnx", input_names=['input'], output_names=['output'])
177
+
178
+ print("Training completed. Model saved.")