# CostalSegment/pipeline/ImgOutlier.py
# Uploaded by AveMujica; commit 40aaca9 ("init"); file size 6.86 kB.
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms as tr
from torchvision.models import vit_h_14
import cv2
class CosineSimilarity:
    """Flag outlier images by cosine similarity to a reference vector.

    Each image is turned into a vector, either the output of a ViT-H/14
    network whose last head layer has been removed (``vector='feature'``)
    or the flattened, normalized pixel tensor (``vector='image'``).  A test
    image is an outlier when its cosine similarity to the reference vector
    is less than or equal to ``threshold``.
    """

    def __init__(self, vector: str = 'feature', threshold: float = 0.8,
                 mean_vec=None, device=None):
        """
        Initialize the CosineSimilarity class.

        Args:
            vector (str): Type of vector to use ('feature' or 'image')
            threshold (float): Threshold for determining outliers
            mean_vec (numpy vector): Preloaded reference vector for
                comparison. ``None`` or an empty sequence means "compute the
                reference from the reference images".
            device (str): Device to use for computation (default: 'mps' if
                available, else 'cuda' if available, else 'cpu')
        """
        if device is None:
            if torch.backends.mps.is_available():
                device = 'mps'
            elif torch.cuda.is_available():
                device = 'cuda'
            else:
                device = 'cpu'
        self.device = device
        self.vector = vector
        self.threshold = threshold
        # The network is built lazily by model() on first use.
        self.model_instance = None
        # A fresh list per instance: a literal [] default in the signature
        # would be one object shared by every instance.
        self.mean_vec = [] if mean_vec is None else mean_vec

    def model(self):
        """Initialize (on first call) and return the ViT model."""
        if self.model_instance is None:
            wt = torchvision.models.ViT_H_14_Weights.DEFAULT
            net = vit_h_14(weights=wt)
            # Drop the last layer of the head so the forward pass returns
            # the representation feeding that layer instead of its output.
            net.heads = nn.Sequential(*list(net.heads.children())[:-1])
            # The network is only used for inference here.
            self.model_instance = net.to(self.device).eval()
        return self.model_instance

    def process_image(self, cv2_img):
        """
        Process a cv2 image for the model.

        Args:
            cv2_img: OpenCV image (BGR format)

        Returns:
            Processed tensor with a leading batch dimension of 1.  In
            'image' mode the tensor is flattened before the batch dimension
            is added; in 'feature' mode it is moved to ``self.device``.
        """
        # Convert BGR to RGB, then to a PIL image for torchvision.
        rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb_img)
        # To tensor, per-channel mean/std normalization, resize to 518x518.
        transformations = tr.Compose([
            tr.ToTensor(),
            tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            tr.Resize((518, 518)),
        ])
        img_tensor = transformations(pil_img).float()
        if self.vector == 'image':
            img_tensor = img_tensor.flatten()
        img_tensor = img_tensor.unsqueeze(0)
        if self.vector == 'feature':
            img_tensor = img_tensor.to(self.device)
        return img_tensor

    def _embed_images(self, images):
        """Return one CPU embedding tensor per image, in input order."""
        if self.vector != 'feature':
            # Pixel mode: the processed image itself is the embedding, so
            # the network is never built.
            return [self.process_image(img) for img in images]
        model = self.model()
        # Inference only: skip autograd bookkeeping during the forward pass.
        with torch.no_grad():
            return [model(self.process_image(img)).cpu() for img in images]

    def get_embeddings(self, ref_images, test_images):
        """
        Get embeddings for reference and test images.

        Args:
            ref_images: List of cv2 reference images (ignored when a
                non-empty ``mean_vec`` was supplied)
            test_images: List of cv2 test images

        Returns:
            Reference embedding, list of test embeddings
        """
        emb_test = self._embed_images(test_images)
        # len() rather than truthiness: mean_vec may be a numpy array or a
        # tensor, whose truth value is ambiguous.
        if len(self.mean_vec) > 0:
            # A preloaded reference vector skips the reference images.
            if isinstance(self.mean_vec, torch.Tensor):
                # Copy without the warning torch.tensor() emits for tensors.
                emb_ref = self.mean_vec.detach().clone()
            else:
                emb_ref = torch.tensor(self.mean_vec)
        else:
            # Average the reference embeddings into a single vector.
            emb_ref = torch.mean(
                torch.stack(self._embed_images(ref_images)), dim=0)
        return emb_ref, emb_test

    def find_outliers(self, ref_images, test_images):
        """
        Find outliers in test images compared to reference images.

        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images

        Returns:
            mask: Boolean array where True indicates an outlier
            scores: Similarity scores for each test image, rounded to 4
                decimal places
            emb_ref: The reference vector the scores were computed against
        """
        emb_ref, emb_test = self.get_embeddings(ref_images, test_images)
        scores = []
        mask = []
        for emb in emb_test:
            score_value = torch.nn.functional.cosine_similarity(
                emb_ref, emb).item()
            scores.append(round(score_value, 4))
            # True if it's an outlier (at or below threshold); the
            # unrounded score is what gets compared.
            mask.append(score_value <= self.threshold)
        # An explicit dtype keeps the mask boolean when it is empty.
        return np.array(mask, dtype=bool), scores, emb_ref

    def filter_outliers(self, ref_images, test_images):
        """
        Filter out outliers from test images.

        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images

        Returns:
            filtered_images: List of non-outlier test images
            outlier_mask: Boolean array where True indicates an outlier
            scores: Similarity scores for each test image
            mean: The reference vector the scores were computed against
        """
        outlier_mask, scores, mean = self.find_outliers(ref_images, test_images)
        # Keep only the images not flagged as outliers.
        filtered_images = [
            img for img, is_outlier in zip(test_images, outlier_mask)
            if not is_outlier
        ]
        return filtered_images, outlier_mask, scores, mean
def detect_outliers(ref_imgs, imgs, mean_vec=None):
    """
    Detect outliers in a set of test images, optionally using a reference vector.

    Builds a ``CosineSimilarity`` in 'feature' mode with a threshold of 0.8
    and drops every test image it flags as an outlier.

    Args:
        ref_imgs: List of cv2 reference images
        imgs: List of cv2 test images
        mean_vec: Optional pre-computed reference vector. ``None`` or an
            empty sequence means none is supplied.

    Returns:
        filtered_images: List of non-outlier test images
        mean_vector: The reference vector used (if a new reference vector
            should be saved)
    """
    similarity = CosineSimilarity(
        vector='feature',
        threshold=0.8,
        # Pass an empty list rather than None: the class checks the
        # reference vector with len().
        mean_vec=[] if mean_vec is None else mean_vec,
    )
    # The per-image scores are not needed here, only the mask and reference.
    outlier_mask, _scores, mean_vector = similarity.find_outliers(ref_imgs, imgs)
    # Keep only the images not flagged as outliers.
    filtered_images = [
        img for img, is_outlier in zip(imgs, outlier_mask) if not is_outlier
    ]
    return filtered_images, mean_vector