# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

import boto3
import torch
from torch import nn, distributed as dist
from torch.nn import functional as F
from torch.distributed import barrier

from imaginaire.utils.distributed import is_local_master

from .clip import build_model
from ...utils.io import download_file_from_google_drive


def get_image_encoder(aws_credentials=None):
    if dist.is_initialized() and not is_local_master():
        # Make sure only the first process in distributed training downloads the model, and the others use the cache.
        barrier()
    # Load the CLIP image encoder.
    print("Loading CLIP image encoder.")
    model_path = os.path.join(torch.hub.get_dir(), 'checkpoints', 'ViT-B-32.pt')
    if not os.path.exists(model_path):
        if aws_credentials is not None:
            s3 = boto3.client('s3', **aws_credentials)
            s3.download_file('lpi-poe', 'model_zoo/ViT-B-32.pt', model_path)
        else:
            download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path)
    model = torch.load(model_path, map_location='cpu')
    if dist.is_initialized() and is_local_master():
        # Make sure only the first process in distributed training downloads the model, and the others use the cache.
        barrier()
    encoder = build_model(model).cuda()
    return ImageEncoder(encoder)


class ImageEncoder(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.model = encoder
        self.image_size = self.model.visual.input_resolution
        # CLIP's standard image normalization statistics.
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda")
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda")

    def forward(self, data, fake_images, align_corners=True):
        # Map generator output from [-1, 1] to [0, 1].
        images = 0.5 * (1 + fake_images)
        # Resize to the CLIP input resolution; bicubic interpolation can overshoot, so clamp back to [0, 1].
        images = F.interpolate(images, (self.image_size, self.image_size), mode='bicubic', align_corners=align_corners)
        images.clamp_(0, 1)
        # Apply CLIP normalization.
        images = (images - self.mean[None, :, None, None]) / self.std[None, :, None, None]
        # Encode the images and concatenate with the precomputed CLIP caption embeddings.
        image_code = self.model.encode_image(images)
        return torch.cat((image_code, data['captions-clip']), dim=1)
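

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged illustration of the preprocessing that ImageEncoder.forward
# applies before CLIP encoding, run on CPU with random tensors so it needs no
# GPU, downloaded checkpoint, or real captions. The batch size, the 256x256
# generator resolution, and the 224x224 CLIP input size are assumptions made
# for illustration only.
if __name__ == '__main__':
    fake_images = torch.rand(2, 3, 256, 256) * 2 - 1  # stand-in for generator output in [-1, 1]
    images = 0.5 * (1 + fake_images)                   # rescale to [0, 1]
    images = F.interpolate(images, (224, 224), mode='bicubic', align_corners=True)
    images.clamp_(0, 1)                                # bicubic interpolation can overshoot [0, 1]
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711])
    images = (images - mean[None, :, None, None]) / std[None, :, None, None]
    print(images.shape)  # torch.Size([2, 3, 224, 224])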