JeanCGuerrero's picture
Update app.py
49bc36c verified
raw
history blame
2.5 kB
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
# Autoencoder model
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel images resized to 64x64.

    The encoder uses three stride-2 convolutions (64 -> 32 -> 16 -> 8 pixels
    per side, 1 -> 32 -> 64 -> 128 channels), flattens the 128x8x8 map and
    projects it to a 32-dimensional code. The decoder mirrors that path with
    a linear layer and three transposed convolutions back to 1x64x64.

    The layer attribute names (conv1..conv6, fc1, fc2) must not change: they
    are the keys of the state_dict loaded from "autoencoder.pth".
    """

    def __init__(self):
        super().__init__()
        # Encoder convolutions: each halves the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        # Bottleneck between the flattened 128x8x8 feature map and the 32-d code.
        self.fc1 = nn.Linear(128 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 128 * 8 * 8)
        # Decoder: with output_padding=1 each transposed conv exactly doubles
        # the spatial size (8 -> 16 -> 32 -> 64).
        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv6 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        """Compress a (N, 1, H, W) batch into a tanh-bounded (N, 32) code."""
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = torch.tanh(conv(hidden))
        flat = hidden.view(hidden.size(0), -1)
        return torch.tanh(self.fc1(flat))

    def decode(self, x):
        """Expand a (N, 32) code into a (N, 1, 64, 64) image with values in [0, 1]."""
        hidden = torch.tanh(self.fc2(x))
        hidden = hidden.view(hidden.size(0), 128, 8, 8)
        for deconv in (self.conv4, self.conv5):
            hidden = torch.tanh(deconv(hidden))
        # Sigmoid keeps the reconstruction in the same range as the input pixels.
        return torch.sigmoid(self.conv6(hidden))

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the resulting code."""
        code = self.encode(x)
        return self.decode(code)
# Build the network, load the trained weights mapped onto the CPU, and switch
# the module to evaluation mode for inference.
# NOTE(review): torch.load unpickles the checkpoint. The file ships with the
# app, but weights_only=True (torch>=1.13) would restrict loading to tensors.
model = Autoencoder()
model.load_state_dict(
    torch.load("autoencoder.pth", map_location="cpu")
)
model.eval()
# Preprocessing: collapse the image to a single grayscale channel, resize it to
# the 64x64 resolution the autoencoder is built for, and convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]
)
# Reconstruction-error threshold (tunable): an image whose mean squared
# reconstruction error is above this value is reported as anomalous.
THRESHOLD = 1e-2
# Prediction function
def detectar_anomalia(imagen):
    """Classify an image as normal or anomalous from its reconstruction error.

    The image is preprocessed with ``transform``, reconstructed by ``model``,
    and the mean squared error between the input tensor and its reconstruction
    is compared against ``THRESHOLD``.

    Args:
        imagen: PIL image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        ``"Anómala"`` if the reconstruction MSE exceeds ``THRESHOLD``,
        otherwise ``"Normal"``.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a stack trace from the preprocessing step.
    """
    if imagen is None:
        raise gr.Error("Sube una imagen para analizar.")
    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = transform(imagen).unsqueeze(0)
    with torch.no_grad():
        reconstruida = model(img_tensor)
    mse = torch.mean((img_tensor - reconstruida) ** 2).item()
    # Images the autoencoder reconstructs poorly are flagged as anomalous.
    return "Anómala" if mse > THRESHOLD else "Normal"
# Gradio interface: one uploaded image in, a "Normal"/"Anómala" label out.
demo = gr.Interface(
    fn=detectar_anomalia,
    inputs=gr.Image(type="pil", label="Sube una imagen para analizar"),
    outputs=gr.Label(label="Resultado"),
    examples=["anomalous.png", "normal.png"],
    title="Detección de Anomalías con Autoencoder (PyTorch)",
    description="Este Space utiliza un autoencoder entrenado con PyTorch para detectar anomalías en imágenes de textiles.",
)

# Start the web server only when run as a script, so importing this module
# does not block on demo.launch().
if __name__ == "__main__":
    demo.launch()