|
|
|
|
|
|
|
|
|
import os |
|
import torch |
|
import numpy as np |
|
import torchvision |
|
|
|
from .utils_vgghead import nms |
|
from .utils_lmks_detector import LmksDetector |
|
|
|
class VGGHeadDetector(torch.nn.Module):
    """Detect the largest face/head in an image with a pickled VGGHead model.

    The detector letterboxes the input to a square ``image_size`` canvas,
    runs the model to get boxes, confidences and FLAME parameters, keeps the
    single largest confident box, and maps it back to original image
    coordinates.

    Args:
        device: torch device the model and inputs are moved to.
        vggheadmodel_path: path to a fully pickled model checkpoint loaded
            via ``torch.load`` (required; loading is attempted in __init__).
    """

    def __init__(self, device,
                 vggheadmodel_path=None):
        super().__init__()
        # Side length of the square canvas the model was trained on.
        self.image_size = 640
        self._device = device
        self.vggheadmodel_path = vggheadmodel_path
        self._init_models()

    def _init_models(self):
        """Load the checkpoint onto CPU, then move it to the target device in eval mode.

        Raises:
            ValueError: if no checkpoint path was provided.
        """
        if self.vggheadmodel_path is None:
            raise ValueError('VGGHeadDetector: vggheadmodel_path must be provided')
        # NOTE(review): this unpickles a full nn.Module, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        self.model = torch.load(self.vggheadmodel_path, map_location='cpu')
        self.model.to(self._device).eval()

    @torch.no_grad()
    def forward(self, image_tensor, image_key, conf_threshold=0.5):
        """Detect the dominant head in ``image_tensor``.

        Args:
            image_tensor: (C, H, W) image tensor; moved to the model device
                and cast to float here.
            image_key: identifier used only in the no-detection log message.
            conf_threshold: minimum detection confidence kept by NMS.

        Returns:
            Tuple ``(vgg_results, bbox, None)`` on success, where
            ``vgg_results`` is the FLAME parameter dict from ``_postprocess``
            plus a ``'normalize'`` entry with the letterbox padding/scale, and
            ``bbox`` is (x1, y1, x2, y2) in original image coordinates.
            Returns ``(None, None, None)`` when no face is detected.
        """
        # Lazily (re)load the model if it was dropped from the instance.
        if not hasattr(self, 'model'):
            self._init_models()
        image_tensor = image_tensor.to(self._device).float()
        image, padding, scale = self._preprocess(image_tensor)
        bbox, scores, flame_params = self.model(image)
        bbox, vgg_results = self._postprocess(bbox, scores, flame_params, conf_threshold)

        if bbox is None:
            print('VGGHeadDetector: No face detected: {}!'.format(image_key))
            return None, None, None
        vgg_results['normalize'] = {'padding': padding, 'scale': scale}

        # Map the box from the padded square canvas back to original coords:
        # clamp to the canvas, undo the letterbox padding, undo the resize.
        bbox = bbox.clip(0, self.image_size)
        bbox[[0, 2]] -= padding[0]
        bbox[[1, 3]] -= padding[1]
        bbox /= scale
        # NOTE(review): image_size / scale == max(h, w), so both axes are
        # clipped to the longer side; a per-axis clip (x to w, y to h) would
        # be tighter -- confirm intent before changing.
        bbox = bbox.clip(0, self.image_size / scale)

        return vgg_results, bbox, None

    def _preprocess(self, image):
        """Letterbox a (C, H, W) image onto a square ``image_size`` canvas.

        The image is resized so its longer side equals ``image_size``, then
        symmetrically padded with gray (127) to a square, and normalized
        to [0, 1].

        Returns:
            Tuple ``(image, padding, scale)``: the (1, C, S, S) float batch,
            the np.array (pad_left, pad_top) offsets, and the resize scale.
        """
        _, h, w = image.shape
        if h > w:
            new_h, new_w = self.image_size, int(w * self.image_size / h)
        else:
            new_h, new_w = int(h * self.image_size / w), self.image_size
        scale = self.image_size / max(h, w)
        image = torchvision.transforms.functional.resize(image, (new_h, new_w), antialias=True)
        pad_w = self.image_size - image.shape[2]
        pad_h = self.image_size - image.shape[1]
        image = torchvision.transforms.functional.pad(
            image, (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2), fill=127)
        image = image.unsqueeze(0).float() / 255.0
        return image, np.array([pad_w // 2, pad_h // 2]), scale

    def _postprocess(self, bbox, scores, flame_params, conf_threshold):
        """Run NMS, keep the largest box, and unpack its FLAME parameters.

        Returns:
            Tuple ``(bbox, vgg_results)`` with ``bbox`` in canvas coordinates,
            or ``(None, None)`` when nothing confident remains or the best box
            spans (almost) the whole canvas, which is treated as a false positive.
        """
        bbox, scores, flame_params = nms(bbox, scores, flame_params, confidence_threshold=conf_threshold)
        if bbox.shape[0] == 0:
            return None, None
        # Keep the detection with the largest box area.
        max_idx = ((bbox[:, 3] - bbox[:, 1]) * (bbox[:, 2] - bbox[:, 0])).argmax().long()
        bbox, flame_params = bbox[max_idx], flame_params[max_idx]
        # Reject boxes covering essentially the full canvas (was the
        # hard-coded 5 / 635 pair, which silently assumed image_size == 640).
        margin = 5
        upper = self.image_size - margin
        if bbox[0] < margin and bbox[1] < margin and bbox[2] > upper and bbox[3] > upper:
            return None, None

        # Flat flame_params layout (presumed from the slicing below -- TODO
        # confirm against the checkpoint): [0:300] shape, [300:400] expression,
        # [400:403] jaw pose, [403:409] 6D rotation, [409:412] translation,
        # [412:] scale. Global rotation of the pose code is zero-filled.
        posecode = torch.cat([flame_params.new_zeros(3), flame_params[400:403]])
        vgg_results = {
            'rotation_6d': flame_params[403:409], 'translation': flame_params[409:412], 'scale': flame_params[412:],
            'shapecode': flame_params[:300], 'expcode': flame_params[300:400], 'posecode': posecode,
        }
        return bbox, vgg_results
|
|