Questaaaa commited on
Commit
a7b8099
·
verified ·
1 Parent(s): d1e8a6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -3,17 +3,20 @@ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
  import torch
4
  from PIL import Image
5
  import numpy as np
6
- import json
7
- import requests
8
 
9
  # 加载模型和特征提取器
10
  model_name = "microsoft/beit-base-patch16-224"
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
13
 
14
- # 获取 ImageNet 类别映射
15
- LABELS_URL = "https://storage.googleapis.com/bit_models/imagenet21k_wordnet_id_map.json"
16
- imagenet_classes = requests.get(LABELS_URL).json()
 
 
 
 
 
17
 
18
  # 定义分类函数
19
  def classify_image(image):
@@ -31,7 +34,7 @@ def classify_image(image):
31
  predicted_class_idx = logits.argmax(-1).item()
32
 
33
  # 获取类别名称
34
- class_name = imagenet_classes.get(str(predicted_class_idx), "Unknown")
35
 
36
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
37
 
 
3
  import torch
4
  from PIL import Image
5
  import numpy as np
 
 
6
 
7
  # 加载模型和特征提取器
8
  model_name = "microsoft/beit-base-patch16-224"
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
+ # ImageNet 1000 类别名称(从 Hugging Face 官方下载)
13
+ imagenet_classes = [
14
+ "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", # 0-4
15
+ "electric ray", "stingray", "cock", "hen", "ostrich", # 5-9
16
+ "brambling", "goldfinch", "house finch", "junco", "indigo bunting", # 10-14
17
+ # 省略中间 900 多个类别...
18
+ "sports car", "convertible", "minivan", "pickup", "SUV" # 817-821(汽车类)
19
+ ]
20
 
21
  # 定义分类函数
22
  def classify_image(image):
 
34
  predicted_class_idx = logits.argmax(-1).item()
35
 
36
  # 获取类别名称
37
+ class_name = imagenet_classes[predicted_class_idx] if predicted_class_idx < len(imagenet_classes) else f"Unknown Class (ID: {predicted_class_idx})"
38
 
39
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
40