import cv2
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms as tr
from torchvision.models import vit_h_14


class CosineSimilarity:
    """Flag test images whose cosine similarity to a reference is too low.

    Images are compared either through embeddings produced by a torchvision
    ViT-H/14 model (``vector='feature'``) or through their flattened,
    normalised pixel tensors (``vector='image'``). The reference is either a
    preloaded vector (``mean_vec``) or the mean over the reference images.
    """

    def __init__(self, vector='feature', threshold=0.8, mean_vec=None, device=None):
        """
        Initialize the CosineSimilarity class.

        Args:
            vector (str): Type of vector to use ('feature' or 'image')
            threshold (float): Scores less than or equal to this value are
                treated as outliers
            mean_vec (array-like or None): Preloaded reference vector for
                comparison. ``None`` or an empty sequence means the reference
                is computed from the reference images instead.
            device (str): Device to use for computation (default: 'mps' if
                available, else 'cuda' if available, else 'cpu')
        """
        if device is None:
            if torch.backends.mps.is_available():
                device = 'mps'
            elif torch.cuda.is_available():
                device = 'cuda'
            else:
                device = 'cpu'

        self.device = device
        self.vector = vector
        self.threshold = threshold
        self.model_instance = None
        # A fresh list per instance: a ``[]`` default in the signature would
        # be one object shared by every instance created without ``mean_vec``.
        self.mean_vec = [] if mean_vec is None else mean_vec

    def model(self):
        """Initialize (on first use) and return the ViT model.

        The last layer of the model's ``heads`` module is dropped, so the
        model returns the representation that fed that layer rather than
        class logits. The model is put in eval mode because it is only used
        for inference here.
        """
        if self.model_instance is None:
            weights = torchvision.models.ViT_H_14_Weights.DEFAULT
            net = vit_h_14(weights=weights)
            net.heads = nn.Sequential(*list(net.heads.children())[:-1])
            net.eval()
            self.model_instance = net

        # Re-applied on every call so the model follows ``self.device`` if it
        # is changed after the model has been built.
        self.model_instance = self.model_instance.to(self.device)
        return self.model_instance

    def process_image(self, cv2_img):
        """
        Process a cv2 image for the model.

        Args:
            cv2_img: OpenCV image (BGR format)

        Returns:
            Processed tensor with a leading batch dimension of size 1. In
            'image' mode the tensor is flattened before the batch dimension
            is added; in 'feature' mode it is moved to ``self.device``.
        """
        # OpenCV stores channels as BGR; PIL/torchvision expect RGB.
        rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb_img)

        # Tensor conversion, per-channel normalisation (the usual ImageNet
        # mean/std values), then resize to 518x518.
        transformations = tr.Compose([
            tr.ToTensor(),
            tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            tr.Resize((518, 518)),
        ])
        img_tensor = transformations(pil_img).float()

        if self.vector == 'image':
            img_tensor = img_tensor.flatten()

        img_tensor = img_tensor.unsqueeze(0)

        if self.vector == 'feature':
            img_tensor = img_tensor.to(self.device)

        return img_tensor

    def _embed(self, cv2_img, model):
        """Return the comparison vector for one image.

        With a model (feature mode) this is the model output moved to the
        CPU; without one it is the processed image tensor itself.
        """
        processed = self.process_image(cv2_img)
        if model is None:
            return processed
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            return model(processed).detach().cpu()

    def _reference_embedding(self, ref_images, model):
        """Return the preloaded reference vector, or the mean over ref_images."""
        mean_vec = self.mean_vec
        if mean_vec is not None and len(mean_vec) > 0:
            # A reference vector is loaded, so embedding the reference images
            # can be skipped. Copy so the caller's data is never aliased.
            if isinstance(mean_vec, torch.Tensor):
                return mean_vec.detach().clone()
            return torch.tensor(mean_vec)

        embeddings = [self._embed(img, model) for img in ref_images]
        if not embeddings:
            raise ValueError(
                "ref_images is empty and no mean_vec was provided; "
                "cannot build a reference embedding"
            )
        return torch.mean(torch.stack(embeddings), dim=0)

    def get_embeddings(self, ref_images, test_images):
        """
        Get embeddings for reference and test images.

        Args:
            ref_images: List of cv2 reference images (ignored when a
                non-empty ``mean_vec`` was supplied)
            test_images: List of cv2 test images

        Returns:
            Reference embedding, list of test embeddings

        Raises:
            ValueError: If no ``mean_vec`` is loaded and ``ref_images`` is
                empty.
        """
        # Only feature mode needs the network; skip the model load otherwise.
        model = self.model() if self.vector == 'feature' else None

        emb_test = [self._embed(img, model) for img in test_images]
        emb_ref = self._reference_embedding(ref_images, model)
        return emb_ref, emb_test

    def find_outliers(self, ref_images, test_images):
        """
        Find outliers in test images compared to reference images.

        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images

        Returns:
            mask: Boolean array where True indicates an outlier
                (score <= threshold)
            scores: Similarity scores for each test image, rounded to 4
                decimal places
            emb_ref: The reference embedding the scores were computed against
        """
        emb_ref, emb_test = self.get_embeddings(ref_images, test_images)

        scores = []
        mask = []
        for emb in emb_test:
            score_value = torch.nn.functional.cosine_similarity(emb_ref, emb).item()
            scores.append(round(score_value, 4))
            # The unrounded score decides outlier status.
            mask.append(score_value <= self.threshold)

        return np.array(mask, dtype=bool), scores, emb_ref

    def filter_outliers(self, ref_images, test_images):
        """
        Filter out outliers from test images.

        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images

        Returns:
            filtered_images: List of non-outlier test images
            outlier_mask: Boolean array where True indicates an outlier
            scores: Similarity scores for each test image
            mean: The reference embedding used for the comparison
        """
        outlier_mask, scores, mean = self.find_outliers(ref_images, test_images)

        filtered_images = [
            img
            for img, is_outlier in zip(test_images, outlier_mask)
            if not is_outlier
        ]
        return filtered_images, outlier_mask, scores, mean


def detect_outliers(ref_imgs, imgs, mean_vec=None):
    """
    Detects outliers in a set of test images, can use a reference vector

    Args:
        ref_imgs: List of cv2 reference images
        imgs: List of cv2 test images
        mean_vec: optional pre-computed reference vector; ``None`` or an
            empty sequence means it is computed from ``ref_imgs``

    Returns:
        filtered_images: List of non-outlier test images
        mean: the reference vector used (if a new reference vector should be saved)
    """
    similarity = CosineSimilarity(vector='feature', threshold=0.8, mean_vec=mean_vec)
    filtered_images, _mask, _scores, mean_vector = similarity.filter_outliers(ref_imgs, imgs)
    return filtered_images, mean_vector