# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch import nn

from imaginaire.generators.funit import (MLP, ContentEncoder, Decoder,
                                         StyleEncoder)
class Generator(nn.Module):
    r"""COCO-FUNIT Generator.

    Thin wrapper that builds the actual translator network from the yaml
    configuration and exposes the training forward pass and an inference
    entry point.
    """

    def __init__(self, gen_cfg, data_cfg):
        r"""COCO-FUNIT Generator constructor.

        Args:
            gen_cfg (obj): Generator definition part of the yaml config file.
            data_cfg (obj): Data definition part of the yaml config file.
        """
        super().__init__()
        # Every generator hyper-parameter comes straight from the config.
        self.generator = COCOFUNITTranslator(**vars(gen_cfg))

    def forward(self, data):
        r"""Training forward pass.

        The content image yields a content code and a style code; the style
        image yields a second style code. Mixing the content code with the
        content image's own style reconstructs the input, while mixing it
        with the style image's code produces the translation.

        Args:
            data (dict): Training data at the current iteration.

        Returns:
            (dict): ``images_trans`` (translation) and ``images_recon``
                (reconstruction) tensors.
        """
        content_code = self.generator.content_encoder(data['images_content'])
        style_content = self.generator.style_encoder(data['images_content'])
        style_other = self.generator.style_encoder(data['images_style'])
        return dict(
            images_trans=self.generator.decode(content_code, style_other),
            images_recon=self.generator.decode(content_code, style_content))

    def inference(self, data, keep_original_size=True):
        r"""COCO-FUNIT inference.

        Args:
            data (dict): Test data at the current iteration.
                - images_content (tensor): Content images.
                - images_style (tensor): Style images.
            keep_original_size (bool): If ``True``, the output is resized
                back to the content image's original height and width.

        Returns:
            (tuple): Output images and the content-image file names.
        """
        content_code = self.generator.content_encoder(data['images_content'])
        style_code = self.generator.style_encoder(data['images_style'])
        output_images = self.generator.decode(content_code, style_code)
        if keep_original_size:
            height = data['original_h_w'][0][0]
            width = data['original_h_w'][0][1]
            output_images = torch.nn.functional.interpolate(
                output_images, size=[height, width])
        file_names = data['key']['images_content'][0]
        return output_images, file_names
class COCOFUNITTranslator(nn.Module):
    r"""COCO-FUNIT Generator architecture.

    Args:
        num_filters (int): Base filter numbers.
        num_filters_mlp (int): Base filter number in the MLP module.
        style_dims (int): Dimension of the style code.
        usb_dims (int): Dimension of the universal style bias code.
        num_res_blocks (int): Number of residual blocks at the end of the
            content encoder.
        num_mlp_blocks (int): Number of layers in the MLP module.
        num_downsamples_content (int): Number of times we reduce
            resolution by 2x2 for the content image.
        num_downsamples_style (int): Number of times we reduce
            resolution by 2x2 for the style image.
        num_image_channels (int): Number of input image channels.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
    """

    def __init__(self,
                 num_filters=64,
                 num_filters_mlp=256,
                 style_dims=64,
                 usb_dims=1024,
                 num_res_blocks=2,
                 num_mlp_blocks=3,
                 num_downsamples_style=4,
                 num_downsamples_content=2,
                 num_image_channels=3,
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        # Settings shared by all sub-networks.
        padding_mode = 'reflect'
        nonlinearity = 'relu'
        self.style_encoder = StyleEncoder(
            num_downsamples_style, num_image_channels, num_filters,
            style_dims, padding_mode, 'none', weight_norm_type, nonlinearity)
        self.content_encoder = ContentEncoder(
            num_downsamples_content, num_res_blocks, num_image_channels,
            num_filters, padding_mode, 'instance', weight_norm_type,
            nonlinearity)
        self.decoder = Decoder(
            self.content_encoder.output_dim, num_filters_mlp,
            num_image_channels, num_downsamples_content, padding_mode,
            weight_norm_type, nonlinearity)
        # Universal style bias: a learned code shared by every style image.
        self.usb = torch.nn.Parameter(torch.randn(1, usb_dims))
        self.mlp = MLP(style_dims, num_filters_mlp, num_filters_mlp,
                       num_mlp_blocks, 'none', nonlinearity)
        num_content_mlp_blocks = 2
        num_style_mlp_blocks = 2
        self.mlp_content = MLP(self.content_encoder.output_dim, style_dims,
                               num_filters_mlp, num_content_mlp_blocks,
                               'none', nonlinearity)
        self.mlp_style = MLP(style_dims + usb_dims, style_dims,
                             num_filters_mlp, num_style_mlp_blocks,
                             'none', nonlinearity)

    def forward(self, images):
        r"""Reconstruct the input image from its own content and style codes.

        Args:
            images (tensor): Input image tensor.
        """
        content, style = self.encode(images)
        return self.decode(content, style)

    def encode(self, images):
        r"""Compute the content and style codes of the input images.

        Args:
            images (tensor): Input image tensor.

        Returns:
            (tuple): Content code and style code.
        """
        style_code = self.style_encoder(images)
        content_code = self.content_encoder(images)
        return content_code, style_code

    def decode(self, content, style):
        r"""Generate images by combining a content and a style code.

        Args:
            content (tensor): Content code tensor.
            style (tensor): Style code tensor.
        """
        # Global-average-pool the content map and project it to a
        # content-conditioned modulation of the style code.
        pooled_content = content.mean(3).mean(2)
        content_bias = self.mlp_content(pooled_content)
        num = style.size(0)
        flat_style = style.view(num, -1)
        universal_bias = self.usb.repeat(num, 1)
        raw_style = self.mlp_style(torch.cat([flat_style, universal_bias], 1))
        # Content-conditioned style: the core COCO-FUNIT ingredient.
        coco_style = self.mlp(raw_style * content_bias)
        return self.decoder(content, coco_style)