File size: 3,639 Bytes
fac61f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95


########################################
# python -m train
###########################################


import torch
import logging
import numpy as np
from tqdm import tqdm
from PIL import Image
    
from data import get_dataloader
from model import get_model_and_tokenizer, get_optimizer
import config

logging.basicConfig(level=logging.INFO)

def get_loss(model, input, target, tokenizer):
    with torch.no_grad():
        assert len(input.shape) == 5 # [batch, s, c, w, h]
        cuts = config.number_k_clip_embed
        assert input.shape[0] * input.shape[1] % cuts == 0, 'batch size * `k` preferred embeds must be divisible by cuts'
        input = input.view(cuts//8, -1, 3, target.shape[-2], target.shape[-1])
        full_seq = []
        for b in input:
            input = tokenizer(b)['image_embeds'] # in our case, tokenizer is a clip embedding model
            full_seq.append(input)
        input = torch.stack(full_seq)

        target = tokenizer(target)['image_embeds']

        input = input.view(target.shape[0], -1, target.shape[-1])
        assert len(input.shape) == 3 # [batch, sequence, inner]
    
    with torch.cuda.amp.autocast(enabled=False, ):
        input = input.to(torch.float32)
        latent = torch.randn(input.shape[0], input.shape[-1], device=input.device)
        output = model(latent, input).predicted_image_embedding

    target = target.to(torch.float32)
    mse_loss = torch.nn.functional.mse_loss(target, output).mean()
    
    assert len(target.shape) == 2 and len(output.shape) == 2
    cosine_loss = 1 - torch.nn.functional.cosine_similarity(output, target).mean()
    loss =  mse_loss + .2 * cosine_loss

    logging.info(f'MSE: {mse_loss.item()}, Cosine: {cosine_loss.item()}, Weighted Total: {loss.item()}')
    # TODO wandb

    return loss

def main():
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    model, tokenizer = get_model_and_tokenizer(config.model_path, config.device, config.dtype)
    optimizer = get_optimizer(list(model.prior.parameters()), config.lr)
    dataloader = get_dataloader(config.data_path, config.batch_size, config.num_workers, 
                                model.prior_pipe.image_processor)

    for epoch in range(config.epochs):
        for ind, batch in tqdm(enumerate(iter(dataloader))):
            if batch is None:
                continue

            input, target = batch
            input = input.to(config.device)
            target = target.to(config.device)

            if ind % 50 == 0:
                with torch.cuda.amp.autocast(enabled=True, dtype=config.dtype): # NOTE using autocast because our training model is also our val model, so don't want to set to full half precision.
                    examples = ['../generative_recommender/Blue_Tigers_space/1o.png',
 '../generative_recommender/Blue_Tigers_space/2o.png',
 '../generative_recommender/Blue_Tigers_space/3o.png',
 '../generative_recommender/Blue_Tigers_space/4o.png',
 '../generative_recommender/Blue_Tigers_space/5o.png',
 '../generative_recommender/Blue_Tigers_space/6o.png',
 '../generative_recommender/Blue_Tigers_space/7o.png',
 '../generative_recommender/Blue_Tigers_space/8o.png',]
                    model.do_validation([[Image.open('../'+j) for j in examples]])

            loss = get_loss(model, input, target, tokenizer)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if ind % 100 == 0:
                # TODO add loading from path
                model.prior.save_pretrained(f'{config.save_path}/last_epoch_ckpt', from_pt=True) 

if __name__ == '__main__':
    main()