File size: 2,595 Bytes
40aaca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from transformers import AutoImageProcessor, AutoModel
import torch
import numpy as np
import cv2


def align_images(
    images,
    segs,
    *,
    match_threshold=0.2,
    min_confidence=0.5,
    ransac_threshold=5.0,
    model_name="magic-leap-community/superglue_outdoor",
):
    """
    Align images to the first image using SuperGlue for feature matching.

    Each image after the first is matched against ``images[0]`` with a
    SuperGlue model. Confident matches are used to estimate a RANSAC
    homography. That homography warps the image, and its segmentation, into
    the reference frame. An image whose alignment fails (too few confident
    matches, or no homography found) is passed through unchanged.

    Args:
        images: List of input images. Each must expose ``.shape`` with
            (height, width) first; presumably numpy arrays, so confirm with
            callers. ``images[0]`` is the reference frame.
        segs: List of segmentation images, parallel to ``images``.
        match_threshold: Score threshold forwarded to the processor's
            ``post_process_keypoint_matching``.
        min_confidence: Only matches scoring strictly above this value are
            used for the homography.
        ransac_threshold: Maximum reprojection error, in pixels, for a point
            pair to count as a RANSAC inlier in ``cv2.findHomography``.
        model_name: Model id or path for both the SuperGlue image processor
            and the model.

    Returns:
        Tuple of (aligned images, aligned segmentation images). If
        ``images`` is empty/None or has fewer than two entries,
        ``(images, segs)`` is returned unchanged.
    """
    if not images or len(images) < 2:
        return images, segs

    reference = images[0]
    reference_seg = segs[0]
    aligned_images = [reference]
    aligned_images_seg = [reference_seg]
    # Every warp targets the reference size; OpenCV takes dsize as (width, height).
    h, w = reference.shape[:2]

    # Load SuperGlue model and processor once for the whole batch.
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    for i in range(1, len(images)):
        current = images[i]
        current_seg = segs[i]

        # Process the (reference, current) image pair.
        image_pair = [reference, current]
        inputs = processor(image_pair, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)

        # Post-processing maps keypoints back to each image's (height, width).
        image_sizes = [[(img.shape[0], img.shape[1]) for img in image_pair]]
        matches = processor.post_process_keypoint_matching(
            outputs, image_sizes, threshold=match_threshold
        )
        match_data = matches[0]

        # Move everything to numpy once, so the boolean mask and the arrays it
        # indexes share a type. findHomography wants float coordinates.
        keypoints0 = match_data["keypoints0"].cpu().numpy().astype(np.float32)
        keypoints1 = match_data["keypoints1"].cpu().numpy().astype(np.float32)
        scores = match_data["matching_scores"].cpu().numpy()

        # Keep only confident matches. A homography has 8 degrees of freedom,
        # so at least 4 point pairs are required.
        valid_matches = scores > min_confidence
        if np.count_nonzero(valid_matches) < 4:
            print(f"Not enough confident matches for image {i}, keeping original")
            aligned_images.append(current)
            aligned_images_seg.append(current_seg)
            continue

        # Map points in the current image (src) onto the reference (dst).
        src_pts = keypoints1[valid_matches].reshape(-1, 1, 2)
        dst_pts = keypoints0[valid_matches].reshape(-1, 1, 2)

        H, _inlier_mask = cv2.findHomography(
            src_pts, dst_pts, cv2.RANSAC, ransac_threshold
        )

        if H is None:
            print(f"Could not find homography for image {i}, keeping original")
            aligned_images.append(current)
            aligned_images_seg.append(current_seg)
            continue

        aligned_images.append(cv2.warpPerspective(current, H, (w, h)))
        # Nearest-neighbour keeps segmentation values discrete. The default
        # bilinear interpolation would blend labels/colours along boundaries.
        aligned_images_seg.append(
            cv2.warpPerspective(current_seg, H, (w, h), flags=cv2.INTER_NEAREST)
        )

    return aligned_images, aligned_images_seg