import torch
import logging
from diffusers import DiffusionPipeline

from prior.pipeline_kandinsky_prior import KandinskyPriorPipeline
from prior.prior_transformer import PriorTransformer


class Zoo(torch.nn.Module):
    def __init__(self, prior, prior_pipe, kandinsky_pipe) -> None:
        super().__init__()
        self.prior = prior
        self.prior_pipe = prior_pipe
        self.kandinsky_pipe = kandinsky_pipe
        self.pre_prior_transformer = None
        # NOTE: we may get better performance by freezing the prior and
        #     training only a transformer adapter; a hedged sketch,
        #     freeze_prior_and_build_adapter, follows this class.

    def forward(self, latents, preferred_embeds):
        pred = self.prior(latents, preferred_embeds)
        return pred
    
    def do_validation(self, images):  # TODO: use a constant validation seed
        assert all(len(i) == 8 for i in images), \
            f'Each element must contain k=8 images; got {[len(i) for i in images]}.'
        image_embeds, negative_image_embeds = self.prior_pipe(images).to_tuple()
        images = self.kandinsky_pipe(
            num_inference_steps=50,
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
        ).images
        images[0].save('latest_val.png')
        return images
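
# A minimal, hypothetical sketch of the adapter idea noted in Zoo.__init__:
# freeze the pretrained prior and train only a small transformer block that
# runs in front of it. The d_model/nhead values are illustrative guesses,
# not taken from this repo.
def freeze_prior_and_build_adapter(model, d_model=1280, n_heads=8):
    for p in model.prior.parameters():
        p.requires_grad_(False)  # keep the pretrained prior fixed
    model.pre_prior_transformer = torch.nn.TransformerEncoderLayer(
        d_model=d_model, nhead=n_heads, batch_first=True
    )
    # only the adapter's parameters go to the optimizer
    return list(model.pre_prior_transformer.parameters())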

def get_model_and_tokenizer(path, device, dtype):
    # Despite the name, this returns the prior pipeline's image encoder,
    # not a tokenizer.
    prior = PriorTransformer.from_pretrained(
        "ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior" if path is None else path
    ).to(device)

    pipe_prior = KandinskyPriorPipeline.from_pretrained(
        "kandinsky-community/kandinsky-2-2-prior", prior=prior
    ).to(device)
    pipe_prior.image_encoder = pipe_prior.image_encoder.to(device, dtype)
    # NOTE: don't cast the prior to `dtype`, which may be half precision;
    #     we train with mixed precision, so the trained parameters must keep
    #     their full-precision weights (see training_step_sketch below).
    kandinsky_pipe = DiffusionPipeline.from_pretrained(
        "kandinsky-community/kandinsky-2-2-decoder"
    ).to(device, dtype)
    model = Zoo(prior, pipe_prior, kandinsky_pipe).to(device)

    return model, model.prior_pipe.image_encoder
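
# A minimal sketch of the mixed-precision step the NOTE above alludes to:
# the prior's weights stay fp32 while the forward pass runs under autocast.
# The MSE loss and the assumption that `preferred_embeds` doubles as the
# regression target are placeholders, not this repo's training objective.
def training_step_sketch(model, optimizer, scaler, latents, preferred_embeds):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        pred = model(latents, preferred_embeds)
        loss = torch.nn.functional.mse_loss(pred, preferred_embeds)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscale grads, then step in fp32
    scaler.update()
    return loss.detach()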

def get_optimizer(params, lr):
    params = list(params)  # materialize so a generator logs usefully
    logging.info(f'Training {len(params)} parameter tensors/groups')
    optimizer = torch.optim.AdamW(params, lr=lr)
    return optimizer
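
# Hedged usage sketch wiring the helpers together; the learning rate and
# fp16 dtype are illustrative defaults, not values taken from this repo.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, image_encoder = get_model_and_tokenizer(None, device, torch.float16)
    optimizer = get_optimizer(
        [p for p in model.prior.parameters() if p.requires_grad], lr=1e-4
    )
    scaler = torch.cuda.amp.GradScaler()  # pairs with training_step_sketch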