# app.py (Hugging Face Space by chrisjay)
# Commit 32e3ac6: "updated app.py to not train" (6.32 kB)
import os
import torch
import gradio as gr
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# This is just to show an interface where one draws a number and gets a prediction.

# Training hyper-parameters.
n_epochs = 10  # only used by the commented-out training loop further down
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # train() logs/checkpoints every `log_interval` batches
random_seed = 1
TRAIN_CUTOFF = 10  # NOTE(review): not referenced elsewhere in this file

# Local checkpoint locations; the directory is created at import time.
MODEL_PATH = 'weights'
os.makedirs(MODEL_PATH, exist_ok=True)
METRIC_PATH = os.path.join(MODEL_PATH, 'metrics.json')  # NOTE(review): unused in this file
MODEL_WEIGHTS_PATH = os.path.join(MODEL_PATH, 'mnist_model.pth')
OPTIMIZER_PATH = os.path.join(MODEL_PATH, 'optimizer.pth')

# Hugging Face Hub settings.
# NOTE(review): none of these are referenced elsewhere in this file;
# presumably used for syncing the dataset/model with the Hub — confirm.
REPOSITORY_DIR = "data"
LOCAL_DIR = 'data_local'
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO = 'mnist-adversarial-model'
HF_DATASET = "mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
# Dataset repos are addressed with a "datasets/" prefix, but model repos live
# directly under huggingface.co/<owner>/<repo> (there is no "model/" prefix).
MODEL_REPO_URL = f"https://huggingface.co/chrisjay/{MODEL_REPO}"

# Disable cuDNN and fix the RNG seed, presumably for reproducible runs.
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

# ToTensor followed by the MNIST mean/std normalization; identical to the
# transform the MNIST data loaders in this file build for training/testing.
TRAIN_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        (0.1307,), (0.3081,))
])
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Small CNN classifier for 10-class MNIST digits.

    Two 5x5 convolutions (10 then 20 channels), each followed by 2x2
    max-pooling and ReLU, then fully connected layers 320 -> 50 -> 10.
    The 320 input features correspond to 20 channels x 4 x 4, i.e. the
    layer sizes assume 1x28x28 single-channel inputs.

    Attribute names (conv1, conv2, fc1, fc2) define the state_dict keys,
    so they must stay stable for saved checkpoints to keep loading.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Dropout is only applied while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit class dimension: implicit-dim log_softmax is deprecated
        # and warns on every call.
        return F.log_softmax(x, dim=1)
def _make_mnist_loader(is_train, batch_size):
    """Return a shuffled DataLoader over one MNIST split.

    The split is chosen by *is_train*; torchvision stores/downloads it under
    'files/' (download=True). Each image is converted to a tensor and
    normalized with the MNIST mean (0.1307) and std (0.3081).
    """
    normalize = torchvision.transforms.Normalize((0.1307,), (0.3081,))
    pipeline = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), normalize]
    )
    split = torchvision.datasets.MNIST(
        'files/', train=is_train, download=True, transform=pipeline
    )
    return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)


# Loaders for the training split and the held-out test split.
train_loader = _make_mnist_loader(True, batch_size_train)
test_loader = _make_mnist_loader(False, batch_size_test)
def train(epoch, network, optimizer, train_loader):
    """
    Run one training epoch with NLL loss and periodically checkpoint.

    Every ``log_interval`` batches the current loss is printed and recorded,
    and the network/optimizer state dicts are saved to MODEL_WEIGHTS_PATH and
    OPTIMIZER_PATH (overwriting the previous checkpoint files).

    :param epoch: epoch number, used only in the progress message.
    :param network: the model to train; it is switched to training mode.
    :param optimizer: optimizer over ``network``'s parameters.
    :param train_loader: DataLoader yielding ``(data, target)`` batches.
    :return: list of the losses recorded at each logging step.
    """
    train_losses = []
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))
            train_losses.append(loss_value)
            torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
            torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
    # Previously collected and discarded; returned so callers can plot/inspect it.
    return train_losses
def test(model=None, loader=None):
    """
    Evaluate a model on a test set and report average loss and accuracy.

    :param model: model to evaluate; defaults to the module-level ``network``.
    :param loader: DataLoader yielding ``(data, target)`` batches; defaults to
        the module-level ``test_loader``.
    :return: ``(test_metric, acc)`` — the printed summary string and the
        accuracy as a percentage (float).
    """
    # Fall back to the module-level objects so existing `test()` calls keep working.
    model = network if model is None else model
    loader = test_loader if loader is None else loader
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            # Sum (not mean) per batch; dividing by the dataset size below gives
            # the true per-sample average even when the last batch is short.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    n_samples = len(loader.dataset)
    test_loss /= n_samples
    acc = 100. * correct / n_samples
    test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
        test_loss, acc)
    print(test_metric)
    return test_metric, acc
# Re-seed right before building the model so its random weight initialization
# is reproducible. (random_seed and the cuDNN setting are already configured
# with the same values at the top of the file.)
torch.manual_seed(random_seed)
network = MNIST_Model()  # Initialize the model with random weights
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

# Restore saved weights and optimizer state when both checkpoint files exist;
# otherwise the randomly initialized model is used as-is.
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
    # Nothing in this file moves the model to a GPU, so remap any tensors that
    # were saved from a CUDA session onto the CPU; without this, loading such a
    # checkpoint on a CPU-only host raises.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    network_state_dict = torch.load(model_state_dict, map_location="cpu")
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(optimizer_state_dict, map_location="cpu")
    optimizer.load_state_dict(optimizer_state_dict)

# Training is disabled in this app. Uncomment to retrain; train() writes its
# checkpoints to MODEL_WEIGHTS_PATH / OPTIMIZER_PATH.
# for epoch in range(n_epochs):
#     train(epoch, network, optimizer, train_loader)
#     test()
def image_classifier(inp):
    """
    Classify a drawn digit and return per-class confidences.

    The image goes through TRAIN_TRANSFORM (ToTensor plus the same MNIST
    normalization the data loaders apply), so inference sees inputs on the
    scale the network was trained on. The network is switched to eval mode so
    dropout is disabled and predictions are deterministic.

    :param inp: the drawn image (a PIL image, per the Gradio input's
        ``type="pil"``), or None (assumed to happen when the canvas is cleared).
    :return: dict mapping each class index (int) to its softmax confidence,
        ordered from most to least confident; None if ``inp`` is None.
    """
    if inp is None:
        # Nothing to classify; returning None clears the label output.
        return None
    # train() puts the module in training mode, as does construction; inference
    # must run with dropout disabled or predictions are random.
    network.eval()
    input_image = TRAIN_TRANSFORM(inp).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        # The network emits log-probabilities; softmax maps them back to
        # probabilities (softmax is invariant to the log-softmax shift).
        prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)
    ranked = torch.sort(prediction, descending=True)
    return dict(zip(ranked.indices.tolist(), ranked.values.tolist()))
def main():
    """Build the Gradio UI (drawing canvas + top-10 label) and launch it."""
    with gr.Blocks() as demo:
        with gr.Row():
            # 28x28 grayscale PIL image; invert_colors=True presumably matches
            # MNIST's light-digit-on-dark-background convention.
            canvas = gr.inputs.Image(
                source="canvas",
                shape=(28, 28),
                invert_colors=True,
                image_mode="L",
                type="pil",
            )
            scores = gr.outputs.Label(num_top_classes=10)
        # Re-run the classifier whenever the drawing changes.
        canvas.change(image_classifier, inputs=[canvas], outputs=[scores])
    demo.launch()


if __name__ == "__main__":
    main()