# Spaces: Sleeping
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms, models | |
import pickle | |
# Class labels; their count determines the width of the classifier head below.
with open('class_names.pkl', mode='rb') as f:
    class_names = pickle.load(f)

# Rebuild the trained model: a ResNet-50 backbone without pretrained weights,
# whose final layer is replaced by dropout plus a linear classifier.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = models.resnet50(weights=None)
model.fc = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=model.fc.in_features, out_features=len(class_names)),
)

# Load the trained weights, then switch to inference mode on the target device.
model.load_state_dict(
    torch.load('best_model.pth', map_location=device, weights_only=True)
)
model = model.to(device).eval()

# Preprocessing pipeline; must match the one used during training.
preprocess = transforms.Compose([
    transforms.Resize(size=(100, 100)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict_image(img):
    """Classify an uploaded image and return the top predictions.

    Args:
        img: PIL image from the Gradio image input, or ``None`` when the
            user has not provided an image.

    Returns:
        A ``(best_class, best_conf, results)`` tuple: the most likely class
        name, its confidence as a percentage (float), and a dict mapping
        each of the top-k (at most 3) class names to its softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # An empty image input arrives as None; surface a readable message
        # in the UI instead of failing on `None.convert`.
        raise gr.Error("请先上传图像")

    img = img.convert('RGB')

    # Preprocess, add the batch dimension, and move to the model's device.
    input_batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # Convert the single sample's logits to class probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    # Work with plain Python numbers: int indices are valid for both list
    # and dict style `class_names`, unlike 0-dim tensors.
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()

    results = {class_names[idx]: prob for prob, idx in zip(top_probs, top_indices)}

    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100

    # Append to the running log. UTF-8 is set explicitly so non-ASCII class
    # names are written correctly regardless of the platform's locale.
    # Note that `{img}` records the PIL object's repr, not a file name.
    with open('prediction_results.txt', 'a', encoding='utf-8') as f:
        f.write(f"Image: {img}\n"
                f"Predicted: {best_class}\n"
                f"Confidence: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")

    return best_class, best_conf, results
# Build the Gradio interface
def create_interface():
    """Assemble the Gradio Blocks UI for the classifier and return it.

    The layout is a single row with two columns: image upload, example
    images and the classify button in the first; the predicted class, its
    confidence and the top-3 label widget in the second. Clicking the button
    runs ``predict_image`` on the uploaded image.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    sample_images = ["data/r0_0_100.jpg", "data/r0_18_100.jpg"]

    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🍎 水果识别系统")

        with gr.Row():
            # Input side: upload widget, clickable examples, trigger button.
            with gr.Column():
                uploader = gr.Image(type="pil", label="上传图像")
                gr.Examples(examples=sample_images, inputs=uploader)
                classify_button = gr.Button("分类", variant="primary")

            # Output side: one widget per value returned by predict_image.
            with gr.Column():
                prediction_box = gr.Textbox(label="预测结果")
                confidence_box = gr.Textbox(label="置信度")
                top3_label = gr.Label(label="Top 3", num_top_classes=3)

        # Click handler for the classify button.
        classify_button.click(
            predict_image,
            inputs=uploader,
            outputs=[prediction_box, confidence_box, top3_label],
        )

    return demo
if __name__ == "__main__":
    # Build the UI and start serving it; share=True also requests a public
    # Gradio share link.
    create_interface().launch(share=True)