# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

import boto3
import torch
from torch import nn, distributed as dist
from torch.nn import functional as F
from torch.distributed import barrier

from imaginaire.utils.distributed import is_local_master
from .clip import build_model
from ...utils.io import download_file_from_google_drive


def get_image_encoder(aws_credentials=None):
    r"""Download the ViT-B/32 CLIP checkpoint if it is not cached and wrap it in an ImageEncoder."""
    if dist.is_initialized() and not is_local_master():
        # Make sure only the first process on each node downloads the model;
        # the others wait here and then use the cached file.
        barrier()
    # Load the CLIP image encoder.
    print("Loading CLIP image encoder.")
    model_path = os.path.join(torch.hub.get_dir(), 'checkpoints', 'ViT-B-32.pt')
    if not os.path.exists(model_path):
        if aws_credentials is not None:
            # Fetch the checkpoint from S3 when credentials are provided.
            s3 = boto3.client('s3', **aws_credentials)
            s3.download_file('lpi-poe', 'model_zoo/ViT-B-32.pt', model_path)
        else:
            # Otherwise fall back to the Google Drive mirror.
            download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path)
    model = torch.load(model_path, map_location='cpu')
    if dist.is_initialized() and is_local_master():
        # The local master has finished downloading; release the waiting processes.
        barrier()
    encoder = build_model(model).cuda()
    return ImageEncoder(encoder)


class ImageEncoder(nn.Module):
    r"""Wraps the CLIP visual backbone and concatenates its image features with the caption features in the data dict."""

    def __init__(self, encoder):
        super().__init__()
        self.model = encoder
        self.image_size = self.model.visual.input_resolution
        # CLIP image normalization statistics.
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda")
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda")

    @torch.no_grad()
    def forward(self, data, fake_images, align_corners=True):
        # Map generator outputs from [-1, 1] to [0, 1], resize to the CLIP input
        # resolution, and apply CLIP normalization.
        images = 0.5 * (1 + fake_images)
        images = F.interpolate(images, (self.image_size, self.image_size),
                               mode='bicubic', align_corners=align_corners)
        images.clamp_(0, 1)
        images = (images - self.mean[None, :, None, None]) / (self.std[None, :, None, None])
        image_code = self.model.encode_image(images)
        # Concatenate the image features with the caption features stored in the data dict.
        return torch.cat((image_code, data['captions-clip']), dim=1)
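

# Minimal usage sketch (not part of the original module). Assumptions: a CUDA
# device is available, the process is not running distributed training, and
# `data['captions-clip']` holds precomputed CLIP text features whose width
# matches the ViT-B/32 embedding size (512). `fake_images` are expected in
# [-1, 1], e.g. the output of a tanh generator head. Run with `python -m
# <package>.<module>` so the relative imports above resolve.
if __name__ == '__main__':
    encoder = get_image_encoder()  # downloads ViT-B-32.pt on first call
    batch = 4
    # Dummy generator outputs in [-1, 1].
    fake_images = torch.rand(batch, 3, 256, 256, device='cuda') * 2 - 1
    # Match the dtype of the CLIP weights (they may have been converted to fp16).
    clip_dtype = next(encoder.model.parameters()).dtype
    data = {'captions-clip': torch.randn(batch, 512, device='cuda', dtype=clip_dtype)}
    codes = encoder(data, fake_images)
    print(codes.shape)  # expected: torch.Size([4, 1024]) = 512 image + 512 caption features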