"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import logging
from typing import Dict, Optional, Tuple

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur
from torchvision.utils import make_grid

import visualize.ca_body.nn.layers as la
from visualize.ca_body.nn.blocks import (
    ConvBlock,
    ConvDownBlock,
    UpConvBlockDeep,
    tile2d,
    weights_initializer,
)
from visualize.ca_body.nn.color_cal import CalV5
from visualize.ca_body.nn.dof_cal import LearnableBlur
from visualize.ca_body.nn.face import FaceDecoderFrontal
from visualize.ca_body.nn.shadow import ShadowUNet, PoseToShadow
from visualize.ca_body.nn.unet import UNetWB
from visualize.ca_body.utils.geom import (
    GeometryModule,
    compute_view_cos,
    depth_discontuity_mask,
    depth2normals,
)
from visualize.ca_body.utils.image import linear2displayBatch
from visualize.ca_body.utils.lbs import LBSModule
from visualize.ca_body.utils.render import RenderLayer
from visualize.ca_body.utils.seams import SeamSampler
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
class CameraPixelBias(nn.Module):
    """Learnable per-camera, per-pixel additive bias stored at reduced resolution.

    The bias grid is kept at ``1/ds_rate`` scale and bilinearly upsampled to the
    full image size on every forward pass.

    NOTE(review): the parameter is allocated as
    (n_cameras, 1, image_width // ds_rate, image_height // ds_rate), i.e. width
    before height, while the upsample target is (height, width). For square
    images this is equivalent; for non-square images the low-res grid is
    transposed relative to the image. Preserved as-is for checkpoint
    compatibility — confirm before changing.
    """

    def __init__(self, image_height, image_width, cameras, ds_rate) -> None:
        super().__init__()
        self.image_height = image_height
        self.image_width = image_width
        self.cameras = cameras
        self.n_cameras = len(cameras)

        init_bias = th.zeros(
            (self.n_cameras, 1, image_width // ds_rate, image_height // ds_rate),
            dtype=th.float32,
        )
        self.register_parameter("bias", nn.Parameter(init_bias))

    def forward(self, idxs: th.Tensor):
        """Return full-resolution bias maps for the given camera indices."""
        selected = self.bias[idxs]
        return F.interpolate(
            selected, size=(self.image_height, self.image_width), mode='bilinear'
        )
class AutoEncoder(nn.Module):
    """End-to-end body avatar autoencoder.

    Encodes driving geometry (plus face embeddings) into latent codes, decodes
    UV geometry deltas and a mean texture, adds view-dependent texture, shadow
    maps and super-resolved detail, then optionally applies per-camera color /
    blur / pixel calibration and rasterized rendering.
    """

    def __init__(
        self,
        encoder,
        decoder,
        decoder_view,
        encoder_face,
        # hqlp decoder to get the codes
        decoder_face,
        shadow_net,
        upscale_net,
        assets,
        pose_to_shadow=None,
        renderer=None,
        cal=None,
        pixel_cal=None,
        learn_blur: bool = True,
    ):
        """The ``encoder*`` / ``decoder*`` / ``*_net`` arguments are config
        mappings splatted into the corresponding submodule constructors;
        ``assets`` bundles topology, LBS, masks and calibration data loaded
        elsewhere.
        """
        super().__init__()
        # TODO: should we have a shared LBS here?
        self.geo_fn = GeometryModule(
            assets.topology.vi,
            assets.topology.vt,
            assets.topology.vti,
            assets.topology.v2uv,
            uv_size=1024,
            impaint=True,
        )
        self.lbs_fn = LBSModule(
            assets.lbs_model_json,
            assets.lbs_config_dict,
            assets.lbs_template_verts,
            assets.lbs_scale,
            assets.global_scaling,
        )
        # seam samplers fix UV-seam discontinuities at 1k (features) / 2k (textures)
        self.seam_sampler = SeamSampler(assets.seam_data_1024)
        self.seam_sampler_2k = SeamSampler(assets.seam_data_2048)

        # joint tex -> body and clothes
        # TODO: why do we have a joint one in the first place?
        tex_mean = gaussian_blur(th.as_tensor(assets.tex_mean)[np.newaxis], kernel_size=11)
        self.register_buffer("tex_mean", F.interpolate(tex_mean, (2048, 2048), mode='bilinear'))
        # this is shared
        self.tex_std = assets.tex_var if 'tex_var' in assets else 64.0

        face_cond_mask = th.as_tensor(assets.face_cond_mask, dtype=th.float32)[
            np.newaxis, np.newaxis
        ]
        self.register_buffer("face_cond_mask", face_cond_mask)

        # mouth/eyes mask lifted from per-vertex weights to a 2k UV map
        meye_mask = self.geo_fn.to_uv(
            th.as_tensor(assets.mouth_eyes_mask_geom[np.newaxis, :, np.newaxis])
        )
        meye_mask = F.interpolate(meye_mask, (2048, 2048), mode='bilinear')
        self.register_buffer("meye_mask", meye_mask)

        self.decoder = ConvDecoder(
            geo_fn=self.geo_fn,
            seam_sampler=self.seam_sampler,
            **decoder,
            assets=assets,
        )

        # embs for everything but face
        non_head_mask = 1.0 - assets.face_mask
        self.encoder = Encoder(
            geo_fn=self.geo_fn,
            mask=non_head_mask,
            **encoder,
        )
        self.encoder_face = FaceEncoder(
            assets=assets,
            **encoder_face,
        )

        # using face decoder to generate better conditioning
        decoder_face_ckpt_path = None
        if 'ckpt' in decoder_face:
            decoder_face_ckpt_path = decoder_face.pop('ckpt')
        self.decoder_face = FaceDecoderFrontal(assets=assets, **decoder_face)
        if decoder_face_ckpt_path is not None:
            self.decoder_face.load_state_dict(th.load(decoder_face_ckpt_path), strict=False)

        self.decoder_view = UNetViewDecoder(
            self.geo_fn,
            seam_sampler=self.seam_sampler,
            **decoder_view,
        )

        self.shadow_net = ShadowUNet(
            ao_mean=assets.ao_mean,
            interp_mode="bilinear",
            biases=False,
            **shadow_net,
        )

        # optional pose-driven shadow predictor (replaces AO-driven shadows at eval)
        self.pose_to_shadow_enabled = False
        if pose_to_shadow is not None:
            self.pose_to_shadow_enabled = True
            self.pose_to_shadow = PoseToShadow(**pose_to_shadow)

        self.upscale_net = UpscaleNet(
            in_channels=6, size=1024, upscale_factor=2, out_channels=3, **upscale_net
        )

        # optional per-camera, per-pixel additive bias on the rendered image
        self.pixel_cal_enabled = False
        if pixel_cal is not None:
            self.pixel_cal_enabled = True
            self.pixel_cal = CameraPixelBias(**pixel_cal, cameras=assets.camera_ids)

        # optional learnable per-camera defocus blur
        self.learn_blur_enabled = False
        if learn_blur:
            self.learn_blur_enabled = True
            self.learn_blur = LearnableBlur(assets.camera_ids)

        # training-only stuff
        self.cal_enabled = False
        if cal is not None:
            self.cal_enabled = True
            self.cal = CalV5(**cal, cameras=assets.camera_ids)

        self.rendering_enabled = False
        if renderer is not None:
            self.rendering_enabled = True
            self.renderer = RenderLayer(
                h=renderer.image_height,
                w=renderer.image_width,
                vt=self.geo_fn.vt,
                vi=self.geo_fn.vi,
                vti=self.geo_fn.vti,
                flip_uvs=False,
            )

    def compute_summaries(self, preds, batch):
        """Build a single stacked progress image: predicted RGB / GT RGB / normals."""
        # TODO: switch to common summaries?
        # return compute_summaries_mesh(preds, batch)
        rgb = linear2displayBatch(preds['rgb'][:, :3])
        rgb_gt = linear2displayBatch(batch['image'])
        depth = preds['depth'][:, np.newaxis]
        mask = depth > 0.0
        # normals remapped from [-1, 1] to [0, 255], zeroed where depth is invalid
        normals = (
            255 * (1.0 - depth2normals(depth, batch['focal'], batch['princpt'])) / 2.0
        ) * mask
        grid_rgb = make_grid(rgb, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8)
        grid_rgb_gt = make_grid(rgb_gt, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8)
        grid_normals = make_grid(normals, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8)
        # grids are HWC after permute; cat along height stacks the three rows
        progress_image = th.cat([grid_rgb, grid_rgb_gt, grid_normals], dim=0)
        return {
            'progress_image': (progress_image, 'png'),
        }

    def forward_tex(self, tex_mean_rec, tex_view_rec, shadow_map):
        """Combine mean + view textures, upscale to 2k, apply shadow and seam fixes."""
        x = th.cat([tex_mean_rec, tex_view_rec], dim=1)
        tex_rec = tex_mean_rec + tex_view_rec

        tex_rec = self.seam_sampler.impaint(tex_rec)
        tex_rec = self.seam_sampler.resample(tex_rec)
        tex_rec = F.interpolate(tex_rec, size=(2048, 2048), mode="bilinear", align_corners=False)
        # residual super-resolution detail predicted from the concatenated 1k textures
        tex_rec = tex_rec + self.upscale_net(x)

        # un-normalize into texel space
        tex_rec = tex_rec * self.tex_std + self.tex_mean

        shadow_map = self.seam_sampler_2k.impaint(shadow_map)
        shadow_map = self.seam_sampler_2k.resample(shadow_map)
        shadow_map = self.seam_sampler_2k.resample(shadow_map)

        tex_rec = tex_rec * shadow_map

        tex_rec = self.seam_sampler_2k.impaint(tex_rec)
        tex_rec = self.seam_sampler_2k.resample(tex_rec)
        tex_rec = self.seam_sampler_2k.resample(tex_rec)
        return tex_rec

    def encode(self, geom: th.Tensor, lbs_motion: th.Tensor, face_embs_hqlp: th.Tensor):
        """Encode posed geometry and HQLP face embeddings into latent predictions."""
        with th.no_grad():
            verts_unposed = self.lbs_fn.unpose(geom, lbs_motion)
            verts_unposed_uv = self.geo_fn.to_uv(verts_unposed)

        # extract face region for geom + tex
        enc_preds = self.encoder(motion=lbs_motion, verts_unposed=verts_unposed)
        # TODO: probably need to rename these to `face_embs_mugsy` or smth
        # TODO: we need the same thing for face?
        # enc_face_preds = self.encoder_face(verts_unposed_uv)
        # face decoder is frozen conditioning; only the face encoder gets gradients
        with th.no_grad():
            face_dec_preds = self.decoder_face(face_embs_hqlp)
        enc_face_preds = self.encoder_face(**face_dec_preds)

        preds = {
            **enc_preds,
            **enc_face_preds,
            'face_dec_preds': face_dec_preds,
        }
        return preds

    def forward(
        self,
        # TODO: should we try using this as well for cond?
        lbs_motion: th.Tensor,
        campos: th.Tensor,
        geom: Optional[th.Tensor] = None,
        ao: Optional[th.Tensor] = None,
        K: Optional[th.Tensor] = None,
        Rt: Optional[th.Tensor] = None,
        image_bg: Optional[th.Tensor] = None,
        image: Optional[th.Tensor] = None,
        image_mask: Optional[th.Tensor] = None,
        embs: Optional[th.Tensor] = None,
        _index: Optional[Dict[str, th.Tensor]] = None,
        face_embs: Optional[th.Tensor] = None,
        embs_conv: Optional[th.Tensor] = None,
        tex_seg: Optional[th.Tensor] = None,
        encode=True,
        iteration: Optional[int] = None,
    ):
        B = lbs_motion.shape[0]

        if not th.jit.is_scripting() and encode:
            # NOTE: these are `face_embs_hqlp`
            enc_preds = self.encode(geom, lbs_motion, face_embs)
            embs = enc_preds['embs']
            # NOTE: these are `face_embs` in body space
            face_embs_body = enc_preds['face_embs']

        # NOTE(review): when `encode` is False (or under scripting),
        # `face_embs_body` is undefined here — callers appear to be expected
        # to always take the encode path; confirm.
        dec_preds = self.decoder(
            motion=lbs_motion,
            embs=embs,
            face_embs=face_embs_body,
            embs_conv=embs_conv,
        )

        # re-pose the decoded geometry deltas via LBS
        geom_rec = self.lbs_fn.pose(dec_preds['geom_delta_rec'], lbs_motion)

        dec_view_preds = self.decoder_view(
            geom_rec=geom_rec,
            tex_mean_rec=dec_preds["tex_mean_rec"],
            camera_pos=campos,
        )

        # TODO: should we train an AO model?
        # training: AO-driven shadows (+ pose-driven side output);
        # eval: pose-driven shadows if available, else AO-driven
        if self.training and self.pose_to_shadow_enabled:
            shadow_preds = self.shadow_net(ao_map=ao)
            pose_shadow_preds = self.pose_to_shadow(lbs_motion)
            shadow_preds['pose_shadow_map'] = pose_shadow_preds['shadow_map']
        elif self.pose_to_shadow_enabled:
            shadow_preds = self.pose_to_shadow(lbs_motion)
        else:
            shadow_preds = self.shadow_net(ao_map=ao)

        tex_rec = self.forward_tex(
            dec_preds["tex_mean_rec"],
            dec_view_preds["tex_view_rec"],
            shadow_preds["shadow_map"],
        )

        # per-camera color calibration on the final texture
        if not th.jit.is_scripting() and self.cal_enabled:
            tex_rec = self.cal(tex_rec, self.cal.name_to_idx(_index['camera']))

        preds = {
            'geom': geom_rec,
            'tex_rec': tex_rec,
            **dec_preds,
            **shadow_preds,
            **dec_view_preds,
        }

        if not th.jit.is_scripting() and encode:
            preds.update(**enc_preds)

        if not th.jit.is_scripting() and self.rendering_enabled:
            # NOTE: this is a reduced version tested for forward only
            renders = self.renderer(
                preds['geom'],
                tex_rec,
                K=K,
                Rt=Rt,
            )
            preds.update(rgb=renders['render'])

        # learnable per-camera blur applied to the rendered image
        if not th.jit.is_scripting() and self.learn_blur_enabled:
            preds['rgb'] = self.learn_blur(preds['rgb'], _index['camera'])
            preds['learn_blur_weights'] = self.learn_blur.reg(_index['camera'])

        # per-camera, per-pixel additive bias (indices come from the color cal)
        if not th.jit.is_scripting() and self.pixel_cal_enabled:
            assert self.cal_enabled
            cam_idxs = self.cal.name_to_idx(_index['camera'])
            pixel_bias = self.pixel_cal(cam_idxs)
            preds['rgb'] = preds['rgb'] + pixel_bias
        return preds
class Encoder(nn.Module):
    """A joint encoder for tex and geometry.

    Maps unposed vertices (rasterized to UV and masked to the non-face region)
    to a Gaussian latent: returns sampled `embs` plus `embs_mu`/`embs_logvar`.
    """

    def __init__(
        self,
        geo_fn,
        n_embs,
        noise_std,
        mask,
        logvar_scale=0.1,
    ):
        """Fixed-width conv encoder."""
        super().__init__()
        self.noise_std = noise_std
        self.n_embs = n_embs
        self.geo_fn = geo_fn
        self.logvar_scale = logvar_scale

        self.verts_conv = ConvDownBlock(3, 8, 512)

        # region-of-interest mask, resized to the 512x512 working resolution
        region_mask = th.as_tensor(mask[np.newaxis, np.newaxis], dtype=th.float32)
        region_mask = F.interpolate(region_mask, size=(512, 512), mode='bilinear').to(th.bool)
        self.register_buffer("mask", region_mask)

        # 256 -> 4 spatial resolution, fixed-width downsampling stack
        self.joint_conv_blocks = nn.Sequential(
            ConvDownBlock(8, 16, 256),
            ConvDownBlock(16, 32, 128),
            ConvDownBlock(32, 32, 64),
            ConvDownBlock(32, 64, 32),
            ConvDownBlock(64, 128, 16),
            ConvDownBlock(128, 128, 8),
        )

        self.mu = la.LinearWN(4 * 4 * 128, self.n_embs)
        self.logvar = la.LinearWN(4 * 4 * 128, self.n_embs)

        self.apply(weights_initializer(0.2))
        self.mu.apply(weights_initializer(1.0))
        self.logvar.apply(weights_initializer(1.0))

    def forward(self, motion, verts_unposed):
        batch = motion.shape[0]

        # unposed geometry rasterized to UV and restricted to the masked region
        verts_uv = F.interpolate(
            self.geo_fn.to_uv(verts_unposed), size=(512, 512), mode='bilinear'
        )
        cond = self.verts_conv(verts_uv * self.mask)

        features = self.joint_conv_blocks(cond).reshape(batch, -1)
        embs_mu = self.mu(features)
        embs_logvar = self.logvar_scale * self.logvar(features)

        # NOTE: the noise is only applied to the input-conditioned values
        if self.training:
            eps = th.randn_like(embs_mu)
            embs = embs_mu + th.exp(embs_logvar) * eps * self.noise_std
        else:
            embs = embs_mu.clone()

        return {
            'embs': embs,
            'embs_mu': embs_mu,
            'embs_logvar': embs_logvar,
        }
class ConvDecoder(nn.Module):
    """Multi-region view-independent decoder.

    Decodes pose + body/face embeddings into UV-space geometry deltas and a
    mean texture via a grouped (geometry / texture) upsampling conv stack.
    """

    def __init__(
        self,
        geo_fn,
        uv_size,
        seam_sampler,
        init_uv_size,
        n_pose_dims,
        n_pose_enc_channels,
        n_embs,
        n_embs_enc_channels,
        n_face_embs,
        n_init_channels,
        n_min_channels,
        assets,
    ):
        super().__init__()
        self.geo_fn = geo_fn
        self.uv_size = uv_size
        self.init_uv_size = init_uv_size
        self.n_pose_dims = n_pose_dims
        self.n_pose_enc_channels = n_pose_enc_channels
        self.n_embs = n_embs
        self.n_embs_enc_channels = n_embs_enc_channels
        self.n_face_embs = n_face_embs

        # number of 2x upsampling stages from init_uv_size up to uv_size
        self.n_blocks = int(np.log2(self.uv_size // init_uv_size))
        self.sizes = [init_uv_size * 2**s for s in range(self.n_blocks + 1)]
        # TODO: just specify a sequence?
        # channel widths halve each stage, floored at n_min_channels
        self.n_channels = [
            max(n_init_channels // 2**b, n_min_channels) for b in range(self.n_blocks + 1)
        ]

        logger.info(f"ConvDecoder: n_channels = {self.n_channels}")

        # 1x1 conv over pose tiled spatially at the initial resolution
        self.local_pose_conv_block = ConvBlock(
            n_pose_dims,
            n_pose_enc_channels,
            init_uv_size,
            kernel_size=1,
            padding=0,
        )

        self.embs_fc = nn.Sequential(
            la.LinearWN(n_embs, 4 * 4 * 128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # TODO: should we switch to the basic version?
        # body embedding: 4x4 latent grid -> 64x64 feature map
        self.embs_conv_block = nn.Sequential(
            UpConvBlockDeep(128, 128, 8),
            UpConvBlockDeep(128, 128, 16),
            UpConvBlockDeep(128, 64, 32),
            UpConvBlockDeep(64, n_embs_enc_channels, 64),
        )

        self.face_embs_fc = nn.Sequential(
            la.LinearWN(n_face_embs, 4 * 4 * 32),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # face embedding: 4x4 latent grid -> 32x32 feature map
        self.face_embs_conv_block = nn.Sequential(
            UpConvBlockDeep(32, 64, 8),
            UpConvBlockDeep(64, 64, 16),
            UpConvBlockDeep(64, n_embs_enc_channels, 32),
        )

        # two channel groups run through the shared stack: geometry and texture
        n_groups = 2

        self.joint_conv_block = ConvBlock(
            n_pose_enc_channels + n_embs_enc_channels,
            n_init_channels,
            self.init_uv_size,
        )

        self.conv_blocks = nn.ModuleList([])
        for b in range(self.n_blocks):
            self.conv_blocks.append(
                UpConvBlockDeep(
                    self.n_channels[b] * n_groups,
                    self.n_channels[b + 1] * n_groups,
                    self.sizes[b + 1],
                    groups=n_groups,
                ),
            )

        # output heads: per-texel vertex deltas and mean texture
        self.verts_conv = la.Conv2dWNUB(
            in_channels=self.n_channels[-1],
            out_channels=3,
            kernel_size=3,
            height=self.uv_size,
            width=self.uv_size,
            padding=1,
        )
        self.tex_conv = la.Conv2dWNUB(
            in_channels=self.n_channels[-1],
            out_channels=3,
            kernel_size=3,
            height=self.uv_size,
            width=self.uv_size,
            padding=1,
        )

        self.apply(weights_initializer(0.2))
        self.verts_conv.apply(weights_initializer(1.0))
        self.tex_conv.apply(weights_initializer(1.0))

        self.seam_sampler = seam_sampler

        # NOTE: removing head region from pose completely
        pose_cond_mask = th.as_tensor(
            assets.pose_cond_mask[np.newaxis] * (1 - assets.head_cond_mask[np.newaxis, np.newaxis]),
            dtype=th.int32,
        )
        self.register_buffer("pose_cond_mask", pose_cond_mask)
        face_cond_mask = th.as_tensor(assets.face_cond_mask, dtype=th.float32)[
            np.newaxis, np.newaxis
        ]
        self.register_buffer("face_cond_mask", face_cond_mask)
        body_cond_mask = th.as_tensor(assets.body_cond_mask, dtype=th.float32)[
            np.newaxis, np.newaxis
        ]
        self.register_buffer("body_cond_mask", body_cond_mask)

    def forward(self, motion, embs, face_embs, embs_conv: Optional[th.Tensor] = None):
        # processing pose
        # presumably the first 6 motion dims are a global transform — TODO confirm
        pose = motion[:, 6:]

        B = pose.shape[0]

        non_head_mask = (self.body_cond_mask * (1.0 - self.face_cond_mask)).clip(0.0, 1.0)

        pose_masked = tile2d(pose, self.init_uv_size) * self.pose_cond_mask
        pose_conv = self.local_pose_conv_block(pose_masked) * non_head_mask

        # TODO: decoding properly?
        if embs_conv is None:
            embs_conv = self.embs_conv_block(self.embs_fc(embs).reshape(B, 128, 4, 4))

        face_conv = self.face_embs_conv_block(self.face_embs_fc(face_embs).reshape(B, 32, 4, 4))
        # merging embeddings with spatial masks
        # NOTE(review): this writes into `embs_conv` in place, so a
        # caller-provided `embs_conv` tensor gets mutated — confirm intended.
        embs_conv[:, :, 32:, :32] = (
            face_conv * self.face_cond_mask[:, :, 32:, :32]
            + embs_conv[:, :, 32:, :32] * non_head_mask[:, :, 32:, :32]
        )

        joint = th.cat([pose_conv, embs_conv], axis=1)
        joint = self.joint_conv_block(joint)

        # duplicate features so each group (geometry / texture) starts identical
        x = th.cat([joint, joint], axis=1)
        for b in range(self.n_blocks):
            x = self.conv_blocks[b](x)

        # NOTE: here we do resampling at feature level
        x = self.seam_sampler.impaint(x)
        x = self.seam_sampler.resample(x)
        x = self.seam_sampler.resample(x)

        verts_features, tex_features = th.split(x, self.n_channels[-1], 1)

        verts_uv_delta_rec = self.verts_conv(verts_features)
        # TODO: need to get values
        verts_delta_rec = self.geo_fn.from_uv(verts_uv_delta_rec)
        tex_mean_rec = self.tex_conv(tex_features)

        preds = {
            'geom_delta_rec': verts_delta_rec,
            'geom_uv_delta_rec': verts_uv_delta_rec,
            'tex_mean_rec': tex_mean_rec,
            'embs_conv': embs_conv,
            'pose_conv': pose_conv,
        }

        return preds
class FaceEncoder(nn.Module):
    """A joint encoder for tex and geometry.

    Encodes face geometry and a masked 512x512 face texture into a Gaussian
    latent; all returned keys are prefixed (default ``face_``).
    """

    def __init__(
        self,
        noise_std,
        assets,
        n_embs=256,
        uv_size=512,
        logvar_scale=0.1,
        n_vert_in=7306 * 3,
        prefix="face_",
    ):
        """Fixed-width conv encoder."""
        super().__init__()
        self.noise_std = noise_std
        self.n_embs = n_embs
        self.logvar_scale = logvar_scale
        self.prefix = prefix
        self.uv_size = uv_size
        # the conv stack below is hard-wired for 512x512 inputs
        assert self.uv_size == 512

        cond_mask = assets.mugsy_face_mask[..., 0]
        cond_mask = th.as_tensor(cond_mask, dtype=th.float32)[np.newaxis, np.newaxis]
        cond_mask = F.interpolate(
            cond_mask, (self.uv_size, self.uv_size), mode="bilinear", align_corners=True
        )
        self.register_buffer("tex_cond_mask", cond_mask)

        # 512 -> 4 spatial resolution texture branch
        self.conv_blocks = nn.Sequential(
            ConvDownBlock(3, 4, 512),
            ConvDownBlock(4, 8, 256),
            ConvDownBlock(8, 16, 128),
            ConvDownBlock(16, 32, 64),
            ConvDownBlock(32, 64, 32),
            ConvDownBlock(64, 128, 16),
            ConvDownBlock(128, 128, 8),
        )
        # geometry branch and joint fusion MLPs
        self.geommod = nn.Sequential(
            la.LinearWN(n_vert_in, 256), nn.LeakyReLU(0.2, inplace=True)
        )
        self.jointmod = nn.Sequential(
            la.LinearWN(256 + 128 * 4 * 4, 512), nn.LeakyReLU(0.2, inplace=True)
        )
        self.mu = la.LinearWN(512, self.n_embs)
        self.logvar = la.LinearWN(512, self.n_embs)

        self.apply(weights_initializer(0.2))
        self.mu.apply(weights_initializer(1.0))
        self.logvar.apply(weights_initializer(1.0))

    def forward(self, face_geom: th.Tensor, face_tex: th.Tensor, **kwargs):
        batch = face_geom.shape[0]

        # normalize texture from [0, 255] to [-0.5, 0.5] and mask to the face region
        tex_cond = F.interpolate(
            face_tex, (self.uv_size, self.uv_size), mode="bilinear", align_corners=False
        )
        tex_cond = (tex_cond / 255.0 - 0.5) * self.tex_cond_mask

        tex_enc = self.conv_blocks(tex_cond).reshape(batch, 4 * 4 * 128)
        geom_enc = self.geommod(face_geom.reshape(batch, -1))
        joint = self.jointmod(th.cat([tex_enc, geom_enc], dim=1))

        embs_mu = self.mu(joint)
        embs_logvar = self.logvar_scale * self.logvar(joint)

        # NOTE: the noise is only applied to the input-conditioned values
        if self.training:
            eps = th.randn_like(embs_mu)
            embs = embs_mu + th.exp(embs_logvar) * eps * self.noise_std
        else:
            embs = embs_mu.clone()

        outputs = {
            "embs": embs,
            "embs_mu": embs_mu,
            "embs_logvar": embs_logvar,
            "tex_cond": tex_cond,
        }
        return {f"{self.prefix}{key}": value for key, value in outputs.items()}
class UNetViewDecoder(nn.Module):
    """Predicts a view-dependent texture residual from view cosines + mean texture."""

    def __init__(self, geo_fn, net_uv_size, seam_sampler, n_init_ftrs=8):
        super().__init__()
        self.geo_fn = geo_fn
        self.net_uv_size = net_uv_size
        # 4 input channels: 1 view-cosine + 3 mean-texture channels
        self.unet = UNetWB(4, 3, n_init_ftrs=n_init_ftrs, size=net_uv_size)
        self.register_buffer("faces", self.geo_fn.vi.to(th.int64), persistent=False)

    def forward(self, geom_rec, tex_mean_rec, camera_pos):
        # the view term is conditioning only — no gradients into the geometry
        with th.no_grad():
            view_cos = compute_view_cos(geom_rec, self.faces, camera_pos)
            view_cos_uv = self.geo_fn.to_uv(view_cos[..., np.newaxis])
        cond_view = th.cat([view_cos_uv, tex_mean_rec], dim=1)
        # TODO: should we try warping here?
        return {"tex_view_rec": self.unet(cond_view), "cond_view": cond_view}
class UpscaleNet(nn.Module):
    """Shallow conv network producing an upscaled residual via pixel shuffle."""

    def __init__(self, in_channels, out_channels, n_ftrs, size=1024, upscale_factor=2):
        super().__init__()
        self.conv_block = nn.Sequential(
            la.Conv2dWNUB(in_channels, n_ftrs, size, size, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # predicts out_channels * r^2 maps that pixel-shuffle rearranges spatially
        self.out_block = la.Conv2dWNUB(
            n_ftrs,
            out_channels * upscale_factor**2,
            size,
            size,
            kernel_size=1,
            padding=0,
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=upscale_factor)

        self.apply(weights_initializer(0.2))
        self.out_block.apply(weights_initializer(1.0))

    def forward(self, x):
        features = self.conv_block(x)
        residual = self.out_block(features)
        return self.pixel_shuffle(residual)