File size: 6,858 Bytes
40aaca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms as tr
from torchvision.models import vit_h_14
import cv2

class CosineSimilarity:
    """Flag outlier images by cosine similarity against a reference vector.

    Each test image is turned into a vector (either ViT-H/14 features or the
    flattened, normalised pixels) and compared with a reference vector. The
    reference is either supplied up front (``mean_vec``) or computed as the
    mean of the vectors of the reference images. A test image whose score is
    less than or equal to ``threshold`` is reported as an outlier.
    """

    def __init__(self, vector='feature', threshold=0.8, mean_vec=None, device=None):
        """
        Initialize the CosineSimilarity class.

        Args:
            vector (str): Type of vector to use ('feature' or 'image')
            threshold (float): Scores at or below this value mark an outlier
            mean_vec (sequence, numpy array or tensor): Preloaded reference
                vector for comparison. ``None`` or an empty sequence means the
                reference is computed from the reference images instead.
            device (str): Device to use for computation (default: 'mps' if
                available, else 'cuda' if available, else 'cpu')
        """
        if device is None:
            # Guarded lookup: ``torch.backends.mps`` is absent on some builds.
            mps_backend = getattr(torch.backends, 'mps', None)
            if mps_backend is not None and mps_backend.is_available():
                device = 'mps'
            elif torch.cuda.is_available():
                device = 'cuda'
            else:
                device = 'cpu'

        self.device = device
        self.vector = vector
        self.threshold = threshold
        self.model_instance = None
        # A fresh list per instance; a ``[]`` default in the signature would
        # be one object shared by every instance.
        self.mean_vec = [] if mean_vec is None else mean_vec

    def model(self):
        """Build the ViT model on first use and return the cached instance."""
        if self.model_instance is None:
            weights = torchvision.models.ViT_H_14_Weights.DEFAULT
            vit = vit_h_14(weights=weights)
            # Drop the last layer of the classification head so the forward
            # pass returns the representation that fed it, not class logits.
            vit.heads = nn.Sequential(*list(vit.heads.children())[:-1])
            vit = vit.to(self.device)
            # The model is only ever used for inference here.
            vit.eval()
            # Cache only once the model is fully configured.
            self.model_instance = vit
        return self.model_instance

    def process_image(self, cv2_img):
        """
        Process a cv2 image for the model.

        Args:
            cv2_img: OpenCV image (BGR format)

        Returns:
            Processed tensor with a leading batch dimension of size 1. In
            'image' mode the pixels are flattened first; in 'feature' mode
            the tensor is moved to ``self.device``.
        """
        # Convert BGR to RGB
        rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        # Convert to PIL Image
        pil_img = Image.fromarray(rgb_img)

        # A set of transformations to prepare the image in tensor format
        transformations = tr.Compose([
            tr.ToTensor(),
            tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            tr.Resize((518, 518))
        ])

        # preparing the image
        img_tensor = transformations(pil_img).float()

        if self.vector == 'image':
            img_tensor = img_tensor.flatten()

        img_tensor = img_tensor.unsqueeze(0)

        if self.vector == 'feature':
            img_tensor = img_tensor.to(self.device)

        return img_tensor

    def _embed(self, cv2_img):
        """Return the comparison vector for one cv2 image, on the CPU."""
        img_tensor = self.process_image(cv2_img)
        if self.vector != 'feature':
            # 'image' mode compares the processed pixels directly.
            return img_tensor
        # The model is fetched lazily so it is only loaded when an image
        # actually has to be embedded; no autograd graph is needed.
        with torch.no_grad():
            return self.model()(img_tensor).detach().cpu()

    def get_embeddings(self, ref_images, test_images):
        """
        Get embeddings for reference and test images.

        Args:
            ref_images: List of cv2 reference images (ignored when a
                preloaded reference vector is available)
            test_images: List of cv2 test images

        Returns:
            Reference embedding, list of test embeddings

        Raises:
            ValueError: If there is no preloaded reference vector and
                ``ref_images`` is empty.
        """
        # Process test images
        emb_test = [self._embed(img) for img in test_images]

        # If a reference vector is loaded, computing reference embeddings
        # can be skipped for efficiency. ``len`` is used rather than
        # truthiness because the vector may be a numpy array or a tensor.
        if len(self.mean_vec) > 0:
            # as_tensor avoids re-copying a reference that is already a
            # tensor (e.g. one returned by an earlier call).
            emb_ref = torch.as_tensor(self.mean_vec)
        else:
            emb_ref_list = [self._embed(img) for img in ref_images]
            if not emb_ref_list:
                raise ValueError(
                    "ref_images must not be empty when no mean_vec is provided"
                )
            # Average the reference embeddings
            emb_ref = torch.mean(torch.stack(emb_ref_list), dim=0)

        return emb_ref, emb_test

    def find_outliers(self, ref_images, test_images):
        """
        Find outliers in test images compared to reference images.

        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images

        Returns:
            mask: Boolean array where True indicates an outlier
            scores: Similarity scores for each test image, rounded to 4 places
            emb_ref: The reference vector the scores were computed against
        """
        emb_ref, emb_test = self.get_embeddings(ref_images, test_images)

        scores = []
        mask = []

        for emb in emb_test:
            score_value = torch.nn.functional.cosine_similarity(emb_ref, emb).item()
            scores.append(round(score_value, 4))
            # True if it's an outlier (at or below threshold); the unrounded
            # score is used for the comparison.
            mask.append(score_value <= self.threshold)

        # Explicit dtype keeps the mask boolean even with no test images.
        return np.array(mask, dtype=bool), scores, emb_ref

    def filter_outliers(self, ref_images, test_images):
        """
        Filter out outliers from test images.

        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images

        Returns:
            filtered_images: List of non-outlier test images
            outlier_mask: Boolean array where True indicates an outlier
            scores: Similarity scores for each test image
            mean: The reference vector the scores were computed against
        """
        outlier_mask, scores, mean = self.find_outliers(ref_images, test_images)

        # Filter out outliers (keep only non-outliers)
        filtered_images = [
            img for img, is_outlier in zip(test_images, outlier_mask) if not is_outlier
        ]

        return filtered_images, outlier_mask, scores, mean

def detect_outliers(ref_imgs, imgs, mean_vec=None):
    """
    Detects outliers in a set of test images, can use a reference vector

    Runs a feature-vector CosineSimilarity check with a threshold of 0.8 and
    keeps only the test images that are not flagged as outliers.

    Args:
        ref_imgs: List of cv2 reference images
        imgs: List of cv2 test images
        mean_vec: optional pre-computed reference vector; ``None`` or an
            empty sequence means the reference is computed from ``ref_imgs``

    Returns:
        filtered_images: List of non-outlier test images
        mean: the reference vector used (if a new reference vector should be saved)
    """
    # ``None`` replaces a mutable ``[]`` default; an empty list is still what
    # tells CosineSimilarity that no reference vector was preloaded.
    similarity = CosineSimilarity(
        vector='feature',
        threshold=0.8,
        mean_vec=[] if mean_vec is None else mean_vec,
    )

    # filter_outliers already scores and filters; the mask and scores it
    # also returns are not part of this function's result.
    filtered_images, _mask, _scores, mean_vector = similarity.filter_outliers(ref_imgs, imgs)

    return filtered_images, mean_vector