File size: 4,205 Bytes
620c260
 
 
 
 
b6fa136
620c260
b0fa5f2
620c260
db99b27
 
b0fa5f2
 
 
 
db99b27
 
b0fa5f2
 
db99b27
b0fa5f2
db99b27
b0fa5f2
 
db99b27
b0fa5f2
db99b27
b0fa5f2
db99b27
b0fa5f2
db99b27
b0fa5f2
db99b27
b0fa5f2
 
 
db99b27
 
b0fa5f2
 
 
 
 
 
 
 
 
db99b27
b0fa5f2
620c260
 
b6fa136
620c260
 
 
 
 
db99b27
620c260
 
 
 
db99b27
620c260
 
 
 
 
 
 
 
 
 
db99b27
620c260
 
db99b27
620c260
 
db99b27
620c260
 
 
db99b27
620c260
 
db99b27
620c260
 
 
b0fa5f2
620c260
 
 
db99b27
620c260
b0fa5f2
620c260
db99b27
fff6c76
db99b27
 
 
620c260
 
 
db99b27
620c260
db99b27
620c260
 
e42d91f
 
620c260
 
db99b27
 
620c260
 
 
db99b27
620c260
db99b27
620c260
 
db99b27
 
620c260
 
db99b27
620c260
 
 
 
 
 
 
 
 
 
 
db99b27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
from resnest.torch import resnest50

# Carregar nomes das classes originais
with open('class_names.pkl', 'rb') as f:
    class_names_en = pickle.load(f)

# Imprimir as classes originais para debug
print("Classes originais encontradas:", class_names_en)

# Dicionário de tradução mais completo (incluindo variações)
class_names_pt = {
    'apple': 'maçã',
    'Apple': 'maçã',
    'Apple 10': 'maçã',  # adicionando variações
    'banana': 'banana',
    'Banana': 'banana',
    'cherry': 'cereja',
    'Cherry': 'cereja',
    'chico': 'sapoti',
    'grape': 'uva',
    'Grape': 'uva',
    'kiwi': 'kiwi',
    'Kiwi': 'kiwi',
    'mango': 'manga',
    'Mango': 'manga',
    'orange': 'laranja',
    'Orange': 'laranja',
    'pear': 'pera',
    'Pear': 'pera',
    'tomato': 'tomate',
    'Tomato': 'tomate'
}

# Criar lista de nomes em português, usando o nome original se não houver tradução
class_names = []
for en in class_names_en:
    # Remover números e espaços extras para normalizar
    base_name = ''.join([i for i in en if not i.isdigit()]).strip()
    translated = class_names_pt.get(base_name, class_names_pt.get(en, en))
    class_names.append(translated)

print("Classes traduzidas:", class_names)

# Restante do código permanece igual...
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = resnest50(pretrained=None)
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, len(class_names))
)

# Carregar os pesos do modelo
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model = model.to(device)
model.eval()

# Definir o mesmo pré-processamento usado no treinamento
preprocess = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def predict_image(img):
    img = img.convert('RGB')

    # Aplicar pré-processamento
    input_tensor = preprocess(img)

    # Adicionar dimensão de batch e mover para o dispositivo
    input_batch = input_tensor.unsqueeze(0).to(device)

    # Fazer previsão
    with torch.no_grad():
        output = model(input_batch)

    # Calcular probabilidades
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Obter as 3 melhores previsões
    top3_probs, top3_indices = torch.topk(probabilities, 3)

    results = {
        class_names[i]: float(p)
        for p, i in zip(top3_probs, top3_indices)
    }

    # Obter a melhor previsão
    best_class = class_names[top3_indices[0]]
    best_conf = float(top3_probs[0]) * 100

    # Salvar resultados
    with open('/tmp/prediction_results.txt', 'a') as f:
        f.write(f"Imagem: {img}\n"
                f"Previsão: {best_class}\n"
                f"Confiança: {best_conf:.2f}%\n"
                f"Top 3: {results}\n"
                f"------------------------\n")

    return best_class, f"{best_conf:.2f}%", results

# Criar interface Gradio
def create_interface():
    examples = [
        "r0_0_100.jpg",
        "r0_18_100.jpg"
    ]

    with gr.Blocks(title="Sistema de Classificação de Frutas", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🍎 Sistema de Reconhecimento de Frutas")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Envie uma imagem")
                gr.Examples(examples=examples, inputs=image_input)
                submit_btn = gr.Button("Classificar", variant="primary")

            with gr.Column():
                best_pred = gr.Textbox(label="Resultado da Previsão")
                confidence = gr.Textbox(label="Nível de Confiança")
                full_results = gr.Label(label="Top 3", num_top_classes=3)

        # Evento de clique do botão 'Classificar'
        submit_btn.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[best_pred, confidence, full_results]
        )

    return demo


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=False)