Questaaaa commited on
Commit
a98b34c
·
verified ·
1 Parent(s): f9746a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -13
app.py CHANGED
@@ -9,14 +9,8 @@ model_name = "microsoft/beit-base-patch16-224"
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
- # ImageNet 1000 类别名称(手动加载)
13
- imagenet_labels = [
14
- "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark",
15
- "electric ray", "stingray", "cock", "hen", "ostrich",
16
- "brambling", "goldfinch", "house finch", "junco", "indigo bunting",
17
- # ...... (省略 900 多个类别)
18
- "sports car", "convertible", "minivan", "pickup", "SUV"
19
- ]
20
 
21
  # 定义分类函数
22
  def classify_image(image):
@@ -34,13 +28,11 @@ def classify_image(image):
34
  predicted_class_idx = logits.argmax(-1).item()
35
 
36
  # 获取类别名称
37
- if predicted_class_idx < len(imagenet_labels):
38
- class_name = imagenet_labels[predicted_class_idx]
39
- else:
40
- class_name = f"Unknown Class (ID: {predicted_class_idx})"
41
-
42
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
43
 
44
  # 创建 Gradio 界面
45
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")
46
  demo.launch()
 
 
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
+ # 获取模型内置的类别标签
13
+ labels = model.config.id2label
 
 
 
 
 
 
14
 
15
  # 定义分类函数
16
  def classify_image(image):
 
28
  predicted_class_idx = logits.argmax(-1).item()
29
 
30
  # 获取类别名称
31
+ class_name = labels.get(predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})")
32
+
 
 
 
33
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
34
 
35
  # 创建 Gradio 界面
36
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")
37
  demo.launch()
38
+