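"""Visualization helpers for overlaying 2D facial landmarks and bounding boxes
on image tensors.

Run as a script to preview the annotations of a NeRSemble sequence frame by
frame, e.g. (all arguments below are placeholders):

    python <this_script>.py --root_folder <nersemble_root> --subject <id> --sequence <name>
"""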
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_keypoints


# Edge list connecting the standard 68 facial landmarks (iBUG ordering).
connectivity_face = (
    [(i, i + 1) for i in range(0, 16)]      # jaw line (0-16)
    + [(i, i + 1) for i in range(17, 21)]   # right eyebrow (17-21)
    + [(i, i + 1) for i in range(22, 26)]   # left eyebrow (22-26)
    + [(i, i + 1) for i in range(27, 30)]   # nose bridge (27-30)
    + [(i, i + 1) for i in range(31, 35)]   # lower nose (31-35)
    + [(i, i + 1) for i in range(36, 41)]   # right eye (36-41)
    + [(36, 41)]                            # close the right-eye contour
    + [(i, i + 1) for i in range(42, 47)]   # left eye (42-47)
    + [(42, 47)]                            # close the left-eye contour
    + [(i, i + 1) for i in range(48, 59)]   # outer lips (48-59)
    + [(48, 59)]                            # close the outer-lip contour
    + [(i, i + 1) for i in range(60, 67)]   # inner lips (60-67)
    + [(60, 67)]                            # close the inner-lip contour
)


def plot_landmarks_2d(
    img: torch.Tensor,
    lmks: torch.Tensor,
    connectivity=None,
    colors="white",
    unit=1,
    input_float=False,
):
    """Draw 2D landmarks (and optional connectivity edges) onto a (C, H, W) image.

    If `input_float` is True, the image is expected in [0, 1]; it is converted
    to uint8 for drawing and converted back to float afterwards.
    """
    if input_float:
        img = (img * 255).byte()

    img = draw_keypoints(
        img,
        lmks,
        connectivity=connectivity,
        colors=colors,
        radius=2 * unit,
        width=2 * unit,
    )

    if input_float:
        img = img.float() / 255
    return img


def blend(a, b, w):
    """Alpha-blend two uint8 image tensors with weight `w` on `a`."""
    return (a * w + b * (1 - w)).byte()


if __name__ == "__main__":
    from argparse import ArgumentParser
    from torch.utils.data import DataLoader

    from vhap.data.nersemble_dataset import NeRSembleDataset

    parser = ArgumentParser()
    parser.add_argument("--root_folder", type=str, required=True)
    parser.add_argument("--subject", type=str, required=True)
    parser.add_argument("--sequence", type=str, required=True)
    parser.add_argument("--division", default=None)
    parser.add_argument("--subset", default=None)
    parser.add_argument("--scale_factor", type=float, default=1.0)
    parser.add_argument("--blend_weight", type=float, default=0.6)
    args = parser.parse_args()

    dataset = NeRSembleDataset(
        root_folder=args.root_folder,
        subject=args.subject,
        sequence=args.sequence,
        division=args.division,
        subset=args.subset,
        n_downsample_rgb=2,
        scale_factor=args.scale_factor,
        use_landmark=True,
    )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    for item in dataloader:
        # Scale drawing sizes with the image resolution.
        unit = int(item["scale_factor"][0] * 3) + 1

        rgb = item["rgb"][0].permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
        vis = rgb

        if "bbox_2d" in item:
            bbox = item["bbox_2d"][0][:4]
            tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d" in item:
            face_landmark = item["lmk2d"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                face_landmark[None, ...],
                connectivity=connectivity_face,
                colors="white",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d_iris" in item:
            iris_landmark = item["lmk2d_iris"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                iris_landmark[None, ...],
                colors="blue",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        vis = vis.permute(1, 2, 0).numpy()  # back to (H, W, C) for matplotlib
        plt.imshow(vis)
        plt.draw()
        # Advance to the next frame on a key press (mouse clicks are ignored).
        while not plt.waitforbuttonpress(timeout=-1):
            pass