yqcyqc's picture
Update app.py
b6fa136 verified
raw
history blame
2.9 kB
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
from resnest.torch import resnest50
# Ordered list of class names; index i of the model's output logits maps to
# class_names[i] (the classifier head below is sized by len(class_names)).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only ship a class_names.pkl you produced yourself (presumably by the
# training script).
with open('class_names.pkl', 'rb') as names_file:
    class_names = pickle.load(names_file)
# Build the ResNeSt-50 architecture and load the trained weights for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pretrained is a boolean flag: False skips downloading generic pretrained
# weights, which would be overwritten by the checkpoint loaded below anyway.
model = resnest50(pretrained=False)
# Replace the classifier head so its output size matches the class list.
# This must mirror the head used when best_model.pth was saved, otherwise
# load_state_dict (strict by default) raises on the mismatched keys/shapes.
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, len(class_names)),
)
# weights_only=True restricts unpickling to tensors and plain containers;
# map_location lets a GPU-saved checkpoint load on a CPU-only host.
state_dict = torch.load('best_model.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
# Switch to evaluation mode so the Dropout in the head is disabled at inference.
model.eval()
# Preprocessing pipeline; must stay identical to the one used during training.
# Resize to a fixed 100x100, convert the PIL image to a float CHW tensor in
# [0, 1], then normalize each channel with the ImageNet mean/std values.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict_image(img):
    """Classify one image and return the top prediction plus the top-3 scores.

    The image is converted to RGB, passed through ``preprocess`` and the
    module-level ``model`` on ``device``, and softmax probabilities of the
    (up to) three most likely classes are collected. Each prediction is also
    appended to ``prediction_results.txt``.

    Args:
        img: A ``PIL.Image.Image`` as delivered by ``gr.Image(type="pil")``,
            or ``None`` when the user submits without uploading an image.

    Returns:
        A tuple ``(best_class, best_conf, results)``:
        ``best_class`` is the most likely class name, ``best_conf`` is its
        probability as a percentage (0-100), and ``results`` maps up to three
        class names to their probabilities (0-1), the format ``gr.Label``
        accepts.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # An empty image component passes None; show a readable message in the
        # UI instead of failing with an AttributeError on .convert().
        raise gr.Error("请先上传图像")

    img = img.convert('RGB')
    # Apply preprocessing, add a batch dimension and move to the model device.
    input_batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # output has a leading batch dimension of 1; softmax over the class logits.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    # Convert to plain Python floats/ints for dict keys and list indexing.
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()

    results = {class_names[i]: p for p, i in zip(top_probs, top_indices)}

    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100

    # Append a human-readable record; UTF-8 keeps non-ASCII class names safe
    # regardless of the platform's default encoding.
    with open('prediction_results.txt', 'a', encoding='utf-8') as f:
        f.write(f"Image: {img}\n"
                f"Predicted: {best_class}\n"
                f"Confidence: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")
    return best_class, best_conf, results
# Build the Gradio interface.
def create_interface():
    """Assemble the Gradio Blocks UI for the fruit classifier.

    The left column holds the image upload, clickable example images and the
    classify button. The right column shows the predicted class, the
    confidence and a top-3 label. Clicking the button runs ``predict_image``
    on the uploaded image.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    sample_images = ["r0_0_100.jpg", "r0_18_100.jpg"]

    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🍎 水果识别系统")
        with gr.Row():
            # Input side: upload, examples, trigger button.
            with gr.Column():
                uploaded = gr.Image(type="pil", label="上传图像")
                gr.Examples(examples=sample_images, inputs=uploaded)
                classify_button = gr.Button("分类", variant="primary")
            # Output side: one component per element of predict_image's tuple.
            with gr.Column():
                label_box = gr.Textbox(label="预测结果")
                confidence_box = gr.Textbox(label="置信度")
                top3_label = gr.Label(label="Top 3", num_top_classes=3)
        # Wire the classify button to the predictor; outputs are matched
        # positionally to (best_class, best_conf, results).
        classify_button.click(
            fn=predict_image,
            inputs=uploaded,
            outputs=[label_box, confidence_box, top3_label],
        )
    return app
if __name__ == "__main__":
    # Build the UI and serve it locally only (no public share link).
    create_interface().launch(share=False)