sashavor commited on
Commit
b444b38
·
1 Parent(s): 2e4f4c8

trying to add comparison between model and user class

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -11,11 +11,6 @@ feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch1
11
  model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
12
 
13
 
14
- title="ImageNet Roulette"
15
- description="Try guessing the category of each image displayed, from the options provided below.\
16
- After 10 guesses, we will show you your accuracy!\
17
- "
18
-
19
  classdict = OrderedDict()
20
  for line in open('LOC_synset_mapping.txt', 'r').readlines():
21
  try:
@@ -37,15 +32,14 @@ def model_classify(radio, im):
37
  outputs = model(**inputs)
38
  logits = outputs.logits
39
  predicted_class_idx = logits.argmax(-1).item()
40
- return model.config.id2label[predicted_class_idx], True
41
  else:
42
- return None, False
43
 
44
  def random_image():
45
  imname = random.choice(images)
46
  im = Image.open('images/'+ imname +'.jpg')
47
  label = str(imagedict[imname])
48
- print(label)
49
  labels.remove(label)
50
  options = sample(labels,3)
51
  options.append(label)
@@ -55,9 +49,9 @@ def random_image():
55
 
56
  def check_score(pred, truth, current_score, total_score, has_guessed):
57
  if not(has_guessed):
58
- if pred == "lion":
59
  total_score +=1
60
- return current_score + 1, f"Your score is {current_score+1} out of {total_score}", total_score
61
  else:
62
  if pred is not None:
63
  total_score +=1
@@ -67,34 +61,42 @@ def check_score(pred, truth, current_score, total_score, has_guessed):
67
 
68
 
69
 
70
- def compare_score(userclass, prediction):
71
- if userclass == str(prediction).split(',')[0]:
72
- return "Great! You and the model agree on the category"
73
- return "You and the model disagree"
 
 
 
 
 
74
 
75
  with gr.Blocks() as demo:
76
  user_score = gr.State(0)
77
  model_score = gr.State(0)
78
  image_label = gr.State()
79
- prediction = gr.State()
80
  total_score = gr.State(0)
81
  has_guessed = gr.State(False)
82
 
 
 
83
  with gr.Row():
 
84
  with gr.Column(min_width= 900):
85
  image = gr.Image(shape=(600, 600))
86
  radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
87
  with gr.Column():
88
- prediction = gr.Label(label="The AI predicts:")
89
  score = gr.Label(label="Your Score")
90
- #message = gr.Text()
91
 
92
  btn = gr.Button("Next image")
93
 
94
  demo.load(random_image, None, [image, image_label, radio, prediction])
95
- radio.change(model_classify, [radio, image], [prediction, has_guessed])
96
  radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
97
- #radio.change(compare_score, [radio, prediction], message)
98
  btn.click(random_image, None, [image, image_label, radio, prediction])
99
  btn.click(lambda :False, None, has_guessed)
100
 
 
11
  model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
12
 
13
 
 
 
 
 
 
14
  classdict = OrderedDict()
15
  for line in open('LOC_synset_mapping.txt', 'r').readlines():
16
  try:
 
32
  outputs = model(**inputs)
33
  logits = outputs.logits
34
  predicted_class_idx = logits.argmax(-1).item()
35
+ return model.config.id2label[predicted_class_idx], predicted_class_idx, True
36
  else:
37
+ return None, None, False
38
 
39
  def random_image():
40
  imname = random.choice(images)
41
  im = Image.open('images/'+ imname +'.jpg')
42
  label = str(imagedict[imname])
 
43
  labels.remove(label)
44
  options = sample(labels,3)
45
  options.append(label)
 
49
 
50
  def check_score(pred, truth, current_score, total_score, has_guessed):
51
  if not(has_guessed):
52
+ if pred == classes[int(truth)]:
53
  total_score +=1
54
+ return current_score + 1, f"Your score is {current_score+1} out of {total_score}!", total_score
55
  else:
56
  if pred is not None:
57
  total_score +=1
 
61
 
62
 
63
 
64
+ def compare_score(userclass, model_class, truth, has_guessed):
65
+ print(model_class)
66
+ prediction= classes[int(model_class)]
67
+ if userclass == classes[int(truth)] == prediction:
68
+ return "Great! You and the model both got the correct answer"
69
+ elif userclass == classes[int(truth)]:
70
+ return "Great! You guessed it right"
71
+ elif prediction == classes[int(truth)]:
72
+ return "The AI model got it right this time, try again!"
73
 
74
  with gr.Blocks() as demo:
75
  user_score = gr.State(0)
76
  model_score = gr.State(0)
77
  image_label = gr.State()
78
+ model_class = gr.State()
79
  total_score = gr.State(0)
80
  has_guessed = gr.State(False)
81
 
82
+ gr.Markdown("# ImageNet Quiz")
83
+ gr.Markdown("Try guessing the category of each image displayed, from the options provided below.")
84
  with gr.Row():
85
+
86
  with gr.Column(min_width= 900):
87
  image = gr.Image(shape=(600, 600))
88
  radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
89
  with gr.Column():
90
+ prediction = gr.Label(label="The AI model predicts:")
91
  score = gr.Label(label="Your Score")
92
+ message = gr.Text(label="Who guessed it right?")
93
 
94
  btn = gr.Button("Next image")
95
 
96
  demo.load(random_image, None, [image, image_label, radio, prediction])
97
+ radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
98
  radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
99
+ radio.change(compare_score, [radio, prediction, image_label, has_guessed], message)
100
  btn.click(random_image, None, [image, image_label, radio, prediction])
101
  btn.click(lambda :False, None, has_guessed)
102