digits / app.py
etweedy's picture
Upload app.py
cbb6fd7
raw
history blame
1.92 kB
import torch
from torch import nn
import gradio as gr
# Define the custom CNN architecture whose weights were trained on the MNIST data
class CNN(nn.Module):
    """
    A custom CNN classifier for 28x28 single-channel images.

    The network has: (1) a convolution layer with 1 input channel and 16
    output channels with ReLU activation and 2x2 max-pooling, (2) a second
    convolution layer with 16 input channels and 32 output channels with ReLU
    activation and 2x2 max-pooling, and (3) a linear output layer with 10
    outputs.

    The attribute names (``conv1``, ``conv2``, ``out``) and the layout of the
    ``nn.Sequential`` containers determine the ``state_dict`` keys, so they
    must stay as they are for previously saved weights to load.
    """

    def __init__(self):
        super().__init__()
        # 5x5 kernels with padding 2 and stride 1 keep the spatial size;
        # each 2x2 max-pool then halves it.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # After two halvings a 28x28 input becomes 32 feature maps of 7x7.
        self.out = nn.Linear(32 * 7 * 7, 10)

    # Forward propagation method
    def forward(self, x):
        """Return raw (unnormalized) class scores, one row of 10 per image.

        Args:
            x: Image tensor of shape (N, 1, 28, 28), or (1, 28, 28) for a
                single unbatched image.

        Returns:
            Tensor of shape (N, 10) (or (1, 10) for unbatched input).

        Raises:
            ValueError: If the input's spatial size does not reduce to the
                32x7x7 feature maps the linear layer was built for.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Reject wrongly sized inputs explicitly: view/reshape(-1, 32*7*7) on
        # e.g. a 56x56 image would silently split it into 4 bogus rows.
        if x.shape[-3:] != (32, 7, 7):
            raise ValueError(
                "expected conv features of shape (32, 7, 7) (a 28x28 input), "
                f"got {tuple(x.shape[-3:])}"
            )
        # reshape (unlike view) also works for non-contiguous tensors such as
        # channels_last activations, and is a free view otherwise.
        x = x.reshape(-1, 32 * 7 * 7)
        return self.out(x)
# Restore the trained parameters from the saved state_dict and attach them to
# a fresh network. map_location puts every tensor on the CPU, so the
# checkpoint loads even on a machine without a GPU.
# NOTE(review): torch.load unpickles the file, so only a trusted checkpoint
# should ever live at this path (torch>=1.13 also accepts weights_only=True
# to restrict what may be unpickled).
_checkpoint = torch.load('mnist2.pkl', map_location=torch.device('cpu'))
model = CNN()
model.load_state_dict(_checkpoint)
# Put the network in evaluation mode before serving predictions.
model.eval()
# Prediction function
def predict(img):
    """Classify a hand-drawn digit and return the predicted class index.

    Args:
        img: The sketchpad drawing as a 2-D array-like of pixel intensities,
            or ``None`` when nothing was submitted. Values are divided by 255,
            so a 0-255 range is assumed (presumably uint8 grayscale from the
            Gradio sketchpad; TODO confirm against the installed Gradio version).

    Returns:
        The index (0-9) of the highest-scoring class as an ``int``, or
        ``None`` when ``img`` is ``None`` so the output label stays empty.

    Raises:
        ValueError: If ``img`` is not two-dimensional.
    """
    # An empty submission would otherwise crash inside torch.as_tensor.
    if img is None:
        return None
    x = torch.as_tensor(img, dtype=torch.float32) / 255.
    if x.dim() != 2:
        raise ValueError(
            f"expected a 2-D grayscale image, got shape {tuple(x.shape)}"
        )
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
    x = x[None, None]
    # The network's linear layer only fits 28x28 inputs; rescale any other
    # canvas size with area averaging instead of failing or mis-batching.
    if x.shape[-2:] != (28, 28):
        x = nn.functional.interpolate(x, size=(28, 28), mode="area")
    with torch.no_grad():
        scores = model(x)[0]
    return int(scores.argmax())
# Build the Gradio app: the user draws on a sketchpad and the predicted digit
# is shown in a classification label. Then start serving it.
title = "Guess that digit"
description = "Draw your favorite base-10 digit (0-9) and click submit - I'll try to guess what you drew! I do a bit better if you're not too messy and your digit is fairly centered."
demo = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs="label",
    title=title,
    description=description,
)
demo.launch()