sashavor commited on
Commit
27938aa
·
1 Parent(s): 00290a4

new version!

Browse files
Files changed (1) hide show
  1. app.py +46 -9
app.py CHANGED
@@ -5,8 +5,9 @@ from collections import OrderedDict
5
  from random import sample
6
  import csv
7
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
8
 
9
- extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
10
  model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
11
 
12
 
@@ -18,34 +19,70 @@ description="Try guessing the category of each image displayed, from the options
18
  classdict = OrderedDict()
19
  for line in open('LOC_synset_mapping.txt', 'r').readlines():
20
  try:
21
- classdict[line.split(' ')[0]]= ' '.join(line.split(' ')[1:]).replace('\n','')
22
  except:
23
  continue
24
-
25
  imagedict={}
26
  with open('image_labels.csv', 'r') as csv_file:
27
  reader = csv.DictReader(csv_file)
28
  for row in reader:
29
  imagedict[row['image_name']] = row['image_label']
 
 
30
 
31
  def model_classify(im):
32
  inputs = feature_extractor(images=im, return_tensors="pt")
33
  outputs = model(**inputs)
34
  logits = outputs.logits
35
  predicted_class_idx = logits.argmax(-1).item()
36
- return("Predicted class:", model.config.id2label[predicted_class_idx])
 
37
 
38
- def check_answer(im):
 
 
 
 
 
 
 
 
 
39
 
40
- return {'cat': 0.3, 'dog': 0.7}
 
 
 
41
 
 
 
 
 
 
 
42
 
43
  with gr.Blocks() as demo:
 
 
 
 
 
44
  with gr.Row():
45
  with gr.Column():
46
- im = Image.open('images/'+sample(imagedict.keys(),1)[0]+'.jpg')
47
- image = gr.Image(im,shape=(224, 224))
48
- radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category")
 
 
 
 
 
49
 
 
 
 
 
 
50
 
51
  demo.launch()
 
5
  from random import sample
6
  import csv
7
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
8
+ import random
9
 
10
+ feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
11
  model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
12
 
13
 
 
19
  classdict = OrderedDict()
20
  for line in open('LOC_synset_mapping.txt', 'r').readlines():
21
  try:
22
+ classdict[line.split(' ')[0]]= ' '.join(line.split(' ')[1:]).replace('\n','').split(',')[0]
23
  except:
24
  continue
25
+ classes = list(classdict.values())
26
  imagedict={}
27
  with open('image_labels.csv', 'r') as csv_file:
28
  reader = csv.DictReader(csv_file)
29
  for row in reader:
30
  imagedict[row['image_name']] = row['image_label']
31
+ images= list(imagedict.keys())
32
+ labels = list(set(imagedict.values()))
33
 
34
  def model_classify(im):
35
  inputs = feature_extractor(images=im, return_tensors="pt")
36
  outputs = model(**inputs)
37
  logits = outputs.logits
38
  predicted_class_idx = logits.argmax(-1).item()
39
+ return model.config.id2label[predicted_class_idx]
40
+
41
 
42
+ def random_image():
43
+ imname = random.choice(images)
44
+ im = Image.open('images/'+ imname +'.jpg')
45
+ label = str(imagedict[imname])
46
+ labels.remove(label)
47
+ options = sample(labels,3)
48
+ options.append(label)
49
+ random.shuffle(options)
50
+ options = [classes[int(i)] for i in options]
51
+ return im, label, gr.Radio.update(choices=options), None
52
 
53
+ def check_score(pred, truth, current_score):
54
+ if pred == classes[int(truth)]:
55
+ return current_score + 1, f"Your score is {current_score}"
56
+ return current_score, f"Your score is {current_score}"
57
 
58
+ def compare_score(userclass, prediction):
59
+ print(userclass)
60
+ print(prediction)
61
+ if userclass == str(prediction).split(',')[0]:
62
+ return "Great! You and the model agree on the category"
63
+ return "You and the model disagree"
64
 
65
  with gr.Blocks() as demo:
66
+ user_score = gr.State(0)
67
+ model_score = gr.State(0)
68
+ image_label = gr.State()
69
+ prediction = gr.State()
70
+
71
  with gr.Row():
72
  with gr.Column():
73
+ image = gr.Image(shape=(448, 448))
74
+ radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
75
+ with gr.Column():
76
+ prediction = gr.Label(label="Model Prediction")
77
+ score = gr.Label(label="Your Score")
78
+ message = gr.Text()
79
+
80
+ btn = gr.Button("Next image")
81
 
82
+ demo.load(random_image, None, [image, image_label, radio, prediction])
83
+ radio.change(model_classify, image, prediction)
84
+ radio.change(check_score, [radio, image_label, user_score], [user_score, score])
85
+ radio.change(compare_score, [radio, prediction], message)
86
+ btn.click(random_image, None, [image, image_label, radio, prediction])
87
 
88
  demo.launch()