File size: 1,595 Bytes
0e4149c
3e61f53
 
 
d1e8a6c
0e4149c
3e61f53
 
 
 
0e4149c
f9746a2
 
 
 
 
 
 
 
d1e8a6c
0e4149c
 
d1e8a6c
 
 
 
 
 
 
 
3e61f53
d1e8a6c
3e61f53
d1e8a6c
 
 
f9746a2
 
 
 
 
d1e8a6c
0e4149c
 
93ca441
0e4149c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import numpy as np

# Load the pretrained BEiT checkpoint: first its preprocessing pipeline,
# then the image-classification model built from the same checkpoint name.
model_name = "microsoft/beit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# ImageNet class names, ordered by class index.
# The previous hand-typed list held only 20 names (the rest were elided), so
# indices 15-19 were mapped to the wrong names and every index >= 20 came back
# as unknown. Derive the list from the label mapping shipped with the loaded
# model instead, keeping only the first (primary) synonym of each entry,
# e.g. "tench, Tinca tinca" -> "tench", to match the short-name style.
imagenet_labels = [
    label.split(",")[0].strip()
    for _, label in sorted(model.config.id2label.items())
]

# Classification callback used by the Gradio interface.
def classify_image(image):
    """Classify an image and describe the top-1 prediction.

    Args:
        image: The image handed over by Gradio -- a NumPy array or a
            ``PIL.Image.Image``; ``None`` when the user submits without
            providing an image.

    Returns:
        A human-readable string with the predicted class name and its index,
        or a prompt to upload an image when no image was given.
    """
    # Gradio passes None for an empty input; the feature extractor would raise
    # on it, so reply with a message instead of surfacing a stack trace.
    if image is None:
        return "Please upload an image."

    # Normalize PIL input to a 3-channel NumPy array so RGBA / grayscale
    # uploads do not reach the feature extractor with an unexpected channel count.
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))

    # Preprocess into model-ready tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Resolve the name: prefer the label mapping that ships with the model
    # (authoritative and complete), then the module-level list, then a
    # placeholder.
    id2label = getattr(model.config, "id2label", None) or {}
    class_name = id2label.get(predicted_class_idx)
    if class_name is not None:
        # Keep only the primary synonym, e.g. "tench, Tinca tinca" -> "tench".
        class_name = class_name.split(",")[0].strip()
    elif predicted_class_idx < len(imagenet_labels):
        class_name = imagenet_labels[predicted_class_idx]
    else:
        class_name = f"Unknown Class (ID: {predicted_class_idx})"

    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"

# Build the Gradio UI (image in, text out) around classify_image and start it.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)
demo.launch()