# Questaaaa's picture
# Update app.py
# 90fb249 verified
# raw
# history blame
# 1.35 kB
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import numpy as np
# Load the pretrained image-classification model and its matching feature
# extractor from the same checkpoint.
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Map class index -> class name for the 1000 ImageNet classes.
# The built-in open() only reads local paths, so handing it an https:// URL
# fails as soon as this module is loaded. The checkpoint's own config carries
# the index-to-label table that lines up with the model's output logits, so it
# is used here instead: no extra download, and no file handle left unclosed.
# dict() takes a copy so later edits to this table never touch the config.
imagenet_labels = dict(model.config.id2label)
# Classification callback wired into the Gradio interface below.
def classify_image(image):
    """Classify one image and return a human-readable prediction string.

    Args:
        image: The picture to classify, either a ``PIL.Image.Image`` or a
            NumPy array. ``None`` is accepted and reported as missing input.

    Returns:
        ``"Predicted class: <name> (ID: <index>)"`` for the highest-scoring
        class, or ``"No image provided."`` when ``image`` is ``None``.
    """
    # NOTE(review): the UI is expected to pass None when the form is submitted
    # without a picture; answer with a message instead of crashing inside the
    # feature extractor.
    if image is None:
        return "No image provided."

    if isinstance(image, Image.Image):
        # Force three channels so RGBA / palette / grayscale pictures reach
        # the extractor as plain RGB.
        image = np.array(image.convert("RGB"))
    else:
        image = np.asarray(image)
        if image.ndim == 2:
            # Single-channel input: replicate it into three identical channels.
            image = np.stack([image] * 3, axis=-1)
        elif image.ndim == 3 and image.shape[-1] == 4:
            # assumes channels-last RGBA here; drop the alpha channel.
            image = image[..., :3]

    # Preprocess the pixels into the tensors the model expects.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()

    # Look up the readable class name, falling back to a placeholder when the
    # index is missing from the label table.
    class_name = imagenet_labels.get(predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})")
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
# Build the Gradio interface: one image input routed through classify_image,
# with the result shown as text.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)
# Start serving the interface.
demo.launch()