CostalSegment / pipeline /normalization.py
AveMujica's picture
init
40aaca9
raw
history blame contribute delete
2.6 kB
from transformers import AutoImageProcessor, AutoModel
import torch
import numpy as np
import cv2
def align_images(images, segs):
    """
    Align images to the first image using SuperGlue feature matching.

    The first image is the reference. For every other image, keypoint matches
    against the reference are computed with the pretrained
    "magic-leap-community/superglue_outdoor" model, a homography is estimated
    with RANSAC, and both the image and its segmentation are warped into the
    reference frame. When fewer than four confident matches are available, or
    no homography can be estimated, the image and its segmentation are kept
    unchanged.

    Args:
        images: List of input images (arrays indexed as height, width, ...).
        segs: List of segmentation images, parallel to ``images``.

    Returns:
        Tuple of (aligned images, aligned segmentation images). When fewer
        than two images are given, the inputs are returned untouched.
    """
    # Use len() rather than truthiness so a stacked NumPy array of images is
    # accepted as well (``not ndarray`` raises for multi-element arrays).
    if images is None or len(images) < 2:
        return images, segs

    reference = images[0]
    # Every warped output is rendered on the reference canvas; loop-invariant.
    ref_h, ref_w = reference.shape[:2]

    aligned_images = [reference]
    aligned_images_seg = [segs[0]]

    # Load SuperGlue model and processor
    processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
    model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

    for i in range(1, len(images)):
        current = images[i]
        current_seg = segs[i]

        # Process image pair (reference first, so keypoints0 belong to it)
        image_pair = [reference, current]
        inputs = processor(image_pair, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # Get matches, rescaled to the original (height, width) of each image
        image_sizes = [[(img.shape[0], img.shape[1]) for img in image_pair]]
        matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)

        # Extract matching keypoints as NumPy arrays
        match_data = matches[0]
        keypoints0 = match_data["keypoints0"].detach().cpu().numpy()
        keypoints1 = match_data["keypoints1"].detach().cpu().numpy()
        scores = match_data["matching_scores"].detach().cpu().numpy()

        # Filter matches by confidence; a NumPy mask keeps the indexing below
        # in one array library and makes counting a single vectorized call.
        valid_matches = scores > 0.5

        # A homography needs at least four point correspondences.
        if np.count_nonzero(valid_matches) < 4:
            print(f"Not enough confident matches for image {i}, keeping original")
            aligned_images.append(current)
            aligned_images_seg.append(current_seg)
            continue

        # Get matching points: map the current image (src) onto the reference
        # (dst). Cast to float32 so OpenCV receives floating-point coordinates
        # whatever dtype the post-processing step returned.
        src_pts = keypoints1[valid_matches].astype(np.float32).reshape(-1, 1, 2)
        dst_pts = keypoints0[valid_matches].astype(np.float32).reshape(-1, 1, 2)

        # Find homography (RANSAC, 5 px reprojection threshold)
        H, _inlier_mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        if H is None:
            print(f"Could not find homography for image {i}, keeping original")
            aligned_images.append(current)
            aligned_images_seg.append(current_seg)
            continue

        # Apply homography
        aligned_images.append(cv2.warpPerspective(current, H, (ref_w, ref_h)))
        # Nearest-neighbour sampling for the segmentation: the default
        # bilinear interpolation would blend neighbouring class values along
        # region borders and create values that belong to no class.
        aligned_images_seg.append(
            cv2.warpPerspective(current_seg, H, (ref_w, ref_h), flags=cv2.INTER_NEAREST)
        )

    return aligned_images, aligned_images_seg