import pickle

import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms

# Number of highest-probability classes reported for each prediction.
_TOP_K = 3
# Plain-text file that every prediction is appended to.
_RESULTS_LOG_PATH = 'prediction_results.txt'

# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling, so
# class_names.pkl must only ever come from a trusted source.
with open('class_names.pkl', 'rb') as f:
    class_names = pickle.load(f)

# Load the trained model: a ResNet-50 whose final layer is replaced by
# dropout followed by a linear layer with one output per class name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(weights=None)
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, len(class_names))
)

# Load the model weights.
model.load_state_dict(
    torch.load('best_model.pth', map_location=device, weights_only=True)
)
model = model.to(device)
model.eval()

# Preprocessing pipeline; per the original author it is meant to be the same
# as the one used during training (the training code is not in this file).
preprocess = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def _append_prediction_log(img, best_class, best_conf, results):
    """Append one prediction record to the results log file."""
    # Explicit UTF-8 so non-ASCII class names cannot fail to encode under a
    # platform-specific default encoding.
    with open(_RESULTS_LOG_PATH, 'a', encoding='utf-8') as f:
        f.write(f"Image: {img}\n"
                f"Predicted: {best_class}\n"
                f"Confidence: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")


def predict_image(img):
    """Classify an image and return the best class plus the top predictions.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None`` when
            the user has not provided an image.

    Returns:
        A 3-tuple ``(best_class, best_conf, results)``: the most likely class
        name, its confidence as a percentage (0-100), and a dict mapping the
        top class names (at most 3) to their probabilities (0-1).

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Without this guard the click handler would crash with an
        # AttributeError; gr.Error is shown to the user in the UI instead.
        raise gr.Error("请先上传图像")

    img = img.convert('RGB')

    # Preprocess, add the batch dimension, and move to the model's device.
    input_batch = preprocess(img).unsqueeze(0).to(device)

    # Inference only; no gradients are needed.
    with torch.no_grad():
        output = model(input_batch)

    # Turn the logits of the single batch item into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(_TOP_K, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    # Convert to plain Python floats/ints so class_names is indexed with real
    # ints rather than 0-dim tensors.
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()

    results = {
        class_names[index]: prob
        for prob, index in zip(top_probs, top_indices)
    }

    # torch.topk returns values in descending order, so item 0 is the best.
    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100

    # Save the result.
    _append_prediction_log(img, best_class, best_conf, results)

    return best_class, best_conf, results


# Build the Gradio interface.
def create_interface():
    """Build and return the Gradio Blocks UI for the classifier.

    The layout has an image input, example images, and a submit button in one
    column, and the prediction outputs in the other. Clicking the button runs
    ``predict_image`` on the current image.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    examples = [
        "data/r0_0_100.jpg",
        "data/r0_18_100.jpg"
    ]

    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🍎 水果识别系统")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="上传图像")
                gr.Examples(examples=examples, inputs=image_input)
                submit_btn = gr.Button("分类", variant="primary")

            with gr.Column():
                best_pred = gr.Textbox(label="预测结果")
                confidence = gr.Textbox(label="置信度")
                full_results = gr.Label(label="Top 3", num_top_classes=_TOP_K)

        # Click handler for the classify button.
        submit_btn.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[best_pred, confidence, full_results]
        )

    return demo


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)