Questaaaa's picture
Update app.py
f9746a2 verified
raw
history blame
1.6 kB
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import numpy as np
# Load the pretrained BEiT image classifier together with the preprocessor
# that matches it; both are resolved from the same checkpoint name so the
# resizing/normalization agrees with what the model was trained on.
model_name = "microsoft/beit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# ImageNet-1k class names, indexed by class ID (position i is the name for
# the model's output index i).
# The previous hand-typed list was truncated (only ~20 of the 1000 names were
# present, with the rest elided), so from index 15 onward every name was
# attached to the wrong ID. The checkpoint's config already carries the
# authoritative id -> label mapping, so derive the list from it instead.
imagenet_labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
# Classification callback used by the Gradio interface.
def classify_image(image):
    """Classify ``image`` with the loaded model and describe the top prediction.

    Args:
        image: The value supplied by the Gradio image input: a PIL image or a
            NumPy array (presumably H x W x 3 uint8 from Gradio's default
            RGB mode; TODO confirm). ``None`` when the user submits without
            uploading an image.

    Returns:
        ``"Predicted class: <name> (ID: <idx>)"`` for the arg-max class, or a
        short prompt asking for an image when ``image`` is ``None``.
    """
    # Gradio passes None when nothing was uploaded; the preprocessor would
    # raise on it, so answer with a prompt instead of an error.
    if image is None:
        return "Please upload an image."

    if isinstance(image, Image.Image):
        # Normalize to 3-channel RGB first: RGBA, grayscale ("L") and
        # palette ("P") images would otherwise reach the preprocessor with
        # an unexpected channel count.
        image = np.array(image.convert("RGB"))

    # Resize/normalize into model-ready tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Prefer the checkpoint's own id -> label mapping; fall back to the
    # module-level list, then to a generic placeholder.
    class_name = model.config.id2label.get(predicted_class_idx)
    if class_name is None and predicted_class_idx < len(imagenet_labels):
        class_name = imagenet_labels[predicted_class_idx]
    if class_name is None:
        class_name = f"Unknown Class (ID: {predicted_class_idx})"

    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
# Build the Gradio UI: one image input wired to classify_image, text output.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)

# Start the server only when executed as a script (e.g. `python app.py`),
# so importing this module does not block inside launch().
if __name__ == "__main__":
    demo.launch()