jiang20 commited on
Commit
90a0539
·
1 Parent(s): b63fd51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -9,6 +9,8 @@ import torchvision.transforms as transforms
9
  # model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
10
  # model.train()
11
 
 
 
12
  import os
13
 
14
  def print_bn():
@@ -45,7 +47,7 @@ def greet_backdoor(image):
45
  image = transform_nor(image).unsqueeze(0)
46
  print(image.shape)
47
  output = model(image).squeeze()
48
- return 'classified as label: ' + str(int(torch.argmax(output)))
49
 
50
 
51
  def greet(image):
 
9
  # model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
10
  # model.train()
11
 
12
+ id_label = {0:'airplane', 1:'automobile', 2:'bird', 3:'cat', 4:'deer', 5:'dog', 6:'frog', 7:'horse', 8:'ship', 9:'trunk'}
13
+
14
  import os
15
 
16
  def print_bn():
 
47
  image = transform_nor(image).unsqueeze(0)
48
  print(image.shape)
49
  output = model(image).squeeze()
50
+ return 'classified as: ' + id_label[int(torch.argmax(output))]
51
 
52
 
53
  def greet(image):