from lib.kits.hsmr_demo import *

import gradio as gr

from lib.modeling.pipelines import HSMRPipeline


class HSMRBackend:
    '''
    Backend class maintaining the HSMR model for inference. Some Gradio
    features (e.g., progress messages) are used in this class.
    '''

    def __init__(self, device:str='cpu') -> None:
        self.max_img_w = 1920
        self.max_img_h = 1080
        self.device = device
        self.pipeline = self._build_pipeline(self.device)
        self.detector = build_detector(
                batch_size   = 1,
                max_img_size = 512,
                device       = self.device,
            )

    def _build_pipeline(self, device) -> HSMRPipeline:
        return build_inference_pipeline(
                model_root = DEFAULT_HSMR_ROOT,
                device     = device,
            )

    def _load_limited_img(self, fn) -> List:
        ''' Load an image and downscale it if it exceeds the size limits. '''
        img, _ = load_img(fn)
        if img.shape[0] > self.max_img_h:
            img = flex_resize_img(img, (self.max_img_h, -1), kp_mod=4)
        if img.shape[1] > self.max_img_w:
            img = flex_resize_img(img, (-1, self.max_img_w), kp_mod=4)
        return [img]

    def __call__(self, input_path:Union[str, Path], args:Dict):
        # 1. Initialization.
        input_type = 'img'
        if isinstance(input_path, str):
            input_path = Path(input_path)
        outputs_root = input_path.parent / 'outputs'
        outputs_root.mkdir(parents=True, exist_ok=True)

        # 2. Preprocess: detect人instances and crop them into fixed-size patches.
        gr.Info('[1/3] Pre-processing...')
        raw_imgs = self._load_limited_img(input_path)
        detector_outputs = self.detector(raw_imgs)
        patches, det_meta = imgs_det2patches(raw_imgs, *detector_outputs, args['max_instances'])  # N * (256, 256, 3)

        # 3. Inference: run the HSMR pipeline on normalized patches, batch by batch.
        gr.Info('[2/3] HSMR inferencing...')
        pd_params, pd_cam_t = [], []
        for bw in bsb(total=len(patches), batch_size=args['rec_bs'], enable_tqdm=True):
            patches_i = np.concatenate(patches[bw.sid:bw.eid], axis=0)  # (N, 256, 256, 3)
            patches_normalized_i = (patches_i - IMG_MEAN_255) / IMG_STD_255  # (N, 256, 256, 3)
            patches_normalized_i = patches_normalized_i.transpose(0, 3, 1, 2)  # (N, 3, 256, 256)
            with torch.no_grad():
                outputs = self.pipeline(patches_normalized_i)
            # Move predictions off the device so batches can be accumulated on CPU.
            pd_params.append({k: v.detach().cpu().clone() for k, v in outputs['pd_params'].items()})
            pd_cam_t.append(outputs['pd_cam_t'].detach().cpu().clone())
        pd_params = assemble_dict(pd_params, expand_dim=False)  # [{k:[x]}, {k:[y]}] -> {k:[x, y]}
        pd_cam_t = torch.cat(pd_cam_t, dim=0)

        # 4. Render: overlay the predicted skin and skeleton meshes on the input.
        gr.Info('[3/3] Rendering results...')
        m_skin, m_skel = prepare_mesh(self.pipeline, pd_params)
        results = visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel)

        outputs = {}
        if input_type == 'img':
            for k, v in results.items():
                img_path = str(outputs_root / f'{k}.jpg')
                save_img(v, img_path)
                outputs[k] = img_path
        return outputs
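

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of driving
# the backend directly, with a hypothetical input image path and argument
# values. `max_instances` caps detected instances per image and `rec_bs` is
# the reconstruction batch size, matching the keys consumed in `__call__`
# above. Note that `gr.Info(...)` assumes a Gradio event context; outside of
# one it may only log a warning.
if __name__ == '__main__':
    backend = HSMRBackend(device='cuda' if torch.cuda.is_available() else 'cpu')
    results = backend(
        input_path = 'path/to/input.jpg',                  # hypothetical image path
        args       = {'max_instances': 5, 'rec_bs': 32},   # hypothetical settings
    )
    print(results)  # {result_name: path_to_saved_jpg, ...}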