File size: 3,137 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from lib.kits.hsmr_demo import *

import gradio as gr

from lib.modeling.pipelines import HSMRPipeline

class HSMRBackend:
    '''
    Backend that holds the HSMR inference pipeline and a human detector.

    Each call takes one image path, detects people, runs HSMR on the cropped
    patches, renders the results and saves them as JPEGs into an `outputs`
    folder next to the input image. Progress is reported through `gr.Info`,
    so the class is presumably meant to be driven from a Gradio app.
    '''
    def __init__(self, device:str='cpu') -> None:
        # Input images larger than this box are down-scaled before detection.
        self.max_img_w = 1920
        self.max_img_h = 1080
        self.device = device
        self.pipeline = self._build_pipeline(self.device)
        self.detector = build_detector(
                batch_size   = 1,
                max_img_size = 512,
                device       = self.device,
            )


    def _build_pipeline(self, device) -> HSMRPipeline:
        '''Build the HSMR inference pipeline from `DEFAULT_HSMR_ROOT` on `device`.'''
        return build_inference_pipeline(
                model_root = DEFAULT_HSMR_ROOT,
                device     = device,
            )


    def _load_limited_img(self, fn) -> List:
        '''
        Load the image at `fn` and shrink it so it fits within
        `self.max_img_h` x `self.max_img_w`.

        Height is checked first, then width on the (possibly already resized)
        image. Returns a single-element list, which is the form passed to the
        detector and the renderer in `__call__`.
        '''
        img, _ = load_img(fn)
        # NOTE(review): `-1` presumably tells `flex_resize_img` to derive that side
        # from the aspect ratio, and `kp_mod=4` to keep sizes a multiple of 4 — confirm.
        if img.shape[0] > self.max_img_h:
            img = flex_resize_img(img, (self.max_img_h, -1), kp_mod=4)
        if img.shape[1] > self.max_img_w:
            img = flex_resize_img(img, (-1, self.max_img_w), kp_mod=4)
        return [img]


    def __call__(self, input_path:Union[str, Path], args:Dict):
        '''
        Run detection, HSMR inference and rendering on a single image.

        ### Args
        - input_path: Union[str, Path]
            - Path of the input image. Results are written to
              `<input_path.parent>/outputs/`, which is created if missing.
        - args: Dict
            - `max_instances`: forwarded to `imgs_det2patches`; presumably caps
              the number of detected people that are processed.
            - `rec_bs`: batch size used when running HSMR over the patches.

        ### Returns
        - Dict[str, str]
            - Maps each key produced by `visualize_img_results` to the path of
              the JPEG saved for it.

        ### Raises
        - gr.Error: if no person patch could be extracted from the image.
        '''
        # 1. Initialization.
        if isinstance(input_path, str): input_path = Path(input_path)
        outputs_root = input_path.parent / 'outputs'
        outputs_root.mkdir(parents=True, exist_ok=True)

        # 2. Preprocess.
        gr.Info('[1/3] Pre-processing...')
        raw_imgs = self._load_limited_img(input_path)
        detector_outputs = self.detector(raw_imgs)
        patches, det_meta = imgs_det2patches(raw_imgs, *detector_outputs, args['max_instances'])  # N * (256, 256, 3)
        if len(patches) == 0:
            # Without this guard the inference loop below collects nothing and
            # `torch.cat` over an empty list fails with an opaque error.
            raise gr.Error('No human detected in the input image.')

        # 3. Inference.
        gr.Info('[2/3] HSMR inferencing...')
        pd_params, pd_cam_t = [], []
        for bw in bsb(total=len(patches), batch_size=args['rec_bs'], enable_tqdm=True):
            patches_i = np.concatenate(patches[bw.sid:bw.eid], axis=0)  # (N, 256, 256, 3)
            patches_normalized_i = (patches_i - IMG_MEAN_255) / IMG_STD_255  # (N, 256, 256, 3)
            patches_normalized_i = patches_normalized_i.transpose(0, 3, 1, 2)  # (N, 3, 256, 256)
            with torch.no_grad():
                batch_out = self.pipeline(patches_normalized_i)
            # Move each batch's results to CPU so device memory is not held across batches.
            pd_params.append({k: v.detach().cpu().clone() for k, v in batch_out['pd_params'].items()})
            pd_cam_t.append(batch_out['pd_cam_t'].detach().cpu().clone())

        pd_params = assemble_dict(pd_params, expand_dim=False)  # [{k:[x]}, {k:[y]}] -> {k:[x, y]}
        pd_cam_t = torch.cat(pd_cam_t, dim=0)

        # 4. Render.
        gr.Info('[3/3] Rendering results...')
        m_skin, m_skel = prepare_mesh(self.pipeline, pd_params)
        results = visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel)

        # 5. Save every rendered result as `<key>.jpg` and return the file paths.
        outputs = {}
        for k, v in results.items():
            img_path = str(outputs_root / f'{k}.jpg')
            save_img(v, img_path)
            outputs[k] = img_path

        return outputs