WwYc commited on
Commit
c935326
·
verified ·
1 Parent(s): 8727b39

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +2 -1
visualization.py CHANGED
@@ -54,7 +54,8 @@ def generate_visualization(original_image, class_index=None):
54
  return vis
55
 
56
 
57
- def print_top_classes(predictions, **kwargs):
 
58
  # Print Top-5 predictions
59
  prob = torch.softmax(predictions, dim=1)
60
  class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
 
54
  return vis
55
 
56
 
57
+ def print_top_classes(original_image, **kwargs):
58
+ predictions = model(original_image.unsqueeze(0))
59
  # Print Top-5 predictions
60
  prob = torch.softmax(predictions, dim=1)
61
  class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()