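# Gradio demo: detect the four corners of a document with an ONNX keypoint model,
# outline them on the input image, and rectify the document's perspective using the
# aspect ratio recovered with Zhang's whiteboard-scanning method (references below).
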
import logging
import math

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageOps

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

MODEL_PATH = "model.onnx"
IMAGE_SIZE = 480

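# Load the ONNX model once at import time and cache its input tensor name.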
SESSION = ort.InferenceSession(MODEL_PATH)
INPUT_NAME = SESSION.get_inputs()[0].name


def preprocess(img: Image.Image) -> np.ndarray:
    # Letterbox to IMAGE_SIZE x IMAGE_SIZE, anchored at the top-left so predicted
    # keypoints can be mapped back to the original image with a single scale factor.
    resized_img = ImageOps.pad(img, (IMAGE_SIZE, IMAGE_SIZE), centering=(0, 0))
    # HWC uint8 -> CHW float32 in [0, 1], then normalize to [-1, 1].
    img_chw = np.array(resized_img).transpose(2, 0, 1).astype(np.float32) / 255
    img_chw = (img_chw - 0.5) / 0.5
    return img_chw


def distance(p1, p2) -> float:
    # Euclidean distance between two 2-D points.
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


# https://stackoverflow.com/a/1222855
# https://www.microsoft.com/en-us/research/wp-content/uploads/2016/11/Digital-Signal-Processing.pdf
def get_aspect_ratio_zhang(keypoints: np.ndarray, img_width: int, img_height: int):
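    """Estimate the width/height ratio of the original rectangle from its projected
    quadrilateral, following Zhang (2006), "Whiteboard Scanning and Image Enhancement"
    (links above). Equation numbers in the comments refer to that paper.
    """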
    keypoints = keypoints[[3, 2, 0, 1]]  # re-order keypoints to match Zhang (2006), Figure 6
    keypoints = np.concatenate([keypoints, np.ones((4, 1))], axis=1)  # convert to homogeneous coordinates

    # equations (11) and (12)
    k2 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[2]) / np.cross(keypoints[1], keypoints[3]).dot(keypoints[2])
    k3 = np.cross(keypoints[0], keypoints[3]).dot(keypoints[1]) / np.cross(keypoints[2], keypoints[3]).dot(keypoints[1])

    # equations (14) and (16)
    n2 = k2 * keypoints[1] - keypoints[0]
    n3 = k3 * keypoints[2] - keypoints[0]

    # equation (21): assume square pixels and the principal point (u0, v0) at the image center
    u0 = img_width / 2
    v0 = img_height / 2
    f2 = -(
        n2[0] * n3[0] - (n2[0] * n3[2] + n2[2] * n3[0]) * u0 + n2[2] * n3[2] * u0 * u0
        + n2[1] * n3[1] - (n2[1] * n3[2] + n2[2] * n3[1]) * v0 + n2[2] * n3[2] * v0 * v0
    ) / (n2[2] * n3[2])
    f = math.sqrt(f2)

    # equation (20)
    A = np.array([[f, 0, u0], [0, f, v0], [0, 0, 1]])
    A_inv = np.linalg.inv(A)
    mid = A_inv.T.dot(A_inv)
    wh_ratio2 = n2.dot(mid).dot(n2) / n3.dot(mid).dot(n3)

    return math.sqrt(wh_ratio2)


def rectify(img_np: np.ndarray, keypoints: np.ndarray):
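    """Warp the quadrilateral given by `keypoints` (top-left, top-right, bottom-right,
    bottom-left) to an upright rectangle, estimating its true aspect ratio when possible.
    """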
    img_height, img_width = img_np.shape[:2]

    h1 = distance(keypoints[0], keypoints[3])
    h2 = distance(keypoints[1], keypoints[2])
    h = (h1 + h2) * 0.5

    # Zhang's method can fail when the quadrilateral degenerates (e.g. opposite edges
    # are parallel), producing a division by zero or a negative f2. Fall back to
    # averaging the measured edge widths in that case.
    try:
        wh_ratio = get_aspect_ratio_zhang(keypoints, img_width, img_height)
        w = h * wh_ratio
    except Exception:
        logging.exception("Failed to estimate aspect ratio from perspective")
        w1 = distance(keypoints[0], keypoints[1])
        w2 = distance(keypoints[3], keypoints[2])
        w = (w1 + w2) * 0.5

    # Map the detected corners to an axis-aligned rectangle with a 1 px margin.
    target_kpts = np.array([[1, 1], [w + 1, 1], [w + 1, h + 1], [1, h + 1]], dtype=np.float32)
    # getPerspectiveTransform requires float32 point arrays.
    transform = cv2.getPerspectiveTransform(keypoints.astype(np.float32), target_kpts)
    cropped = cv2.warpPerspective(img_np, transform, (round(w) + 2, round(h) + 2), flags=cv2.INTER_CUBIC)
    return cropped


def predict(img: Image.Image):
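    """Detect the document's corners, outline them on the input, and return
    (rectified crop or None, annotated image).
    """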
    img_chw = preprocess(img)

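    # The model predicts four corner keypoints, each (x, y, confidence), in the
    # padded IMAGE_SIZE x IMAGE_SIZE coordinate space.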
    pred_kpts = SESSION.run(None, {INPUT_NAME: img_chw[None]})[0][0]
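    # Undo the letterbox: padding is anchored top-left, so a single scale factor
    # maps keypoints back to original-image coordinates.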
    kpts_xy = pred_kpts[:, :2] * max(img.size) / IMAGE_SIZE

    img_np = np.array(img)
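    # Outline the detected quadrilateral (green, anti-aliased) for visualization.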
    cv2.polylines(
        img_np,
        [kpts_xy.astype(int)],
        True,
        (0, 255, 0),
        thickness=5,
        lineType=cv2.LINE_AA,
    )

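    # Rectify only when all four corners are predicted with confidence >= 0.25,
    # using a fresh copy of the input so the drawn outline is not in the crop.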
    if (pred_kpts[:, 2] >= 0.25).all():
        cropped = rectify(np.array(img), kpts_xy)
    else:
        cropped = None

    return cropped, img_np


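# Two image outputs: the rectified crop (None when detection confidence is low) and
# the input annotated with the detected corners.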
gr.Interface(
    predict,
    inputs=[gr.Image(type="pil")],
    outputs=["image", "image"],
    examples=["estonia_id_card.jpg", "german_bundesdruckerei_passport.webp"],
).launch(server_name="0.0.0.0")