Questaaaa's picture
Update app.py
a98b34c verified
raw
history blame contribute delete
1.19 kB
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
# Load the pretrained classifier and its matching image preprocessor.
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
# AutoImageProcessor supersedes the deprecated AutoFeatureExtractor for vision
# models; the variable keeps its original name so existing callers still work.
feature_extractor = AutoImageProcessor.from_pretrained(model_name)
# Class-index -> label-name mapping stored in the model's config.
labels = model.config.id2label
# Classification callback used by the Gradio interface.
def classify_image(image):
    """Classify a single image and describe the top-1 prediction.

    Relies on the module-level ``model``, ``feature_extractor`` and
    ``labels`` objects loaded at import time.

    Args:
        image: A PIL ``Image.Image`` or an array-like image accepted by
            ``feature_extractor``, or ``None`` when nothing was supplied.

    Returns:
        ``"Predicted class: <label> (ID: <idx>)"`` for the highest-scoring
        class, or ``"No image provided."`` when ``image`` is ``None``.
    """
    # The UI can invoke the callback with None (e.g. submitting with an empty
    # image slot); return a message instead of letting the preprocessor fail.
    if image is None:
        return "No image provided."
    if isinstance(image, Image.Image):
        # Force 3-channel RGB before converting to an array: RGBA, palette and
        # grayscale images would otherwise produce a channel layout that the
        # preprocessor may not accept.
        image = np.array(image.convert("RGB"))
    # Preprocess into model inputs as PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index of the highest logit; .item() assumes a single image in the batch.
    predicted_class_idx = logits.argmax(-1).item()
    # Fall back to a placeholder if the config has no name for this index.
    class_name = labels.get(
        predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})"
    )
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
# Build the Gradio UI (one image input, one text output) around
# classify_image, then start serving it.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)
demo.launch()