from transformers import AutoImageProcessor, AutoModel
import torch
import numpy as np
import cv2

def align_images(images, segs):
    """
    Align images to the first image using SuperGlue feature matching.

    Args:
        images: List of input images (NumPy arrays in HxWxC layout).
        segs: List of segmentation images, one per input image.

    Returns:
        Tuple of (aligned images, aligned segmentation images).
    """
    if not images or len(images) < 2:
        return images, segs

    reference = images[0]
    reference_seg = segs[0]
    aligned_images = [reference]
    aligned_images_seg = [reference_seg]

    # Load the SuperGlue model and processor once, before the loop
    processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
    model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

    for i in range(1, len(images)):
        current = images[i]
        current_seg = segs[i]

        # Run SuperGlue on the (reference, current) image pair
        image_pair = [reference, current]
        inputs = processor(image_pair, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process raw outputs into matched keypoints in pixel coordinates
        image_sizes = [[(img.shape[0], img.shape[1]) for img in image_pair]]
        matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)

        # Extract the matched keypoints for this pair
        match_data = matches[0]
        keypoints0 = match_data["keypoints0"].numpy()
        keypoints1 = match_data["keypoints1"].numpy()
        # Keep only matches above the confidence threshold; estimating a
        # homography requires at least 4 point correspondences
        valid_matches = (match_data["matching_scores"] > 0.5).numpy()
        if valid_matches.sum() < 4:
            print(f"Not enough confident matches for image {i}, keeping original")
            aligned_images.append(current)
            aligned_images_seg.append(current_seg)
            continue

        # Arrange the matched points in the (N, 1, 2) float32 layout OpenCV expects
        src_pts = keypoints1[valid_matches].reshape(-1, 1, 2).astype(np.float32)
        dst_pts = keypoints0[valid_matches].reshape(-1, 1, 2).astype(np.float32)
        # Estimate a homography mapping the current image onto the reference,
        # using RANSAC to reject outlier matches
        H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)

        if H is not None:
            # Warp the current image and its segmentation into the reference frame
            h, w = reference.shape[:2]
            aligned = cv2.warpPerspective(current, H, (w, h))
            aligned_images.append(aligned)
            aligned_seg = cv2.warpPerspective(current_seg, H, (w, h))
            aligned_images_seg.append(aligned_seg)
        else:
            print(f"Could not find homography for image {i}, keeping original")
            aligned_images.append(current)
            aligned_images_seg.append(current_seg)

    return aligned_images, aligned_images_seg
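
# Example usage: a minimal sketch, assuming a few frames and their segmentation
# masks exist on disk. The file names below are hypothetical stand-ins; any
# images loadable by cv2 will do.
if __name__ == "__main__":
    frame_paths = ["frame_000.png", "frame_001.png", "frame_002.png"]
    seg_paths = ["seg_000.png", "seg_001.png", "seg_002.png"]

    frames = [cv2.imread(p) for p in frame_paths]
    masks = [cv2.imread(p, cv2.IMREAD_GRAYSCALE) for p in seg_paths]

    aligned_frames, aligned_masks = align_images(frames, masks)

    # Write the warped frames and masks next to the inputs
    for i, (frame, mask) in enumerate(zip(aligned_frames, aligned_masks)):
        cv2.imwrite(f"aligned_{i:03d}.png", frame)
        cv2.imwrite(f"aligned_seg_{i:03d}.png", mask)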