Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,8 @@ import torchvision.transforms as transforms
|
|
9 |
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
|
10 |
# model.train()
|
11 |
|
|
|
|
|
12 |
import os
|
13 |
|
14 |
def print_bn():
|
@@ -45,7 +47,7 @@ def greet_backdoor(image):
|
|
45 |
image = transform_nor(image).unsqueeze(0)
|
46 |
print(image.shape)
|
47 |
output = model(image).squeeze()
|
48 |
-
return 'classified as
|
49 |
|
50 |
|
51 |
def greet(image):
|
|
|
9 |
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
|
10 |
# model.train()
|
11 |
|
12 |
+
id_label = {0:'airplane', 1:'automobile', 2:'bird', 3:'cat', 4:'deer', 5:'dog', 6:'frog', 7:'horse', 8:'ship', 9:'trunk'}
|
13 |
+
|
14 |
import os
|
15 |
|
16 |
def print_bn():
|
|
|
47 |
image = transform_nor(image).unsqueeze(0)
|
48 |
print(image.shape)
|
49 |
output = model(image).squeeze()
|
50 |
+
return 'classified as: ' + id_label[int(torch.argmax(output))]
|
51 |
|
52 |
|
53 |
def greet(image):
|