Spaces:
Running
Running
import numpy as np | |
import torch | |
import torchvision | |
from PIL import Image | |
from torch import nn | |
from torchvision import transforms as tr | |
from torchvision.models import vit_h_14 | |
import cv2 | |
class CosineSimilarity: | |
def __init__(self, vector='feature', threshold=0.8, mean_vec=[], device=None): | |
""" | |
Initialize the CosineSimilarity class. | |
Args: | |
vector (str): Type of vector to use ('feature' or 'image') | |
threshold (float): Threshold for determining outliers | |
mean_vec (numpy vector): Preloaded reference vector for comparison | |
device (str): Device to use for computation (default: 'mps' if available, else 'cuda' if available, else 'cpu') | |
""" | |
if device is None: | |
if torch.backends.mps.is_available(): | |
self.device = 'mps' | |
elif torch.cuda.is_available(): | |
self.device = 'cuda' | |
else: | |
self.device = 'cpu' | |
else: | |
self.device = device | |
self.vector = vector | |
self.threshold = threshold | |
self.model_instance = None | |
self.mean_vec = mean_vec | |
def model(self): | |
"""Initialize and return the ViT model.""" | |
if self.model_instance is None: | |
wt = torchvision.models.ViT_H_14_Weights.DEFAULT | |
self.model_instance = vit_h_14(weights=wt) | |
self.model_instance.heads = nn.Sequential(*list(self.model_instance.heads.children())[:-1]) | |
self.model_instance = self.model_instance.to(self.device) | |
return self.model_instance | |
def process_image(self, cv2_img): | |
""" | |
Process a cv2 image for the model. | |
Args: | |
cv2_img: OpenCV image (BGR format) | |
Returns: | |
Processed tensor | |
""" | |
# Convert BGR to RGB | |
rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB) | |
# Convert to PIL Image | |
pil_img = Image.fromarray(rgb_img) | |
# A set of transformations to prepare the image in tensor format | |
transformations = tr.Compose([ | |
tr.ToTensor(), | |
tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), | |
tr.Resize((518, 518)) | |
]) | |
# preparing the image | |
img_tensor = transformations(pil_img).float() | |
if self.vector == 'image': | |
img_tensor = img_tensor.flatten() | |
img_tensor = img_tensor.unsqueeze_(0) | |
if self.vector == 'feature': | |
img_tensor = img_tensor.to(self.device) | |
return img_tensor | |
def get_embeddings(self, ref_images, test_images): | |
""" | |
Get embeddings for reference and test images. | |
Args: | |
ref_images: List of cv2 reference images | |
test_images: List of cv2 test images | |
Returns: | |
Reference embedding, list of test embeddings | |
""" | |
model = self.model() | |
# Process test images | |
emb_test = [] | |
for img in test_images: | |
processed_img = self.process_image(img) | |
if self.vector == 'feature': | |
emb = model(processed_img).detach().cpu() | |
emb_test.append(emb) | |
else: # 'image' | |
emb_test.append(processed_img) | |
# This checks if a reference vector is loaded, if so the process of getting | |
# reference embeddings can be skipped for efficiency | |
if len(self.mean_vec) > 0: | |
emb_ref = torch.tensor(self.mean_vec) | |
# Process reference images if necessary | |
else: | |
if self.vector == 'feature': | |
# Standard method of getting reference embedding vector | |
emb_ref_list = [] | |
for img in ref_images: | |
processed_img = self.process_image(img) | |
emb = model(processed_img).detach().cpu() | |
emb_ref_list.append(emb) | |
# Average the reference embeddings | |
emb_ref = torch.mean(torch.stack(emb_ref_list), dim=0) | |
else: # 'image' | |
emb_ref_list = [] | |
for img in ref_images: | |
processed_img = self.process_image(img) | |
emb_ref_list.append(processed_img) | |
# Average the reference images | |
emb_ref = torch.mean(torch.stack(emb_ref_list), dim=0) | |
return emb_ref, emb_test | |
def find_outliers(self, ref_images, test_images): | |
""" | |
Find outliers in test images compared to reference images. | |
Args: | |
ref_images: List of cv2 reference images | |
test_images: List of cv2 test images | |
Returns: | |
mask: Boolean array where True indicates an outlier | |
scores: Similarity scores for each test image | |
""" | |
emb_ref, emb_test = self.get_embeddings(ref_images, test_images) | |
scores = [] | |
mask = [] | |
for i in range(len(emb_test)): | |
score = torch.nn.functional.cosine_similarity(emb_ref, emb_test[i]) | |
score_value = score.item() | |
scores.append(round(score_value, 4)) | |
# True if it's an outlier (below threshold) | |
mask.append(score_value <= self.threshold) | |
return np.array(mask), scores, emb_ref | |
def filter_outliers(self, ref_images, test_images): | |
""" | |
Filter out outliers from test images. | |
Args: | |
ref_images: List of cv2 reference images | |
test_images: List of cv2 test images | |
Returns: | |
filtered_images: List of non-outlier test images | |
outlier_mask: Boolean array where True indicates an outlier | |
scores: Similarity scores for each test image | |
""" | |
outlier_mask, scores, mean = self.find_outliers(ref_images, test_images) | |
# Filter out outliers (keep only non-outliers) | |
filtered_images = [img for i, img in enumerate(test_images) if not outlier_mask[i]] | |
return filtered_images, outlier_mask, scores, mean | |
def detect_outliers(ref_imgs, imgs, mean_vec=[]): | |
""" | |
Detects outliers in a set of test images, can use a reference vector | |
Args: | |
ref_images: List of cv2 reference images | |
images: List of cv2 test images | |
mean_vec: optional pre-computed reference vector | |
Returns: | |
filtered_images: List of non-outlier test images | |
mean: the reference vector used (if a new reference vector should be saved) | |
""" | |
similarity = CosineSimilarity(vector='feature', threshold=0.8, mean_vec=mean_vec) | |
# Get outlier mask, scores, and reference vector | |
outlier_mask, scores, mean_vector = similarity.find_outliers(ref_imgs, imgs) | |
# Filter out outliers | |
filtered_images = [img for i, img in enumerate(imgs) if not outlier_mask[i]] | |
return filtered_images, mean_vector |