JeanCGuerrero commited on
Commit
824f72f
verified
1 Parent(s): 8b8a289

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -22
app.py CHANGED
@@ -5,20 +5,18 @@ import numpy as np
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
 
8
- # Modelo autoencoder
9
  class Autoencoder(nn.Module):
10
  def __init__(self):
11
  super(Autoencoder, self).__init__()
12
- # Encoder
13
- self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1) # 64x64 -> 32x32
14
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) # 32x32 -> 16x16
15
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) # 16x16 -> 8x8
16
- self.fc1 = nn.Linear(128 * 8 * 8, 32) # Espacio latente
17
- # Decoder
18
  self.fc2 = nn.Linear(32, 128 * 8 * 8)
19
- self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1) # 8x8 -> 16x16
20
- self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1) # 16x16 -> 32x32
21
- self.conv6 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1) # 32x32 -> 64x64
22
 
23
  def encode(self, x):
24
  z = torch.tanh(self.conv1(x))
@@ -44,29 +42,35 @@ model = Autoencoder()
44
  model.load_state_dict(torch.load("autoencoder.pth", map_location=torch.device("cpu")))
45
  model.eval()
46
 
47
- # Transformaci贸n de entrada
48
  transform = transforms.Compose([
49
  transforms.Grayscale(),
50
  transforms.Resize((64, 64)),
51
  transforms.ToTensor()
52
  ])
53
 
54
- # Funci贸n de inferencia
 
 
 
55
  def detectar_anomalia(imagen):
 
56
  with torch.no_grad():
57
- img_tensor = transform(imagen).unsqueeze(0) # A帽adir batch
58
- reconstruida = model(img_tensor).squeeze(0).squeeze(0)
59
- return reconstruida.numpy() # Convertir a numpy para visualizaci贸n
60
 
 
 
 
61
 
62
- # Interfaz de Gradio
63
- interface = gr.Interface(
64
  fn=detectar_anomalia,
65
- inputs=gr.Image(type="pil"),
66
- outputs=[gr.Image(type="numpy"), gr.Text()],
67
- title="Detecci贸n de Anomal铆as con Autoencoder",
68
- description="Sube una imagen para detectar anomal铆as usando un autoencoder entrenado."
 
69
  )
70
 
71
- interface.launch()
72
 
 
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
 
8
+ # Modelo Autoencoder
9
  class Autoencoder(nn.Module):
10
  def __init__(self):
11
  super(Autoencoder, self).__init__()
12
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
13
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
14
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
15
+ self.fc1 = nn.Linear(128 * 8 * 8, 32)
 
 
16
  self.fc2 = nn.Linear(32, 128 * 8 * 8)
17
+ self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
18
+ self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
19
+ self.conv6 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)
20
 
21
  def encode(self, x):
22
  z = torch.tanh(self.conv1(x))
 
42
  model.load_state_dict(torch.load("autoencoder.pth", map_location=torch.device("cpu")))
43
  model.eval()
44
 
45
+ # Transformaci贸n
46
  transform = transforms.Compose([
47
  transforms.Grayscale(),
48
  transforms.Resize((64, 64)),
49
  transforms.ToTensor()
50
  ])
51
 
52
+ # Umbral de error (ajustable)
53
+ THRESHOLD = 0.01
54
+
55
+ # Funci贸n de predicci贸n
56
  def detectar_anomalia(imagen):
57
+ img_tensor = transform(imagen).unsqueeze(0)
58
  with torch.no_grad():
59
+ reconstruida = model(img_tensor)
 
 
60
 
61
+ mse = torch.mean((img_tensor - reconstruida) ** 2).item()
62
+ resultado = "An贸mala" if mse > THRESHOLD else "Normal"
63
+ return resultado
64
 
65
+ # Interfaz Gradio
66
+ demo = gr.Interface(
67
  fn=detectar_anomalia,
68
+ inputs=gr.Image(type="pil", label="Sube una imagen para analizar"),
69
+ outputs=gr.Label(label="Resultado"),
70
+ examples=["anomalous.png", "normal.png"],
71
+ title="馃搶 Detecci贸n de Anomal铆as con Autoencoder (PyTorch)",
72
+ description="Este Space utiliza un autoencoder entrenado con PyTorch para detectar anomal铆as en im谩genes de textiles industriales.",
73
  )
74
 
75
+ demo.launch()
76