JeanCGuerrero's picture
Update app.py
f2c9943 verified
raw
history blame
2.49 kB
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
# Autoencoder model
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 64x64 images.

    The encoder compresses an image into a 32-dimensional latent code and the
    decoder expands that code back into a 1x64x64 image.

    The attribute names conv1..conv6, fc1 and fc2 are the keys that
    ``load_state_dict`` matches against the saved checkpoint, so they must not
    be renamed.
    """

    def __init__(self):
        super().__init__()
        # Encoder: each stride-2 conv halves the spatial size
        # (64 -> 32 -> 16 -> 8), then a linear layer projects to 32 dims.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 32)  # latent space
        # Decoder mirrors the encoder. output_padding=1 makes every transposed
        # conv exactly double the spatial size (8 -> 16 -> 32 -> 64).
        self.fc2 = nn.Linear(32, 128 * 8 * 8)
        self.conv4 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv5 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv6 = nn.ConvTranspose2d(
            32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )

    def encode(self, x):
        """Map a batch of single-channel images to latent codes in (-1, 1).

        Returns a tensor of shape (batch, 32).
        """
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = torch.tanh(conv(h))
        # Flatten every feature map of a sample into one vector for fc1.
        flat = h.view(h.shape[0], -1)
        return torch.tanh(self.fc1(flat))

    def decode(self, x):
        """Expand latent codes of shape (batch, 32) into (batch, 1, 64, 64) images."""
        h = torch.tanh(self.fc2(x))
        # Restore the 128x8x8 feature-map layout the transposed convs expect.
        h = h.view(h.shape[0], 128, 8, 8)
        for deconv in (self.conv4, self.conv5):
            h = torch.tanh(deconv(h))
        # Sigmoid bounds the reconstructed pixel values to (0, 1).
        return torch.sigmoid(self.conv6(h))

    def forward(self, x):
        """Encode then decode ``x``, returning its reconstruction."""
        latent = self.encode(x)
        return self.decode(latent)
# Load the trained model: build the network in evaluation mode, then restore
# the weights from the checkpoint file, mapping every tensor onto the CPU.
# load_state_dict defaults to strict=True, so the checkpoint keys must match
# the Autoencoder attribute names exactly.
# NOTE(review): torch.load unpickles the file. On PyTorch versions where
# weights_only=True is not the default this can execute arbitrary code, so
# only load checkpoints you trust.
model = Autoencoder().eval()  # eval() returns the module itself
model.load_state_dict(torch.load("autoencoder.pth", map_location=torch.device("cpu")))
# Input preprocessing: reduce the image to one grayscale channel, resize it to
# the 64x64 resolution the autoencoder is built for, and convert it to a float
# tensor of shape (1, 64, 64) (ToTensor scales 8-bit pixels into [0, 1]).
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]
)
# Inference function
def detectar_anomalia(imagen, *, modelo=None, transformacion=None):
    """Reconstruct ``imagen`` with the autoencoder and report the reconstruction error.

    The Gradio interface declares two outputs (an image and a text box), so
    this function returns a pair: the reconstructed image and a message with
    the mean squared error between the preprocessed input and its
    reconstruction. A larger error means the model reproduces the input less
    faithfully, which is the anomaly score.

    Args:
        imagen: Input image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.
        modelo: Optional model override (keyword-only). Defaults to the
            module-level ``model``.
        transformacion: Optional preprocessing override (keyword-only).
            Defaults to the module-level ``transform``.

    Returns:
        ``(reconstruction, message)``. ``reconstruction`` is a 2-D numpy
        array: the single channel of the single batch item. When ``imagen``
        is ``None``, returns ``(None, message)`` instead of raising.
    """
    if imagen is None:
        return None, "No se ha proporcionado ninguna imagen."
    if modelo is None:
        modelo = model
    if transformacion is None:
        transformacion = transform

    with torch.no_grad():
        entrada = transformacion(imagen).unsqueeze(0)  # add batch dimension
        reconstruida = modelo(entrada)
        # Mean squared reconstruction error, used as the anomaly score.
        error = nn.functional.mse_loss(reconstruida, entrada).item()

    # Drop the batch and channel dimensions so Gradio receives a 2-D image.
    imagen_reconstruida = reconstruida.squeeze(0).squeeze(0).numpy()
    return imagen_reconstruida, f"Error de reconstrucción (MSE): {error:.6f}"
# Gradio interface: one PIL image in, and two outputs: the reconstruction
# (numpy array) and a text box. detectar_anomalia must therefore return two
# values.
interface = gr.Interface(
    fn=detectar_anomalia,
    inputs=gr.Image(type="pil", label="Imagen de entrada"),
    outputs=[
        gr.Image(type="numpy", label="Reconstrucción"),
        gr.Text(label="Resultado"),
    ],
    # Title was mis-encoded (mojibake); restored to proper UTF-8 Spanish.
    title="Detección de Anomalías con Autoencoder",
    description="Sube una imagen para detectar anomalías usando un autoencoder entrenado.",
)
interface.launch()