# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

import boto3
import torch
from torch import nn, distributed as dist
from torch.nn import functional as F
from torch.distributed import barrier

from imaginaire.utils.distributed import is_local_master
from .clip import build_model
from ...utils.io import download_file_from_google_drive


def get_image_encoder(aws_credentials=None):
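    """Load the CLIP ViT-B/32 image encoder, downloading the checkpoint on first use.

    Args:
        aws_credentials (dict, optional): Keyword arguments for ``boto3.client``.
            When provided, the checkpoint is fetched from S3; otherwise it is
            downloaded from Google Drive.

    Returns:
        ImageEncoder: The CLIP image encoder wrapped for feature extraction, moved to GPU.
    """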
    if dist.is_initialized() and not is_local_master():
        # Non-master processes wait here so that only the local master downloads the
        # checkpoint; the others pick it up from the local cache afterwards.
        barrier()

    # Load the CLIP image encoder.
    print("Loading CLIP image encoder.")
    model_path = os.path.join(torch.hub.get_dir(), 'checkpoints', 'ViT-B-32.pt')
    if not os.path.exists(model_path):
        # The torch hub checkpoints directory may not exist yet on a fresh install.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        if aws_credentials is not None:
            s3 = boto3.client('s3', **aws_credentials)
            s3.download_file('lpi-poe', 'model_zoo/ViT-B-32.pt', model_path)
        else:
            download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path)
    model = torch.load(model_path, map_location='cpu')

    if dist.is_initialized() and is_local_master():
        # The local master reaches the barrier only after the checkpoint is available,
        # releasing the processes waiting above.
        barrier()

    encoder = build_model(model).cuda()
    return ImageEncoder(encoder)


class ImageEncoder(nn.Module):
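    """Wraps the CLIP visual backbone to encode generated images alongside precomputed caption embeddings."""
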
    def __init__(self, encoder):
        super().__init__()
        self.model = encoder
        self.image_size = self.model.visual.input_resolution
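        # Per-channel RGB mean and std that CLIP uses to normalize its input images.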
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda")
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda")

    @torch.no_grad()
    def forward(self, data, fake_images, align_corners=True):
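        """Encode generated images with CLIP and append the caption embeddings.

        Args:
            data (dict): Batch dictionary; must contain 'captions-clip', the
                precomputed CLIP text embeddings.
            fake_images (Tensor): Generator outputs in [-1, 1] of shape (N, 3, H, W).
            align_corners (bool): Forwarded to ``F.interpolate`` when resizing
                to the CLIP input resolution.

        Returns:
            Tensor: Image and caption embeddings concatenated along dim 1.
        """
        # Map generator outputs from [-1, 1] to [0, 1] before CLIP normalization.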
        images = 0.5 * (1 + fake_images)
        images = F.interpolate(images, (self.image_size, self.image_size), mode='bicubic', align_corners=align_corners)
        images.clamp_(0, 1)
        images = (images - self.mean[None, :, None, None]) / self.std[None, :, None, None]
        image_code = self.model.encode_image(images)
        return torch.cat((image_code, data['captions-clip']), dim=1)