Spaces:
Build error
Build error
sashavor
commited on
Commit
·
b444b38
1
Parent(s):
2e4f4c8
trying to add comparison between model and user class
Browse files
app.py
CHANGED
@@ -11,11 +11,6 @@ feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch1
|
|
11 |
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
12 |
|
13 |
|
14 |
-
title="ImageNet Roulette"
|
15 |
-
description="Try guessing the category of each image displayed, from the options provided below.\
|
16 |
-
After 10 guesses, we will show you your accuracy!\
|
17 |
-
"
|
18 |
-
|
19 |
classdict = OrderedDict()
|
20 |
for line in open('LOC_synset_mapping.txt', 'r').readlines():
|
21 |
try:
|
@@ -37,15 +32,14 @@ def model_classify(radio, im):
|
|
37 |
outputs = model(**inputs)
|
38 |
logits = outputs.logits
|
39 |
predicted_class_idx = logits.argmax(-1).item()
|
40 |
-
return model.config.id2label[predicted_class_idx], True
|
41 |
else:
|
42 |
-
return None, False
|
43 |
|
44 |
def random_image():
|
45 |
imname = random.choice(images)
|
46 |
im = Image.open('images/'+ imname +'.jpg')
|
47 |
label = str(imagedict[imname])
|
48 |
-
print(label)
|
49 |
labels.remove(label)
|
50 |
options = sample(labels,3)
|
51 |
options.append(label)
|
@@ -55,9 +49,9 @@ def random_image():
|
|
55 |
|
56 |
def check_score(pred, truth, current_score, total_score, has_guessed):
|
57 |
if not(has_guessed):
|
58 |
-
if pred ==
|
59 |
total_score +=1
|
60 |
-
return current_score + 1, f"Your score is {current_score+1} out of {total_score}", total_score
|
61 |
else:
|
62 |
if pred is not None:
|
63 |
total_score +=1
|
@@ -67,34 +61,42 @@ def check_score(pred, truth, current_score, total_score, has_guessed):
|
|
67 |
|
68 |
|
69 |
|
70 |
-
def compare_score(userclass,
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
with gr.Blocks() as demo:
|
76 |
user_score = gr.State(0)
|
77 |
model_score = gr.State(0)
|
78 |
image_label = gr.State()
|
79 |
-
|
80 |
total_score = gr.State(0)
|
81 |
has_guessed = gr.State(False)
|
82 |
|
|
|
|
|
83 |
with gr.Row():
|
|
|
84 |
with gr.Column(min_width= 900):
|
85 |
image = gr.Image(shape=(600, 600))
|
86 |
radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
|
87 |
with gr.Column():
|
88 |
-
prediction = gr.Label(label="The AI predicts:")
|
89 |
score = gr.Label(label="Your Score")
|
90 |
-
|
91 |
|
92 |
btn = gr.Button("Next image")
|
93 |
|
94 |
demo.load(random_image, None, [image, image_label, radio, prediction])
|
95 |
-
radio.change(model_classify, [radio, image], [prediction, has_guessed])
|
96 |
radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
|
97 |
-
|
98 |
btn.click(random_image, None, [image, image_label, radio, prediction])
|
99 |
btn.click(lambda :False, None, has_guessed)
|
100 |
|
|
|
11 |
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
14 |
classdict = OrderedDict()
|
15 |
for line in open('LOC_synset_mapping.txt', 'r').readlines():
|
16 |
try:
|
|
|
32 |
outputs = model(**inputs)
|
33 |
logits = outputs.logits
|
34 |
predicted_class_idx = logits.argmax(-1).item()
|
35 |
+
return model.config.id2label[predicted_class_idx], predicted_class_idx, True
|
36 |
else:
|
37 |
+
return None, None, False
|
38 |
|
39 |
def random_image():
|
40 |
imname = random.choice(images)
|
41 |
im = Image.open('images/'+ imname +'.jpg')
|
42 |
label = str(imagedict[imname])
|
|
|
43 |
labels.remove(label)
|
44 |
options = sample(labels,3)
|
45 |
options.append(label)
|
|
|
49 |
|
50 |
def check_score(pred, truth, current_score, total_score, has_guessed):
|
51 |
if not(has_guessed):
|
52 |
+
if pred == classes[int(truth)]:
|
53 |
total_score +=1
|
54 |
+
return current_score + 1, f"Your score is {current_score+1} out of {total_score}!", total_score
|
55 |
else:
|
56 |
if pred is not None:
|
57 |
total_score +=1
|
|
|
61 |
|
62 |
|
63 |
|
64 |
+
def compare_score(userclass, model_class, truth, has_guessed):
|
65 |
+
print(model_class)
|
66 |
+
prediction= classes[int(model_class)]
|
67 |
+
if userclass == classes[int(truth)] == prediction:
|
68 |
+
return "Great! You and the model both got the correct answer"
|
69 |
+
elif userclass == classes[int(truth)]:
|
70 |
+
return "Great! You guessed it right"
|
71 |
+
elif prediction == classes[int(truth)]:
|
72 |
+
return "The AI model got it right this time, try again!"
|
73 |
|
74 |
with gr.Blocks() as demo:
|
75 |
user_score = gr.State(0)
|
76 |
model_score = gr.State(0)
|
77 |
image_label = gr.State()
|
78 |
+
model_class = gr.State()
|
79 |
total_score = gr.State(0)
|
80 |
has_guessed = gr.State(False)
|
81 |
|
82 |
+
gr.Markdown("# ImageNet Quiz")
|
83 |
+
gr.Markdown("Try guessing the category of each image displayed, from the options provided below.")
|
84 |
with gr.Row():
|
85 |
+
|
86 |
with gr.Column(min_width= 900):
|
87 |
image = gr.Image(shape=(600, 600))
|
88 |
radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
|
89 |
with gr.Column():
|
90 |
+
prediction = gr.Label(label="The AI model predicts:")
|
91 |
score = gr.Label(label="Your Score")
|
92 |
+
message = gr.Text(label="Who guessed it right?")
|
93 |
|
94 |
btn = gr.Button("Next image")
|
95 |
|
96 |
demo.load(random_image, None, [image, image_label, radio, prediction])
|
97 |
+
radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
|
98 |
radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
|
99 |
+
radio.change(compare_score, [radio, prediction, image_label, has_guessed], message)
|
100 |
btn.click(random_image, None, [image, image_label, radio, prediction])
|
101 |
btn.click(lambda :False, None, has_guessed)
|
102 |
|