# digits / app.py
# Uploaded by etweedy in commit fac144d ("Upload 2 files"); 1.02 kB.
from pathlib import Path

import gradio as gr
import torch
from torch import nn
class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 single-channel images.

    Each stage is a 5x5 convolution (stride 1, padding 2, so spatial size is
    preserved), a ReLU, and a 2x2 max-pool (which halves spatial size):

        (N, 1, 28, 28) -> conv1 -> (N, 16, 14, 14)
                       -> conv2 -> (N, 32, 7, 7)
                       -> flatten -> out (Linear) -> (N, 10) logits

    The attribute names (``conv1``, ``conv2``, ``out``) and the
    ``nn.Sequential`` layouts determine the ``state_dict`` keys; changing them
    would break loading previously saved checkpoints.
    """

    # Feature-map shape (C, H, W) that conv2 produces for a 28x28 input; the
    # linear head is sized for exactly this many features.
    _FEATURE_SHAPE = (32, 7, 7)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5, stride=1, padding=2),  # 28x28 -> 28x28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 28x28 -> 14x14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),  # 14x14 -> 14x14
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 -> 7x7
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return raw class logits for a batch of images.

        Args:
            x: image tensor, expected shape (N, 1, 28, 28).

        Returns:
            Logits (no softmax applied) of shape (N, 10).

        Raises:
            ValueError: if the convolutional features are not 32x7x7, i.e. the
                input's spatial size does not match what the linear head
                was built for.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Check before flattening: view(-1, ...) would otherwise silently fold
        # surplus spatial features into the batch dimension (a 56x56 input
        # would yield 4 rows of logits per image instead of an error).
        features = tuple(x.shape[-3:])
        if features != self._FEATURE_SHAPE:
            raise ValueError(
                f"expected conv features of shape {self._FEATURE_SHAPE}, got "
                f"{features}; the network expects 28x28 input images"
            )
        x = x.view(-1, 32 * 7 * 7)
        return self.out(x)
# Build the network and load the trained weights from the checkpoint stored
# next to this script. Resolving the path against __file__ (rather than the
# current working directory) lets the app start from any directory.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only ever point this at a trusted checkpoint.
model = CNN()
_CHECKPOINT_PATH = Path(__file__).resolve().parent / "mnist2.pkl"
model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location=torch.device("cpu")))
# Put the model in evaluation mode for inference.
model.eval()
def predict(img):
    """Classify a sketched image and return the predicted class index.

    Args:
        img: the sketchpad value passed in by Gradio. The code treats it as a
            2-D (H, W) array-like of 0-255 pixel values: it adds batch and
            channel dimensions and divides by 255. Gradio may pass None when
            nothing was drawn (TODO confirm for the installed version).

    Returns:
        The argmax class index as an ``int``, or ``None`` when ``img`` is
        ``None``.
    """
    if img is None:
        # torch.tensor(None) raises, which would surface only as a generic
        # error in the UI; return an empty result instead.
        return None
    # (H, W) -> (1, 1, H, W), scaled to [0, 1].
    x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(x)[0]
    return int(logits.argmax())
# Web UI: a drawing canvas whose value is passed to `predict`; the returned
# class index is displayed by a "label" output component. Binding the
# interface to a module-level `demo` name lets tooling that imports this file
# (such as Gradio's reload mode) find it.
demo = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs="label",
)

# Start the server only when executed as a script, so importing this module
# (tests, notebooks) does not block inside launch().
if __name__ == "__main__":
    demo.launch()