# app.py — uploaded by Questaaaa
# Last commit: 3e61f53 ("Update app.py", verified), 793 bytes.
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
# Load the classification model and the matching image preprocessor.
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
# AutoImageProcessor is the replacement for the deprecated AutoFeatureExtractor
# on vision checkpoints. The variable keeps its old name so code that calls
# feature_extractor(images=..., return_tensors="pt") keeps working unchanged.
feature_extractor = AutoImageProcessor.from_pretrained(model_name)
# Classification function used as the Gradio callback.
def classify_image(image):
    """Classify *image* and describe the top-scoring class.

    Args:
        image: The value delivered by the Gradio ``"image"`` input, or ``None``
            when the form is submitted without an uploaded image. It is passed
            as-is to ``feature_extractor``; a single image is assumed
            (TODO confirm if batches are ever needed).

    Returns:
        ``"Predicted class: <label>"``, where ``<label>`` is the human-readable
        name from ``model.config.id2label``, falling back to the raw class
        index if the mapping has no entry for it. Returns a prompt string when
        no image was provided.
    """
    if image is None:
        # Gradio passes None when nothing was uploaded; answer with a prompt
        # instead of handing None to the preprocessor.
        return "Please upload an image."

    # Keep the preprocessed tensors under their own name rather than
    # overwriting the caller's image argument.
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(**inputs).logits
    predicted_index = logits.argmax(-1).item()
    # Show the class name rather than an opaque integer.
    label = model.config.id2label.get(predicted_index, predicted_index)
    return f"Predicted class: {label}"
# Assemble the web UI: an uploaded image is fed to classify_image and the
# returned string is displayed in a text box.
demo = gr.Interface(
    title="Image Classification Demo",
    fn=classify_image,
    inputs="image",
    outputs="text",
)
demo.launch()