eaglelandsonce commited on
Commit
741a2c4
·
verified ·
1 Parent(s): 003dc9d

Create 20_ResNet2.py

Browse files
Files changed (1) hide show
  1. pages/20_ResNet2.py +128 -0
pages/20_ResNet2.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import matplotlib.pyplot as plt
6
+ import torchvision.transforms as transforms
7
+ from torchvision.datasets import CIFAR10
8
+ from torch.utils.data import DataLoader
9
+
10
+ # Define the ResNet model
11
+ class BasicBlock(nn.Module):
12
+ expansion = 1
13
+
14
+ def __init__(self, in_planes, planes, stride=1):
15
+ super(BasicBlock, self).__init__()
16
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
17
+ self.bn1 = nn.BatchNorm2d(planes)
18
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
19
+ self.bn2 = nn.BatchNorm2d(planes)
20
+
21
+ self.shortcut = nn.Sequential()
22
+ if stride != 1 or in_planes != self.expansion * planes:
23
+ self.shortcut = nn.Sequential(
24
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
25
+ nn.BatchNorm2d(self.expansion * planes)
26
+ )
27
+
28
+ def forward(self, x):
29
+ identity = x
30
+ out = F.relu(self.bn1(self.conv1(x)))
31
+ out = self.bn2(self.conv2(out))
32
+ out += self.shortcut(identity)
33
+ out = F.relu(out)
34
+ return out
35
+
36
+ class ResNet(nn.Module):
37
+ def __init__(self, block, num_blocks, num_classes=10):
38
+ super(ResNet, self).__init__()
39
+ self.in_planes = 64
40
+
41
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
42
+ self.bn1 = nn.BatchNorm2d(64)
43
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
44
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
45
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
46
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
47
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
48
+
49
+ def _make_layer(self, block, planes, num_blocks, stride):
50
+ strides = [stride] + [1] * (num_blocks - 1)
51
+ layers = []
52
+ for stride in strides:
53
+ layers.append(block(self.in_planes, planes, stride))
54
+ self.in_planes = planes * block.expansion
55
+ return nn.Sequential(*layers)
56
+
57
+ def forward(self, x):
58
+ out = F.relu(self.bn1(self.conv1(x)))
59
+ out = self.layer1(out)
60
+ out = self.layer2(out)
61
+ out = self.layer3(out)
62
+ out = self.layer4(out)
63
+ out = F.avg_pool2d(out, 4)
64
+ out = out.view(out.size(0), -1)
65
+ out = self.linear(out)
66
+ return out
67
+
68
+ def ResNet18():
69
+ return ResNet(BasicBlock, [2, 2, 2, 2])
70
+
71
+ # Define a function to load CIFAR-10 dataset
72
+ def load_data():
73
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
74
+ train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
75
+ train_loader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=2)
76
+ return train_loader
77
+
78
+ # Streamlit Interface
79
+ st.title('ResNet with Streamlit')
80
+ st.write("This is an example of integrating a ResNet model with Streamlit.")
81
+
82
+ # Load data button
83
+ if st.button('Load Data'):
84
+ st.write("Loading CIFAR-10 data...")
85
+ train_loader = load_data()
86
+ st.write("Data loaded successfully!")
87
+
88
+ # Initialize and test the model
89
+ if st.button('Initialize and Test ResNet18'):
90
+ net = ResNet18()
91
+ sample_input = torch.randn(1, 3, 32, 32)
92
+ output = net(sample_input)
93
+ st.write("Output size: ", output.size())
94
+
95
+ # Train the model (for demonstration, we'll just do one epoch)
96
+ if st.button('Train ResNet18'):
97
+ st.write("Training ResNet18 on CIFAR-10...")
98
+ net = ResNet18()
99
+ train_loader = load_data()
100
+ criterion = nn.CrossEntropyLoss()
101
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
102
+
103
+ net.train()
104
+ for epoch in range(1): # Single epoch for demonstration
105
+ running_loss = 0.0
106
+ for i, data in enumerate(train_loader, 0):
107
+ inputs, labels = data
108
+ optimizer.zero_grad()
109
+ outputs = net(inputs)
110
+ loss = criterion(outputs, labels)
111
+ loss.backward()
112
+ optimizer.step()
113
+ running_loss += loss.item()
114
+ if i % 100 == 99: # Print every 100 mini-batches
115
+ st.write(f'Epoch [{epoch + 1}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')
116
+ running_loss = 0.0
117
+
118
+ st.write("Training complete!")
119
+
120
+ # Plotting example (dummy plot for demonstration)
121
+ if st.button('Show Plot'):
122
+ st.write("Displaying a sample plot...")
123
+ fig, ax = plt.subplots()
124
+ ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
125
+ st.pyplot(fig)
126
+
127
+ # To run the Streamlit app, use the command below in your terminal:
128
+ # streamlit run your_script_name.py