Questaaaa commited on
Commit
d1e8a6c
·
verified ·
1 Parent(s): 141eaa1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -2,20 +2,38 @@ import gradio as gr
2
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
  import torch
4
  from PIL import Image
 
 
 
5
 
6
  # 加载模型和特征提取器
7
  model_name = "microsoft/beit-base-patch16-224"
8
  model = AutoModelForImageClassification.from_pretrained(model_name)
9
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
10
 
 
 
 
 
11
  # 定义分类函数
12
  def classify_image(image):
13
- image = feature_extractor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
14
  with torch.no_grad():
15
- outputs = model(**image)
16
  logits = outputs.logits
17
- predicted_class = logits.argmax(-1).item()
18
- return f"Predicted class: {predicted_class}"
 
 
 
 
19
 
20
  # 创建 Gradio 界面
21
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")
 
2
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
  import torch
4
  from PIL import Image
5
+ import numpy as np
6
+ import json
7
+ import requests
8
 
9
  # 加载模型和特征提取器
10
  model_name = "microsoft/beit-base-patch16-224"
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
13
 
14
+ # 获取 ImageNet 类别映射
15
+ LABELS_URL = "https://storage.googleapis.com/bit_models/imagenet21k_wordnet_id_map.json"
16
+ imagenet_classes = requests.get(LABELS_URL).json()
17
+
18
  # 定义分类函数
19
  def classify_image(image):
20
+ # 转换 PIL Image 为 numpy 数组
21
+ if isinstance(image, Image.Image):
22
+ image = np.array(image)
23
+
24
+ # 进行特征提取
25
+ inputs = feature_extractor(images=image, return_tensors="pt")
26
+
27
+ # 预测类别
28
  with torch.no_grad():
29
+ outputs = model(**inputs)
30
  logits = outputs.logits
31
+ predicted_class_idx = logits.argmax(-1).item()
32
+
33
+ # 获取类别名称
34
+ class_name = imagenet_classes.get(str(predicted_class_idx), "Unknown")
35
+
36
+ return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
37
 
38
  # 创建 Gradio 界面
39
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")