|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
|
|
|
|
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    Encoder: three stride-2 convolutions (1 -> 32 -> 64 -> 128 channels),
    each halving height and width, followed by a linear projection of the
    flattened 128x8x8 feature map to a 32-dimensional code. The fixed
    128 * 8 * 8 linear sizes mean the input must be 64x64 so that it reaches
    8x8 after the three downsamplings.

    Decoder: a linear layer back to 128x8x8, then three stride-2 transposed
    convolutions (128 -> 64 -> 32 -> 1 channels) that double height and
    width each step, ending in a sigmoid so outputs lie in [0, 1].

    The attribute names conv1..conv6, fc1 and fc2 are the keys used by
    ``load_state_dict``; renaming them would break loading saved weights.
    """

    def __init__(self):
        super().__init__()
        # Encoder convolutions: kernel 3, stride 2, padding 1 halves H and W.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        # Bottleneck: flattened 128x8x8 features <-> 32-dim latent code.
        self.fc1 = nn.Linear(128 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 128 * 8 * 8)
        # Decoder transposed convolutions; output_padding=1 makes each
        # step exactly double H and W (8 -> 16 -> 32 -> 64).
        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv6 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        """Map an (N, 1, 64, 64) batch to (N, 32) latent codes in [-1, 1]."""
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = conv(h).tanh()
        # Flatten everything except the batch dimension.
        flat = h.view(h.shape[0], -1)
        return self.fc1(flat).tanh()

    def decode(self, x):
        """Map (N, 32) latent codes back to (N, 1, 64, 64) images in [0, 1]."""
        h = self.fc2(x).tanh()
        h = h.view(h.shape[0], 128, 8, 8)
        for deconv in (self.conv4, self.conv5):
            h = deconv(h).tanh()
        # Sigmoid keeps the reconstruction in the same [0, 1] range as pixels.
        return self.conv6(h).sigmoid()

    def forward(self, x):
        """Encode then decode ``x``, returning its reconstruction."""
        code = self.encode(x)
        return self.decode(code)
|
|
|
|
|
# Build the network and restore its trained weights onto the CPU.
# weights_only=True restricts torch.load to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code through pickle; a state dict
# (which is all load_state_dict needs) loads fine under this restriction.
model = Autoencoder()
model.load_state_dict(
    torch.load("autoencoder.pth", map_location=torch.device("cpu"), weights_only=True)
)
# Put the model in inference mode before serving predictions.
model.eval()
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference:
# collapse to a single channel, resize to the 64x64 resolution the
# autoencoder's layers are sized for, and convert to a float tensor
# (ToTensor scales 8-bit pixel values into [0, 1]).
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]
)
|
|
|
|
|
# Decision boundary on the per-image reconstruction MSE: images the
# autoencoder reconstructs with a larger error are reported as anomalous.
# NOTE(review): value presumably tuned on the training data — confirm.
THRESHOLD = 1e-2
|
|
|
|
|
def detectar_anomalia(imagen):
    """Classify an image as normal or anomalous by its reconstruction error.

    The image is preprocessed with the module-level ``transform``, passed
    through the autoencoder ``model``, and the mean squared error between
    input and reconstruction is compared against ``THRESHOLD``.

    Args:
        imagen: PIL image supplied by the ``gr.Image(type="pil")`` input.

    Returns:
        ``"Anómala"`` if the reconstruction MSE exceeds ``THRESHOLD``,
        otherwise ``"Normal"``.

    Raises:
        ValueError: if no image was provided (Gradio passes ``None`` when
            the image input is empty).
    """
    if imagen is None:
        raise ValueError("No se ha proporcionado ninguna imagen.")

    # Add a batch dimension: (1, 64, 64) -> (1, 1, 64, 64).
    img_tensor = transform(imagen).unsqueeze(0)

    with torch.no_grad():
        reconstruida = model(img_tensor)
        # Average squared error over all pixels of the image.
        mse = torch.mean((img_tensor - reconstruida) ** 2).item()

    return "Anómala" if mse > THRESHOLD else "Normal"
|
|
|
|
|
# Gradio UI: one PIL image in, a text label ("Anómala" / "Normal") out.
demo = gr.Interface(
    fn=detectar_anomalia,
    inputs=gr.Image(type="pil", label="Sube una imagen para analizar"),
    outputs=gr.Label(label="Resultado"),
    # Example images referenced by relative path (presumably shipped
    # alongside this script).
    examples=["anomalous.png", "normal.png"],
    title="Detección de Anomalías con Autoencoder (PyTorch)",
    description=(
        "Este Space utiliza un autoencoder entrenado con PyTorch para "
        "detectar anomalías en imágenes de textiles."
    ),
)

# Start the web server only when executed as a script, so importing this
# module (e.g. for testing) does not launch the app.
if __name__ == "__main__":
    demo.launch()
|
|
|
|