Spaces:
Build error
Build error
sashavor
commited on
Commit
·
27938aa
1
Parent(s):
00290a4
new version!
Browse files
app.py
CHANGED
@@ -5,8 +5,9 @@ from collections import OrderedDict
|
|
5 |
from random import sample
|
6 |
import csv
|
7 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
|
|
8 |
|
9 |
-
|
10 |
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
11 |
|
12 |
|
@@ -18,34 +19,70 @@ description="Try guessing the category of each image displayed, from the options
|
|
18 |
classdict = OrderedDict()
|
19 |
for line in open('LOC_synset_mapping.txt', 'r').readlines():
|
20 |
try:
|
21 |
-
classdict[line.split(' ')[0]]= ' '.join(line.split(' ')[1:]).replace('\n','')
|
22 |
except:
|
23 |
continue
|
24 |
-
|
25 |
imagedict={}
|
26 |
with open('image_labels.csv', 'r') as csv_file:
|
27 |
reader = csv.DictReader(csv_file)
|
28 |
for row in reader:
|
29 |
imagedict[row['image_name']] = row['image_label']
|
|
|
|
|
30 |
|
31 |
def model_classify(im):
|
32 |
inputs = feature_extractor(images=im, return_tensors="pt")
|
33 |
outputs = model(**inputs)
|
34 |
logits = outputs.logits
|
35 |
predicted_class_idx = logits.argmax(-1).item()
|
36 |
-
return
|
|
|
37 |
|
38 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
|
|
|
|
|
|
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
44 |
with gr.Row():
|
45 |
with gr.Column():
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
49 |
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
demo.launch()
|
|
|
5 |
from random import sample
|
6 |
import csv
|
7 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
8 |
+
import random
|
9 |
|
10 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
11 |
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
12 |
|
13 |
|
|
|
19 |
classdict = OrderedDict()
|
20 |
for line in open('LOC_synset_mapping.txt', 'r').readlines():
|
21 |
try:
|
22 |
+
classdict[line.split(' ')[0]]= ' '.join(line.split(' ')[1:]).replace('\n','').split(',')[0]
|
23 |
except:
|
24 |
continue
|
25 |
+
classes = list(classdict.values())
|
26 |
imagedict={}
|
27 |
with open('image_labels.csv', 'r') as csv_file:
|
28 |
reader = csv.DictReader(csv_file)
|
29 |
for row in reader:
|
30 |
imagedict[row['image_name']] = row['image_label']
|
31 |
+
images= list(imagedict.keys())
|
32 |
+
labels = list(set(imagedict.values()))
|
33 |
|
34 |
def model_classify(im):
|
35 |
inputs = feature_extractor(images=im, return_tensors="pt")
|
36 |
outputs = model(**inputs)
|
37 |
logits = outputs.logits
|
38 |
predicted_class_idx = logits.argmax(-1).item()
|
39 |
+
return model.config.id2label[predicted_class_idx]
|
40 |
+
|
41 |
|
42 |
+
def random_image():
|
43 |
+
imname = random.choice(images)
|
44 |
+
im = Image.open('images/'+ imname +'.jpg')
|
45 |
+
label = str(imagedict[imname])
|
46 |
+
labels.remove(label)
|
47 |
+
options = sample(labels,3)
|
48 |
+
options.append(label)
|
49 |
+
random.shuffle(options)
|
50 |
+
options = [classes[int(i)] for i in options]
|
51 |
+
return im, label, gr.Radio.update(choices=options), None
|
52 |
|
53 |
+
def check_score(pred, truth, current_score):
|
54 |
+
if pred == classes[int(truth)]:
|
55 |
+
return current_score + 1, f"Your score is {current_score}"
|
56 |
+
return current_score, f"Your score is {current_score}"
|
57 |
|
58 |
+
def compare_score(userclass, prediction):
|
59 |
+
print(userclass)
|
60 |
+
print(prediction)
|
61 |
+
if userclass == str(prediction).split(',')[0]:
|
62 |
+
return "Great! You and the model agree on the category"
|
63 |
+
return "You and the model disagree"
|
64 |
|
65 |
with gr.Blocks() as demo:
|
66 |
+
user_score = gr.State(0)
|
67 |
+
model_score = gr.State(0)
|
68 |
+
image_label = gr.State()
|
69 |
+
prediction = gr.State()
|
70 |
+
|
71 |
with gr.Row():
|
72 |
with gr.Column():
|
73 |
+
image = gr.Image(shape=(448, 448))
|
74 |
+
radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
|
75 |
+
with gr.Column():
|
76 |
+
prediction = gr.Label(label="Model Prediction")
|
77 |
+
score = gr.Label(label="Your Score")
|
78 |
+
message = gr.Text()
|
79 |
+
|
80 |
+
btn = gr.Button("Next image")
|
81 |
|
82 |
+
demo.load(random_image, None, [image, image_label, radio, prediction])
|
83 |
+
radio.change(model_classify, image, prediction)
|
84 |
+
radio.change(check_score, [radio, image_label, user_score], [user_score, score])
|
85 |
+
radio.change(compare_score, [radio, prediction], message)
|
86 |
+
btn.click(random_image, None, [image, image_label, radio, prediction])
|
87 |
|
88 |
demo.launch()
|