Spaces:
Sleeping
Sleeping
File size: 4,205 Bytes
620c260 b6fa136 620c260 b0fa5f2 620c260 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 db99b27 b0fa5f2 620c260 b6fa136 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 b0fa5f2 620c260 db99b27 620c260 b0fa5f2 620c260 db99b27 fff6c76 db99b27 620c260 db99b27 620c260 db99b27 620c260 e42d91f 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 db99b27 620c260 db99b27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
import pickle
from resnest.torch import resnest50
# Carregar nomes das classes originais
with open('class_names.pkl', 'rb') as f:
class_names_en = pickle.load(f)
# Imprimir as classes originais para debug
print("Classes originais encontradas:", class_names_en)
# Dicionário de tradução mais completo (incluindo variações)
class_names_pt = {
'apple': 'maçã',
'Apple': 'maçã',
'Apple 10': 'maçã', # adicionando variações
'banana': 'banana',
'Banana': 'banana',
'cherry': 'cereja',
'Cherry': 'cereja',
'chico': 'sapoti',
'grape': 'uva',
'Grape': 'uva',
'kiwi': 'kiwi',
'Kiwi': 'kiwi',
'mango': 'manga',
'Mango': 'manga',
'orange': 'laranja',
'Orange': 'laranja',
'pear': 'pera',
'Pear': 'pera',
'tomato': 'tomate',
'Tomato': 'tomate'
}
# Criar lista de nomes em português, usando o nome original se não houver tradução
class_names = []
for en in class_names_en:
# Remover números e espaços extras para normalizar
base_name = ''.join([i for i in en if not i.isdigit()]).strip()
translated = class_names_pt.get(base_name, class_names_pt.get(en, en))
class_names.append(translated)
print("Classes traduzidas:", class_names)
# Restante do código permanece igual...
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnest50(pretrained=None)
model.fc = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(model.fc.in_features, len(class_names))
)
# Carregar os pesos do modelo
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model = model.to(device)
model.eval()
# Definir o mesmo pré-processamento usado no treinamento
preprocess = transforms.Compose([
transforms.Resize((100, 100)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict_image(img):
img = img.convert('RGB')
# Aplicar pré-processamento
input_tensor = preprocess(img)
# Adicionar dimensão de batch e mover para o dispositivo
input_batch = input_tensor.unsqueeze(0).to(device)
# Fazer previsão
with torch.no_grad():
output = model(input_batch)
# Calcular probabilidades
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# Obter as 3 melhores previsões
top3_probs, top3_indices = torch.topk(probabilities, 3)
results = {
class_names[i]: float(p)
for p, i in zip(top3_probs, top3_indices)
}
# Obter a melhor previsão
best_class = class_names[top3_indices[0]]
best_conf = float(top3_probs[0]) * 100
# Salvar resultados
with open('/tmp/prediction_results.txt', 'a') as f:
f.write(f"Imagem: {img}\n"
f"Previsão: {best_class}\n"
f"Confiança: {best_conf:.2f}%\n"
f"Top 3: {results}\n"
f"------------------------\n")
return best_class, f"{best_conf:.2f}%", results
# Criar interface Gradio
def create_interface():
examples = [
"r0_0_100.jpg",
"r0_18_100.jpg"
]
with gr.Blocks(title="Sistema de Classificação de Frutas", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🍎 Sistema de Reconhecimento de Frutas")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Envie uma imagem")
gr.Examples(examples=examples, inputs=image_input)
submit_btn = gr.Button("Classificar", variant="primary")
with gr.Column():
best_pred = gr.Textbox(label="Resultado da Previsão")
confidence = gr.Textbox(label="Nível de Confiança")
full_results = gr.Label(label="Top 3", num_top_classes=3)
# Evento de clique do botão 'Classificar'
submit_btn.click(
fn=predict_image,
inputs=image_input,
outputs=[best_pred, confidence, full_results]
)
return demo
if __name__ == "__main__":
interface = create_interface()
interface.launch(share=False) |