turtlegraphics commited on
Commit
739f38e
·
verified ·
1 Parent(s): 1d99f23

Switching to mobilenet_v2

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -2,18 +2,20 @@
2
  # gradio demo
3
  #
4
  import gradio as gr
5
- from transformers import ViTFeatureExtractor, ViTModel
6
 
7
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
8
- model = ViTModel.from_pretrained('akahana/vit-base-cats-vs-dogs')
9
 
10
  title = "Sandbox"
11
  description = "Place to try various models"
12
 
13
  def classify(image):
14
- inputs = feature_extractor(images=image, return_tensors="pt")
15
  outputs = model(**inputs)
16
- return str(outputs)
 
 
17
 
18
  demo = gr.Interface(fn=classify, inputs="image", outputs="text")
19
 
 
2
  # gradio demo
3
  #
4
  import gradio as gr
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
7
+ preprocessor = AutoImageProcessor.from_pretrained("google/mobilenet_v2_1.0_224")
8
+ model = AutoModelForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
9
 
10
  title = "Sandbox"
11
  description = "Place to try various models"
12
 
13
  def classify(image):
14
+ inputs = preprocessor(images=image, return_tensors="pt")
15
  outputs = model(**inputs)
16
+ logits = outputs.logits
17
+ predicted_class_idx = logits.argmax(-1).item()
18
+ return model.config.id2label[predicted_class_idx]
19
 
20
  demo = gr.Interface(fn=classify, inputs="image", outputs="text")
21