# HSMR Gradio demo backend.
from lib.kits.hsmr_demo import *

import gradio as gr

from lib.modeling.pipelines import HSMRPipeline
class HSMRBackend:
    '''
    Backend class for maintaining the HSMR model for inference.

    Some Gradio features (progress toasts via `gr.Info`) are used inside this
    class, so it is intended to run inside a Gradio app.
    '''

    def __init__(self, device:str='cpu') -> None:
        # Resolution caps: inputs larger than this are downscaled before
        # detection/inference to bound memory and compute.
        self.max_img_w = 1920
        self.max_img_h = 1080
        self.device = device
        self.pipeline = self._build_pipeline(self.device)
        self.detector = build_detector(
            batch_size   = 1,
            max_img_size = 512,
            device       = self.device,
        )

    def _build_pipeline(self, device) -> HSMRPipeline:
        ''' Build the HSMR inference pipeline on the given device. '''
        return build_inference_pipeline(
            model_root = DEFAULT_HSMR_ROOT,
            device     = device,
        )

    def _load_limited_img(self, fn) -> List:
        '''
        Load an image from `fn` and downscale it if it exceeds the size caps.

        Each dimension is checked independently; `flex_resize_img` with -1
        presumably keeps the aspect ratio for the free dimension — confirm in
        its definition. Returns a single-element list so downstream code can
        treat it as a batch.
        '''
        img, _ = load_img(fn)
        if img.shape[0] > self.max_img_h:
            img = flex_resize_img(img, (self.max_img_h, -1), kp_mod=4)
        if img.shape[1] > self.max_img_w:
            img = flex_resize_img(img, (-1, self.max_img_w), kp_mod=4)
        return [img]

    def __call__(self, input_path:Union[str, Path], args:Dict):
        '''
        Run the full pipeline (detect -> HSMR inference -> render) on one image.

        ### Args
        - input_path: str or Path, path to the input image.
        - args: Dict, expects keys 'max_instances' and 'rec_bs' (recovery batch size).

        ### Returns
        - Dict mapping result name -> path of the rendered image saved under
          `<input dir>/outputs/`.
        '''
        # 1. Initialization.
        input_type = 'img'  # Only image inputs are supported for now.
        if isinstance(input_path, str): input_path = Path(input_path)
        outputs_root = input_path.parent / 'outputs'
        outputs_root.mkdir(parents=True, exist_ok=True)

        # 2. Preprocess: load (size-limited) image, detect instances, crop patches.
        gr.Info('[1/3] Pre-processing...')
        raw_imgs = self._load_limited_img(input_path)
        detector_outputs = self.detector(raw_imgs)
        patches, det_meta = imgs_det2patches(raw_imgs, *detector_outputs, args['max_instances'])  # N * (256, 256, 3)

        # 3. Inference: run HSMR over patches in mini-batches.
        gr.Info('[2/3] HSMR inferencing...')
        pd_params, pd_cam_t = [], []
        for bw in bsb(total=len(patches), batch_size=args['rec_bs'], enable_tqdm=True):
            patches_i = np.concatenate(patches[bw.sid:bw.eid], axis=0)  # (N, 256, 256, 3)
            patches_normalized_i = (patches_i - IMG_MEAN_255) / IMG_STD_255  # (N, 256, 256, 3)
            patches_normalized_i = patches_normalized_i.transpose(0, 3, 1, 2)  # (N, 3, 256, 256)
            with torch.no_grad():
                outputs = self.pipeline(patches_normalized_i)
            # Detach and move to CPU so GPU memory is freed between batches.
            pd_params.append({k: v.detach().cpu().clone() for k, v in outputs['pd_params'].items()})
            pd_cam_t.append(outputs['pd_cam_t'].detach().cpu().clone())
        pd_params = assemble_dict(pd_params, expand_dim=False)  # [{k:[x]}, {k:[y]}] -> {k:[x, y]}
        pd_cam_t = torch.cat(pd_cam_t, dim=0)

        # 4. Render the recovered skin/skeleton meshes over the input image.
        gr.Info('[3/3] Rendering results...')
        m_skin, m_skel = prepare_mesh(self.pipeline, pd_params)
        results = visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel)

        outputs = {}
        if input_type == 'img':
            for k, v in results.items():
                img_path = str(outputs_root / f'{k}.jpg')
                save_img(v, img_path)
                outputs[k] = img_path  # Fixed: original assigned this twice.
        return outputs