File size: 5,688 Bytes
8826642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import torch
import dill
import pickle
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms

# Preprocessing pipeline shared by both inference functions in this module.
# Resize(224) with an int argument scales the *shorter* edge to 224 px and keeps
# the aspect ratio, so only square inputs come out exactly 224x224. ToTensor then
# turns the PIL image into a float CHW tensor scaled to [0, 1].
_PREPROCESSING_STEPS = [
    transforms.Resize(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(_PREPROCESSING_STEPS)

def decision_function(segm_map, top_k=10):
    """
    Calculate one anomaly score per segmentation map in a batch.

    Each map is flattened and scored as the mean of its ``top_k`` largest
    values (10 by default). A map with fewer than ``top_k`` elements is scored
    on all of its elements.

    Args:
        segm_map: Tensor whose first dimension indexes the batch; all remaining
            dimensions of each entry are flattened before scoring.
        top_k: Number of largest values to average per map (must be >= 1).

    Returns:
        1-D tensor of shape ``(batch,)`` holding one score per map, on the same
        device as ``segm_map``. An empty batch yields an empty tensor.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be >= 1")

    # One row per map, so every map is scored in a single vectorized call
    # instead of a Python loop. A 1-D input is treated as one value per map.
    flat = segm_map.flatten(1) if segm_map.dim() > 1 else segm_map.unsqueeze(1)

    # Nothing to rank: an empty batch gives an empty result, and empty maps
    # give NaN (the mean of no values), matching the previous behaviour.
    if flat.numel() == 0:
        return flat.mean(dim=1)

    # topk only extracts the k largest values (no full sort), but unlike
    # slicing it rejects k larger than the row length, so clamp it.
    k = min(top_k, flat.shape[1])
    top_values, _ = torch.topk(flat, k, dim=1)
    return top_values.mean(dim=1)

def run_inference_autoencoder(image_path, model, backbone, threshold):
    """
    Score a single image with the feature-reconstruction autoencoder.

    The image is preprocessed with the module-level ``transform`` and moved to
    the GPU. ``backbone`` turns it into a feature tensor, which ``model``
    reconstructs. The squared reconstruction error, averaged over dimension 1,
    forms the segmentation map. ``decision_function`` reduces that map to a
    single anomaly score.

    Args:
        image_path: Path of the image file to score.
        model: Autoencoder applied to the backbone features.
        backbone: Feature extractor applied to the preprocessed image.
        threshold: Score at or above which the image is flagged as anomalous.

    Returns:
        dict with these keys:

        - 'original_image': the preprocessed image as an H x W x C numpy array.
        - 'heat_map': the error map resized to a 224x224 numpy array.
        - 'anomaly_score': the score as a Python number.
        - 'threshold': ``threshold``, echoed back.
        - 'is_anomaly': a Python bool.
        - 'classification': 'NOK' or 'OK'.
    """
    rgb = Image.open(image_path).convert('RGB')
    # Preprocess, move to the GPU, then add a leading batch dimension of 1.
    batch = transform(rgb).cuda().unsqueeze(0)

    with torch.no_grad():
        feats = backbone(batch)
        reconstruction = model(feats)

        # Per-location squared error averaged over dim 1 (presumably the
        # feature-channel axis — TODO confirm against the backbone output).
        error_map = (feats - reconstruction).pow(2).mean(axis=1)
        score = decision_function(segm_map=error_map)
        flagged = (score >= threshold).cpu().numpy().item()

        error_np = error_map.squeeze().cpu().numpy()
        heat = cv2.resize(error_np, (224, 224))

    image_hwc = batch.squeeze().permute(1, 2, 0).cpu().numpy()
    label = 'NOK' if flagged else 'OK'
    return {
        'original_image': image_hwc,
        'heat_map': heat,
        'anomaly_score': score.item(),
        'threshold': threshold,
        'is_anomaly': flagged,
        'classification': label
    }

def run_inference_knn(image_path, model, memory_bank, threshold):
    """
    Score a single image by nearest-neighbour distance to a memory bank.

    The image is preprocessed with the module-level ``transform``, moved to the
    GPU and passed through ``model`` to obtain feature vectors. Each vector's
    Euclidean distance to its nearest ``memory_bank`` entry becomes that
    location's anomaly value. The maximum over all locations is the
    image-level anomaly score.

    Args:
        image_path: Path of the image file to score.
        model: Feature extractor. Its output is compared against ``memory_bank``
            with ``torch.cdist`` and reduced over dim 1. It is therefore
            presumably a 2-D ``(num_locations, dim)`` tensor (TODO confirm).
        memory_bank: Reference feature tensor compatible with ``cdist`` against
            the model output. ``.cuda()`` is applied on every call, which is a
            no-op if it already lives on the GPU. Pass a CUDA tensor to avoid a
            host-to-device copy per call.
        threshold: Score at or above which the image is flagged as anomalous.

    Returns:
        dict with these keys:

        - 'original_image': the preprocessed image as an H x W x C numpy array.
        - 'heat_map': a 224x224 numpy array.
        - 'anomaly_score': the score as a Python number.
        - 'threshold': ``threshold``, echoed back.
        - 'is_anomaly': a Python bool.
        - 'classification': 'NOK' or 'OK'.

    Raises:
        RuntimeError: if the number of per-location distances is not a perfect
            square, because they cannot be laid out as a square map.
    """
    image = Image.open(image_path).convert('RGB')
    test_image = transform(image).cuda().unsqueeze(0)

    # Returns the same tensor when it is already on the GPU, so this only
    # copies when the caller handed in a CPU tensor.
    memory_bank = memory_bank.cuda()

    # Inference only: keep every tensor op below out of autograd.
    with torch.no_grad():
        features = model(test_image)

        # Distance of every feature vector to every memory-bank entry, then the
        # nearest neighbour per location; the worst location scores the image.
        distances = torch.cdist(features, memory_bank, p=2.0)
        dist_score, _ = torch.min(distances, dim=1)
        y_score = torch.max(dist_score)

        is_anomaly = (y_score >= threshold).cpu().item()

        # Lay the per-location distances out as a square map. For 784 locations
        # this gives the previously hard-coded 28x28. A non-square count makes
        # view() raise, as before. The map is then upsampled to 224x224, which
        # is already the final heat-map size (no further resize needed).
        side = int(round(dist_score.numel() ** 0.5))
        segm_map = dist_score.view(1, 1, side, side)
        heat_map = torch.nn.functional.interpolate(
            segm_map, size=(224, 224), mode='bilinear'
        ).cpu().squeeze().numpy()

    return {
        'original_image': test_image.squeeze().permute(1, 2, 0).cpu().numpy(),
        'heat_map': heat_map,
        'anomaly_score': y_score.item(),
        'threshold': threshold,
        'is_anomaly': is_anomaly,
        'classification': 'NOK' if is_anomaly else 'OK',
    }


def load_model_autoencoder(checkpoint_dir='models'):
    """
    Load the saved Autoencoder model, its backbone and its evaluation metrics.

    Args:
        checkpoint_dir: Directory containing 'models_autoencoder.dill',
            'backbone_autoencoder.dill' and 'evaluation_metrics_autoencoder.pkl'.

    Returns:
        Tuple ``(models, backbone, evaluation_metrics)`` as deserialized from
        those three files, in that order.

    Raises:
        FileNotFoundError: if any of the three files is missing. Files are
            checked in the order listed above, before anything is loaded.
        RuntimeError: if deserializing any file fails. The underlying error is
            attached as ``__cause__``.

    Security: dill and pickle can execute arbitrary code while loading. Only
    point this at checkpoint directories you trust.
    """
    models_path = os.path.join(checkpoint_dir, 'models_autoencoder.dill')
    backbone_path = os.path.join(checkpoint_dir, 'backbone_autoencoder.dill')
    metrics_path = os.path.join(checkpoint_dir, 'evaluation_metrics_autoencoder.pkl')

    # Checked up front, outside the try below, so a missing file surfaces as
    # FileNotFoundError instead of being wrapped in RuntimeError.
    required_files = (
        (models_path, "Autoencoder model file not found"),
        (backbone_path, "Autoencoder backbone file not found"),
        (metrics_path, "Autoencoder metrics file not found"),
    )
    for path, message in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{message}: {path}")

    try:
        with open(models_path, 'rb') as f:
            models = dill.load(f)

        with open(backbone_path, 'rb') as f:
            backbone = dill.load(f)

        with open(metrics_path, 'rb') as f:
            evaluation_metrics = pickle.load(f)
    except Exception as e:
        # Chain explicitly so the original traceback stays attached as the cause.
        raise RuntimeError(f"Error loading Autoencoder models: {e}") from e

    return models, backbone, evaluation_metrics

def load_model_knn(checkpoint_dir='models'):
    """
    Load the saved KNN memory bank, backbone and evaluation metrics.

    Args:
        checkpoint_dir: Directory containing 'memory_bank_selected.pkl',
            'backbone_knn.dill' and 'evaluation_metrics_knn.pkl'.

    Returns:
        Tuple ``(backbone, memory_bank, evaluation_metrics)`` as deserialized
        from those files. Note that the backbone comes first in the tuple even
        though the memory bank is loaded first.

    Raises:
        FileNotFoundError: if any of the three files is missing. Files are
            checked in the order listed above, before anything is loaded.
        RuntimeError: if deserializing any file fails. The underlying error is
            attached as ``__cause__``.

    Security: dill and pickle can execute arbitrary code while loading. Only
    point this at checkpoint directories you trust.
    """
    memory_bank_path = os.path.join(checkpoint_dir, 'memory_bank_selected.pkl')
    backbone_path = os.path.join(checkpoint_dir, 'backbone_knn.dill')
    metrics_path = os.path.join(checkpoint_dir, 'evaluation_metrics_knn.pkl')

    # Checked up front, outside the try below, so a missing file surfaces as
    # FileNotFoundError instead of being wrapped in RuntimeError.
    required_files = (
        (memory_bank_path, "KNN memory bank file not found"),
        (backbone_path, "KNN backbone file not found"),
        (metrics_path, "KNN metrics file not found"),
    )
    for path, message in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{message}: {path}")

    try:
        with open(memory_bank_path, 'rb') as f:
            memory_bank = pickle.load(f)

        with open(backbone_path, 'rb') as f:
            backbone = dill.load(f)

        with open(metrics_path, 'rb') as f:
            evaluation_metrics = pickle.load(f)
    except Exception as e:
        # Chain explicitly so the original traceback stays attached as the cause.
        raise RuntimeError(f"Error loading KNN models: {e}") from e

    return backbone, memory_bank, evaluation_metrics