Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms, models | |
import pickle | |
from resnest.torch import resnest50 | |
# Carregar nomes das classes originais | |
with open('class_names.pkl', 'rb') as f: | |
class_names_en = pickle.load(f) | |
# Imprimir as classes originais para debug | |
print("Classes originais encontradas:", class_names_en) | |
# Dicionário de tradução mais completo (incluindo variações) | |
class_names_pt = { | |
'apple': 'maçã', | |
'Apple': 'maçã', | |
'Apple 10': 'maçã', # adicionando variações | |
'banana': 'banana', | |
'Banana': 'banana', | |
'cherry': 'cereja', | |
'Cherry': 'cereja', | |
'chico': 'sapoti', | |
'grape': 'uva', | |
'Grape': 'uva', | |
'kiwi': 'kiwi', | |
'Kiwi': 'kiwi', | |
'mango': 'manga', | |
'Mango': 'manga', | |
'orange': 'laranja', | |
'Orange': 'laranja', | |
'pear': 'pera', | |
'Pear': 'pera', | |
'tomato': 'tomate', | |
'Tomato': 'tomate' | |
} | |
# Criar lista de nomes em português, usando o nome original se não houver tradução | |
class_names = [] | |
for en in class_names_en: | |
# Remover números e espaços extras para normalizar | |
base_name = ''.join([i for i in en if not i.isdigit()]).strip() | |
translated = class_names_pt.get(base_name, class_names_pt.get(en, en)) | |
class_names.append(translated) | |
print("Classes traduzidas:", class_names) | |
# Restante do código permanece igual... | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model = resnest50(pretrained=None) | |
model.fc = nn.Sequential( | |
nn.Dropout(0.2), | |
nn.Linear(model.fc.in_features, len(class_names)) | |
) | |
# Carregar os pesos do modelo | |
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True)) | |
model = model.to(device) | |
model.eval() | |
# Definir o mesmo pré-processamento usado no treinamento | |
preprocess = transforms.Compose([ | |
transforms.Resize((100, 100)), | |
transforms.ToTensor(), | |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
]) | |
def predict_image(img): | |
img = img.convert('RGB') | |
# Aplicar pré-processamento | |
input_tensor = preprocess(img) | |
# Adicionar dimensão de batch e mover para o dispositivo | |
input_batch = input_tensor.unsqueeze(0).to(device) | |
# Fazer previsão | |
with torch.no_grad(): | |
output = model(input_batch) | |
# Calcular probabilidades | |
probabilities = torch.nn.functional.softmax(output[0], dim=0) | |
# Obter as 3 melhores previsões | |
top3_probs, top3_indices = torch.topk(probabilities, 3) | |
results = { | |
class_names[i]: float(p) | |
for p, i in zip(top3_probs, top3_indices) | |
} | |
# Obter a melhor previsão | |
best_class = class_names[top3_indices[0]] | |
best_conf = float(top3_probs[0]) * 100 | |
# Salvar resultados | |
with open('/tmp/prediction_results.txt', 'a') as f: | |
f.write(f"Imagem: {img}\n" | |
f"Previsão: {best_class}\n" | |
f"Confiança: {best_conf:.2f}%\n" | |
f"Top 3: {results}\n" | |
f"------------------------\n") | |
return best_class, f"{best_conf:.2f}%", results | |
# Criar interface Gradio | |
def create_interface(): | |
examples = [ | |
"r0_0_100.jpg", | |
"r0_18_100.jpg" | |
] | |
with gr.Blocks(title="Sistema de Classificação de Frutas", theme=gr.themes.Soft()) as demo: | |
gr.Markdown("# 🍎 Sistema de Reconhecimento de Frutas") | |
with gr.Row(): | |
with gr.Column(): | |
image_input = gr.Image(type="pil", label="Envie uma imagem") | |
gr.Examples(examples=examples, inputs=image_input) | |
submit_btn = gr.Button("Classificar", variant="primary") | |
with gr.Column(): | |
best_pred = gr.Textbox(label="Resultado da Previsão") | |
confidence = gr.Textbox(label="Nível de Confiança") | |
full_results = gr.Label(label="Top 3", num_top_classes=3) | |
# Evento de clique do botão 'Classificar' | |
submit_btn.click( | |
fn=predict_image, | |
inputs=image_input, | |
outputs=[best_pred, confidence, full_results] | |
) | |
return demo | |
if __name__ == "__main__": | |
interface = create_interface() | |
interface.launch(share=False) |