File size: 2,876 Bytes
620c260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import pickle

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models


# Label list saved at training time; position i names output logit i of the model.
# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted file.
with open('class_names.pkl', 'rb') as f:
    class_names = pickle.load(f)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Rebuild the network whose weights are stored in best_model.pth: an untrained
# ResNet-50 whose final classifier is replaced by dropout + a linear layer with
# one output per class. load_state_dict is strict, so this must match exactly.
model = models.resnet50(weights=None)
model.fc = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=model.fc.in_features, out_features=len(class_names)),
)

# Load the trained weights (weights_only=True restricts unpickling to tensors),
# move the network to the selected device and switch it to inference mode
# (disables dropout; batch-norm layers use their running statistics).
model.load_state_dict(
    torch.load('best_model.pth', map_location=device, weights_only=True)
)
model = model.to(device).eval()

# Inference-time preprocessing; it must mirror the pipeline used during training.
# Resize to 100x100, convert the PIL image to a float tensor in [0, 1], then
# normalize each channel with the standard ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def predict_image(img):
    """Classify one image and return its top predictions.

    As a side effect, appends a human-readable summary of the prediction to
    ``prediction_results.txt`` in the current working directory.

    Args:
        img: A PIL image as delivered by ``gr.Image(type="pil")``, or ``None``
            when the button is clicked before an image has been uploaded.

    Returns:
        A tuple ``(best_class, best_conf, results)``:
        ``best_class`` is the top-1 label, ``best_conf`` is its softmax
        probability expressed in percent (0-100), and ``results`` maps up to
        three labels to their probabilities (0-1), the format ``gr.Label``
        expects.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    if img is None:
        # Without this guard img.convert raises an opaque AttributeError in the UI.
        raise gr.Error("请先上传图像")

    img = img.convert('RGB')

    # Apply the training-time preprocessing, add a batch dimension and move to device.
    input_batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    # Softmax over the class logits of the single image in the batch.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # torch.topk raises if k exceeds the number of classes, so clamp k.
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    # Convert to plain Python floats/ints before indexing the label list.
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()

    results = {class_names[i]: p for p, i in zip(top_probs, top_indices)}

    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100

    # Append a record of this prediction; utf-8 so non-ASCII labels always encode.
    with open('prediction_results.txt', 'a', encoding='utf-8') as f:
        f.write(f"Image: {img}\n"
                f"Predicted: {best_class}\n"
                f"Confidence: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")

    return best_class, best_conf, results

# Build the Gradio interface.
def create_interface():
    """Build the Gradio Blocks UI for the fruit classifier.

    Layout: the left column has an image upload, optional example images and a
    classify button. The right column has the top-1 label, its confidence and
    a top-3 ``gr.Label``. Clicking the button runs ``predict_image``.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    candidate_examples = [
        "data/r0_0_100.jpg",
        "data/r0_18_100.jpg",
    ]
    # Paths are relative to the current working directory. Keep only files that
    # exist so a missing data/ folder does not break building the app.
    examples = [path for path in candidate_examples if os.path.isfile(path)]

    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🍎 水果识别系统")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="上传图像")
                if examples:
                    gr.Examples(examples=examples, inputs=image_input)
                submit_btn = gr.Button("分类", variant="primary")

            with gr.Column():
                best_pred = gr.Textbox(label="预测结果")
                confidence = gr.Textbox(label="置信度")
                full_results = gr.Label(label="Top 3", num_top_classes=3)

        # Wire the classify button to the model; outputs follow predict_image's
        # (best_class, best_conf, results) return order.
        submit_btn.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[best_pred, confidence, full_results],
        )

    return demo


if __name__ == "__main__":
    # NOTE: share=True additionally publishes the app through a public Gradio share link.
    create_interface().launch(share=True)