DHEIVER's picture
Update app.py
b0fa5f2 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
from resnest.torch import resnest50
# Carregar nomes das classes originais
with open('class_names.pkl', 'rb') as f:
class_names_en = pickle.load(f)
# Imprimir as classes originais para debug
print("Classes originais encontradas:", class_names_en)
# Dicionário de tradução mais completo (incluindo variações)
class_names_pt = {
'apple': 'maçã',
'Apple': 'maçã',
'Apple 10': 'maçã', # adicionando variações
'banana': 'banana',
'Banana': 'banana',
'cherry': 'cereja',
'Cherry': 'cereja',
'chico': 'sapoti',
'grape': 'uva',
'Grape': 'uva',
'kiwi': 'kiwi',
'Kiwi': 'kiwi',
'mango': 'manga',
'Mango': 'manga',
'orange': 'laranja',
'Orange': 'laranja',
'pear': 'pera',
'Pear': 'pera',
'tomato': 'tomate',
'Tomato': 'tomate'
}
# Criar lista de nomes em português, usando o nome original se não houver tradução
class_names = []
for en in class_names_en:
# Remover números e espaços extras para normalizar
base_name = ''.join([i for i in en if not i.isdigit()]).strip()
translated = class_names_pt.get(base_name, class_names_pt.get(en, en))
class_names.append(translated)
print("Classes traduzidas:", class_names)
# Restante do código permanece igual...
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnest50(pretrained=None)
model.fc = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(model.fc.in_features, len(class_names))
)
# Carregar os pesos do modelo
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model = model.to(device)
model.eval()
# Definir o mesmo pré-processamento usado no treinamento
preprocess = transforms.Compose([
transforms.Resize((100, 100)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict_image(img):
img = img.convert('RGB')
# Aplicar pré-processamento
input_tensor = preprocess(img)
# Adicionar dimensão de batch e mover para o dispositivo
input_batch = input_tensor.unsqueeze(0).to(device)
# Fazer previsão
with torch.no_grad():
output = model(input_batch)
# Calcular probabilidades
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# Obter as 3 melhores previsões
top3_probs, top3_indices = torch.topk(probabilities, 3)
results = {
class_names[i]: float(p)
for p, i in zip(top3_probs, top3_indices)
}
# Obter a melhor previsão
best_class = class_names[top3_indices[0]]
best_conf = float(top3_probs[0]) * 100
# Salvar resultados
with open('/tmp/prediction_results.txt', 'a') as f:
f.write(f"Imagem: {img}\n"
f"Previsão: {best_class}\n"
f"Confiança: {best_conf:.2f}%\n"
f"Top 3: {results}\n"
f"------------------------\n")
return best_class, f"{best_conf:.2f}%", results
# Criar interface Gradio
def create_interface():
examples = [
"r0_0_100.jpg",
"r0_18_100.jpg"
]
with gr.Blocks(title="Sistema de Classificação de Frutas", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🍎 Sistema de Reconhecimento de Frutas")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Envie uma imagem")
gr.Examples(examples=examples, inputs=image_input)
submit_btn = gr.Button("Classificar", variant="primary")
with gr.Column():
best_pred = gr.Textbox(label="Resultado da Previsão")
confidence = gr.Textbox(label="Nível de Confiança")
full_results = gr.Label(label="Top 3", num_top_classes=3)
# Evento de clique do botão 'Classificar'
submit_btn.click(
fn=predict_image,
inputs=image_input,
outputs=[best_pred, confidence, full_results]
)
return demo
if __name__ == "__main__":
interface = create_interface()
interface.launch(share=False)