# MVTec_Website / prediction.py
# Hosting-page header residue, kept as comments so the module parses:
# author: pepperumo | "all" | commit 8826642 (verified)
import os
import torch
import dill
import pickle
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
# Shared preprocessing for both inference paths: resize the PIL image, then
# convert it to a float tensor with ToTensor.
# NOTE(review): Resize with a single int matches the *smaller* edge to 224, so
# non-square inputs do not become 224x224; the KNN path's 28x28 view assumes
# square inputs — confirm the data is square.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
    ]
)
def decision_function(segm_map, top_k=10):
    """
    Calculate anomaly score from segmentation map using mean of top 10 values.

    Each entry along the first dimension of ``segm_map`` is flattened and
    scored by the mean of its ``top_k`` largest values. Maps with fewer than
    ``top_k`` elements are scored by the mean of all their values.

    Args:
        segm_map: Tensor (or iterable of tensors) whose first dimension
            indexes the maps to score.
        top_k: Number of largest values averaged per map. Defaults to 10.

    Returns:
        1-D tensor with one score per map.

    Raises:
        RuntimeError: If ``segm_map`` is empty (``torch.stack`` of no tensors).
    """
    scores = []
    for single_map in segm_map:
        flat = single_map.reshape(-1)
        # topk avoids fully sorting the map. Clamping k keeps the old
        # ``sort()[:k]`` behaviour for maps with fewer than top_k values.
        k = min(top_k, flat.numel())
        scores.append(torch.topk(flat, k).values.mean())
    return torch.stack(scores)
def run_inference_autoencoder(image_path, model, backbone, threshold, device='cuda'):
    """
    Run inference on a single image using Autoencoder.

    The image is preprocessed with the module-level ``transform``, passed
    through ``backbone`` to obtain features, and reconstructed by ``model``.
    The squared reconstruction error averaged over dim 1 forms the
    segmentation map, which ``decision_function`` reduces to a score.

    Args:
        image_path: Path of the image file to score.
        model: Autoencoder applied to the backbone features.
        backbone: Feature extractor applied to the preprocessed image.
        threshold: Score at or above which the image is flagged as anomalous.
        device: Device the input tensor is moved to. It must match where
            ``model`` and ``backbone`` live. Defaults to ``'cuda'``.

    Returns:
        dict with keys:
            'original_image': H x W x 3 array.
            'heat_map': 224 x 224 array.
            'anomaly_score', 'threshold', 'is_anomaly'.
            'classification': 'NOK' if anomalous, else 'OK'.
    """
    # Load and preprocess the image. The context manager closes the file
    # handle that Image.open keeps open for lazy loading.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    test_image = transform(image).to(device).unsqueeze(0)

    # Perform inference; nothing here needs gradients.
    with torch.no_grad():
        features = backbone(test_image)
        recon = model(features)
        # Squared reconstruction error averaged over dim 1 (presumably the
        # channel axis of a (B, C, H, W) feature map — TODO confirm).
        segm_map = ((features - recon) ** 2).mean(axis=1)
        y_score = decision_function(segm_map=segm_map)

    is_anomaly = (y_score >= threshold).cpu().numpy().item()

    # Resize the error map to the 224 x 224 display resolution.
    heat_map = cv2.resize(segm_map.squeeze().cpu().numpy(), (224, 224))

    return {
        'original_image': test_image.squeeze().permute(1, 2, 0).cpu().numpy(),
        'heat_map': heat_map,
        'anomaly_score': y_score.item(),
        'threshold': threshold,
        'is_anomaly': is_anomaly,
        'classification': 'NOK' if is_anomaly else 'OK',
    }
def run_inference_knn(image_path, model, memory_bank, threshold, device='cuda'):
    """
    Run inference on a single image using KNN.

    Each feature row produced by ``model`` is scored by its Euclidean distance
    to the nearest entry of ``memory_bank``. The image score is the largest of
    those nearest-neighbour distances.

    Args:
        image_path: Path of the image file to score.
        model: Feature extractor applied to the preprocessed image.
        memory_bank: Tensor of reference feature vectors (one per row).
        threshold: Score at or above which the image is flagged as anomalous.
        device: Device used for the input and the memory bank. It must match
            where ``model`` lives. Defaults to ``'cuda'``.

    Returns:
        dict with keys:
            'original_image': H x W x 3 array.
            'heat_map': 224 x 224 array.
            'anomaly_score', 'threshold', 'is_anomaly'.
            'classification': 'NOK' if anomalous, else 'OK'.
    """
    # Load and preprocess the image. The context manager closes the file
    # handle that Image.open keeps open for lazy loading.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    test_image = transform(image).to(device).unsqueeze(0)

    # NOTE: this runs on every call. It is a no-op when the bank is already on
    # ``device``; move it there once at load time to avoid repeated copies.
    memory_bank = memory_bank.to(device)

    with torch.no_grad():
        features = model(test_image)
        # Pairwise L2 distances: one row per feature vector, one column per
        # memory-bank entry. Assumes 2-D features, presumably (784, D) —
        # TODO confirm against the backbone.
        distances = torch.cdist(features, memory_bank, p=2.0)
        dist_score, _ = torch.min(distances, dim=1)  # nearest neighbour per row
        y_score = torch.max(dist_score)  # image-level anomaly score

    is_anomaly = (y_score >= threshold).cpu().item()

    # The view requires exactly 28 * 28 nearest-neighbour scores.
    segm_map = dist_score.view(1, 1, 28, 28)
    # Bilinear upsampling already yields 224 x 224, so no further resize is
    # needed for the heat map.
    heat_map = torch.nn.functional.interpolate(
        segm_map, size=(224, 224), mode='bilinear'
    ).cpu().squeeze().numpy()

    return {
        'original_image': test_image.squeeze().permute(1, 2, 0).cpu().numpy(),
        'heat_map': heat_map,
        'anomaly_score': y_score.item(),
        'threshold': threshold,
        'is_anomaly': is_anomaly,
        'classification': 'NOK' if is_anomaly else 'OK',
    }
def load_model_autoencoder(checkpoint_dir='models'):
    """
    Load the saved models and evaluation metrics for Autoencoder.

    Args:
        checkpoint_dir: Directory containing 'models_autoencoder.dill',
            'backbone_autoencoder.dill' and
            'evaluation_metrics_autoencoder.pkl'.

    Returns:
        Tuple ``(models, backbone, evaluation_metrics)``.

    Raises:
        FileNotFoundError: If any of the three files is missing. All files are
            checked before anything is loaded.
        RuntimeError: If deserializing any file fails. The original error is
            chained as ``__cause__``.
    """
    models_path = os.path.join(checkpoint_dir, 'models_autoencoder.dill')
    backbone_path = os.path.join(checkpoint_dir, 'backbone_autoencoder.dill')
    metrics_path = os.path.join(checkpoint_dir, 'evaluation_metrics_autoencoder.pkl')

    # Ensure files exist before loading (checked in this order).
    required = (
        (models_path, "Autoencoder model file not found"),
        (backbone_path, "Autoencoder backbone file not found"),
        (metrics_path, "Autoencoder metrics file not found"),
    )
    for path, message in required:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{message}: {path}")

    # SECURITY: dill/pickle execute arbitrary code while loading. Only point
    # this at trusted checkpoint files.
    try:
        with open(models_path, 'rb') as f:
            models = dill.load(f)
        with open(backbone_path, 'rb') as f:
            backbone = dill.load(f)
        with open(metrics_path, 'rb') as f:
            evaluation_metrics = pickle.load(f)
    except Exception as e:
        raise RuntimeError(f"Error loading Autoencoder models: {e}") from e
    return models, backbone, evaluation_metrics
def load_model_knn(checkpoint_dir='models'):
    """
    Load the saved models and evaluation metrics for KNN.

    Args:
        checkpoint_dir: Directory containing 'memory_bank_selected.pkl',
            'backbone_knn.dill' and 'evaluation_metrics_knn.pkl'.

    Returns:
        Tuple ``(backbone, memory_bank, evaluation_metrics)``. Note that the
        backbone comes first, unlike the on-disk check order.

    Raises:
        FileNotFoundError: If any of the three files is missing. All files are
            checked before anything is loaded.
        RuntimeError: If deserializing any file fails. The original error is
            chained as ``__cause__``.
    """
    memory_bank_path = os.path.join(checkpoint_dir, 'memory_bank_selected.pkl')
    # KNN-specific backbone checkpoint, distinct from the autoencoder's.
    backbone_path = os.path.join(checkpoint_dir, 'backbone_knn.dill')
    metrics_path = os.path.join(checkpoint_dir, 'evaluation_metrics_knn.pkl')

    # Ensure files exist before loading (checked in this order).
    required = (
        (memory_bank_path, "KNN memory bank file not found"),
        (backbone_path, "KNN backbone file not found"),
        (metrics_path, "KNN metrics file not found"),
    )
    for path, message in required:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{message}: {path}")

    # SECURITY: dill/pickle execute arbitrary code while loading. Only point
    # this at trusted checkpoint files.
    try:
        with open(memory_bank_path, 'rb') as f:
            memory_bank = pickle.load(f)
        with open(backbone_path, 'rb') as f:
            backbone = dill.load(f)
        with open(metrics_path, 'rb') as f:
            evaluation_metrics = pickle.load(f)
    except Exception as e:
        raise RuntimeError(f"Error loading KNN models: {e}") from e
    return backbone, memory_bank, evaluation_metrics