# Source: test_virtual/visualize/ca_body/models/mesh_vae_drivable.py
# Uploaded by whyun13 using huggingface_hub (commit 882f6e2, verified).
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import logging
from typing import Dict, Optional, Tuple
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
from torchvision.transforms.functional import gaussian_blur
import visualize.ca_body.nn.layers as la
from visualize.ca_body.nn.blocks import (
ConvBlock,
ConvDownBlock,
UpConvBlockDeep,
tile2d,
weights_initializer,
)
from visualize.ca_body.nn.dof_cal import LearnableBlur
from visualize.ca_body.utils.geom import (
GeometryModule,
compute_view_cos,
depth_discontuity_mask,
depth2normals,
)
from visualize.ca_body.nn.shadow import ShadowUNet, PoseToShadow
from visualize.ca_body.nn.unet import UNetWB
from visualize.ca_body.nn.color_cal import CalV5
from visualize.ca_body.utils.image import linear2displayBatch
from visualize.ca_body.utils.lbs import LBSModule
from visualize.ca_body.utils.render import RenderLayer
from visualize.ca_body.utils.seams import SeamSampler
from visualize.ca_body.utils.render import RenderLayer
from visualize.ca_body.nn.face import FaceDecoderFrontal
logger = logging.getLogger(__name__)
class CameraPixelBias(nn.Module):
    """Learnable per-camera additive bias stored at reduced resolution.

    One single-channel map per camera is kept as a parameter named ``bias``
    and is bilinearly upsampled to ``(image_height, image_width)`` on every
    forward pass.
    """

    def __init__(self, image_height, image_width, cameras, ds_rate) -> None:
        """Allocate one zero-initialised low-resolution bias map per camera.

        Args:
            image_height: height of the upsampled bias returned by ``forward``.
            image_width: width of the upsampled bias returned by ``forward``.
            cameras: sized collection of cameras; only its length is used here.
            ds_rate: integer divisor applied to the image size to obtain the
                resolution of the stored maps.
        """
        super().__init__()
        self.image_height = image_height
        self.image_width = image_width
        self.cameras = cameras
        self.n_cameras = len(cameras)

        # NOTE(review): the stored map is laid out (width // ds, height // ds)
        # while `forward` upsamples to (height, width). Left untouched so the
        # parameter shape (and therefore checkpoints) stays the same - confirm
        # whether the swapped order is intended.
        lowres_shape = (
            self.n_cameras,
            1,
            image_width // ds_rate,
            image_height // ds_rate,
        )
        # Assigning an nn.Parameter registers it under the attribute name.
        self.bias = nn.Parameter(th.zeros(lowres_shape, dtype=th.float32))

    def forward(self, idxs: th.Tensor):
        """Return the full-resolution bias maps selected by ``idxs``.

        Args:
            idxs: index tensor used to select rows of ``self.bias``.

        Returns:
            The selected maps resized to ``(image_height, image_width)``.
        """
        selected = self.bias[idxs]
        target_size = (self.image_height, self.image_width)
        return F.interpolate(selected, size=target_size, mode='bilinear')
class AutoEncoder(nn.Module):
    """Drivable body auto-encoder producing posed geometry and a 2k texture.

    The module wires together a body encoder, a face encoder (fed by a
    face decoder that runs without gradients), a view-independent decoder,
    a view-dependent texture decoder, a shadow network and a texture
    upscaling network. Colour calibration, per-pixel camera bias, learnable
    blur and rendering are optional and are enabled by passing the
    corresponding configuration to the constructor.
    """

    def __init__(
        self,
        encoder,
        decoder,
        decoder_view,
        encoder_face,
        # hqlp decoder to get the codes
        decoder_face,
        shadow_net,
        upscale_net,
        assets,
        pose_to_shadow=None,
        renderer=None,
        cal=None,
        pixel_cal=None,
        learn_blur: bool = True,
    ):
        """Build all sub-networks from their keyword-argument mappings.

        Args:
            encoder: kwargs for ``Encoder`` (besides ``geo_fn`` and ``mask``).
            decoder: kwargs for ``ConvDecoder``.
            decoder_view: kwargs for ``UNetViewDecoder``.
            encoder_face: kwargs for ``FaceEncoder``.
            decoder_face: kwargs for ``FaceDecoderFrontal``; an optional
                ``'ckpt'`` entry is a checkpoint path loaded non-strictly.
                The mapping itself is not modified.
            shadow_net: extra kwargs for ``ShadowUNet``.
            upscale_net: extra kwargs for ``UpscaleNet``.
            assets: asset container (topology, LBS data, masks, texture
                statistics, camera ids, ...).
            pose_to_shadow: optional kwargs for ``PoseToShadow``.
            renderer: optional config exposing ``image_height`` and
                ``image_width``; enables rendering when given.
            cal: optional kwargs for ``CalV5``.
            pixel_cal: optional kwargs for ``CameraPixelBias``.
            learn_blur: whether to create a ``LearnableBlur`` module.
        """
        super().__init__()
        # TODO: should we have a shared LBS here?
        self.geo_fn = GeometryModule(
            assets.topology.vi,
            assets.topology.vt,
            assets.topology.vti,
            assets.topology.v2uv,
            uv_size=1024,
            impaint=True,
        )
        self.lbs_fn = LBSModule(
            assets.lbs_model_json,
            assets.lbs_config_dict,
            assets.lbs_template_verts,
            assets.lbs_scale,
            assets.global_scaling,
        )
        self.seam_sampler = SeamSampler(assets.seam_data_1024)
        self.seam_sampler_2k = SeamSampler(assets.seam_data_2048)

        # joint tex -> body and clothes
        # TODO: why do we have a joint one in the first place?
        tex_mean = gaussian_blur(th.as_tensor(assets.tex_mean)[np.newaxis], kernel_size=11)
        self.register_buffer("tex_mean", F.interpolate(tex_mean, (2048, 2048), mode='bilinear'))
        # this is shared
        self.tex_std = assets.tex_var if 'tex_var' in assets else 64.0

        face_cond_mask = th.as_tensor(assets.face_cond_mask, dtype=th.float32)[
            np.newaxis, np.newaxis
        ]
        self.register_buffer("face_cond_mask", face_cond_mask)

        meye_mask = self.geo_fn.to_uv(
            th.as_tensor(assets.mouth_eyes_mask_geom[np.newaxis, :, np.newaxis])
        )
        meye_mask = F.interpolate(meye_mask, (2048, 2048), mode='bilinear')
        self.register_buffer("meye_mask", meye_mask)

        self.decoder = ConvDecoder(
            geo_fn=self.geo_fn,
            seam_sampler=self.seam_sampler,
            **decoder,
            assets=assets,
        )

        # embs for everything but face
        non_head_mask = 1.0 - assets.face_mask
        self.encoder = Encoder(
            geo_fn=self.geo_fn,
            mask=non_head_mask,
            **encoder,
        )
        self.encoder_face = FaceEncoder(
            assets=assets,
            **encoder_face,
        )

        # using face decoder to generate better conditioning
        # Work on a copy: popping 'ckpt' from the caller's mapping would make
        # a second construction from the same config skip the checkpoint.
        decoder_face_kwargs = dict(decoder_face)
        decoder_face_ckpt_path = decoder_face_kwargs.pop('ckpt', None)
        self.decoder_face = FaceDecoderFrontal(assets=assets, **decoder_face_kwargs)
        if decoder_face_ckpt_path is not None:
            # Load onto CPU so a checkpoint saved on GPU also works on
            # CPU-only hosts; load_state_dict copies into the module's params.
            face_state = th.load(decoder_face_ckpt_path, map_location="cpu")
            self.decoder_face.load_state_dict(face_state, strict=False)

        self.decoder_view = UNetViewDecoder(
            self.geo_fn,
            seam_sampler=self.seam_sampler,
            **decoder_view,
        )
        self.shadow_net = ShadowUNet(
            ao_mean=assets.ao_mean,
            interp_mode="bilinear",
            biases=False,
            **shadow_net,
        )

        self.pose_to_shadow_enabled = False
        if pose_to_shadow is not None:
            self.pose_to_shadow_enabled = True
            self.pose_to_shadow = PoseToShadow(**pose_to_shadow)

        self.upscale_net = UpscaleNet(
            in_channels=6, size=1024, upscale_factor=2, out_channels=3, **upscale_net
        )

        self.pixel_cal_enabled = False
        if pixel_cal is not None:
            self.pixel_cal_enabled = True
            self.pixel_cal = CameraPixelBias(**pixel_cal, cameras=assets.camera_ids)

        self.learn_blur_enabled = False
        if learn_blur:
            self.learn_blur_enabled = True
            self.learn_blur = LearnableBlur(assets.camera_ids)

        # training-only stuff
        self.cal_enabled = False
        if cal is not None:
            self.cal_enabled = True
            self.cal = CalV5(**cal, cameras=assets.camera_ids)

        self.rendering_enabled = False
        if renderer is not None:
            self.rendering_enabled = True
            self.renderer = RenderLayer(
                h=renderer.image_height,
                w=renderer.image_width,
                vt=self.geo_fn.vt,
                vi=self.geo_fn.vi,
                vti=self.geo_fn.vti,
                flip_uvs=False,
            )

    @th.jit.unused
    def compute_summaries(self, preds, batch):
        """Build a progress image stacking prediction, ground truth and normals.

        Args:
            preds: dict providing ``'rgb'`` and ``'depth'``.
            batch: dict providing ``'image'``, ``'focal'`` and ``'princpt'``.

        Returns:
            ``{'progress_image': (uint8 image tensor, 'png')}``.
        """
        # TODO: switch to common summaries?
        # return compute_summaries_mesh(preds, batch)
        rgb = linear2displayBatch(preds['rgb'][:, :3])
        rgb_gt = linear2displayBatch(batch['image'])
        # NOTE(review): `forward` in this class does not add a 'depth' entry;
        # presumably it is supplied elsewhere - confirm against the caller.
        depth = preds['depth'][:, np.newaxis]
        mask = depth > 0.0
        normals = (
            255 * (1.0 - depth2normals(depth, batch['focal'], batch['princpt'])) / 2.0
        ) * mask
        grid_rgb = make_grid(rgb, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8)
        grid_rgb_gt = make_grid(rgb_gt, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8)
        grid_normals = make_grid(normals, nrow=16).permute(1, 2, 0).clip(0, 255).to(th.uint8)
        progress_image = th.cat([grid_rgb, grid_rgb_gt, grid_normals], dim=0)
        return {
            'progress_image': (progress_image, 'png'),
        }

    def forward_tex(self, tex_mean_rec, tex_view_rec, shadow_map):
        """Compose the final 2048x2048 texture from its predicted parts.

        The mean and view-dependent textures are summed, seam-corrected,
        bilinearly upsampled and refined by ``upscale_net``; the result is
        de-normalised with ``tex_std`` / ``tex_mean``, multiplied by the
        seam-corrected shadow map and seam-corrected once more.

        Args:
            tex_mean_rec: view-independent texture prediction.
            tex_view_rec: view-dependent texture prediction.
            shadow_map: multiplicative shadow map; it is combined with the
                2048x2048 texture, so it presumably has that resolution -
                TODO confirm.

        Returns:
            The shaded texture.
        """
        x = th.cat([tex_mean_rec, tex_view_rec], dim=1)
        tex_rec = tex_mean_rec + tex_view_rec

        tex_rec = self.seam_sampler.impaint(tex_rec)
        tex_rec = self.seam_sampler.resample(tex_rec)

        tex_rec = F.interpolate(tex_rec, size=(2048, 2048), mode="bilinear", align_corners=False)
        tex_rec = tex_rec + self.upscale_net(x)
        tex_rec = tex_rec * self.tex_std + self.tex_mean

        # NOTE(review): `resample` is applied twice on the 2k maps below; kept
        # exactly as in the original - confirm that the repetition is intended.
        shadow_map = self.seam_sampler_2k.impaint(shadow_map)
        shadow_map = self.seam_sampler_2k.resample(shadow_map)
        shadow_map = self.seam_sampler_2k.resample(shadow_map)
        tex_rec = tex_rec * shadow_map

        tex_rec = self.seam_sampler_2k.impaint(tex_rec)
        tex_rec = self.seam_sampler_2k.resample(tex_rec)
        tex_rec = self.seam_sampler_2k.resample(tex_rec)
        return tex_rec

    def encode(self, geom: th.Tensor, lbs_motion: th.Tensor, face_embs_hqlp: th.Tensor):
        """Encode posed geometry and face codes into body / face embeddings.

        Args:
            geom: posed geometry, unposed here with ``lbs_motion``.
            lbs_motion: LBS motion parameters.
            face_embs_hqlp: face codes fed to ``decoder_face``.

        Returns:
            Dict with the body encoder outputs, the (prefixed) face encoder
            outputs and the raw face decoder outputs under
            ``'face_dec_preds'``.
        """
        # Unposing is treated as a fixed pre-processing step (no gradients).
        with th.no_grad():
            verts_unposed = self.lbs_fn.unpose(geom, lbs_motion)

        # extract face region for geom + tex
        enc_preds = self.encoder(motion=lbs_motion, verts_unposed=verts_unposed)
        # TODO: probably need to rename these to `face_embs_mugsy` or smth
        # TODO: we need the same thing for face?
        # The face decoder only provides conditioning; it runs without grads.
        with th.no_grad():
            face_dec_preds = self.decoder_face(face_embs_hqlp)
        enc_face_preds = self.encoder_face(**face_dec_preds)

        preds = {
            **enc_preds,
            **enc_face_preds,
            'face_dec_preds': face_dec_preds,
        }
        return preds

    def forward(
        self,
        # TODO: should we try using this as well for cond?
        lbs_motion: th.Tensor,
        campos: th.Tensor,
        geom: Optional[th.Tensor] = None,
        ao: Optional[th.Tensor] = None,
        K: Optional[th.Tensor] = None,
        Rt: Optional[th.Tensor] = None,
        image_bg: Optional[th.Tensor] = None,
        image: Optional[th.Tensor] = None,
        image_mask: Optional[th.Tensor] = None,
        embs: Optional[th.Tensor] = None,
        _index: Optional[Dict[str, th.Tensor]] = None,
        face_embs: Optional[th.Tensor] = None,
        embs_conv: Optional[th.Tensor] = None,
        tex_seg: Optional[th.Tensor] = None,
        encode=True,
        iteration: Optional[int] = None,
        **kwargs,
    ):
        """Decode geometry and texture and, if enabled, render an image.

        Args:
            lbs_motion: LBS motion parameters used by the encoder, the
                decoder, the pose-to-shadow net and for posing the output.
            campos: camera position for the view-dependent decoder.
            geom: posed geometry; needed when ``encode`` is true.
            ao: ambient-occlusion map passed to ``shadow_net``.
            K, Rt: camera parameters forwarded to the renderer.
            embs: body embeddings; overwritten by the encoder output when
                ``encode`` is true.
            _index: dict whose ``'camera'`` entry selects the calibration,
                blur and pixel-bias parameters.
            face_embs: when ``encode`` is true these are fed to
                ``decoder_face``; otherwise they are handed to the decoder
                unchanged (assumed to already be in body space - TODO confirm).
            embs_conv: optional precomputed embedding feature map forwarded
                to the decoder.
            encode: whether to run the encoders.
            image_bg, image, image_mask, tex_seg, iteration, **kwargs:
                accepted for interface compatibility; not used here.

        Returns:
            Dict with ``'geom'``, ``'tex_rec'`` and the outputs of the
            decoder, shadow and view-decoder modules; the encoder outputs
            when ``encode`` is true; ``'rgb'`` (and ``'learn_blur_weights'``
            when blur is enabled) when rendering is enabled.
        """
        # Without encoding, `face_embs` is passed through as-is. Previously
        # this path raised UnboundLocalError because the name was only bound
        # inside the encode branch.
        face_embs_body = face_embs
        if not th.jit.is_scripting() and encode:
            # NOTE: these are `face_embs_hqlp`
            enc_preds = self.encode(geom, lbs_motion, face_embs)
            embs = enc_preds['embs']
            # NOTE: these are `face_embs` in body space
            face_embs_body = enc_preds['face_embs']

        dec_preds = self.decoder(
            motion=lbs_motion,
            embs=embs,
            face_embs=face_embs_body,
            embs_conv=embs_conv,
        )

        geom_rec = self.lbs_fn.pose(dec_preds['geom_delta_rec'], lbs_motion)

        dec_view_preds = self.decoder_view(
            geom_rec=geom_rec,
            tex_mean_rec=dec_preds["tex_mean_rec"],
            camera_pos=campos,
        )

        # TODO: should we train an AO model?
        if self.training and self.pose_to_shadow_enabled:
            # Training: the AO-driven map shades the texture; the pose-driven
            # prediction is exposed alongside it under 'pose_shadow_map'.
            shadow_preds = self.shadow_net(ao_map=ao)
            pose_shadow_preds = self.pose_to_shadow(lbs_motion)
            shadow_preds['pose_shadow_map'] = pose_shadow_preds['shadow_map']
        elif self.pose_to_shadow_enabled:
            shadow_preds = self.pose_to_shadow(lbs_motion)
        else:
            shadow_preds = self.shadow_net(ao_map=ao)

        tex_rec = self.forward_tex(
            dec_preds["tex_mean_rec"],
            dec_view_preds["tex_view_rec"],
            shadow_preds["shadow_map"],
        )

        if not th.jit.is_scripting() and self.cal_enabled:
            tex_rec = self.cal(tex_rec, self.cal.name_to_idx(_index['camera']))

        preds = {
            'geom': geom_rec,
            'tex_rec': tex_rec,
            **dec_preds,
            **shadow_preds,
            **dec_view_preds,
        }

        if not th.jit.is_scripting() and encode:
            preds.update(**enc_preds)

        if not th.jit.is_scripting() and self.rendering_enabled:
            # NOTE: this is a reduced version tested for forward only
            renders = self.renderer(
                preds['geom'],
                tex_rec,
                K=K,
                Rt=Rt,
            )
            preds.update(rgb=renders['render'])

        # Blur and pixel bias post-process the rendered image, so they only
        # apply when a renderer produced 'rgb'. Without this guard the default
        # configuration (learn_blur=True, renderer=None) raised KeyError.
        if not th.jit.is_scripting() and self.rendering_enabled and self.learn_blur_enabled:
            preds['rgb'] = self.learn_blur(preds['rgb'], _index['camera'])
            preds['learn_blur_weights'] = self.learn_blur.reg(_index['camera'])

        if not th.jit.is_scripting() and self.rendering_enabled and self.pixel_cal_enabled:
            # Camera names are mapped to indices through the colour
            # calibration module, so it has to be configured as well.
            assert self.cal_enabled
            cam_idxs = self.cal.name_to_idx(_index['camera'])
            pixel_bias = self.pixel_cal(cam_idxs)
            preds['rgb'] = preds['rgb'] + pixel_bias

        return preds
class Encoder(nn.Module):
    """Convolutional encoder turning unposed vertices into a latent code.

    Vertices are mapped to UV space, masked, passed through a stack of
    down-sampling blocks and projected to a mean / log-variance pair.
    """

    def __init__(
        self,
        geo_fn,
        n_embs,
        noise_std,
        mask,
        logvar_scale=0.1,
    ):
        """Create the conv stack and the two linear heads.

        Args:
            geo_fn: geometry helper providing ``to_uv``.
            n_embs: size of the latent code.
            noise_std: scale of the noise injected while training.
            mask: 2D array; resized to 512x512 and stored as a boolean buffer.
            logvar_scale: multiplier applied to the raw log-variance head.
        """
        super().__init__()
        self.noise_std = noise_std
        self.n_embs = n_embs
        self.geo_fn = geo_fn
        self.logvar_scale = logvar_scale

        self.verts_conv = ConvDownBlock(3, 8, 512)

        mask_uv = th.as_tensor(mask[np.newaxis, np.newaxis], dtype=th.float32)
        mask_uv = F.interpolate(mask_uv, size=(512, 512), mode='bilinear')
        self.register_buffer("mask", mask_uv.to(th.bool))

        # (in_channels, out_channels, size) per block, in application order.
        down_plan = (
            (8, 16, 256),
            (16, 32, 128),
            (32, 32, 64),
            (32, 64, 32),
            (64, 128, 16),
            (128, 128, 8),
        )
        self.joint_conv_blocks = nn.Sequential(
            *[ConvDownBlock(c_in, c_out, size) for c_in, c_out, size in down_plan]
        )

        # TODO: should we put initializer
        n_flat = 4 * 4 * 128
        self.mu = la.LinearWN(n_flat, self.n_embs)
        self.logvar = la.LinearWN(n_flat, self.n_embs)

        # Default gain everywhere, then a different gain for the two heads.
        self.apply(weights_initializer(0.2))
        for head in (self.mu, self.logvar):
            head.apply(weights_initializer(1.0))

    def forward(self, motion, verts_unposed):
        """Encode unposed vertices.

        Args:
            motion: only its leading dimension (batch size) is used.
            verts_unposed: vertices handed to ``geo_fn.to_uv``.

        Returns:
            Dict with ``'embs'``, ``'embs_mu'`` and ``'embs_logvar'``.
        """
        batch_size = motion.shape[0]

        # converting motion to the unposed
        verts_uv = self.geo_fn.to_uv(verts_unposed)
        verts_uv = F.interpolate(verts_uv, size=(512, 512), mode='bilinear')
        features = self.verts_conv(verts_uv * self.mask)
        features = self.joint_conv_blocks(features)
        flat = features.reshape(batch_size, -1)

        embs_mu = self.mu(flat)
        embs_logvar = self.logvar_scale * self.logvar(flat)

        # NOTE: the noise is only applied to the input-conditioned values
        if self.training:
            noise = th.randn_like(embs_mu)
            embs = embs_mu + th.exp(embs_logvar) * noise * self.noise_std
        else:
            embs = embs_mu.clone()

        return {"embs": embs, "embs_mu": embs_mu, "embs_logvar": embs_logvar}
class ConvDecoder(nn.Module):
    """Multi-region view-independent decoder.

    Pose, body embeddings and face embeddings are turned into feature maps,
    merged with spatial masks and up-sampled by grouped conv blocks into a
    UV vertex-delta map and a mean texture.
    """

    def __init__(
        self,
        geo_fn,
        uv_size,
        seam_sampler,
        init_uv_size,
        n_pose_dims,
        n_pose_enc_channels,
        n_embs,
        n_embs_enc_channels,
        n_face_embs,
        n_init_channels,
        n_min_channels,
        assets,
    ):
        """Build the decoder.

        Args:
            geo_fn: geometry helper providing ``from_uv``.
            uv_size: output UV resolution.
            seam_sampler: sampler applied to the final feature map.
            init_uv_size: resolution at which conditioning is merged. The
                hard-coded 4x4 seeds, up-sampling stacks and ``32`` slices in
                ``forward`` presumably require this to be 64 - TODO confirm.
            n_pose_dims: number of pose channels after dropping the first six
                motion entries.
            n_pose_enc_channels: channels of the encoded pose map.
            n_embs: size of the body embedding.
            n_embs_enc_channels: channels of the embedding feature map.
            n_face_embs: size of the face embedding.
            n_init_channels: channels (per group) at ``init_uv_size``.
            n_min_channels: lower bound for the per-group channel count.
            assets: container providing ``pose_cond_mask``,
                ``head_cond_mask``, ``face_cond_mask`` and ``body_cond_mask``.
        """
        super().__init__()
        self.geo_fn = geo_fn
        self.uv_size = uv_size
        self.init_uv_size = init_uv_size
        self.n_pose_dims = n_pose_dims
        self.n_pose_enc_channels = n_pose_enc_channels
        self.n_embs = n_embs
        self.n_embs_enc_channels = n_embs_enc_channels
        self.n_face_embs = n_face_embs

        # One up-sampling block per doubling between init_uv_size and uv_size.
        self.n_blocks = int(np.log2(self.uv_size // init_uv_size))
        self.sizes = [init_uv_size * 2**s for s in range(self.n_blocks + 1)]
        # TODO: just specify a sequence?
        # Channels halve at every block but never drop below n_min_channels.
        self.n_channels = [
            max(n_init_channels // 2**b, n_min_channels) for b in range(self.n_blocks + 1)
        ]
        logger.info(f"ConvDecoder: n_channels = {self.n_channels}")

        self.local_pose_conv_block = ConvBlock(
            n_pose_dims,
            n_pose_enc_channels,
            init_uv_size,
            kernel_size=1,
            padding=0,
        )

        self.embs_fc = nn.Sequential(
            la.LinearWN(n_embs, 4 * 4 * 128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # TODO: should we switch to the basic version?
        self.embs_conv_block = nn.Sequential(
            UpConvBlockDeep(128, 128, 8),
            UpConvBlockDeep(128, 128, 16),
            UpConvBlockDeep(128, 64, 32),
            UpConvBlockDeep(64, n_embs_enc_channels, 64),
        )

        self.face_embs_fc = nn.Sequential(
            la.LinearWN(n_face_embs, 4 * 4 * 32),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.face_embs_conv_block = nn.Sequential(
            UpConvBlockDeep(32, 64, 8),
            UpConvBlockDeep(64, 64, 16),
            UpConvBlockDeep(64, n_embs_enc_channels, 32),
        )

        # Two groups: one half of the channels feeds the geometry head, the
        # other half the texture head (see the split in `forward`).
        n_groups = 2

        self.joint_conv_block = ConvBlock(
            n_pose_enc_channels + n_embs_enc_channels,
            n_init_channels,
            self.init_uv_size,
        )

        self.conv_blocks = nn.ModuleList([])
        for b in range(self.n_blocks):
            self.conv_blocks.append(
                UpConvBlockDeep(
                    self.n_channels[b] * n_groups,
                    self.n_channels[b + 1] * n_groups,
                    self.sizes[b + 1],
                    groups=n_groups,
                ),
            )

        self.verts_conv = la.Conv2dWNUB(
            in_channels=self.n_channels[-1],
            out_channels=3,
            kernel_size=3,
            height=self.uv_size,
            width=self.uv_size,
            padding=1,
        )
        self.tex_conv = la.Conv2dWNUB(
            in_channels=self.n_channels[-1],
            out_channels=3,
            kernel_size=3,
            height=self.uv_size,
            width=self.uv_size,
            padding=1,
        )

        self.apply(weights_initializer(0.2))
        self.verts_conv.apply(weights_initializer(1.0))
        self.tex_conv.apply(weights_initializer(1.0))

        self.seam_sampler = seam_sampler

        # NOTE: removing head region from pose completely
        pose_cond_mask = th.as_tensor(
            assets.pose_cond_mask[np.newaxis] * (1 - assets.head_cond_mask[np.newaxis, np.newaxis]),
            dtype=th.int32,
        )
        self.register_buffer("pose_cond_mask", pose_cond_mask)

        face_cond_mask = th.as_tensor(assets.face_cond_mask, dtype=th.float32)[
            np.newaxis, np.newaxis
        ]
        self.register_buffer("face_cond_mask", face_cond_mask)

        body_cond_mask = th.as_tensor(assets.body_cond_mask, dtype=th.float32)[
            np.newaxis, np.newaxis
        ]
        self.register_buffer("body_cond_mask", body_cond_mask)

    def forward(self, motion, embs, face_embs, embs_conv: Optional[th.Tensor] = None):
        """Decode vertex deltas and the mean texture.

        Args:
            motion: motion parameters; entries from index 6 on are used as
                the pose (the first six are dropped).
            embs: body embeddings; only used when ``embs_conv`` is ``None``.
            face_embs: face embeddings.
            embs_conv: optional precomputed embedding feature map. It is
                copied before the face features are merged in, so the
                caller's tensor is left untouched.

        Returns:
            Dict with ``'geom_delta_rec'``, ``'geom_uv_delta_rec'``,
            ``'tex_mean_rec'``, ``'embs_conv'`` (after merging) and
            ``'pose_conv'``.
        """
        # processing pose
        pose = motion[:, 6:]

        B = pose.shape[0]

        non_head_mask = (self.body_cond_mask * (1.0 - self.face_cond_mask)).clip(0.0, 1.0)

        pose_masked = tile2d(pose, self.init_uv_size) * self.pose_cond_mask
        pose_conv = self.local_pose_conv_block(pose_masked) * non_head_mask

        # TODO: decoding properly?
        if embs_conv is None:
            embs_conv = self.embs_conv_block(self.embs_fc(embs).reshape(B, 128, 4, 4))
        else:
            # The merge below writes in place. Clone so a caller-owned tensor
            # is not silently modified (and so a leaf tensor that requires
            # grad does not trigger an in-place autograd error).
            embs_conv = embs_conv.clone()

        face_conv = self.face_embs_conv_block(self.face_embs_fc(face_embs).reshape(B, 32, 4, 4))
        # merging embeddings with spatial masks
        embs_conv[:, :, 32:, :32] = (
            face_conv * self.face_cond_mask[:, :, 32:, :32]
            + embs_conv[:, :, 32:, :32] * non_head_mask[:, :, 32:, :32]
        )

        joint = th.cat([pose_conv, embs_conv], dim=1)
        joint = self.joint_conv_block(joint)

        # Duplicate the features: one copy per conv group (geometry / texture).
        x = th.cat([joint, joint], dim=1)
        for b in range(self.n_blocks):
            x = self.conv_blocks[b](x)

        # NOTE: here we do resampling at feature level
        x = self.seam_sampler.impaint(x)
        x = self.seam_sampler.resample(x)
        x = self.seam_sampler.resample(x)

        verts_features, tex_features = th.split(x, self.n_channels[-1], 1)

        verts_uv_delta_rec = self.verts_conv(verts_features)
        # TODO: need to get values
        verts_delta_rec = self.geo_fn.from_uv(verts_uv_delta_rec)
        tex_mean_rec = self.tex_conv(tex_features)

        preds = {
            'geom_delta_rec': verts_delta_rec,
            'geom_uv_delta_rec': verts_uv_delta_rec,
            'tex_mean_rec': tex_mean_rec,
            'embs_conv': embs_conv,
            'pose_conv': pose_conv,
        }
        return preds
class FaceEncoder(nn.Module):
    """Encoder mapping a face texture and face geometry to a latent code."""

    def __init__(
        self,
        noise_std,
        assets,
        n_embs=256,
        uv_size=512,
        logvar_scale=0.1,
        n_vert_in=7306 * 3,
        prefix="face_",
    ):
        """Create the texture conv stack, geometry MLP and output heads.

        Args:
            noise_std: scale of the noise injected while training.
            assets: container providing ``mugsy_face_mask``; its first
                channel is used as the texture conditioning mask.
            n_embs: size of the latent code.
            uv_size: texture resolution; must be 512.
            logvar_scale: multiplier applied to the raw log-variance head.
            n_vert_in: flattened size of the geometry input.
            prefix: string prepended to every key of the output dict.
        """
        super().__init__()
        # TODO:
        self.noise_std = noise_std
        self.n_embs = n_embs
        self.logvar_scale = logvar_scale
        self.prefix = prefix
        self.uv_size = uv_size

        assert self.uv_size == 512

        target_hw = (self.uv_size, self.uv_size)
        face_mask = th.as_tensor(assets.mugsy_face_mask[..., 0], dtype=th.float32)
        face_mask = F.interpolate(
            face_mask[np.newaxis, np.newaxis],
            target_hw,
            mode="bilinear",
            align_corners=True,
        )
        self.register_buffer("tex_cond_mask", face_mask)

        # (in_channels, out_channels, size) per block, in application order.
        down_plan = (
            (3, 4, 512),
            (4, 8, 256),
            (8, 16, 128),
            (16, 32, 64),
            (32, 64, 32),
            (64, 128, 16),
            (128, 128, 8),
        )
        self.conv_blocks = nn.Sequential(
            *[ConvDownBlock(c_in, c_out, size) for c_in, c_out, size in down_plan]
        )

        self.geommod = nn.Sequential(
            la.LinearWN(n_vert_in, 256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.jointmod = nn.Sequential(
            la.LinearWN(256 + 128 * 4 * 4, 512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # TODO: should we put initializer
        self.mu = la.LinearWN(512, self.n_embs)
        self.logvar = la.LinearWN(512, self.n_embs)

        # Default gain everywhere, then a different gain for the two heads.
        self.apply(weights_initializer(0.2))
        for head in (self.mu, self.logvar):
            head.apply(weights_initializer(1.0))

    # TODO: compute_losses()?

    def forward(self, face_geom: th.Tensor, face_tex: th.Tensor, **kwargs):
        """Encode a face texture / geometry pair.

        Args:
            face_geom: face geometry; flattened per sample before the MLP.
            face_tex: face texture; resized to ``uv_size`` and rescaled with
                ``/ 255 - 0.5`` (so presumably in [0, 255] - TODO confirm).
            **kwargs: ignored; allows unpacking a larger dict into the call.

        Returns:
            Dict with ``embs``, ``embs_mu``, ``embs_logvar`` and ``tex_cond``,
            each key prefixed with ``self.prefix``.
        """
        batch_size = face_geom.shape[0]
        target_hw = (self.uv_size, self.uv_size)

        tex_cond = F.interpolate(face_tex, target_hw, mode="bilinear", align_corners=False)
        tex_cond = (tex_cond / 255.0 - 0.5) * self.tex_cond_mask

        tex_enc = self.conv_blocks(tex_cond).reshape(batch_size, 4 * 4 * 128)
        geom_enc = self.geommod(face_geom.reshape(batch_size, -1))
        joint = self.jointmod(th.cat([tex_enc, geom_enc], dim=1))

        embs_mu = self.mu(joint)
        embs_logvar = self.logvar_scale * self.logvar(joint)

        # NOTE: the noise is only applied to the input-conditioned values
        if self.training:
            noise = th.randn_like(embs_mu)
            embs = embs_mu + th.exp(embs_logvar) * noise * self.noise_std
        else:
            embs = embs_mu.clone()

        outputs = (
            ("embs", embs),
            ("embs_mu", embs_mu),
            ("embs_logvar", embs_logvar),
            ("tex_cond", tex_cond),
        )
        return {f"{self.prefix}{key}": value for key, value in outputs}
class UNetViewDecoder(nn.Module):
    """View-dependent texture decoder built around a U-Net.

    The network is conditioned on a UV map of the view cosine concatenated
    with the mean texture.
    """

    def __init__(self, geo_fn, net_uv_size, seam_sampler, n_init_ftrs=8):
        """Create the U-Net and cache the face indices.

        Args:
            geo_fn: geometry helper providing ``vi`` and ``to_uv``.
            net_uv_size: spatial size handed to the U-Net.
            seam_sampler: accepted for interface compatibility; not used.
            n_init_ftrs: initial feature count of the U-Net.
        """
        super().__init__()
        self.geo_fn = geo_fn
        self.net_uv_size = net_uv_size
        # Four input channels (view cosine is concatenated with the mean
        # texture in `forward`), three output channels.
        self.unet = UNetWB(4, 3, n_init_ftrs=n_init_ftrs, size=net_uv_size)
        face_indices = self.geo_fn.vi.to(th.int64)
        # Non-persistent: the buffer is left out of the state dict.
        self.register_buffer("faces", face_indices, persistent=False)

    def forward(self, geom_rec, tex_mean_rec, camera_pos):
        """Predict the view-dependent texture.

        Args:
            geom_rec: reconstructed geometry.
            tex_mean_rec: view-independent texture.
            camera_pos: camera position.

        Returns:
            Dict with ``'tex_view_rec'`` (U-Net output) and ``'cond_view'``
            (the conditioning tensor fed to the U-Net).
        """
        # The view conditioning is computed without tracking gradients.
        with th.no_grad():
            cos_per_vertex = compute_view_cos(geom_rec, self.faces, camera_pos)
            cos_uv = self.geo_fn.to_uv(cos_per_vertex[..., np.newaxis])

        cond_view = th.cat([cos_uv, tex_mean_rec], dim=1)
        # TODO: should we try warping here?
        return {"tex_view_rec": self.unet(cond_view), "cond_view": cond_view}
class UpscaleNet(nn.Module):
    """Two-layer conv network followed by a pixel shuffle for upscaling."""

    def __init__(self, in_channels, out_channels, n_ftrs, size=1024, upscale_factor=2):
        """Create the hidden conv, the output conv and the pixel shuffle.

        Args:
            in_channels: channels of the input map.
            out_channels: channels of the upscaled output.
            n_ftrs: channels of the hidden layer.
            size: spatial size passed to both conv layers.
            upscale_factor: pixel-shuffle factor; the output conv produces
                ``out_channels * upscale_factor**2`` channels.
        """
        super().__init__()

        hidden_conv = la.Conv2dWNUB(in_channels, n_ftrs, size, size, kernel_size=3, padding=1)
        self.conv_block = nn.Sequential(hidden_conv, nn.LeakyReLU(0.2, inplace=True))

        shuffle_channels = out_channels * upscale_factor**2
        self.out_block = la.Conv2dWNUB(
            n_ftrs, shuffle_channels, size, size, kernel_size=1, padding=0
        )

        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=upscale_factor)

        # Default gain everywhere, then a different gain for the output conv.
        self.apply(weights_initializer(0.2))
        self.out_block.apply(weights_initializer(1.0))

    def forward(self, x):
        """Run both conv layers and pixel-shuffle the result."""
        hidden = self.conv_block(x)
        return self.pixel_shuffle(self.out_block(hidden))