Spaces:
Veein
/
Runtime error

File size: 3,700 Bytes
17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python
# Copyright (c) Xuangeng Chu ([email protected])
# Modified based on code from Orest Kupyn (University of Oxford).

import torch
import torchvision

def reproject_vertices(flame_model, vgg_results):
    # flame_model = FLAMEModel(n_shape=300, n_exp=100, scale=1.0)
    vertices, _ = flame_model(
        shape_params=vgg_results['shapecode'], 
        expression_params=vgg_results['expcode'], 
        pose_params=vgg_results['posecode'],
        verts_sclae=1.0
    )
    vertices[:, :, 2] += 0.05 # MESH_OFFSET_Z
    vgg_landmarks3d = flame_model._vertices2landmarks(vertices)
    vgg_transform_results = vgg_results['transform']
    rotation_mat = rot_mat_from_6dof(vgg_transform_results['rotation_6d']).type(vertices.dtype)
    translation = vgg_transform_results['translation'][:, None, :]
    scale = torch.clamp(vgg_transform_results['scale'][:, None], 1e-8)
    rot_vertices = vertices.clone()
    rot_vertices = torch.matmul(rotation_mat.unsqueeze(1), rot_vertices.unsqueeze(-1))[..., 0]
    vgg_landmarks3d = torch.matmul(rotation_mat.unsqueeze(1), vgg_landmarks3d.unsqueeze(-1))[..., 0]
    proj_vertices = (rot_vertices * scale) + translation
    vgg_landmarks3d = (vgg_landmarks3d * scale) + translation
    
    trans_padding, trans_scale = vgg_results['normalize']['padding'], vgg_results['normalize']['scale']
    proj_vertices[:, :, 0] -= trans_padding[:, 0, None]
    proj_vertices[:, :, 1] -= trans_padding[:, 1, None]
    proj_vertices = proj_vertices / trans_scale[:, None, None]
    vgg_landmarks3d[:, :, 0] -= trans_padding[:, 0, None]
    vgg_landmarks3d[:, :, 1] -= trans_padding[:, 1, None]
    vgg_landmarks3d = vgg_landmarks3d / trans_scale[:, None, None]
    return proj_vertices.float()[..., :2], vgg_landmarks3d.float()[..., :2]


def rot_mat_from_6dof(v: torch.Tensor) -> torch.Tensor:
    assert v.shape[-1] == 6
    v = v.view(-1, 6)
    vx, vy = v[..., :3].clone(), v[..., 3:].clone()

    b1 = torch.nn.functional.normalize(vx, dim=-1)
    b3 = torch.nn.functional.normalize(torch.cross(b1, vy, dim=-1), dim=-1)
    b2 = -torch.cross(b1, b3, dim=1)
    return torch.stack((b1, b2, b3), dim=-1)


def nms(boxes_xyxy, scores, flame_params,
        confidence_threshold: float = 0.5, iou_threshold: float = 0.5, 
        top_k: int = 1000, keep_top_k: int = 100
    ):
    for pred_bboxes_xyxy, pred_bboxes_conf, pred_flame_params in zip(
            boxes_xyxy.detach().float(),
            scores.detach().float(),
            flame_params.detach().float(),
    ):
        pred_bboxes_conf = pred_bboxes_conf.squeeze(-1)  # [Anchors]
        conf_mask = pred_bboxes_conf >= confidence_threshold

        pred_bboxes_conf = pred_bboxes_conf[conf_mask]
        pred_bboxes_xyxy = pred_bboxes_xyxy[conf_mask]
        pred_flame_params = pred_flame_params[conf_mask]

        # Filter all predictions by self.nms_top_k
        if pred_bboxes_conf.size(0) > top_k:
            topk_candidates = torch.topk(pred_bboxes_conf, k=top_k, largest=True, sorted=True)
            pred_bboxes_conf = pred_bboxes_conf[topk_candidates.indices]
            pred_bboxes_xyxy = pred_bboxes_xyxy[topk_candidates.indices]
            pred_flame_params = pred_flame_params[topk_candidates.indices]

        # NMS
        idx_to_keep = torchvision.ops.boxes.nms(boxes=pred_bboxes_xyxy, scores=pred_bboxes_conf, iou_threshold=iou_threshold)

        final_bboxes = pred_bboxes_xyxy[idx_to_keep][: keep_top_k]  # [Instances, 4]
        final_scores = pred_bboxes_conf[idx_to_keep][: keep_top_k]  # [Instances, 1]
        final_params = pred_flame_params[idx_to_keep][: keep_top_k]  # [Instances, Flame Params]
        return final_bboxes, final_scores, final_params