File size: 1,348 Bytes
0e4149c
3e61f53
 
 
d1e8a6c
0e4149c
3e61f53
 
 
 
0e4149c
90fb249
 
 
 
 
 
d1e8a6c
0e4149c
 
d1e8a6c
 
 
 
 
 
 
 
3e61f53
d1e8a6c
3e61f53
d1e8a6c
 
 
90fb249
 
d1e8a6c
0e4149c
 
93ca441
0e4149c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import numpy as np

# Load the image-classification model and its matching preprocessor.
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Map class index -> human-readable class name.
# The builtin open() only accepts local file paths, so passing it an HTTPS URL
# fails at import time (and the handle would never have been closed anyway).
# The checkpoint's config already carries the index -> label mapping for its
# own classification head, so no download is needed. Using it also keeps the
# names aligned with the logits the model actually produces.
imagenet_labels = {
    int(idx): label.strip() for idx, label in model.config.id2label.items()
}

# Classification callback used by the Gradio interface below.
def classify_image(image):
    """Classify *image* and describe the top-1 predicted class.

    Args:
        image: The uploaded image. This is a PIL ``Image`` or an array accepted
            by ``feature_extractor``; Gradio's default ``"image"`` input supplies
            a NumPy array. It is ``None`` when the user submits before uploading
            anything.

    Returns:
        A string naming the predicted class and its index. If *image* is
        ``None``, a prompt asking the user to upload an image.
    """
    # Gradio invokes the callback with None if "Submit" is pressed with no
    # upload; the feature extractor would raise on that input.
    if image is None:
        return "Please upload an image."

    # Normalize PIL input to a NumPy array so both input types share one path.
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Preprocess into model-ready PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # .item() assumes a batch of exactly one image.
    predicted_class_idx = logits.argmax(-1).item()

    # Fall back to a placeholder if the index has no known label.
    class_name = imagenet_labels.get(
        predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})"
    )
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"

# Build the web UI: one image input, the predicted class rendered as text.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)
demo.launch()