JeanCGuerrero's picture
Create app.py
4d5be2f verified
raw
history blame
1.67 kB
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
# Modelo autoencoder
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel square images.

    Encoder: two stride-2 3x3 convolutions (1 -> 32 -> 64 channels), each of
    which halves the spatial size, then a linear projection to a
    ``latent_dim``-dimensional code (tanh activations throughout).
    Decoder: the mirror image -- a linear layer back to the flattened
    feature map, then two stride-2 transposed convolutions that double the
    spatial size, ending in a sigmoid so reconstructions lie in (0, 1).

    Args:
        latent_dim: Size of the bottleneck code. Defaults to 16 (the value
            the original model was built with).
        image_size: Height and width of the square input. Must be a positive
            multiple of 4: the encoder halves it twice, and the decoder must
            double it back to exactly the same size. Defaults to 64.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.

    The submodule attribute names (conv1, conv2, fc1, fc2, conv3, conv4) are
    part of the ``state_dict`` keys, so they must not be renamed or saved
    checkpoints will no longer load.
    """

    def __init__(self, latent_dim=16, image_size=64):
        super().__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        side = image_size // 4
        # Shape of the encoder's last conv output (channels, H, W). The
        # encoder flattens it and the decoder restores it from this one value
        # instead of repeating the literal 64 * 16 * 16.
        self._feat_shape = (64, side, side)
        flat = 64 * side * side

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(flat, latent_dim)
        self.fc2 = nn.Linear(latent_dim, flat)
        # output_padding=1 makes each transposed conv exactly double the
        # spatial size, undoing the halving done by conv1/conv2.
        self.conv3 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv4 = nn.ConvTranspose2d(
            32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )

    def encode(self, x):
        """Encode a batch of images into latent codes.

        Args:
            x: Tensor of shape (batch, 1, image_size, image_size).

        Returns:
            Tensor of shape (batch, latent_dim) with values in (-1, 1).
        """
        z = torch.tanh(self.conv1(x))
        z = torch.tanh(self.conv2(z))
        z = torch.flatten(z, 1)
        return torch.tanh(self.fc1(z))

    def decode(self, x):
        """Decode latent codes back into images.

        Args:
            x: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 1, image_size, image_size) in (0, 1).
        """
        z = torch.tanh(self.fc2(x))
        z = z.view(z.size(0), *self._feat_shape)
        z = torch.tanh(self.conv3(z))
        return torch.sigmoid(self.conv4(z))

    def forward(self, x):
        """Reconstruct ``x`` by encoding and then decoding it."""
        return self.decode(self.encode(x))
# Cargar el modelo
# Build the network, restore its trained weights from the local checkpoint
# (map_location remaps every stored tensor onto the CPU, so a checkpoint saved
# on a GPU still loads here), then put the module in evaluation mode.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints,
# or pass weights_only=True on torch versions that support it.
model = Autoencoder()
_pesos = torch.load("autoencoder.pth", map_location=torch.device("cpu"))
model.load_state_dict(_pesos)
del _pesos  # drop the temporary so the module namespace matches the original
model.eval()
# Transformación de entrada
# Preprocessing applied to each input image before it reaches the model:
#   Grayscale -> one channel (conv1 of the autoencoder takes 1 input channel)
#   Resize    -> 64x64, the spatial size the encoder/decoder layers are sized for
#   ToTensor  -> PIL image to a float tensor laid out as (C, H, W); torchvision
#                scales 8-bit images into [0, 1]
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)
# Función de inferencia
def detectar_anomalia(imagen):
with torch.no_grad():
img_tensor = transform(imagen).unsqueeze(0) # A帽adir batch
reconstruida = model(img_tensor).squeeze(0).squeeze(_