from lib.kits.hsmr_demo import *
import gradio as gr
from lib.modeling.pipelines import HSMRPipeline


class HSMRBackend:
    '''
    Backend class that maintains the HSMR model for inference.
    Some Gradio features (progress messages) are included in this class.
    '''

    def __init__(self, device:str='cpu') -> None:
        # Images larger than this resolution are downscaled before detection.
        self.max_img_w = 1920
        self.max_img_h = 1080
        self.device = device
        self.pipeline = self._build_pipeline(self.device)
        self.detector = build_detector(
            batch_size = 1,
            max_img_size = 512,
            device = self.device,
        )

    def _build_pipeline(self, device) -> HSMRPipeline:
        return build_inference_pipeline(
            model_root = DEFAULT_HSMR_ROOT,
            device = device,
        )

    def _load_limited_img(self, fn) -> List:
        img, _ = load_img(fn)
        # Downscale over-sized images; `-1` lets `flex_resize_img` infer the other
        # side, and `kp_mod=4` appears to keep each side a multiple of 4.
        if img.shape[0] > self.max_img_h:
            img = flex_resize_img(img, (self.max_img_h, -1), kp_mod=4)
        if img.shape[1] > self.max_img_w:
            img = flex_resize_img(img, (-1, self.max_img_w), kp_mod=4)
        return [img]

    def __call__(self, input_path:Union[str, Path], args:Dict):
        # 1. Initialization.
        input_type = 'img'  # Only image inputs are supported in this demo.
        if isinstance(input_path, str):
            input_path = Path(input_path)
        outputs_root = input_path.parent / 'outputs'
        outputs_root.mkdir(parents=True, exist_ok=True)

        # 2. Preprocess.
        gr.Info('[1/3] Pre-processing...')
        raw_imgs = self._load_limited_img(input_path)
        detector_outputs = self.detector(raw_imgs)
        patches, det_meta = imgs_det2patches(raw_imgs, *detector_outputs, args['max_instances'])  # N * (256, 256, 3)

        # 3. Inference.
        gr.Info('[2/3] HSMR inferencing...')
        pd_params, pd_cam_t = [], []
        for bw in bsb(total=len(patches), batch_size=args['rec_bs'], enable_tqdm=True):
            patches_i = np.concatenate(patches[bw.sid:bw.eid], axis=0)  # (N, 256, 256, 3)
            patches_normalized_i = (patches_i - IMG_MEAN_255) / IMG_STD_255  # (N, 256, 256, 3)
            patches_normalized_i = patches_normalized_i.transpose(0, 3, 1, 2)  # (N, 3, 256, 256)
            with torch.no_grad():
                outputs = self.pipeline(patches_normalized_i)
            pd_params.append({k: v.detach().cpu().clone() for k, v in outputs['pd_params'].items()})
            pd_cam_t.append(outputs['pd_cam_t'].detach().cpu().clone())
        pd_params = assemble_dict(pd_params, expand_dim=False)  # [{k:[x]}, {k:[y]}] -> {k:[x, y]}
        pd_cam_t = torch.cat(pd_cam_t, dim=0)

        # 4. Render.
        gr.Info('[3/3] Rendering results...')
        m_skin, m_skel = prepare_mesh(self.pipeline, pd_params)
        results = visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel)
        outputs = {}
        if input_type == 'img':
            for k, v in results.items():
                img_path = str(outputs_root / f'{k}.jpg')
                save_img(v, img_path)
                outputs[k] = img_path
        return outputs
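

# --- Usage sketch (illustrative, not part of the Space) ---
# A minimal sketch of driving the backend directly, assuming a local test image.
# The image path and the `args` values below are hypothetical, though `__call__`
# above only reads the 'max_instances' and 'rec_bs' keys. Outside a running
# Gradio app, the `gr.Info(...)` progress messages may only be logged.
if __name__ == '__main__':
    backend = HSMRBackend(device='cuda' if torch.cuda.is_available() else 'cpu')
    outputs = backend(
        input_path = 'assets/example.jpg',  # hypothetical test image
        args = {'max_instances': 5, 'rec_bs': 32},  # hypothetical values
    )
    for name, path in outputs.items():
        print(f'{name} -> {path}')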