Spaces:
Build error
Build error
File size: 3,328 Bytes
f37cfad 0490df4 27938aa 0490df4 27938aa 0490df4 f37cfad 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad a0a2e9f f37cfad 27938aa a0a2e9f 27938aa dc0eec0 820d8a7 a0a2e9f 27938aa a0a2e9f f37cfad 27938aa f37cfad 27938aa a0a2e9f 27938aa f37cfad 9e1d8e2 a0a2e9f 27938aa 9e1d8e2 27938aa f37cfad 27938aa a0a2e9f 9e1d8e2 27938aa f37cfad dc0eec0 f37cfad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import random
# Hugging Face checkpoint shared by the preprocessor and the classifier, so the
# image preprocessing always matches the model that consumes it.
_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)
# Title and instructions for the demo page.
title = "ImageNet Roulette"
# Adjacent string literals are concatenated at compile time. The explicit
# trailing space keeps the two sentences apart; the previous backslash line
# continuation inside the literal dropped the newline and fused them into
# "below.After".
description = (
    "Try guessing the category of each image displayed, from the options provided below. "
    "After 10 guesses, we will show you your accuracy!"
)
# Map each synset id (the first space-separated token of every line in
# LOC_synset_mapping.txt) to the first comma-separated name that follows it.
# Insertion order is preserved. The rest of the app indexes `classes` with the
# integer labels read from image_labels.csv, so it assumes label i corresponds
# to the i-th non-blank line of this file.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r', encoding='utf-8') as synset_file:
    for line in synset_file:
        # Strip both \n and \r so CRLF files do not leave a stray '\r' in names.
        line = line.rstrip('\r\n')
        if not line:
            # A blank line names no class; keeping it would add a bogus entry
            # and shift the index of every class after it.
            continue
        synset_id, _, names = line.partition(' ')
        classdict[synset_id] = names.split(',')[0]
classes = list(classdict.values())
# Map each image file stem (CSV column `image_name`) to its label (CSV column
# `image_label`). Elsewhere the label is converted with int() and used as an
# index into `classes`.
# The csv module requires files to be opened with newline='' so that quoted
# fields and \r\n row endings are parsed correctly.
with open('image_labels.csv', 'r', newline='', encoding='utf-8') as csv_file:
    imagedict = {
        row['image_name']: row['image_label'] for row in csv.DictReader(csv_file)
    }
images = list(imagedict.keys())
# Distinct labels. They are sorted because set iteration order over strings
# varies between interpreter runs (hash randomization); sorting makes the
# distractor pool, and hence seeded sampling, reproducible.
labels = sorted(set(imagedict.values()))
def model_classify(radio, im):
    """Return the model's top-1 label for the displayed image.

    Inference only runs once an option is selected, so the model's answer is
    revealed after the user has guessed.

    Args:
        radio: The currently selected option, or ``None`` when nothing is
            selected (for example right after the radio is cleared).
        im: The image currently shown in the ``gr.Image`` component.

    Returns:
        The string from ``model.config.id2label`` for the highest-scoring
        logit, or ``None`` when there is no selection or no image.
    """
    # return_tensors="pt" already requires PyTorch, so this import is cheap.
    import torch

    if radio is None or im is None:
        return None
    inputs = feature_extractor(images=im, return_tensors="pt")
    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def random_image():
    """Pick a random image and build four shuffled answer options for it.

    Returns:
        A 4-tuple matching the outputs wired in the UI:

        * ``im`` -- the opened image.
        * ``label`` -- the raw label string from image_labels.csv, later
          passed to the scorer.
        * a radio update that clears the selection and offers the true class
          name plus three distractor class names.
        * ``None`` -- clears the model-prediction output.

    Raises:
        ValueError: from ``sample`` if fewer than three other labels exist.
    """
    imname = random.choice(images)
    im = Image.open('images/' + imname + '.jpg')
    label = str(imagedict[imname])
    # Build the distractor pool without touching the shared module-level
    # `labels` list. Removing from it permanently shrank the pool, and the
    # second draw of the same label raised ValueError.
    distractors = [other for other in labels if other != label]
    options = sample(distractors, 3)
    options.append(label)
    random.shuffle(options)
    options = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(value=None, choices=options), None
def check_score(pred, truth, current_score, total_score):
    """Score a single guess and format the running-score message.

    Args:
        pred: The class name the user selected.
        truth: The image's label; ``int(truth)`` indexes ``classes``.
        current_score: Number of correct guesses so far.
        total_score: Number of guesses made before this one.

    Returns:
        ``(new_score, message)``. ``new_score`` is ``current_score`` plus one
        when the guess matches. ``message`` reports it out of
        ``total_score + 1`` guesses.
    """
    is_correct = pred == classes[int(truth)]
    new_score = current_score + 1 if is_correct else current_score
    attempts = total_score + 1
    return new_score, f"Your score is {new_score} out of {attempts}"
def compare_score(userclass, prediction):
    """Report whether the user's pick matches the model's prediction.

    Only the part of ``str(prediction)`` before its first comma is compared
    against ``userclass``.
    """
    model_class = str(prediction).split(',')[0]
    if userclass != model_class:
        return "You and the model disagree"
    return "Great! You and the model agree on the category"
def _record_guess(pred, truth, current_score, total):
    """Score a guess with check_score and persist the updated attempt count.

    Returns ``(user_score, total_score, score_message)`` for the three outputs
    wired below. Returning the new total is what keeps the attempt count
    across guesses; previously it was computed but never stored. A ``None``
    selection comes from random_image clearing the radio for a new image, not
    from the user. It therefore scores nothing and just restates the current
    standing.
    """
    if pred is None:
        return current_score, total, f"Your score is {current_score} out of {total}"
    new_score, message = check_score(pred, truth, current_score, total)
    return new_score, total + 1, message


with gr.Blocks(title=title) as demo:
    # Per-session state: correct guesses, the current image's label, guesses made.
    user_score = gr.State(0)
    image_label = gr.State()
    total_score = gr.State(0)
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(min_width=900):
            image = gr.Image(shape=(600, 600))
            # Placeholder choices; random_image replaces them on load.
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="Model Prediction")
            score = gr.Label(label="Your Score")
            # NOTE: compare_score could feed a message component here if the
            # agree/disagree feedback is re-enabled.
            btn = gr.Button("Next image")
    demo.load(random_image, None, [image, image_label, radio, prediction])
    radio.change(model_classify, [radio, image], prediction)
    radio.change(
        _record_guess,
        [radio, image_label, user_score, total_score],
        [user_score, total_score, score],
    )
    btn.click(random_image, None, [image, image_label, radio, prediction])
demo.launch()
|