ClassCat commited on
Commit
c5efafb
Β·
1 Parent(s): 47371c2

add app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import ToTensor
6
+
7
+ # Define model
8
+ class ConvNet(nn.Module):
9
+ def __init__(self):
10
+ super(CNN, self).__init__()
11
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
12
+ self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
13
+ self.conv3 = nn.Conv2d(32,64, kernel_size=5)
14
+ self.fc1 = nn.Linear(3*3*64, 256)
15
+ self.fc2 = nn.Linear(256, 10)
16
+
17
+ def forward(self, x):
18
+ x = F.relu(self.conv1(x))
19
+ #x = F.dropout(x, p=0.5, training=self.training)
20
+ x = F.relu(F.max_pool2d(self.conv2(x), 2))
21
+ x = F.dropout(x, p=0.5, training=self.training)
22
+ x = F.relu(F.max_pool2d(self.conv3(x),2))
23
+ x = F.dropout(x, p=0.5, training=self.training)
24
+ x = x.view(-1,3*3*64 )
25
+ x = F.relu(self.fc1(x))
26
+ x = F.dropout(x, training=self.training)
27
+ logits = self.fc2(x)
28
+ return logits
29
+
30
+
31
+ model = ConvNet()
32
+ model.load_state_dict(
33
+ torch.load("weights/mnist_convnet_model.pth",
34
+ map_location=torch.device('cpu'))
35
+ )
36
+
37
+ model.eval()
38
+
39
+ import gradio as gr
40
+ from torchvision import transforms
41
+
42
+ def predict(image):
43
+ tsr_image = transforms.ToTensor()(image)
44
+
45
+ with torch.no_grad():
46
+ pred = model(tsr_image)
47
+ prob = torch.nn.functional.softmax(pred[0], dim=0)
48
+
49
+ confidences = {i: float(prob[i]) for i in range(10)}
50
+ return confidences
51
+
52
+
53
+ with gr.Blocks(css=".gradio-container {background:lightyellow;color:red;}", title="MNIST εˆ†ι‘žε™¨"
54
+ ) as demo:
55
+ gr.HTML('<div style="font-size:12pt; text-align:center; color:yellow;"MNIST εˆ†ι‘žε™¨</div>')
56
+
57
+ with gr.Row():
58
+ with gr.Tab("キャンバス"):
59
+ input_image1 = gr.Image(label="画像ε…₯εŠ›", source="canvas", type="pil", image_mode="L", shape=(28,28), invert_colors=True)
60
+ send_btn1 = gr.Button("ζŽ¨θ«–γ™γ‚‹")
61
+
62
+ with gr.Tab("画像フゑむル"):
63
+ input_image2 = gr.Image(label="画像ε…₯εŠ›", type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
64
+ send_btn2 = gr.Button("ζŽ¨θ«–γ™γ‚‹")
65
+ gr.Examples(['examples/sample2.png', 'examples/sample4.png'], inputs=input_image2)
66
+
67
+ output_label=gr.Label(label="ζŽ¨θ«–η’ΊηŽ‡", num_top_classes=3)
68
+
69
+ send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
70
+ send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)
71
+
72
+ # demo.queue(concurrency_count=3)
73
+ demo.launch()
74
+
75
+
76
+ ### EOF ###