Spaces:
Sleeping
Sleeping
File size: 2,876 Bytes
620c260 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
# Ordered list of class labels saved at training time; position i is the
# label for model output index i.
# NOTE: pickle can execute arbitrary code while loading, so this must only
# ever point at a trusted, locally produced file.
with open('class_names.pkl', mode='rb') as fh:
    class_names = pickle.load(fh)
# Load the trained model: a ResNet-50 backbone whose final fully-connected
# layer is replaced by Dropout + Linear sized to the number of class labels.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = models.resnet50(weights=None)  # architecture only; weights come from the checkpoint
model.fc = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=model.fc.in_features, out_features=len(class_names)),
)
# weights_only=True restricts what torch.load will unpickle; map_location lets
# a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load('best_model.pth', map_location=device, weights_only=True)
)
# Move to the target device and switch to inference mode (disables dropout).
model = model.to(device).eval()
# Preprocessing pipeline; per the original author this must match the one used
# during training: resize to 100x100, convert to a tensor, then normalize each
# channel with the standard ImageNet mean/std values.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict_image(img):
    """Classify one image and append the prediction to ``prediction_results.txt``.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user clicks the button without uploading anything.

    Returns:
        ``(best_class, best_conf, results)``: the top-1 label, its probability
        expressed as a percentage (float), and a dict mapping up to three
        top-ranked labels to their probabilities (floats in [0, 1]).

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of an AttributeError traceback.
    """
    if img is None:
        raise gr.Error("请先上传图像")
    img = img.convert('RGB')
    # Apply the same preprocessing used at training time
    input_tensor = preprocess(img)
    # Add a batch dimension of size 1 and move to the model's device
    input_batch = input_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_batch)
    # Softmax over the class scores of the single batch item
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # torch.topk raises if k exceeds the number of classes, so clamp it
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    # Convert once to plain Python floats/ints for indexing and display
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()
    results = {class_names[i]: p for p, i in zip(top_probs, top_indices)}
    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100
    # Append to the log; explicit utf-8 so non-ASCII labels don't fail on
    # platforms whose default file encoding is not UTF-8
    with open('prediction_results.txt', 'a', encoding='utf-8') as f:
        f.write(f"Image: {img}\n"
                f"Predicted: {best_class}\n"
                f"Confidence: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")
    return best_class, best_conf, results
# Build the Gradio user interface
def create_interface():
    """Assemble and return the Gradio Blocks app for fruit classification.

    Layout: the left column holds the image upload, two clickable example
    images and the classify button; the right column shows the top-1 label,
    its confidence and a top-3 label view. The button runs ``predict_image``.
    """
    sample_images = ["data/r0_0_100.jpg", "data/r0_18_100.jpg"]
    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🍎 水果识别系统")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="上传图像")
                gr.Examples(examples=sample_images, inputs=image_input)
                classify_button = gr.Button("分类", variant="primary")
            with gr.Column():
                top_label_box = gr.Textbox(label="预测结果")
                confidence_box = gr.Textbox(label="置信度")
                top3_view = gr.Label(label="Top 3", num_top_classes=3)
        # Clicking the '分类' (Classify) button runs inference on the upload;
        # the three return values of predict_image fill the three outputs in order
        classify_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[top_label_box, confidence_box, top3_view],
        )
    return app
if __name__ == "__main__":
    # Build the UI and start the server; share=True asks Gradio for a public share link
    create_interface().launch(share=True)
|