DHEIVER commited on
Commit
db99b27
·
verified ·
1 Parent(s): fff6c76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -25
app.py CHANGED
@@ -5,10 +5,28 @@ from torchvision import transforms, models
5
  import pickle
6
  from resnest.torch import resnest50
7
 
 
8
  with open('class_names.pkl', 'rb') as f:
9
- class_names = pickle.load(f)
10
-
11
- # 加载训练好的模型
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  model = resnest50(pretrained=None)
@@ -17,12 +35,12 @@ model.fc = nn.Sequential(
17
  nn.Linear(model.fc.in_features, len(class_names))
18
  )
19
 
20
- # 加载模型权重
21
  model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
22
  model = model.to(device)
23
  model.eval()
24
 
25
- # 定义与训练时相同的预处理流程
26
  preprocess = transforms.Compose([
27
  transforms.Resize((100, 100)),
28
  transforms.ToTensor(),
@@ -33,20 +51,20 @@ preprocess = transforms.Compose([
33
  def predict_image(img):
34
  img = img.convert('RGB')
35
 
36
- # 应用预处理
37
  input_tensor = preprocess(img)
38
 
39
- # 添加批次维度并移动到设备
40
  input_batch = input_tensor.unsqueeze(0).to(device)
41
 
42
- # 预测
43
  with torch.no_grad():
44
  output = model(input_batch)
45
 
46
- # 计算概率
47
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
48
 
49
- # 获取前3个预测结果
50
  top3_probs, top3_indices = torch.topk(probabilities, 3)
51
 
52
  results = {
@@ -54,42 +72,42 @@ def predict_image(img):
54
  for p, i in zip(top3_probs, top3_indices)
55
  }
56
 
57
- # 获取最佳预测结果
58
  best_class = class_names[top3_indices[0]]
59
  best_conf = top3_probs[0].item() * 100
60
 
61
- # 保存结果
62
  with open('/tmp/prediction_results.txt', 'a') as f:
63
- f.write(f"Image: {img}\n"
64
- f"Predicted: {best_class}\n"
65
- f"Confidence: {best_conf:.2f}%\n"
66
  f"Top 3: {results}\n"
67
  f"------------------------\n")
68
 
69
- return best_class, best_conf, results
70
 
71
- # 创建Gradio界面
72
  def create_interface():
73
  examples = [
74
  "r0_0_100.jpg",
75
  "r0_18_100.jpg"
76
  ]
77
 
78
- with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
79
- gr.Markdown("# 🍎 水果识别系统")
80
 
81
  with gr.Row():
82
  with gr.Column():
83
- image_input = gr.Image(type="pil", label="上传图像")
84
  gr.Examples(examples=examples, inputs=image_input)
85
- submit_btn = gr.Button("分类", variant="primary")
86
 
87
  with gr.Column():
88
- best_pred = gr.Textbox(label="预测结果")
89
- confidence = gr.Textbox(label="置信度")
90
  full_results = gr.Label(label="Top 3", num_top_classes=3)
91
 
92
- # ‘分类’按钮点击事件
93
  submit_btn.click(
94
  fn=predict_image,
95
  inputs=image_input,
@@ -101,4 +119,4 @@ def create_interface():
101
 
102
  if __name__ == "__main__":
103
  interface = create_interface()
104
- interface.launch(share=False)
 
5
  import pickle
6
  from resnest.torch import resnest50
7
 
8
+ # Carregar nomes das classes e criar mapeamento para português
9
  with open('class_names.pkl', 'rb') as f:
10
+ class_names_en = pickle.load(f)
11
+
12
+ # Dicionário de tradução das classes para português
13
+ class_names_pt = {
14
+ 'apple': 'maçã',
15
+ 'banana': 'banana',
16
+ 'cherry': 'cereja',
17
+ 'chico': 'sapoti', # nome em português para chico/fruit
18
+ 'grape': 'uva',
19
+ 'kiwi': 'kiwi',
20
+ 'mango': 'manga',
21
+ 'orange': 'laranja',
22
+ 'pear': 'pera',
23
+ 'tomato': 'tomate'
24
+ }
25
+
26
+ # Criar lista de nomes em português na mesma ordem que class_names_en
27
+ class_names = [class_names_pt[en] for en in class_names_en]
28
+
29
+ # Carregar o modelo treinado
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
  model = resnest50(pretrained=None)
 
35
  nn.Linear(model.fc.in_features, len(class_names))
36
  )
37
 
38
+ # Carregar os pesos do modelo
39
  model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
40
  model = model.to(device)
41
  model.eval()
42
 
43
+ # Definir o mesmo pré-processamento usado no treinamento
44
  preprocess = transforms.Compose([
45
  transforms.Resize((100, 100)),
46
  transforms.ToTensor(),
 
51
  def predict_image(img):
52
  img = img.convert('RGB')
53
 
54
+ # Aplicar pré-processamento
55
  input_tensor = preprocess(img)
56
 
57
+ # Adicionar dimensão de batch e mover para o dispositivo
58
  input_batch = input_tensor.unsqueeze(0).to(device)
59
 
60
+ # Fazer previsão
61
  with torch.no_grad():
62
  output = model(input_batch)
63
 
64
+ # Calcular probabilidades
65
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
66
 
67
+ # Obter as 3 melhores previsões
68
  top3_probs, top3_indices = torch.topk(probabilities, 3)
69
 
70
  results = {
 
72
  for p, i in zip(top3_probs, top3_indices)
73
  }
74
 
75
+ # Obter a melhor previsão
76
  best_class = class_names[top3_indices[0]]
77
  best_conf = top3_probs[0].item() * 100
78
 
79
+ # Salvar resultados
80
  with open('/tmp/prediction_results.txt', 'a') as f:
81
+ f.write(f"Imagem: {img}\n"
82
+ f"Previsão: {best_class}\n"
83
+ f"Confiança: {best_conf:.2f}%\n"
84
  f"Top 3: {results}\n"
85
  f"------------------------\n")
86
 
87
+ return best_class, f"{best_conf:.2f}%", results
88
 
89
+ # Criar interface Gradio
90
  def create_interface():
91
  examples = [
92
  "r0_0_100.jpg",
93
  "r0_18_100.jpg"
94
  ]
95
 
96
+ with gr.Blocks(title="Sistema de Classificação de Frutas", theme=gr.themes.Soft()) as demo:
97
+ gr.Markdown("# 🍎 Sistema de Reconhecimento de Frutas")
98
 
99
  with gr.Row():
100
  with gr.Column():
101
+ image_input = gr.Image(type="pil", label="Envie uma imagem")
102
  gr.Examples(examples=examples, inputs=image_input)
103
+ submit_btn = gr.Button("Classificar", variant="primary")
104
 
105
  with gr.Column():
106
+ best_pred = gr.Textbox(label="Resultado da Previsão")
107
+ confidence = gr.Textbox(label="Nível de Confiança")
108
  full_results = gr.Label(label="Top 3", num_top_classes=3)
109
 
110
+ # Evento de clique do botão 'Classificar'
111
  submit_btn.click(
112
  fn=predict_image,
113
  inputs=image_input,
 
119
 
120
  if __name__ == "__main__":
121
  interface = create_interface()
122
+ interface.launch(share=False)