working
Browse files- app.py +79 -3
- configs/t2i_256px_clip_dimr.py +118 -0
- configs/t2i_256px_t5_dimr.py +118 -0
- configs/t2i_512px_clip_dimr.py +131 -0
- configs/t2i_512px_t5_dimr.py +131 -0
- configs/t2i_512px_t5_dit.py +92 -0
- configs/t2i_training_demo.py +132 -0
- datasets.py +305 -0
- demo_t2i.py +194 -0
- demo_t2i_arith.py +290 -0
- diffusion/base_solver.py +203 -0
- diffusion/flow_matching.py +702 -0
- libs/__init__.py +1 -0
- libs/autoencoder.py +519 -0
- libs/clip.py +68 -0
- libs/model/axial_rope.py +109 -0
- libs/model/common_layers.py +104 -0
- libs/model/dimr_t2i.py +443 -0
- libs/model/dit_t2i.py +405 -0
- libs/model/flags.py +56 -0
- libs/model/sigmoid/kernel.py +316 -0
- libs/model/sigmoid/module.py +274 -0
- libs/model/trans_autoencoder.py +289 -0
- libs/t5.py +237 -0
- libs/timm.py +114 -0
- requirements.txt +19 -4
- scripts/extract_empty_feature.py +56 -0
- scripts/extract_mscoco_feature.py +83 -0
- scripts/extract_test_prompt_feature.py +72 -0
- scripts/extract_train_feature.py +159 -0
- sde.py +326 -0
- tools/clip_score.py +90 -0
- tools/fid_score.py +268 -0
- tools/inception.py +328 -0
- train_t2i.py +328 -0
- utils.py +274 -0
app.py
CHANGED
@@ -1,8 +1,83 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import numpy as np
|
|
|
|
|
3 |
import random
|
4 |
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from diffusers import DiffusionPipeline
|
7 |
import torch
|
8 |
|
@@ -21,7 +96,7 @@ MAX_SEED = np.iinfo(np.int32).max
|
|
21 |
MAX_IMAGE_SIZE = 1024
|
22 |
|
23 |
|
24 |
-
|
25 |
def infer(
|
26 |
prompt1,
|
27 |
prompt2,
|
@@ -69,7 +144,8 @@ css = """
|
|
69 |
|
70 |
with gr.Blocks(css=css) as demo:
|
71 |
with gr.Column(elem_id="col-container"):
|
72 |
-
gr.Markdown(" #
|
|
|
73 |
|
74 |
with gr.Row():
|
75 |
prompt1 = gr.Text(
|
|
|
1 |
import gradio as gr
|
2 |
+
|
3 |
+
from absl import flags
|
4 |
+
from absl import app
|
5 |
+
from ml_collections import config_flags
|
6 |
+
import os
|
7 |
+
|
8 |
+
import ml_collections
|
9 |
+
import torch
|
10 |
+
from torch import multiprocessing as mp
|
11 |
+
import torch.nn as nn
|
12 |
+
import accelerate
|
13 |
+
import utils
|
14 |
+
import tempfile
|
15 |
+
from absl import logging
|
16 |
+
import builtins
|
17 |
+
import einops
|
18 |
+
import math
|
19 |
import numpy as np
|
20 |
+
import time
|
21 |
+
from PIL import Image
|
22 |
import random
|
23 |
|
24 |
+
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
|
25 |
+
from tools.clip_score import ClipSocre
|
26 |
+
import libs.autoencoder
|
27 |
+
from libs.clip import FrozenCLIPEmbedder
|
28 |
+
from libs.t5 import T5Embedder
|
29 |
+
|
30 |
+
|
31 |
+
def unpreprocess(x):
|
32 |
+
x = 0.5 * (x + 1.)
|
33 |
+
x.clamp_(0., 1.)
|
34 |
+
return x
|
35 |
+
|
36 |
+
def batch_decode(_z, decode, batch_size=10):
|
37 |
+
"""
|
38 |
+
The VAE decoder requires large GPU memory. To run the interpolation model on GPUs with 24 GB or smaller RAM, you can use this code to reduce memory usage for the VAE.
|
39 |
+
It works by splitting the input tensor into smaller chunks.
|
40 |
+
"""
|
41 |
+
num_samples = _z.size(0)
|
42 |
+
decoded_batches = []
|
43 |
+
|
44 |
+
for i in range(0, num_samples, batch_size):
|
45 |
+
batch = _z[i:i + batch_size]
|
46 |
+
decoded_batch = decode(batch)
|
47 |
+
decoded_batches.append(decoded_batch)
|
48 |
+
|
49 |
+
image_unprocessed = torch.cat(decoded_batches, dim=0)
|
50 |
+
return image_unprocessed
|
51 |
+
|
52 |
+
def get_caption(llm, text_model, prompt_dict, batch_size):
|
53 |
+
|
54 |
+
if batch_size == 3:
|
55 |
+
# only addition or only subtraction
|
56 |
+
assert len(prompt_dict) == 2
|
57 |
+
_batch_con = list(prompt_dict.values()) + [' ']
|
58 |
+
elif batch_size == 4:
|
59 |
+
# addition and subtraction
|
60 |
+
assert len(prompt_dict) == 3
|
61 |
+
_batch_con = list(prompt_dict.values()) + [' ']
|
62 |
+
elif batch_size >= 5:
|
63 |
+
# linear interpolation
|
64 |
+
assert len(prompt_dict) == 2
|
65 |
+
_batch_con = [prompt_dict['prompt_1']] + [' '] * (batch_size-2) + [prompt_dict['prompt_2']]
|
66 |
+
|
67 |
+
if llm == "clip":
|
68 |
+
_latent, _latent_and_others = text_model.encode(_batch_con)
|
69 |
+
_con = _latent_and_others['token_embedding'].detach()
|
70 |
+
elif llm == "t5":
|
71 |
+
_latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
|
72 |
+
_con = (_latent_and_others['token_embedding'] * 10.0).detach()
|
73 |
+
else:
|
74 |
+
raise NotImplementedError
|
75 |
+
_con_mask = _latent_and_others['token_mask'].detach()
|
76 |
+
_batch_token = _latent_and_others['tokens'].detach()
|
77 |
+
_batch_caption = _batch_con
|
78 |
+
return (_con, _con_mask, _batch_token, _batch_caption)
|
79 |
+
|
80 |
+
import spaces #[uncomment to use ZeroGPU]
|
81 |
from diffusers import DiffusionPipeline
|
82 |
import torch
|
83 |
|
|
|
96 |
MAX_IMAGE_SIZE = 1024
|
97 |
|
98 |
|
99 |
+
@spaces.GPU #[uncomment to use ZeroGPU]
|
100 |
def infer(
|
101 |
prompt1,
|
102 |
prompt2,
|
|
|
144 |
|
145 |
with gr.Blocks(css=css) as demo:
|
146 |
with gr.Column(elem_id="col-container"):
|
147 |
+
gr.Markdown(" # CrossFlow")
|
148 |
+
gr.Markdown(" CrossFlow directly transforms text representations into images for text-to-image generation, enabling interpolation in the input text latent space.")
|
149 |
|
150 |
with gr.Row():
|
151 |
prompt1 = gr.Text(
|
configs/t2i_256px_clip_dimr.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class Args:
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
for key, value in kwargs.items():
|
8 |
+
setattr(self, key, value)
|
9 |
+
|
10 |
+
model = Args(
|
11 |
+
channels = 4,
|
12 |
+
block_grad_to_lowres = False,
|
13 |
+
norm_type = "TDRMSN",
|
14 |
+
use_t2i = True,
|
15 |
+
clip_dim=768,
|
16 |
+
num_clip_token=77,
|
17 |
+
gradient_checking=True,
|
18 |
+
cfg_indicator=0.1,
|
19 |
+
textVAE = Args(
|
20 |
+
num_blocks = 11,
|
21 |
+
hidden_dim = 1024,
|
22 |
+
hidden_token_length = 256,
|
23 |
+
num_attention_heads = 8,
|
24 |
+
dropout_prob = 0.1,
|
25 |
+
),
|
26 |
+
stage_configs = [
|
27 |
+
Args(
|
28 |
+
block_type = "TransformerBlock",
|
29 |
+
dim = 1024, # channel
|
30 |
+
hidden_dim = 2048,
|
31 |
+
num_attention_heads = 16,
|
32 |
+
num_blocks = 65, # depth
|
33 |
+
max_height = 16,
|
34 |
+
max_width = 16,
|
35 |
+
image_input_ratio = 1,
|
36 |
+
input_feature_ratio = 2,
|
37 |
+
final_kernel_size = 3,
|
38 |
+
dropout_prob = 0,
|
39 |
+
),
|
40 |
+
Args(
|
41 |
+
block_type = "ConvNeXtBlock",
|
42 |
+
dim = 512,
|
43 |
+
hidden_dim = 1024,
|
44 |
+
kernel_size = 7,
|
45 |
+
num_blocks = 33,
|
46 |
+
max_height = 32,
|
47 |
+
max_width = 32,
|
48 |
+
image_input_ratio = 1,
|
49 |
+
input_feature_ratio = 1,
|
50 |
+
final_kernel_size = 3,
|
51 |
+
dropout_prob = 0,
|
52 |
+
),
|
53 |
+
],
|
54 |
+
)
|
55 |
+
|
56 |
+
def d(**kwargs):
|
57 |
+
"""Helper of creating a config dict."""
|
58 |
+
return ml_collections.ConfigDict(initial_dictionary=kwargs)
|
59 |
+
|
60 |
+
|
61 |
+
def get_config():
|
62 |
+
config = ml_collections.ConfigDict()
|
63 |
+
|
64 |
+
config.seed = 1234
|
65 |
+
config.z_shape = (4, 32, 32)
|
66 |
+
|
67 |
+
config.autoencoder = d(
|
68 |
+
pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
|
69 |
+
scale_factor=0.23010
|
70 |
+
)
|
71 |
+
|
72 |
+
config.train = d(
|
73 |
+
n_steps=1000000,
|
74 |
+
batch_size=1024,
|
75 |
+
mode='cond',
|
76 |
+
log_interval=10,
|
77 |
+
eval_interval=5000,
|
78 |
+
save_interval=50000,
|
79 |
+
)
|
80 |
+
|
81 |
+
config.optimizer = d(
|
82 |
+
name='adamw',
|
83 |
+
lr=0.00001,
|
84 |
+
weight_decay=0.03,
|
85 |
+
betas=(0.9, 0.9),
|
86 |
+
)
|
87 |
+
|
88 |
+
config.lr_scheduler = d(
|
89 |
+
name='customized',
|
90 |
+
warmup_steps=5000
|
91 |
+
)
|
92 |
+
|
93 |
+
global model
|
94 |
+
config.nnet = d(
|
95 |
+
name='dimr',
|
96 |
+
model_args=model,
|
97 |
+
)
|
98 |
+
config.loss_coeffs = [1/4, 1]
|
99 |
+
|
100 |
+
config.dataset = d(
|
101 |
+
name='JDB_demo_features',
|
102 |
+
resolution=256,
|
103 |
+
llm='clip',
|
104 |
+
train_path='/data/qihao/dataset/JDB_demo_feature/',
|
105 |
+
val_path='/data/qihao/dataset/coco_val_features/',
|
106 |
+
cfg=False
|
107 |
+
)
|
108 |
+
|
109 |
+
config.sample = d(
|
110 |
+
sample_steps=50,
|
111 |
+
n_samples=30000,
|
112 |
+
mini_batch_size=20,
|
113 |
+
cfg=False,
|
114 |
+
scale=7,
|
115 |
+
path=''
|
116 |
+
)
|
117 |
+
|
118 |
+
return config
|
configs/t2i_256px_t5_dimr.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class Args:
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
for key, value in kwargs.items():
|
8 |
+
setattr(self, key, value)
|
9 |
+
|
10 |
+
model = Args(
|
11 |
+
channels = 4,
|
12 |
+
block_grad_to_lowres = False,
|
13 |
+
norm_type = "TDRMSN",
|
14 |
+
use_t2i = True,
|
15 |
+
clip_dim=4096,
|
16 |
+
num_clip_token=77,
|
17 |
+
gradient_checking=True,
|
18 |
+
cfg_indicator=0.1,
|
19 |
+
textVAE = Args(
|
20 |
+
num_blocks = 11,
|
21 |
+
hidden_dim = 1024,
|
22 |
+
hidden_token_length = 256,
|
23 |
+
num_attention_heads = 8,
|
24 |
+
dropout_prob = 0.1,
|
25 |
+
),
|
26 |
+
stage_configs = [
|
27 |
+
Args(
|
28 |
+
block_type = "TransformerBlock",
|
29 |
+
dim = 1024, # channel
|
30 |
+
hidden_dim = 2048,
|
31 |
+
num_attention_heads = 16,
|
32 |
+
num_blocks = 65, # depth
|
33 |
+
max_height = 16,
|
34 |
+
max_width = 16,
|
35 |
+
image_input_ratio = 1,
|
36 |
+
input_feature_ratio = 2,
|
37 |
+
final_kernel_size = 3,
|
38 |
+
dropout_prob = 0,
|
39 |
+
),
|
40 |
+
Args(
|
41 |
+
block_type = "ConvNeXtBlock",
|
42 |
+
dim = 512,
|
43 |
+
hidden_dim = 1024,
|
44 |
+
kernel_size = 7,
|
45 |
+
num_blocks = 33,
|
46 |
+
max_height = 32,
|
47 |
+
max_width = 32,
|
48 |
+
image_input_ratio = 1,
|
49 |
+
input_feature_ratio = 1,
|
50 |
+
final_kernel_size = 3,
|
51 |
+
dropout_prob = 0,
|
52 |
+
),
|
53 |
+
],
|
54 |
+
)
|
55 |
+
|
56 |
+
def d(**kwargs):
|
57 |
+
"""Helper of creating a config dict."""
|
58 |
+
return ml_collections.ConfigDict(initial_dictionary=kwargs)
|
59 |
+
|
60 |
+
|
61 |
+
def get_config():
|
62 |
+
config = ml_collections.ConfigDict()
|
63 |
+
|
64 |
+
config.seed = 1234
|
65 |
+
config.z_shape = (4, 32, 32)
|
66 |
+
|
67 |
+
config.autoencoder = d(
|
68 |
+
pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
|
69 |
+
scale_factor=0.23010
|
70 |
+
)
|
71 |
+
|
72 |
+
config.train = d(
|
73 |
+
n_steps=1000000,
|
74 |
+
batch_size=1024,
|
75 |
+
mode='cond',
|
76 |
+
log_interval=10,
|
77 |
+
eval_interval=5000,
|
78 |
+
save_interval=50000,
|
79 |
+
)
|
80 |
+
|
81 |
+
config.optimizer = d(
|
82 |
+
name='adamw',
|
83 |
+
lr=0.00005,
|
84 |
+
weight_decay=0.03,
|
85 |
+
betas=(0.9, 0.9),
|
86 |
+
)
|
87 |
+
|
88 |
+
config.lr_scheduler = d(
|
89 |
+
name='customized',
|
90 |
+
warmup_steps=5000
|
91 |
+
)
|
92 |
+
|
93 |
+
global model
|
94 |
+
config.nnet = d(
|
95 |
+
name='dimr',
|
96 |
+
model_args=model,
|
97 |
+
)
|
98 |
+
config.loss_coeffs = [1/4, 1]
|
99 |
+
|
100 |
+
config.dataset = d(
|
101 |
+
name='JDB_demo_features',
|
102 |
+
resolution=256,
|
103 |
+
llm='t5',
|
104 |
+
train_path='/data/qihao/dataset/JDB_demo_feature/',
|
105 |
+
val_path='/data/qihao/dataset/coco_val_features/',
|
106 |
+
cfg=False
|
107 |
+
)
|
108 |
+
|
109 |
+
config.sample = d(
|
110 |
+
sample_steps=50,
|
111 |
+
n_samples=30000,
|
112 |
+
mini_batch_size=20,
|
113 |
+
cfg=False,
|
114 |
+
scale=7,
|
115 |
+
path=''
|
116 |
+
)
|
117 |
+
|
118 |
+
return config
|
configs/t2i_512px_clip_dimr.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class Args:
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
for key, value in kwargs.items():
|
8 |
+
setattr(self, key, value)
|
9 |
+
|
10 |
+
model = Args(
|
11 |
+
channels = 4,
|
12 |
+
block_grad_to_lowres = False,
|
13 |
+
norm_type = "TDRMSN",
|
14 |
+
use_t2i = True,
|
15 |
+
clip_dim=768,
|
16 |
+
num_clip_token=77,
|
17 |
+
gradient_checking=True,
|
18 |
+
cfg_indicator=0.15,
|
19 |
+
textVAE = Args(
|
20 |
+
num_blocks = 11,
|
21 |
+
hidden_dim = 1024,
|
22 |
+
hidden_token_length = 256,
|
23 |
+
num_attention_heads = 8,
|
24 |
+
dropout_prob = 0.1,
|
25 |
+
),
|
26 |
+
stage_configs = [
|
27 |
+
Args(
|
28 |
+
block_type = "TransformerBlock",
|
29 |
+
dim = 1024, # channel
|
30 |
+
hidden_dim = 2048,
|
31 |
+
num_attention_heads = 16,
|
32 |
+
num_blocks = 65, # depth
|
33 |
+
max_height = 16,
|
34 |
+
max_width = 16,
|
35 |
+
image_input_ratio = 1,
|
36 |
+
input_feature_ratio = 4,
|
37 |
+
final_kernel_size = 3,
|
38 |
+
dropout_prob = 0,
|
39 |
+
),
|
40 |
+
Args(
|
41 |
+
block_type = "ConvNeXtBlock",
|
42 |
+
dim = 512,
|
43 |
+
hidden_dim = 1024,
|
44 |
+
kernel_size = 7,
|
45 |
+
num_blocks = 33,
|
46 |
+
max_height = 32,
|
47 |
+
max_width = 32,
|
48 |
+
image_input_ratio = 1,
|
49 |
+
input_feature_ratio = 2,
|
50 |
+
final_kernel_size = 3,
|
51 |
+
dropout_prob = 0,
|
52 |
+
),
|
53 |
+
Args(
|
54 |
+
block_type = "ConvNeXtBlock",
|
55 |
+
dim = 256,
|
56 |
+
hidden_dim = 512,
|
57 |
+
kernel_size = 7,
|
58 |
+
num_blocks = 33,
|
59 |
+
max_height = 64,
|
60 |
+
max_width = 64,
|
61 |
+
image_input_ratio = 1,
|
62 |
+
input_feature_ratio = 1,
|
63 |
+
final_kernel_size = 3,
|
64 |
+
dropout_prob = 0,
|
65 |
+
),
|
66 |
+
],
|
67 |
+
)
|
68 |
+
|
69 |
+
def d(**kwargs):
|
70 |
+
"""Helper of creating a config dict."""
|
71 |
+
return ml_collections.ConfigDict(initial_dictionary=kwargs)
|
72 |
+
|
73 |
+
|
74 |
+
def get_config():
|
75 |
+
config = ml_collections.ConfigDict()
|
76 |
+
|
77 |
+
config.seed = 1234
|
78 |
+
config.z_shape = (4, 64, 64)
|
79 |
+
|
80 |
+
config.autoencoder = d(
|
81 |
+
pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
|
82 |
+
scale_factor=0.23010
|
83 |
+
)
|
84 |
+
|
85 |
+
config.train = d(
|
86 |
+
n_steps=1000000,
|
87 |
+
batch_size=1024,
|
88 |
+
mode='cond',
|
89 |
+
log_interval=10,
|
90 |
+
eval_interval=5000,
|
91 |
+
save_interval=50000,
|
92 |
+
)
|
93 |
+
|
94 |
+
config.optimizer = d(
|
95 |
+
name='adamw',
|
96 |
+
lr=0.00001,
|
97 |
+
weight_decay=0.03,
|
98 |
+
betas=(0.9, 0.9),
|
99 |
+
)
|
100 |
+
|
101 |
+
config.lr_scheduler = d(
|
102 |
+
name='customized',
|
103 |
+
warmup_steps=5000
|
104 |
+
)
|
105 |
+
|
106 |
+
global model
|
107 |
+
config.nnet = d(
|
108 |
+
name='dimr',
|
109 |
+
model_args=model,
|
110 |
+
)
|
111 |
+
config.loss_coeffs = [1/4, 1/2, 1]
|
112 |
+
|
113 |
+
config.dataset = d(
|
114 |
+
name='JDB_demo_features',
|
115 |
+
resolution=512,
|
116 |
+
llm='clip',
|
117 |
+
train_path='/data/qihao/dataset/JDB_demo_feature/',
|
118 |
+
val_path='/data/qihao/dataset/coco_val_features/',
|
119 |
+
cfg=False
|
120 |
+
)
|
121 |
+
|
122 |
+
config.sample = d(
|
123 |
+
sample_steps=50,
|
124 |
+
n_samples=30000,
|
125 |
+
mini_batch_size=10,
|
126 |
+
cfg=False,
|
127 |
+
scale=7,
|
128 |
+
path=''
|
129 |
+
)
|
130 |
+
|
131 |
+
return config
|
configs/t2i_512px_t5_dimr.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class Args:
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
for key, value in kwargs.items():
|
8 |
+
setattr(self, key, value)
|
9 |
+
|
10 |
+
model = Args(
|
11 |
+
channels = 4,
|
12 |
+
block_grad_to_lowres = False,
|
13 |
+
norm_type = "TDRMSN",
|
14 |
+
use_t2i = True,
|
15 |
+
clip_dim=4096,
|
16 |
+
num_clip_token=77,
|
17 |
+
gradient_checking=True,
|
18 |
+
cfg_indicator=0.15,
|
19 |
+
textVAE = Args(
|
20 |
+
num_blocks = 11,
|
21 |
+
hidden_dim = 1024,
|
22 |
+
hidden_token_length = 256,
|
23 |
+
num_attention_heads = 8,
|
24 |
+
dropout_prob = 0.1,
|
25 |
+
),
|
26 |
+
stage_configs = [
|
27 |
+
Args(
|
28 |
+
block_type = "TransformerBlock",
|
29 |
+
dim = 1024, # channel
|
30 |
+
hidden_dim = 2048,
|
31 |
+
num_attention_heads = 16,
|
32 |
+
num_blocks = 65, # depth
|
33 |
+
max_height = 16,
|
34 |
+
max_width = 16,
|
35 |
+
image_input_ratio = 1,
|
36 |
+
input_feature_ratio = 4,
|
37 |
+
final_kernel_size = 3,
|
38 |
+
dropout_prob = 0,
|
39 |
+
),
|
40 |
+
Args(
|
41 |
+
block_type = "ConvNeXtBlock",
|
42 |
+
dim = 512,
|
43 |
+
hidden_dim = 1024,
|
44 |
+
kernel_size = 7,
|
45 |
+
num_blocks = 33,
|
46 |
+
max_height = 32,
|
47 |
+
max_width = 32,
|
48 |
+
image_input_ratio = 1,
|
49 |
+
input_feature_ratio = 2,
|
50 |
+
final_kernel_size = 3,
|
51 |
+
dropout_prob = 0,
|
52 |
+
),
|
53 |
+
Args(
|
54 |
+
block_type = "ConvNeXtBlock",
|
55 |
+
dim = 256,
|
56 |
+
hidden_dim = 512,
|
57 |
+
kernel_size = 7,
|
58 |
+
num_blocks = 33,
|
59 |
+
max_height = 64,
|
60 |
+
max_width = 64,
|
61 |
+
image_input_ratio = 1,
|
62 |
+
input_feature_ratio = 1,
|
63 |
+
final_kernel_size = 3,
|
64 |
+
dropout_prob = 0,
|
65 |
+
),
|
66 |
+
],
|
67 |
+
)
|
68 |
+
|
69 |
+
def d(**kwargs):
|
70 |
+
"""Helper of creating a config dict."""
|
71 |
+
return ml_collections.ConfigDict(initial_dictionary=kwargs)
|
72 |
+
|
73 |
+
|
74 |
+
def get_config():
|
75 |
+
config = ml_collections.ConfigDict()
|
76 |
+
|
77 |
+
config.seed = 1234
|
78 |
+
config.z_shape = (4, 64, 64)
|
79 |
+
|
80 |
+
config.autoencoder = d(
|
81 |
+
pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
|
82 |
+
scale_factor=0.23010
|
83 |
+
)
|
84 |
+
|
85 |
+
config.train = d(
|
86 |
+
n_steps=1000000,
|
87 |
+
batch_size=1024,
|
88 |
+
mode='cond',
|
89 |
+
log_interval=10,
|
90 |
+
eval_interval=5000,
|
91 |
+
save_interval=50000,
|
92 |
+
)
|
93 |
+
|
94 |
+
config.optimizer = d(
|
95 |
+
name='adamw',
|
96 |
+
lr=0.00001,
|
97 |
+
weight_decay=0.03,
|
98 |
+
betas=(0.9, 0.9),
|
99 |
+
)
|
100 |
+
|
101 |
+
config.lr_scheduler = d(
|
102 |
+
name='customized',
|
103 |
+
warmup_steps=5000
|
104 |
+
)
|
105 |
+
|
106 |
+
global model
|
107 |
+
config.nnet = d(
|
108 |
+
name='dimr',
|
109 |
+
model_args=model,
|
110 |
+
)
|
111 |
+
config.loss_coeffs = [1/4, 1/2, 1]
|
112 |
+
|
113 |
+
config.dataset = d(
|
114 |
+
name='JDB_demo_features',
|
115 |
+
resolution=512,
|
116 |
+
llm='t5',
|
117 |
+
train_path='/data/qihao/dataset/JDB_demo_feature/',
|
118 |
+
val_path='/data/qihao/dataset/coco_val_features/',
|
119 |
+
cfg=False
|
120 |
+
)
|
121 |
+
|
122 |
+
config.sample = d(
|
123 |
+
sample_steps=50,
|
124 |
+
n_samples=30000,
|
125 |
+
mini_batch_size=10,
|
126 |
+
cfg=False,
|
127 |
+
scale=7,
|
128 |
+
path=''
|
129 |
+
)
|
130 |
+
|
131 |
+
return config
|
configs/t2i_512px_t5_dit.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class Args:
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
for key, value in kwargs.items():
|
8 |
+
setattr(self, key, value)
|
9 |
+
|
10 |
+
model = Args(
|
11 |
+
latent_size = 64,
|
12 |
+
learn_sigma = False, # different from DiT, we direct predict noise here
|
13 |
+
channels = 4,
|
14 |
+
block_grad_to_lowres = False,
|
15 |
+
norm_type = "TDRMSN",
|
16 |
+
use_t2i = True,
|
17 |
+
clip_dim=4096,
|
18 |
+
num_clip_token=77,
|
19 |
+
gradient_checking=True, # for larger model
|
20 |
+
cfg_indicator=0.10,
|
21 |
+
textVAE = Args(
|
22 |
+
num_blocks = 11,
|
23 |
+
hidden_dim = 1024,
|
24 |
+
hidden_token_length = 256,
|
25 |
+
num_attention_heads = 8,
|
26 |
+
dropout_prob = 0.1,
|
27 |
+
),
|
28 |
+
)
|
29 |
+
|
30 |
+
def d(**kwargs):
|
31 |
+
"""Helper of creating a config dict."""
|
32 |
+
return ml_collections.ConfigDict(initial_dictionary=kwargs)
|
33 |
+
|
34 |
+
|
35 |
+
def get_config():
|
36 |
+
config = ml_collections.ConfigDict()
|
37 |
+
|
38 |
+
config.seed = 1234
|
39 |
+
config.z_shape = (4, 64, 64)
|
40 |
+
|
41 |
+
config.autoencoder = d(
|
42 |
+
pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
|
43 |
+
scale_factor=0.23010
|
44 |
+
)
|
45 |
+
|
46 |
+
config.train = d(
|
47 |
+
n_steps=1000000,
|
48 |
+
batch_size=1024,
|
49 |
+
mode='cond',
|
50 |
+
log_interval=10,
|
51 |
+
eval_interval=5000,
|
52 |
+
save_interval=50000,
|
53 |
+
)
|
54 |
+
|
55 |
+
config.optimizer = d(
|
56 |
+
name='adamw',
|
57 |
+
lr=0.00002,
|
58 |
+
weight_decay=0.03,
|
59 |
+
betas=(0.9, 0.9),
|
60 |
+
)
|
61 |
+
|
62 |
+
config.lr_scheduler = d(
|
63 |
+
name='customized',
|
64 |
+
warmup_steps=5000
|
65 |
+
)
|
66 |
+
|
67 |
+
global model
|
68 |
+
config.nnet = d(
|
69 |
+
name='dit',
|
70 |
+
model_args=model,
|
71 |
+
)
|
72 |
+
config.loss_coeffs = []
|
73 |
+
|
74 |
+
config.dataset = d(
|
75 |
+
name='JDB_demo_features',
|
76 |
+
resolution=512,
|
77 |
+
llm='t5',
|
78 |
+
train_path='/data/qihao/dataset/JDB_demo_feature/',
|
79 |
+
val_path='/data/qihao/dataset/coco_val_features/',
|
80 |
+
cfg=False
|
81 |
+
)
|
82 |
+
|
83 |
+
config.sample = d(
|
84 |
+
sample_steps=50,
|
85 |
+
n_samples=30000,
|
86 |
+
mini_batch_size=10,
|
87 |
+
cfg=False,
|
88 |
+
scale=7,
|
89 |
+
path=''
|
90 |
+
)
|
91 |
+
|
92 |
+
return config
|
configs/t2i_training_demo.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class Args:
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
for key, value in kwargs.items():
|
8 |
+
setattr(self, key, value)
|
9 |
+
|
10 |
+
model = Args(
|
11 |
+
channels = 4,
|
12 |
+
block_grad_to_lowres = False,
|
13 |
+
norm_type = "TDRMSN",
|
14 |
+
use_t2i = True,
|
15 |
+
clip_dim=768, # 768 for CLIP, 4096 for T5-XXL
|
16 |
+
num_clip_token=77,
|
17 |
+
gradient_checking=True,
|
18 |
+
cfg_indicator=0.1,
|
19 |
+
textVAE = Args(
|
20 |
+
num_blocks = 11,
|
21 |
+
hidden_dim = 1024,
|
22 |
+
hidden_token_length = 256,
|
23 |
+
num_attention_heads = 8,
|
24 |
+
dropout_prob = 0.1,
|
25 |
+
),
|
26 |
+
stage_configs = [ # this is just an example
|
27 |
+
Args(
|
28 |
+
block_type = "TransformerBlock",
|
29 |
+
dim = 960,
|
30 |
+
hidden_dim = 1920,
|
31 |
+
num_attention_heads = 16,
|
32 |
+
num_blocks = 29,
|
33 |
+
max_height = 16,
|
34 |
+
max_width = 16,
|
35 |
+
image_input_ratio = 1,
|
36 |
+
input_feature_ratio = 4,
|
37 |
+
final_kernel_size = 3,
|
38 |
+
dropout_prob = 0,
|
39 |
+
),
|
40 |
+
Args(
|
41 |
+
block_type = "ConvNeXtBlock",
|
42 |
+
dim = 480,
|
43 |
+
hidden_dim = 960,
|
44 |
+
kernel_size = 7,
|
45 |
+
num_blocks = 15,
|
46 |
+
max_height = 32,
|
47 |
+
max_width = 32,
|
48 |
+
image_input_ratio = 1,
|
49 |
+
input_feature_ratio = 2,
|
50 |
+
final_kernel_size = 3,
|
51 |
+
dropout_prob = 0,
|
52 |
+
),
|
53 |
+
Args(
|
54 |
+
block_type = "ConvNeXtBlock",
|
55 |
+
dim = 240,
|
56 |
+
hidden_dim = 480,
|
57 |
+
kernel_size = 7,
|
58 |
+
num_blocks = 15,
|
59 |
+
max_height = 64,
|
60 |
+
max_width = 64,
|
61 |
+
image_input_ratio = 1,
|
62 |
+
input_feature_ratio = 1,
|
63 |
+
final_kernel_size = 3,
|
64 |
+
dropout_prob = 0,
|
65 |
+
),
|
66 |
+
],
|
67 |
+
)
|
68 |
+
|
69 |
+
def d(**kwargs):
|
70 |
+
"""Helper of creating a config dict."""
|
71 |
+
return ml_collections.ConfigDict(initial_dictionary=kwargs)
|
72 |
+
|
73 |
+
|
74 |
+
def get_config():
|
75 |
+
config = ml_collections.ConfigDict()
|
76 |
+
|
77 |
+
config.seed = 1234 # random seed
|
78 |
+
config.z_shape = (4, 64, 64) # image latent size
|
79 |
+
|
80 |
+
config.autoencoder = d(
|
81 |
+
pretrained_path='assets/stable-diffusion/autoencoder_kl.pth', # path of pretrained VAE CKPT from LDM
|
82 |
+
scale_factor=0.23010
|
83 |
+
)
|
84 |
+
|
85 |
+
config.train = d(
|
86 |
+
n_steps=1000000, # total training iterations
|
87 |
+
batch_size=4, # overall batch size across ALL gpus, where batch_size_per_gpu == batch_size / number_of_gpus
|
88 |
+
mode='cond',
|
89 |
+
log_interval=10,
|
90 |
+
eval_interval=10, # iteration interval for visual testing on the specified prompt
|
91 |
+
save_interval=100, # iteration interval for saving checkpoints and testing FID
|
92 |
+
n_samples_eval=5, # number of samples duing visual testing. This depends on your GPU memory and can be any integer between 1 and 15 (as we provide only 15 prompts).
|
93 |
+
)
|
94 |
+
|
95 |
+
config.optimizer = d(
|
96 |
+
name='adamw',
|
97 |
+
lr=0.00001, # learning rate
|
98 |
+
weight_decay=0.03,
|
99 |
+
betas=(0.9, 0.9),
|
100 |
+
)
|
101 |
+
|
102 |
+
config.lr_scheduler = d(
|
103 |
+
name='customized',
|
104 |
+
warmup_steps=5000 # warmup steps
|
105 |
+
)
|
106 |
+
|
107 |
+
global model
|
108 |
+
config.nnet = d(
|
109 |
+
name='dimr',
|
110 |
+
model_args=model,
|
111 |
+
)
|
112 |
+
config.loss_coeffs = [1/4, 1/2, 1] # weight on loss, only needed for DiMR. Here, loss = 1/4 * loss_block1 + 1/2 * loss_block2 + 1 * loss_block3
|
113 |
+
|
114 |
+
config.dataset = d(
|
115 |
+
name='JDB_demo_features', # dataset name
|
116 |
+
resolution=512, # dataset resolution
|
117 |
+
llm='clip', # language model to generate language embedding
|
118 |
+
train_path='/data/qihao/dataset/JDB_demo_feature/', # training set path
|
119 |
+
val_path='/data/qihao/dataset/coco_val_features/', # val set path
|
120 |
+
cfg=False
|
121 |
+
)
|
122 |
+
|
123 |
+
config.sample = d(
|
124 |
+
sample_steps=50, # sample steps duing inference/testing
|
125 |
+
n_samples=30000, # number of samples for testing (during training, we sample 10K images, which is hardcoded in the training script)
|
126 |
+
mini_batch_size=10, # batch size for testing (i.e., the number of images generated per GPU)
|
127 |
+
cfg=False,
|
128 |
+
scale=7, # cfg scale
|
129 |
+
path=''
|
130 |
+
)
|
131 |
+
|
132 |
+
return config
|
datasets.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from torchvision import datasets
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
from scipy.signal import convolve2d
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import math
|
8 |
+
import random
|
9 |
+
from PIL import Image
|
10 |
+
import os
|
11 |
+
import glob
|
12 |
+
import einops
|
13 |
+
import torchvision.transforms.functional as F
|
14 |
+
import time
|
15 |
+
from tqdm import tqdm
|
16 |
+
import json
|
17 |
+
import pickle
|
18 |
+
import io
|
19 |
+
import cv2
|
20 |
+
|
21 |
+
import libs.clip
|
22 |
+
import bisect
|
23 |
+
|
24 |
+
|
25 |
+
class UnlabeledDataset(Dataset):
|
26 |
+
def __init__(self, dataset):
|
27 |
+
self.dataset = dataset
|
28 |
+
|
29 |
+
def __len__(self):
|
30 |
+
return len(self.dataset)
|
31 |
+
|
32 |
+
def __getitem__(self, item):
|
33 |
+
data = tuple(self.dataset[item][:-1]) # remove label
|
34 |
+
if len(data) == 1:
|
35 |
+
data = data[0]
|
36 |
+
return data
|
37 |
+
|
38 |
+
|
39 |
+
class LabeledDataset(Dataset):
|
40 |
+
def __init__(self, dataset, labels):
|
41 |
+
self.dataset = dataset
|
42 |
+
self.labels = labels
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return len(self.dataset)
|
46 |
+
|
47 |
+
def __getitem__(self, item):
|
48 |
+
return self.dataset[item], self.labels[item]
|
49 |
+
|
50 |
+
|
51 |
+
class DatasetFactory(object):
|
52 |
+
|
53 |
+
def __init__(self):
|
54 |
+
self.train = None
|
55 |
+
self.test = None
|
56 |
+
|
57 |
+
def get_split(self, split, labeled=False):
|
58 |
+
if split == "train":
|
59 |
+
dataset = self.train
|
60 |
+
elif split == "test":
|
61 |
+
dataset = self.test
|
62 |
+
else:
|
63 |
+
raise ValueError
|
64 |
+
|
65 |
+
if self.has_label:
|
66 |
+
return dataset if labeled else UnlabeledDataset(dataset)
|
67 |
+
else:
|
68 |
+
assert not labeled
|
69 |
+
return dataset
|
70 |
+
|
71 |
+
def unpreprocess(self, v): # to B C H W and [0, 1]
|
72 |
+
v = 0.5 * (v + 1.)
|
73 |
+
v.clamp_(0., 1.)
|
74 |
+
return v
|
75 |
+
|
76 |
+
@property
|
77 |
+
def has_label(self):
|
78 |
+
return True
|
79 |
+
|
80 |
+
@property
|
81 |
+
def data_shape(self):
|
82 |
+
raise NotImplementedError
|
83 |
+
|
84 |
+
@property
|
85 |
+
def data_dim(self):
|
86 |
+
return int(np.prod(self.data_shape))
|
87 |
+
|
88 |
+
@property
|
89 |
+
def fid_stat(self):
|
90 |
+
return None
|
91 |
+
|
92 |
+
def sample_label(self, n_samples, device):
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
def label_prob(self, k):
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
|
99 |
+
def center_crop_arr(pil_image, image_size):
|
100 |
+
# We are not on a new enough PIL to support the `reducing_gap`
|
101 |
+
# argument, which uses BOX downsampling at powers of two first.
|
102 |
+
# Thus, we do it by hand to improve downsample quality.
|
103 |
+
while min(*pil_image.size) >= 2 * image_size:
|
104 |
+
pil_image = pil_image.resize(
|
105 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
106 |
+
)
|
107 |
+
|
108 |
+
scale = image_size / min(*pil_image.size)
|
109 |
+
pil_image = pil_image.resize(
|
110 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
111 |
+
)
|
112 |
+
|
113 |
+
arr = np.array(pil_image)
|
114 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
115 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
116 |
+
return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
|
117 |
+
|
118 |
+
|
119 |
+
# MS COCO
|
120 |
+
|
121 |
+
|
122 |
+
def center_crop(width, height, img):
|
123 |
+
resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
|
124 |
+
crop = np.min(img.shape[:2])
|
125 |
+
img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
|
126 |
+
(img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
|
127 |
+
try:
|
128 |
+
img = Image.fromarray(img, 'RGB')
|
129 |
+
except:
|
130 |
+
img = Image.fromarray(img)
|
131 |
+
img = img.resize((width, height), resample)
|
132 |
+
|
133 |
+
return np.array(img).astype(np.uint8)
|
134 |
+
|
135 |
+
|
136 |
+
class MSCOCODatabase(Dataset):
|
137 |
+
def __init__(self, root, annFile, size=None):
|
138 |
+
from pycocotools.coco import COCO
|
139 |
+
self.root = root
|
140 |
+
self.height = self.width = size
|
141 |
+
|
142 |
+
self.coco = COCO(annFile)
|
143 |
+
self.keys = list(sorted(self.coco.imgs.keys()))
|
144 |
+
|
145 |
+
def _load_image(self, key: int):
|
146 |
+
path = self.coco.loadImgs(key)[0]["file_name"]
|
147 |
+
return Image.open(os.path.join(self.root, path)).convert("RGB")
|
148 |
+
|
149 |
+
def _load_target(self, key: int):
|
150 |
+
return self.coco.loadAnns(self.coco.getAnnIds(key))
|
151 |
+
|
152 |
+
def __len__(self):
|
153 |
+
return len(self.keys)
|
154 |
+
|
155 |
+
def __getitem__(self, index):
|
156 |
+
key = self.keys[index]
|
157 |
+
image = self._load_image(key)
|
158 |
+
image = np.array(image).astype(np.uint8)
|
159 |
+
image = center_crop(self.width, self.height, image).astype(np.float32)
|
160 |
+
image = (image / 127.5 - 1.0).astype(np.float32)
|
161 |
+
image = einops.rearrange(image, 'h w c -> c h w')
|
162 |
+
|
163 |
+
anns = self._load_target(key)
|
164 |
+
target = []
|
165 |
+
for ann in anns:
|
166 |
+
target.append(ann['caption'])
|
167 |
+
|
168 |
+
return image, target
|
169 |
+
|
170 |
+
|
171 |
+
def get_feature_dir_info(root):
|
172 |
+
files = glob.glob(os.path.join(root, '*.npy'))
|
173 |
+
files_caption = glob.glob(os.path.join(root, '*_*.npy'))
|
174 |
+
num_data = len(files) - len(files_caption)
|
175 |
+
n_captions = {k: 0 for k in range(num_data)}
|
176 |
+
for f in files_caption:
|
177 |
+
name = os.path.split(f)[-1]
|
178 |
+
k1, k2 = os.path.splitext(name)[0].split('_')
|
179 |
+
n_captions[int(k1)] += 1
|
180 |
+
return num_data, n_captions
|
181 |
+
|
182 |
+
|
183 |
+
class MSCOCOFeatureDataset(Dataset):
|
184 |
+
# the image features are got through sample
|
185 |
+
def __init__(self, root, need_squeeze=False, full_feature=False, fix_test_order=False):
|
186 |
+
self.root = root
|
187 |
+
self.num_data, self.n_captions = get_feature_dir_info(root)
|
188 |
+
self.need_squeeze = need_squeeze
|
189 |
+
self.full_feature = full_feature
|
190 |
+
self.fix_test_order = fix_test_order
|
191 |
+
|
192 |
+
def __len__(self):
|
193 |
+
return self.num_data
|
194 |
+
|
195 |
+
def __getitem__(self, index):
|
196 |
+
if self.full_feature:
|
197 |
+
z = np.load(os.path.join(self.root, f'{index}.npy'))
|
198 |
+
|
199 |
+
if self.fix_test_order:
|
200 |
+
k = self.n_captions[index] - 1
|
201 |
+
else:
|
202 |
+
k = random.randint(0, self.n_captions[index] - 1)
|
203 |
+
|
204 |
+
test_item = np.load(os.path.join(self.root, f'{index}_{k}.npy'), allow_pickle=True).item()
|
205 |
+
token_embedding = test_item['token_embedding']
|
206 |
+
token_mask = test_item['token_mask']
|
207 |
+
token = test_item['token']
|
208 |
+
caption = test_item['promt']
|
209 |
+
return z, token_embedding, token_mask, token, caption
|
210 |
+
else:
|
211 |
+
z = np.load(os.path.join(self.root, f'{index}.npy'))
|
212 |
+
k = random.randint(0, self.n_captions[index] - 1)
|
213 |
+
c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
|
214 |
+
if self.need_squeeze:
|
215 |
+
return z, c.squeeze()
|
216 |
+
else:
|
217 |
+
return z, c
|
218 |
+
|
219 |
+
|
220 |
+
class JDBFeatureDataset(Dataset):
|
221 |
+
def __init__(self, root, resolution, llm):
|
222 |
+
super().__init__()
|
223 |
+
json_path = os.path.join(root,'img_text_pair.jsonl')
|
224 |
+
self.img_root = os.path.join(root,'imgs')
|
225 |
+
self.feature_root = os.path.join(root,'features')
|
226 |
+
self.resolution = resolution
|
227 |
+
self.llm = llm
|
228 |
+
self.file_list = []
|
229 |
+
with open(json_path, 'r', encoding='utf-8') as file:
|
230 |
+
for line in file:
|
231 |
+
self.file_list.append(json.loads(line)['img_path'])
|
232 |
+
|
233 |
+
def __len__(self):
|
234 |
+
return len(self.file_list)
|
235 |
+
|
236 |
+
def __getitem__(self, idx):
|
237 |
+
data_item = self.file_list[idx]
|
238 |
+
feature_path = os.path.join(self.feature_root, data_item.split('/')[-1].replace('.jpg','.npy'))
|
239 |
+
img_path = os.path.join(self.img_root, data_item)
|
240 |
+
|
241 |
+
train_item = np.load(feature_path, allow_pickle=True).item()
|
242 |
+
pil_image = Image.open(img_path)
|
243 |
+
pil_image.load()
|
244 |
+
pil_image = pil_image.convert("RGB")
|
245 |
+
|
246 |
+
|
247 |
+
z = train_item[f'image_latent_{self.resolution}']
|
248 |
+
token_embedding = train_item[f'token_embedding_{self.llm}']
|
249 |
+
token_mask = train_item[f'token_mask_{self.llm}']
|
250 |
+
token = train_item[f'token_{self.llm}']
|
251 |
+
caption = train_item['batch_caption']
|
252 |
+
|
253 |
+
img = center_crop_arr(pil_image, image_size=self.resolution)
|
254 |
+
img = (img / 127.5 - 1.0).astype(np.float32)
|
255 |
+
img = einops.rearrange(img, 'h w c -> c h w')
|
256 |
+
|
257 |
+
# return z, token_embedding, token_mask, token, caption, 0, img, 0, 0
|
258 |
+
return z, token_embedding, token_mask, token, caption, img
|
259 |
+
|
260 |
+
|
261 |
+
class JDBFullFeatures(DatasetFactory): # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip
|
262 |
+
def __init__(self, train_path, val_path, resolution, llm, cfg=False, p_uncond=None, fix_test_order=False):
|
263 |
+
super().__init__()
|
264 |
+
print('Prepare dataset...')
|
265 |
+
self.resolution = resolution
|
266 |
+
|
267 |
+
self.train = JDBFeatureDataset(train_path, resolution=resolution, llm=llm)
|
268 |
+
self.test = MSCOCOFeatureDataset(os.path.join(val_path, 'val'), full_feature=True, fix_test_order=fix_test_order)
|
269 |
+
assert len(self.test) == 40504
|
270 |
+
|
271 |
+
print('Prepare dataset ok')
|
272 |
+
|
273 |
+
self.empty_context = np.load(os.path.join(val_path, 'empty_context.npy'), allow_pickle=True).item()
|
274 |
+
|
275 |
+
assert not cfg
|
276 |
+
|
277 |
+
# text embedding extracted by clip
|
278 |
+
self.prompts, self.token_embedding, self.token_mask, self.token = [], [], [], []
|
279 |
+
for f in sorted(os.listdir(os.path.join(val_path, 'run_vis')), key=lambda x: int(x.split('.')[0])):
|
280 |
+
vis_item = np.load(os.path.join(val_path, 'run_vis', f), allow_pickle=True).item()
|
281 |
+
self.prompts.append(vis_item['promt'])
|
282 |
+
self.token_embedding.append(vis_item['token_embedding'])
|
283 |
+
self.token_mask.append(vis_item['token_mask'])
|
284 |
+
self.token.append(vis_item['token'])
|
285 |
+
self.token_embedding = np.array(self.token_embedding)
|
286 |
+
self.token_mask = np.array(self.token_mask)
|
287 |
+
self.token = np.array(self.token)
|
288 |
+
|
289 |
+
@property
|
290 |
+
def data_shape(self):
|
291 |
+
if self.resolution==512:
|
292 |
+
return 4, 64, 64
|
293 |
+
else:
|
294 |
+
return 4, 32, 32
|
295 |
+
|
296 |
+
@property
|
297 |
+
def fid_stat(self):
|
298 |
+
return f'assets/fid_stats/fid_stats_mscoco256_val.npz'
|
299 |
+
|
300 |
+
|
301 |
+
def get_dataset(name, **kwargs):
|
302 |
+
if name == 'JDB_demo_features':
|
303 |
+
return JDBFullFeatures(**kwargs)
|
304 |
+
else:
|
305 |
+
raise NotImplementedError(name)
|
demo_t2i.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is used for T2I generation, it also compute the clip similarity between the generated images and the input prompt
|
3 |
+
"""
|
4 |
+
from absl import flags
|
5 |
+
from absl import app
|
6 |
+
from ml_collections import config_flags
|
7 |
+
import os
|
8 |
+
|
9 |
+
import ml_collections
|
10 |
+
import torch
|
11 |
+
from torch import multiprocessing as mp
|
12 |
+
import torch.nn as nn
|
13 |
+
import accelerate
|
14 |
+
import utils
|
15 |
+
import tempfile
|
16 |
+
from absl import logging
|
17 |
+
import builtins
|
18 |
+
import einops
|
19 |
+
import math
|
20 |
+
import numpy as np
|
21 |
+
import time
|
22 |
+
from PIL import Image
|
23 |
+
|
24 |
+
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
|
25 |
+
from tools.clip_score import ClipSocre
|
26 |
+
import libs.autoencoder
|
27 |
+
from libs.clip import FrozenCLIPEmbedder
|
28 |
+
from libs.t5 import T5Embedder
|
29 |
+
|
30 |
+
|
31 |
+
def unpreprocess(x):
|
32 |
+
x = 0.5 * (x + 1.)
|
33 |
+
x.clamp_(0., 1.)
|
34 |
+
return x
|
35 |
+
|
36 |
+
def get_caption(llm, text_model, _batch_prompt):
|
37 |
+
_batch_con = _batch_prompt
|
38 |
+
if llm == "clip":
|
39 |
+
_latent, _latent_and_others = text_model.encode(_batch_con)
|
40 |
+
_con = _latent_and_others['token_embedding'].detach()
|
41 |
+
elif llm == "t5":
|
42 |
+
_latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
|
43 |
+
_con = (_latent_and_others['token_embedding'] * 10.0).detach()
|
44 |
+
else:
|
45 |
+
raise NotImplementedError
|
46 |
+
_con_mask = _latent_and_others['token_mask'].detach()
|
47 |
+
_batch_token = _latent_and_others['tokens'].detach()
|
48 |
+
_batch_caption = _batch_con
|
49 |
+
return (_con, _con_mask, _batch_token, _batch_caption)
|
50 |
+
|
51 |
+
|
52 |
+
def evaluate(config):
|
53 |
+
|
54 |
+
if config.get('benchmark', False):
|
55 |
+
torch.backends.cudnn.benchmark = True
|
56 |
+
torch.backends.cudnn.deterministic = False
|
57 |
+
|
58 |
+
mp.set_start_method('spawn')
|
59 |
+
accelerator = accelerate.Accelerator()
|
60 |
+
device = accelerator.device
|
61 |
+
accelerate.utils.set_seed(config.seed, device_specific=True)
|
62 |
+
logging.info(f'Process {accelerator.process_index} using device: {device}')
|
63 |
+
|
64 |
+
config.mixed_precision = accelerator.mixed_precision
|
65 |
+
config = ml_collections.FrozenConfigDict(config)
|
66 |
+
if accelerator.is_main_process:
|
67 |
+
utils.set_logger(log_level='info', fname=config.output_path)
|
68 |
+
else:
|
69 |
+
utils.set_logger(log_level='error')
|
70 |
+
builtins.print = lambda *args: None
|
71 |
+
|
72 |
+
nnet = utils.get_nnet(**config.nnet)
|
73 |
+
nnet = accelerator.prepare(nnet)
|
74 |
+
logging.info(f'load nnet from {config.nnet_path}')
|
75 |
+
accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
|
76 |
+
nnet.eval()
|
77 |
+
|
78 |
+
##
|
79 |
+
|
80 |
+
if config.nnet.model_args.clip_dim == 4096:
|
81 |
+
llm = "t5"
|
82 |
+
t5 = T5Embedder(device=device)
|
83 |
+
elif config.nnet.model_args.clip_dim == 768:
|
84 |
+
llm = "clip"
|
85 |
+
clip = FrozenCLIPEmbedder()
|
86 |
+
clip.eval()
|
87 |
+
clip.to(device)
|
88 |
+
else:
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
if llm == "clip":
|
92 |
+
context_generator = get_caption(llm, clip, _batch_prompt=[config.prompt]*config.sample.mini_batch_size)
|
93 |
+
elif llm == "t5":
|
94 |
+
context_generator = get_caption(llm, t5, _batch_prompt=[config.prompt]*config.sample.mini_batch_size)
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
##
|
99 |
+
|
100 |
+
autoencoder = libs.autoencoder.get_model(**config.autoencoder)
|
101 |
+
autoencoder.to(device)
|
102 |
+
|
103 |
+
@torch.cuda.amp.autocast()
|
104 |
+
def encode(_batch):
|
105 |
+
return autoencoder.encode(_batch)
|
106 |
+
|
107 |
+
@torch.cuda.amp.autocast()
|
108 |
+
def decode(_batch):
|
109 |
+
return autoencoder.decode(_batch)
|
110 |
+
|
111 |
+
bdv_nnet = None # We don't use Autoguidance
|
112 |
+
ClipSocre_model = ClipSocre(device=device) # we also return clip score
|
113 |
+
|
114 |
+
#######
|
115 |
+
logging.info(config.sample)
|
116 |
+
logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')
|
117 |
+
|
118 |
+
|
119 |
+
def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, bdv_nnet=bdv_nnet, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token=None, token_mask=None, return_clipScore=False, ClipSocre_model=None):
|
120 |
+
with torch.no_grad():
|
121 |
+
del testbatch_img_blurred
|
122 |
+
|
123 |
+
_z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)
|
124 |
+
|
125 |
+
if 'dimr' in config.nnet.name or 'dit' in config.nnet.name:
|
126 |
+
_z_x0, _mu, _log_var = nnet_ema(context, text_encoder = True, shape = _z_gaussian.shape, mask=token_mask)
|
127 |
+
_z_init = _z_x0.reshape(_z_gaussian.shape)
|
128 |
+
else:
|
129 |
+
raise NotImplementedError
|
130 |
+
|
131 |
+
assert config.sample.scale > 1
|
132 |
+
if config.cfg != -1:
|
133 |
+
_cfg = config.cfg
|
134 |
+
else:
|
135 |
+
_cfg = config.sample.scale
|
136 |
+
|
137 |
+
has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
|
138 |
+
|
139 |
+
_sample_steps = config.sample.sample_steps
|
140 |
+
|
141 |
+
ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, bdv_model_fn=bdv_nnet, step_size_type="step_in_dsigma", guidance_scale=_cfg)
|
142 |
+
_z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)
|
143 |
+
|
144 |
+
image_unprocessed = decode(_z)
|
145 |
+
clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)
|
146 |
+
|
147 |
+
return image_unprocessed, clip_score
|
148 |
+
|
149 |
+
|
150 |
+
def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
|
151 |
+
_context, _token_mask, _token, _caption = context_generator
|
152 |
+
assert _context.size(0) == _n_samples
|
153 |
+
assert return_clipScore
|
154 |
+
assert not return_caption
|
155 |
+
return ode_fm_solver_sample(nnet, _n_samples, config.sample.sample_steps, bdv_nnet=bdv_nnet, context=_context, token=_token, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
|
156 |
+
|
157 |
+
|
158 |
+
with tempfile.TemporaryDirectory() as temp_path:
|
159 |
+
path = config.img_save_path or config.sample.path or temp_path
|
160 |
+
if accelerator.is_main_process:
|
161 |
+
os.makedirs(path, exist_ok=True)
|
162 |
+
logging.info(f'Samples are saved in {path}')
|
163 |
+
|
164 |
+
clip_score_list = utils.sample2dir_wCLIP(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
|
165 |
+
if clip_score_list is not None:
|
166 |
+
_clip_score_list = torch.cat(clip_score_list)
|
167 |
+
if accelerator.is_main_process:
|
168 |
+
logging.info(f'nnet_path={config.nnet_path}, clip_score{len(_clip_score_list)}={_clip_score_list.mean().item()}')
|
169 |
+
|
170 |
+
|
171 |
+
FLAGS = flags.FLAGS
|
172 |
+
config_flags.DEFINE_config_file(
|
173 |
+
"config", None, "Training configuration.", lock_config=False)
|
174 |
+
|
175 |
+
flags.mark_flags_as_required(["config"])
|
176 |
+
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
|
177 |
+
flags.DEFINE_string("prompt", None, "The prompt used for generation.")
|
178 |
+
flags.DEFINE_string("output_path", None, "The path to output log.")
|
179 |
+
flags.DEFINE_float("cfg", -1, 'cfg scale, will use the scale defined in the config file is not assigned')
|
180 |
+
flags.DEFINE_string("img_save_path", None, "The path to image log.")
|
181 |
+
|
182 |
+
|
183 |
+
def main(argv):
|
184 |
+
config = FLAGS.config
|
185 |
+
config.nnet_path = FLAGS.nnet_path
|
186 |
+
config.prompt = FLAGS.prompt
|
187 |
+
config.output_path = FLAGS.output_path
|
188 |
+
config.img_save_path = FLAGS.img_save_path
|
189 |
+
config.cfg = FLAGS.cfg
|
190 |
+
evaluate(config)
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == "__main__":
|
194 |
+
app.run(main)
|
demo_t2i_arith.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is used for T2I generation, it also compute the clip similarity between the generated images and the input prompt
|
3 |
+
"""
|
4 |
+
from absl import flags
|
5 |
+
from absl import app
|
6 |
+
from ml_collections import config_flags
|
7 |
+
import os
|
8 |
+
|
9 |
+
import ml_collections
|
10 |
+
import torch
|
11 |
+
from torch import multiprocessing as mp
|
12 |
+
import torch.nn as nn
|
13 |
+
import accelerate
|
14 |
+
import utils
|
15 |
+
import tempfile
|
16 |
+
from absl import logging
|
17 |
+
import builtins
|
18 |
+
import einops
|
19 |
+
import math
|
20 |
+
import numpy as np
|
21 |
+
import time
|
22 |
+
from PIL import Image
|
23 |
+
|
24 |
+
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
|
25 |
+
from tools.clip_score import ClipSocre
|
26 |
+
import libs.autoencoder
|
27 |
+
from libs.clip import FrozenCLIPEmbedder
|
28 |
+
from libs.t5 import T5Embedder
|
29 |
+
|
30 |
+
|
31 |
+
def unpreprocess(x):
|
32 |
+
x = 0.5 * (x + 1.)
|
33 |
+
x.clamp_(0., 1.)
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
def batch_decode(_z, decode, batch_size=10):
|
38 |
+
"""
|
39 |
+
The VAE decoder requires large GPU memory. To run the interpolation model on GPUs with 24 GB or smaller RAM, you can use this code to reduce memory usage for the VAE.
|
40 |
+
It works by splitting the input tensor into smaller chunks.
|
41 |
+
"""
|
42 |
+
num_samples = _z.size(0)
|
43 |
+
decoded_batches = []
|
44 |
+
|
45 |
+
for i in range(0, num_samples, batch_size):
|
46 |
+
batch = _z[i:i + batch_size]
|
47 |
+
decoded_batch = decode(batch)
|
48 |
+
decoded_batches.append(decoded_batch)
|
49 |
+
|
50 |
+
image_unprocessed = torch.cat(decoded_batches, dim=0)
|
51 |
+
return image_unprocessed
|
52 |
+
|
53 |
+
def get_caption(llm, text_model, prompt_dict, batch_size):
|
54 |
+
|
55 |
+
if batch_size == 3:
|
56 |
+
# only addition or only subtraction
|
57 |
+
assert len(prompt_dict) == 2
|
58 |
+
_batch_con = list(prompt_dict.values()) + [' ']
|
59 |
+
elif batch_size == 4:
|
60 |
+
# addition and subtraction
|
61 |
+
assert len(prompt_dict) == 3
|
62 |
+
_batch_con = list(prompt_dict.values()) + [' ']
|
63 |
+
elif batch_size >= 5:
|
64 |
+
# linear interpolation
|
65 |
+
assert len(prompt_dict) == 2
|
66 |
+
_batch_con = [prompt_dict['prompt_1']] + [' '] * (batch_size-2) + [prompt_dict['prompt_2']]
|
67 |
+
|
68 |
+
if llm == "clip":
|
69 |
+
_latent, _latent_and_others = text_model.encode(_batch_con)
|
70 |
+
_con = _latent_and_others['token_embedding'].detach()
|
71 |
+
elif llm == "t5":
|
72 |
+
_latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
|
73 |
+
_con = (_latent_and_others['token_embedding'] * 10.0).detach()
|
74 |
+
else:
|
75 |
+
raise NotImplementedError
|
76 |
+
_con_mask = _latent_and_others['token_mask'].detach()
|
77 |
+
_batch_token = _latent_and_others['tokens'].detach()
|
78 |
+
_batch_caption = _batch_con
|
79 |
+
return (_con, _con_mask, _batch_token, _batch_caption)
|
80 |
+
|
81 |
+
|
82 |
+
def evaluate(config):
|
83 |
+
|
84 |
+
if config.get('benchmark', False):
|
85 |
+
torch.backends.cudnn.benchmark = True
|
86 |
+
torch.backends.cudnn.deterministic = False
|
87 |
+
|
88 |
+
mp.set_start_method('spawn')
|
89 |
+
accelerator = accelerate.Accelerator()
|
90 |
+
device = accelerator.device
|
91 |
+
accelerate.utils.set_seed(config.seed, device_specific=True)
|
92 |
+
logging.info(f'Process {accelerator.process_index} using device: {device}')
|
93 |
+
|
94 |
+
config.mixed_precision = accelerator.mixed_precision
|
95 |
+
config = ml_collections.FrozenConfigDict(config)
|
96 |
+
if accelerator.is_main_process:
|
97 |
+
utils.set_logger(log_level='info', fname=config.output_path)
|
98 |
+
else:
|
99 |
+
utils.set_logger(log_level='error')
|
100 |
+
builtins.print = lambda *args: None
|
101 |
+
|
102 |
+
nnet = utils.get_nnet(**config.nnet)
|
103 |
+
nnet = accelerator.prepare(nnet)
|
104 |
+
logging.info(f'load nnet from {config.nnet_path}')
|
105 |
+
accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
|
106 |
+
nnet.eval()
|
107 |
+
|
108 |
+
##
|
109 |
+
|
110 |
+
if config.nnet.model_args.clip_dim == 4096:
|
111 |
+
llm = "t5"
|
112 |
+
t5 = T5Embedder(device=device)
|
113 |
+
elif config.nnet.model_args.clip_dim == 768:
|
114 |
+
llm = "clip"
|
115 |
+
clip = FrozenCLIPEmbedder()
|
116 |
+
clip.eval()
|
117 |
+
clip.to(device)
|
118 |
+
else:
|
119 |
+
raise NotImplementedError
|
120 |
+
|
121 |
+
|
122 |
+
config = ml_collections.ConfigDict(config)
|
123 |
+
|
124 |
+
if config.test_type == 'interpolation':
|
125 |
+
prompt_dict = {'prompt_1':config.prompt_1, 'prompt_2':config.prompt_2}
|
126 |
+
for key in prompt_dict.keys():
|
127 |
+
assert prompt_dict[key] is not None
|
128 |
+
config.sample.mini_batch_size = config.num_of_interpolation
|
129 |
+
assert config.sample.mini_batch_size >= 5, "for linear interpolation, please sample at least five image"
|
130 |
+
elif config.test_type == 'arithmetic':
|
131 |
+
prompt_dict = {'prompt_ori':config.prompt_ori, 'prompt_a':config.prompt_a, 'prompt_s':config.prompt_s}
|
132 |
+
keys_to_remove = [key for key, value in prompt_dict.items() if value is None]
|
133 |
+
for key in keys_to_remove:
|
134 |
+
del prompt_dict[key]
|
135 |
+
counter = len(prompt_dict)
|
136 |
+
assert prompt_dict['prompt_ori'] is not None
|
137 |
+
assert counter == 2 or counter == 3
|
138 |
+
config.sample.mini_batch_size = counter + 1
|
139 |
+
else:
|
140 |
+
raise NotImplementedError
|
141 |
+
|
142 |
+
config = ml_collections.FrozenConfigDict(config)
|
143 |
+
|
144 |
+
if llm == "clip":
|
145 |
+
context_generator = get_caption(llm, clip, prompt_dict=prompt_dict, batch_size=config.sample.mini_batch_size)
|
146 |
+
elif llm == "t5":
|
147 |
+
context_generator = get_caption(llm, t5, prompt_dict=prompt_dict, batch_size=config.sample.mini_batch_size)
|
148 |
+
else:
|
149 |
+
raise NotImplementedError
|
150 |
+
|
151 |
+
##
|
152 |
+
|
153 |
+
autoencoder = libs.autoencoder.get_model(**config.autoencoder)
|
154 |
+
autoencoder.to(device)
|
155 |
+
|
156 |
+
@torch.cuda.amp.autocast()
|
157 |
+
def encode(_batch):
|
158 |
+
return autoencoder.encode(_batch)
|
159 |
+
|
160 |
+
@torch.cuda.amp.autocast()
|
161 |
+
def decode(_batch):
|
162 |
+
return autoencoder.decode(_batch)
|
163 |
+
|
164 |
+
bdv_nnet = None # We don't use Autoguidance
|
165 |
+
ClipSocre_model = ClipSocre(device=device) # we also return clip score
|
166 |
+
|
167 |
+
#######
|
168 |
+
logging.info(config.sample)
|
169 |
+
logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')
|
170 |
+
|
171 |
+
|
172 |
+
def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, bdv_nnet=bdv_nnet, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token=None, token_mask=None, return_clipScore=False, ClipSocre_model=None):
|
173 |
+
with torch.no_grad():
|
174 |
+
del testbatch_img_blurred
|
175 |
+
|
176 |
+
_z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)
|
177 |
+
|
178 |
+
if 'dimr' in config.nnet.name or 'dit' in config.nnet.name:
|
179 |
+
_z_x0, _mu, _log_var = nnet_ema(context, text_encoder = True, shape = _z_gaussian.shape, mask=token_mask)
|
180 |
+
_z_init = _z_x0.reshape(_z_gaussian.shape)
|
181 |
+
else:
|
182 |
+
raise NotImplementedError
|
183 |
+
|
184 |
+
if len(_z_init) == 3:
|
185 |
+
if config.prompt_a is not None:
|
186 |
+
assert config.prompt_s is None
|
187 |
+
_z_x0_temp = _z_x0[0] + _z_x0[1]
|
188 |
+
elif config.prompt_s is not None:
|
189 |
+
assert config.prompt_a is None
|
190 |
+
_z_x0_temp = _z_x0[0] - _z_x0[1]
|
191 |
+
else:
|
192 |
+
raise NotImplementedError
|
193 |
+
mean = _z_x0_temp.mean()
|
194 |
+
std = _z_x0_temp.std()
|
195 |
+
_z_x0[2] = (_z_x0_temp - mean) / std
|
196 |
+
elif len(_z_init) == 4:
|
197 |
+
_z_x0_temp = _z_x0[0] + _z_x0[1] - _z_x0[2]
|
198 |
+
mean = _z_x0_temp.mean()
|
199 |
+
std = _z_x0_temp.std()
|
200 |
+
_z_x0[3] = (_z_x0_temp - mean) / std
|
201 |
+
elif len(_z_init) >= 5:
|
202 |
+
tensor_a = _z_init[0]
|
203 |
+
tensor_b = _z_init[-1]
|
204 |
+
num_interpolations = len(_z_init) - 2
|
205 |
+
interpolations = [tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1)) for i in range(1, num_interpolations + 1)]
|
206 |
+
_z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
|
207 |
+
|
208 |
+
assert config.sample.scale > 1
|
209 |
+
if config.cfg != -1:
|
210 |
+
_cfg = config.cfg
|
211 |
+
else:
|
212 |
+
_cfg = config.sample.scale
|
213 |
+
|
214 |
+
has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
|
215 |
+
|
216 |
+
_sample_steps = config.sample.sample_steps
|
217 |
+
|
218 |
+
ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, bdv_model_fn=bdv_nnet, step_size_type="step_in_dsigma", guidance_scale=_cfg)
|
219 |
+
_z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)
|
220 |
+
|
221 |
+
if config.save_gpu_memory:
|
222 |
+
image_unprocessed = batch_decode(_z, decode)
|
223 |
+
else:
|
224 |
+
image_unprocessed = decode(_z)
|
225 |
+
clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)
|
226 |
+
|
227 |
+
return image_unprocessed, clip_score
|
228 |
+
|
229 |
+
|
230 |
+
def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
|
231 |
+
_context, _token_mask, _token, _caption = context_generator
|
232 |
+
assert return_clipScore
|
233 |
+
assert not return_caption
|
234 |
+
return ode_fm_solver_sample(nnet, _n_samples, config.sample.sample_steps, bdv_nnet=bdv_nnet, context=_context, token=_token, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
|
235 |
+
|
236 |
+
|
237 |
+
with tempfile.TemporaryDirectory() as temp_path:
|
238 |
+
path = config.img_save_path or config.sample.path or temp_path
|
239 |
+
if accelerator.is_main_process:
|
240 |
+
os.makedirs(path, exist_ok=True)
|
241 |
+
logging.info(f'Samples are saved in {path}')
|
242 |
+
|
243 |
+
clip_score_list = utils.sample2dir_wCLIP(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
|
244 |
+
if clip_score_list is not None:
|
245 |
+
_clip_score_list = torch.cat(clip_score_list)
|
246 |
+
if accelerator.is_main_process:
|
247 |
+
logging.info(f'nnet_path={config.nnet_path}, clip_score{len(_clip_score_list)}={_clip_score_list.mean().item()}')
|
248 |
+
|
249 |
+
|
250 |
+
FLAGS = flags.FLAGS
|
251 |
+
config_flags.DEFINE_config_file(
|
252 |
+
"config", None, "Training configuration.", lock_config=False)
|
253 |
+
|
254 |
+
flags.mark_flags_as_required(["config"])
|
255 |
+
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
|
256 |
+
flags.DEFINE_string("output_path", None, "The path to output log.")
|
257 |
+
flags.DEFINE_float("cfg", -1, 'cfg scale, will use the scale defined in the config file is not assigned')
|
258 |
+
flags.DEFINE_string("img_save_path", None, "The path to image log.")
|
259 |
+
|
260 |
+
flags.DEFINE_string("test_type", None, "The prompt used for generation.")
|
261 |
+
|
262 |
+
flags.DEFINE_string("prompt_1", None, "The prompt used for linear interpolation.")
|
263 |
+
flags.DEFINE_string("prompt_2", None, "The prompt used for linear interpolation.")
|
264 |
+
flags.DEFINE_integer("num_of_interpolation", -1, 'number of image being samples for linear interpolation')
|
265 |
+
flags.DEFINE_boolean('save_gpu_memory', False, 'To save VRAM')
|
266 |
+
|
267 |
+
flags.DEFINE_string("prompt_ori", None, "The prompt used for arithmetic operations.")
|
268 |
+
flags.DEFINE_string("prompt_a", None, "The prompt used for arithmetic operations (addition).")
|
269 |
+
flags.DEFINE_string("prompt_s", None, "The prompt used for arithmetic operations (subtraction).")
|
270 |
+
|
271 |
+
|
272 |
+
def main(argv):
|
273 |
+
config = FLAGS.config
|
274 |
+
config.nnet_path = FLAGS.nnet_path
|
275 |
+
config.output_path = FLAGS.output_path
|
276 |
+
config.img_save_path = FLAGS.img_save_path
|
277 |
+
config.cfg = FLAGS.cfg
|
278 |
+
config.test_type = FLAGS.test_type
|
279 |
+
config.prompt_1 = FLAGS.prompt_1
|
280 |
+
config.prompt_2 = FLAGS.prompt_2
|
281 |
+
config.num_of_interpolation = FLAGS.num_of_interpolation
|
282 |
+
config.save_gpu_memory = FLAGS.save_gpu_memory
|
283 |
+
config.prompt_ori = FLAGS.prompt_ori
|
284 |
+
config.prompt_a = FLAGS.prompt_a
|
285 |
+
config.prompt_s = FLAGS.prompt_s
|
286 |
+
evaluate(config)
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == "__main__":
|
290 |
+
app.run(main)
|
diffusion/base_solver.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file contains the solver base class, including the cfg indicator
|
3 |
+
"""
|
4 |
+
|
5 |
+
import enum
|
6 |
+
import logging
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Callable, Dict, List, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
import random
|
14 |
+
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
_default_cfg_processor = {"caption": lambda x, T, t: x}
|
19 |
+
|
20 |
+
|
21 |
+
class ConditionTypes(enum.Enum):
|
22 |
+
IMAGE_EMBED: str = "image_conditioning" # not implemented yet
|
23 |
+
TEXT_EMBED: str = "caption"
|
24 |
+
HINT_EMBED: str = "hint" # not implemented yet
|
25 |
+
|
26 |
+
class Solver:
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
model_fn,
|
30 |
+
bdv_model_fn=None,
|
31 |
+
schedule="linear",
|
32 |
+
conditioning_types: List[str] = ["caption"],
|
33 |
+
guidance_scale: Union[float, Dict[ConditionTypes, float]] = 1.0,
|
34 |
+
cfg_processor: Callable = _default_cfg_processor,
|
35 |
+
**kwargs,
|
36 |
+
):
|
37 |
+
self.model = model_fn
|
38 |
+
self.bdv_model = bdv_model_fn
|
39 |
+
self.schedule = schedule
|
40 |
+
# This list (conditioning_types) is important to decide which conditioning variable is given the priority
|
41 |
+
# For multi_cfg with 2 variables c,i, the cfg equation is
|
42 |
+
# output = e(null,null) + scale_c * (e(i,c) - e(i,null)) + scale_i * (e(i,null) - e(null,null))
|
43 |
+
# Note that the marginalization can be changed slightly to obtain a different equation
|
44 |
+
# output = e(null,null) + scale_i * (e(c,i) - e(c,null)) + scale_c * (e(c,null) - e(null,null))
|
45 |
+
# The order of the conditioning variables in the list decides which of the two equations above are used
|
46 |
+
# If the list is ["image", "caption"] then the first equation is used and
|
47 |
+
# if the list is ["caption", "image"] then the second is used
|
48 |
+
self.condition_types = [ConditionTypes(el) for el in conditioning_types]
|
49 |
+
|
50 |
+
self.unconditional_guidance_scale = guidance_scale
|
51 |
+
if isinstance(guidance_scale, dict):
|
52 |
+
self.unconditional_guidance_scale = {
|
53 |
+
ConditionTypes(k): v for k, v in guidance_scale.items()
|
54 |
+
}
|
55 |
+
else:
|
56 |
+
# If a single float is provided, we assume it is for text conditioning
|
57 |
+
self.unconditional_guidance_scale = {
|
58 |
+
ConditionTypes.TEXT_EMBED: guidance_scale
|
59 |
+
}
|
60 |
+
assert all(
|
61 |
+
[
|
62 |
+
el in self.unconditional_guidance_scale.keys()
|
63 |
+
for el in self.condition_types
|
64 |
+
]
|
65 |
+
)
|
66 |
+
self.cfg_processor = cfg_processor
|
67 |
+
if self.cfg_processor is None:
|
68 |
+
self.cfg_processor = _default_cfg_processor
|
69 |
+
if isinstance(self.cfg_processor, dict):
|
70 |
+
assert all(callable(v) for k, v in self.cfg_processor.items())
|
71 |
+
self.cfg_processor = {
|
72 |
+
ConditionTypes(k): v for k, v in self.cfg_processor.items()
|
73 |
+
}
|
74 |
+
else:
|
75 |
+
assert callable(self.cfg_processor)
|
76 |
+
self.cfg_processor = {ConditionTypes.TEXT_EMBED: cfg_processor}
|
77 |
+
|
78 |
+
if self.cfg_processor is not None:
|
79 |
+
assert all([el in self.cfg_processor.keys() for el in self.condition_types])
|
80 |
+
self.inf_steps_completed = 0
|
81 |
+
|
82 |
+
@property
|
83 |
+
def device(self):
|
84 |
+
return self.model.device
|
85 |
+
|
86 |
+
def register_buffer(self, name, attr):
|
87 |
+
if isinstance(attr, torch.Tensor):
|
88 |
+
attr = attr.to(self.device)
|
89 |
+
setattr(self, name, attr)
|
90 |
+
|
91 |
+
def _check_the_conditioning(self, conditioning, batch_size):
|
92 |
+
# Checks if batch sizes match
|
93 |
+
if conditioning is not None:
|
94 |
+
if isinstance(conditioning, dict):
|
95 |
+
ctmp = conditioning[list(conditioning.keys())[0]]
|
96 |
+
while isinstance(ctmp, list):
|
97 |
+
ctmp = ctmp[0]
|
98 |
+
if isinstance(ctmp, dict):
|
99 |
+
if isinstance(ctmp["c"], list):
|
100 |
+
cbs = ctmp["c"][0].shape[0]
|
101 |
+
else:
|
102 |
+
cbs = ctmp["c"].shape[0]
|
103 |
+
else:
|
104 |
+
cbs = ctmp.shape[0]
|
105 |
+
if cbs != batch_size:
|
106 |
+
logger.info(
|
107 |
+
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
108 |
+
)
|
109 |
+
|
110 |
+
elif isinstance(conditioning, list):
|
111 |
+
for ctmp in conditioning:
|
112 |
+
if ctmp.shape[0] != batch_size:
|
113 |
+
logger.info(
|
114 |
+
f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}"
|
115 |
+
)
|
116 |
+
|
117 |
+
else:
|
118 |
+
if conditioning.shape[0] != batch_size:
|
119 |
+
logger.info(
|
120 |
+
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
121 |
+
)
|
122 |
+
|
123 |
+
def sample(
|
124 |
+
self,
|
125 |
+
sample_steps,
|
126 |
+
batch_size,
|
127 |
+
sampling_method,
|
128 |
+
unconditional_guidance_scale,
|
129 |
+
has_null_indicator,
|
130 |
+
shape=None, # no longer use it
|
131 |
+
callback=None,
|
132 |
+
normals_sequence=None,
|
133 |
+
img_callback=None,
|
134 |
+
quantize_x0=False,
|
135 |
+
eta=0.0,
|
136 |
+
mask=None,
|
137 |
+
x0=None,
|
138 |
+
temperature=1.0,
|
139 |
+
noise_dropout=0.0,
|
140 |
+
verbose=True,
|
141 |
+
x_T=None,
|
142 |
+
log_every_t=100,
|
143 |
+
dynamic_threshold=None,
|
144 |
+
ucg_schedule=None,
|
145 |
+
t_schedule=None, # Default value is set below
|
146 |
+
skip_type=None, # Deprecated, kept for backward compatibility. Use `t_schedule` instead.
|
147 |
+
start_timestep=None,
|
148 |
+
num_timesteps=None,
|
149 |
+
do_make_schedule=True,
|
150 |
+
**kwargs,
|
151 |
+
):
|
152 |
+
self.num_inf_timesteps = sample_steps
|
153 |
+
assert skip_type is None
|
154 |
+
|
155 |
+
t_schedule = t_schedule or "time_uniform"
|
156 |
+
|
157 |
+
if self.unconditional_guidance_scale is None:
|
158 |
+
self.unconditional_guidance_scale = unconditional_guidance_scale
|
159 |
+
|
160 |
+
assert isinstance(sampling_method, Callable)
|
161 |
+
samples, intermediates = sampling_method(
|
162 |
+
x_T=x_T,
|
163 |
+
# Hardcoded in PLMS file
|
164 |
+
ddim_use_original_steps=False,
|
165 |
+
callback=callback,
|
166 |
+
num_timesteps=num_timesteps,
|
167 |
+
quantize_denoised=quantize_x0,
|
168 |
+
mask=mask,
|
169 |
+
x0=x0,
|
170 |
+
img_callback=img_callback,
|
171 |
+
log_every_t=log_every_t,
|
172 |
+
temperature=temperature,
|
173 |
+
noise_dropout=noise_dropout,
|
174 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
175 |
+
has_null_indicator=has_null_indicator,
|
176 |
+
dynamic_threshold=dynamic_threshold,
|
177 |
+
verbose=verbose,
|
178 |
+
ucg_schedule=ucg_schedule,
|
179 |
+
start_timestep=start_timestep,
|
180 |
+
)
|
181 |
+
return samples, intermediates
|
182 |
+
|
183 |
+
@torch.no_grad()
|
184 |
+
def get_model_output_dimr(
|
185 |
+
self,
|
186 |
+
x,
|
187 |
+
t_continuous,
|
188 |
+
unconditional_guidance_scale,
|
189 |
+
has_null_indicator,
|
190 |
+
):
|
191 |
+
|
192 |
+
log_snr = 4 - t_continuous * 8 # inversed
|
193 |
+
|
194 |
+
if has_null_indicator:
|
195 |
+
_cond = self.model(x, t=t_continuous, log_snr=log_snr, null_indicator=torch.tensor([False] * x.shape[0]).to(x.device))[-1]
|
196 |
+
_uncond = self.model(x, t=t_continuous, log_snr=log_snr, null_indicator=torch.tensor([True] * x.shape[0]).to(x.device))[-1]
|
197 |
+
|
198 |
+
assert unconditional_guidance_scale > 1
|
199 |
+
return _uncond + unconditional_guidance_scale * (_cond - _uncond)
|
200 |
+
else:
|
201 |
+
_cond = self.model(x, log_snr=log_snr)[-1]
|
202 |
+
return _cond
|
203 |
+
|
diffusion/flow_matching.py
ADDED
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import logging
|
3 |
+
from typing import Callable, Dict, Optional, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
import torchdiffeq
|
9 |
+
import random
|
10 |
+
|
11 |
+
from sde import multi_scale_targets
|
12 |
+
from diffusion.base_solver import Solver
|
13 |
+
import numpy as np
|
14 |
+
from torchvision import transforms
|
15 |
+
|
16 |
+
|
17 |
+
def check_zip(*args):
|
18 |
+
args = [list(arg) for arg in args]
|
19 |
+
length = len(args[0])
|
20 |
+
for arg in args:
|
21 |
+
assert len(arg) == length
|
22 |
+
return zip(*args)
|
23 |
+
|
24 |
+
|
25 |
+
def kl_divergence(source, target):
|
26 |
+
q_raw = source.view(-1)
|
27 |
+
p_raw = target.view(-1)
|
28 |
+
|
29 |
+
p = F.softmax(p_raw, dim=0)
|
30 |
+
q = F.softmax(q_raw, dim=0)
|
31 |
+
|
32 |
+
|
33 |
+
q_log = torch.log(q)
|
34 |
+
kl_div_1 = F.kl_div(q_log, p, reduction='sum')
|
35 |
+
|
36 |
+
return kl_div_1
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
class TimeStepSampler:
|
41 |
+
"""
|
42 |
+
Abstract class to sample timesteps for flow matching.
|
43 |
+
"""
|
44 |
+
|
45 |
+
def sample_time(self, x_start):
|
46 |
+
# In flow matching, time is in range [0, 1] and 1 indicates the original image; 0 is pure noise
|
47 |
+
# this convention is *REVERSE* of diffusion
|
48 |
+
raise NotImplementedError
|
49 |
+
|
50 |
+
class ClipLoss(nn.Module):
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
local_loss=False,
|
55 |
+
gather_with_grad=False,
|
56 |
+
cache_labels=False,
|
57 |
+
rank=0,
|
58 |
+
world_size=1,
|
59 |
+
use_horovod=False,
|
60 |
+
):
|
61 |
+
super().__init__()
|
62 |
+
self.local_loss = local_loss
|
63 |
+
self.gather_with_grad = gather_with_grad
|
64 |
+
self.cache_labels = cache_labels
|
65 |
+
self.rank = rank
|
66 |
+
self.world_size = world_size
|
67 |
+
self.use_horovod = use_horovod
|
68 |
+
|
69 |
+
# cache state
|
70 |
+
self.prev_num_logits = 0
|
71 |
+
self.labels = {}
|
72 |
+
|
73 |
+
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
|
74 |
+
# calculated ground-truth and cache if enabled
|
75 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
76 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
77 |
+
if self.world_size > 1 and self.local_loss:
|
78 |
+
labels = labels + num_logits * self.rank
|
79 |
+
if self.cache_labels:
|
80 |
+
self.labels[device] = labels
|
81 |
+
self.prev_num_logits = num_logits
|
82 |
+
else:
|
83 |
+
labels = self.labels[device]
|
84 |
+
return labels
|
85 |
+
|
86 |
+
def get_logits(self, image_features, text_features, logit_scale):
|
87 |
+
if self.world_size > 1:
|
88 |
+
all_image_features, all_text_features = gather_features(
|
89 |
+
image_features, text_features,
|
90 |
+
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
91 |
+
|
92 |
+
if self.local_loss:
|
93 |
+
logits_per_image = logit_scale * image_features @ all_text_features.T
|
94 |
+
logits_per_text = logit_scale * text_features @ all_image_features.T
|
95 |
+
else:
|
96 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
97 |
+
logits_per_text = logits_per_image.T
|
98 |
+
else:
|
99 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
100 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
101 |
+
|
102 |
+
return logits_per_image, logits_per_text
|
103 |
+
|
104 |
+
def forward(self, image_features, text_features, logit_scale, output_dict=False):
|
105 |
+
device = image_features.device
|
106 |
+
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
|
107 |
+
|
108 |
+
labels = self.get_ground_truth(device, logits_per_image.shape[0])
|
109 |
+
|
110 |
+
total_loss = (
|
111 |
+
F.cross_entropy(logits_per_image, labels) +
|
112 |
+
F.cross_entropy(logits_per_text, labels)
|
113 |
+
) / 2
|
114 |
+
|
115 |
+
return {"contrastive_loss": total_loss} if output_dict else total_loss
|
116 |
+
|
117 |
+
|
118 |
+
class SigLipLoss(nn.Module):
|
119 |
+
""" Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
|
120 |
+
|
121 |
+
@article{zhai2023sigmoid,
|
122 |
+
title={Sigmoid loss for language image pre-training},
|
123 |
+
author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
|
124 |
+
journal={arXiv preprint arXiv:2303.15343},
|
125 |
+
year={2023}
|
126 |
+
}
|
127 |
+
"""
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
cache_labels=False,
|
131 |
+
rank=0,
|
132 |
+
world_size=1,
|
133 |
+
bidir=True,
|
134 |
+
use_horovod=False,
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
self.cache_labels = cache_labels
|
138 |
+
self.rank = rank
|
139 |
+
self.world_size = world_size
|
140 |
+
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
|
141 |
+
self.use_horovod = use_horovod
|
142 |
+
self.bidir = bidir
|
143 |
+
|
144 |
+
# cache state FIXME cache not currently used, worthwhile?
|
145 |
+
self.prev_num_logits = 0
|
146 |
+
self.labels = {}
|
147 |
+
|
148 |
+
def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
|
149 |
+
labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
|
150 |
+
if not negative_only:
|
151 |
+
labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
|
152 |
+
return labels
|
153 |
+
|
154 |
+
def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
|
155 |
+
logits = logit_scale * image_features @ text_features.T
|
156 |
+
if logit_bias is not None:
|
157 |
+
logits += logit_bias
|
158 |
+
return logits
|
159 |
+
|
160 |
+
def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
|
161 |
+
logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
|
162 |
+
labels = self.get_ground_truth(
|
163 |
+
image_features.device,
|
164 |
+
image_features.dtype,
|
165 |
+
image_features.shape[0],
|
166 |
+
negative_only=negative_only,
|
167 |
+
)
|
168 |
+
loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
|
169 |
+
return loss
|
170 |
+
|
171 |
+
def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
|
172 |
+
loss = self._loss(image_features, text_features, logit_scale, logit_bias)
|
173 |
+
|
174 |
+
if self.world_size > 1:
|
175 |
+
# exchange text features w/ neighbour world_size - 1 times
|
176 |
+
right_rank = (self.rank + 1) % self.world_size
|
177 |
+
left_rank = (self.rank - 1 + self.world_size) % self.world_size
|
178 |
+
if self.bidir:
|
179 |
+
text_features_to_right = text_features_to_left = text_features
|
180 |
+
num_bidir, remainder = divmod(self.world_size - 1, 2)
|
181 |
+
for i in range(num_bidir):
|
182 |
+
text_features_recv = neighbour_exchange_bidir_with_grad(
|
183 |
+
left_rank,
|
184 |
+
right_rank,
|
185 |
+
text_features_to_left,
|
186 |
+
text_features_to_right,
|
187 |
+
)
|
188 |
+
|
189 |
+
for f in text_features_recv:
|
190 |
+
loss += self._loss(
|
191 |
+
image_features,
|
192 |
+
f,
|
193 |
+
logit_scale,
|
194 |
+
logit_bias,
|
195 |
+
negative_only=True,
|
196 |
+
)
|
197 |
+
text_features_to_left, text_features_to_right = text_features_recv
|
198 |
+
|
199 |
+
if remainder:
|
200 |
+
text_features_recv = neighbour_exchange_with_grad(
|
201 |
+
left_rank, right_rank, text_features_to_right)
|
202 |
+
|
203 |
+
loss += self._loss(
|
204 |
+
image_features,
|
205 |
+
text_features_recv,
|
206 |
+
logit_scale,
|
207 |
+
logit_bias,
|
208 |
+
negative_only=True,
|
209 |
+
)
|
210 |
+
else:
|
211 |
+
text_features_to_right = text_features
|
212 |
+
for i in range(self.world_size - 1):
|
213 |
+
text_features_from_left = neighbour_exchange_with_grad(
|
214 |
+
left_rank, right_rank, text_features_to_right)
|
215 |
+
|
216 |
+
loss += self._loss(
|
217 |
+
image_features,
|
218 |
+
text_features_from_left,
|
219 |
+
logit_scale,
|
220 |
+
logit_bias,
|
221 |
+
negative_only=True,
|
222 |
+
)
|
223 |
+
text_features_to_right = text_features_from_left
|
224 |
+
|
225 |
+
return {"contrastive_loss": loss} if output_dict else loss
|
226 |
+
|
227 |
+
|
228 |
+
class ResolutionScaledTimeStepSampler(TimeStepSampler):
|
229 |
+
def __init__(self, scale: float, base_time_step_sampler: TimeStepSampler):
|
230 |
+
self.scale = scale
|
231 |
+
self.base_time_step_sampler = base_time_step_sampler
|
232 |
+
|
233 |
+
@torch.no_grad()
|
234 |
+
def sample_time(self, x_start):
|
235 |
+
base_time = self.base_time_step_sampler.sample_time(x_start)
|
236 |
+
# based on eq (23) of https://arxiv.org/abs/2403.03206
|
237 |
+
scaled_time = (base_time * self.scale) / (1 + (self.scale - 1) * base_time)
|
238 |
+
return scaled_time
|
239 |
+
|
240 |
+
|
241 |
+
class LogitNormalSampler(TimeStepSampler):
|
242 |
+
def __init__(self, normal_mean: float = 0, normal_std: float = 1):
|
243 |
+
# follows https://arxiv.org/pdf/2403.03206.pdf
|
244 |
+
# sample from a normal distribution
|
245 |
+
# pass the output through standard logistic function, i.e., sigmoid
|
246 |
+
self.normal_mean = float(normal_mean)
|
247 |
+
self.normal_std = float(normal_std)
|
248 |
+
|
249 |
+
@torch.no_grad()
|
250 |
+
def sample_time(self, x_start):
|
251 |
+
x_normal = torch.normal(
|
252 |
+
mean=self.normal_mean,
|
253 |
+
std=self.normal_std,
|
254 |
+
size=(x_start.shape[0],),
|
255 |
+
device=x_start.device,
|
256 |
+
)
|
257 |
+
x_logistic = torch.nn.functional.sigmoid(x_normal)
|
258 |
+
return x_logistic
|
259 |
+
|
260 |
+
|
261 |
+
class UniformTimeSampler(TimeStepSampler):
|
262 |
+
@torch.no_grad()
|
263 |
+
def sample_time(self, x_start):
|
264 |
+
# [0, 1] and 1 indicates the original image; 0 is pure noise
|
265 |
+
return torch.rand(x_start.shape[0], device=x_start.device)
|
266 |
+
|
267 |
+
|
268 |
+
class FlowMatching(nn.Module):
|
269 |
+
def __init__(
|
270 |
+
self,
|
271 |
+
sigma_min: float = 1e-5,
|
272 |
+
sigma_max: float = 1.0,
|
273 |
+
timescale: float = 1.0,
|
274 |
+
**kwargs,
|
275 |
+
):
|
276 |
+
# LatentDiffusion/DDPM will create too many class variables we do not need
|
277 |
+
super().__init__(**kwargs)
|
278 |
+
self.time_step_sampler = LogitNormalSampler()
|
279 |
+
self.sigma_min = sigma_min
|
280 |
+
self.sigma_max = sigma_max
|
281 |
+
self.timescale = timescale
|
282 |
+
|
283 |
+
self.clip_loss = ClipLoss()
|
284 |
+
# self.SigLipLoss = SigLipLoss()
|
285 |
+
|
286 |
+
self.resizer = transforms.Resize(256) # for clip
|
287 |
+
|
288 |
+
def sample_noise(self, x_start):
|
289 |
+
# simple IID noise
|
290 |
+
return torch.randn_like(x_start, device=x_start.device) * self.sigma_max
|
291 |
+
|
292 |
+
def mos(self, err, start_dim=1, con_mask=None): # mean of square
|
293 |
+
if con_mask is not None:
|
294 |
+
return (err.pow(2).mean(dim=-1) * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
|
295 |
+
else:
|
296 |
+
return err.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
|
297 |
+
|
298 |
+
|
299 |
+
def Xentropy(self, pred, tar, con_mask=None):
|
300 |
+
if con_mask is not None:
|
301 |
+
return (nn.functional.cross_entropy(pred, tar, reduction='none') * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
|
302 |
+
else:
|
303 |
+
return nn.functional.cross_entropy(pred, tar, reduction='none').mean(dim=-1)
|
304 |
+
|
305 |
+
def l2_reg(self, pred, lam = 0.0001):
|
306 |
+
return lam * torch.norm(pred, p=2, dim=(1, 2, 3)) ** 2
|
307 |
+
|
308 |
+
# model forward and prediction
|
309 |
+
def forward(
|
310 |
+
self,
|
311 |
+
x,
|
312 |
+
nnet,
|
313 |
+
loss_coeffs,
|
314 |
+
cond,
|
315 |
+
con_mask,
|
316 |
+
nnet_style,
|
317 |
+
training_step,
|
318 |
+
cond_ori=None, # not using
|
319 |
+
con_mask_ori=None, # not using
|
320 |
+
batch_img_clip=None, # not using
|
321 |
+
model_config=None,
|
322 |
+
all_config=None,
|
323 |
+
text_token=None,
|
324 |
+
return_raw_loss=False,
|
325 |
+
additional_embeddings=None,
|
326 |
+
timesteps: Optional[Tuple[int, int]] = None,
|
327 |
+
*args,
|
328 |
+
**kwargs,
|
329 |
+
):
|
330 |
+
assert timesteps is None, "timesteps must be None"
|
331 |
+
|
332 |
+
timesteps = self.time_step_sampler.sample_time(x)
|
333 |
+
|
334 |
+
if nnet_style == 'dimr':
|
335 |
+
if hasattr(model_config, "standard_diffusion") and model_config.standard_diffusion:
|
336 |
+
standard_diffusion=True
|
337 |
+
else:
|
338 |
+
standard_diffusion=False
|
339 |
+
return self.p_losses_textVAE(
|
340 |
+
x, cond, con_mask, timesteps, nnet, batch_img_clip=batch_img_clip, cond_ori=cond_ori, con_mask_ori=con_mask_ori, text_token=text_token, loss_coeffs=loss_coeffs, return_raw_loss=return_raw_loss, nnet_style=nnet_style, standard_diffusion=standard_diffusion, all_config=all_config, training_step=training_step, *args, **kwargs
|
341 |
+
)
|
342 |
+
elif nnet_style == 'dit':
|
343 |
+
if hasattr(model_config, "standard_diffusion") and model_config.standard_diffusion:
|
344 |
+
standard_diffusion=True
|
345 |
+
raise NotImplementedError("need update")
|
346 |
+
else:
|
347 |
+
standard_diffusion=False
|
348 |
+
return self.p_losses_textVAE_dit(
|
349 |
+
x, cond, con_mask, timesteps, nnet, batch_img_clip=batch_img_clip, cond_ori=cond_ori, con_mask_ori=con_mask_ori, text_token=text_token, loss_coeffs=loss_coeffs, return_raw_loss=return_raw_loss, nnet_style=nnet_style, standard_diffusion=standard_diffusion, all_config=all_config, training_step=training_step, *args, **kwargs
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
raise NotImplementedError
|
353 |
+
|
354 |
+
|
355 |
+
|
356 |
+
def p_losses_textVAE(
|
357 |
+
self,
|
358 |
+
x_start,
|
359 |
+
cond,
|
360 |
+
con_mask,
|
361 |
+
t,
|
362 |
+
nnet,
|
363 |
+
loss_coeffs,
|
364 |
+
training_step,
|
365 |
+
text_token=None,
|
366 |
+
nnet_style=None,
|
367 |
+
all_config=None,
|
368 |
+
batch_img_clip=None,
|
369 |
+
cond_ori=None, # not using
|
370 |
+
con_mask_ori=None, # not using
|
371 |
+
return_raw_loss=False,
|
372 |
+
additional_embeddings=None,
|
373 |
+
standard_diffusion=False,
|
374 |
+
noise=None,
|
375 |
+
):
|
376 |
+
"""
|
377 |
+
CrossFlow training for DiMR
|
378 |
+
"""
|
379 |
+
|
380 |
+
assert noise is None
|
381 |
+
|
382 |
+
x0, mu, log_var = nnet(cond, text_encoder = True, shape = x_start.shape, mask = con_mask)
|
383 |
+
|
384 |
+
############ loss for Text VE
|
385 |
+
if batch_img_clip.shape[-1] == 512:
|
386 |
+
recon_gt = self.resizer(batch_img_clip)
|
387 |
+
else:
|
388 |
+
recon_gt = batch_img_clip
|
389 |
+
recon_gt_clip, logit_scale = nnet(recon_gt, image_clip = True)
|
390 |
+
image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
|
391 |
+
text_features = x0 / x0.norm(dim=-1, keepdim=True)
|
392 |
+
recons_loss = self.clip_loss(image_features, text_features, logit_scale)
|
393 |
+
|
394 |
+
# kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1)
|
395 |
+
kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim = 1) # slightly different KL loss function: mu -> 0 [(0.3*mu) ** 6] and var -> 1
|
396 |
+
kld_loss_weight = 1e-2 # 0.0005
|
397 |
+
|
398 |
+
loss_mlp = recons_loss + kld_loss * kld_loss_weight
|
399 |
+
|
400 |
+
|
401 |
+
############ loss for FM
|
402 |
+
noise = x0.reshape(x_start.shape)
|
403 |
+
|
404 |
+
if hasattr(all_config.nnet.model_args, "cfg_indicator"):
|
405 |
+
null_indicator = torch.from_numpy(np.array([random.random() < all_config.nnet.model_args.cfg_indicator for _ in range(x_start.shape[0])])).to(x_start.device)
|
406 |
+
if null_indicator.sum()<=1:
|
407 |
+
null_indicator[null_indicator==True] = False
|
408 |
+
assert null_indicator.sum() == 0
|
409 |
+
pass
|
410 |
+
else:
|
411 |
+
target_null = x_start[null_indicator]
|
412 |
+
target_null = torch.cat((target_null[1:], target_null[:1]))
|
413 |
+
x_start[null_indicator] = target_null
|
414 |
+
else:
|
415 |
+
null_indicator = None
|
416 |
+
|
417 |
+
|
418 |
+
x_noisy = self.psi(t, x=noise, x1=x_start)
|
419 |
+
target_velocity = self.Dt_psi(t, x=noise, x1=x_start)
|
420 |
+
log_snr = 4 - t * 8 # compute from timestep : inversed
|
421 |
+
|
422 |
+
prediction = nnet(x_noisy, log_snr = log_snr, null_indicator=null_indicator)
|
423 |
+
|
424 |
+
target = multi_scale_targets(target_velocity, levels = len(prediction), scale_correction = True)
|
425 |
+
|
426 |
+
loss_diff = 0
|
427 |
+
for pred, coeff in check_zip(prediction, loss_coeffs):
|
428 |
+
loss_diff = loss_diff + coeff * self.mos(pred - target[pred.shape[-1]])
|
429 |
+
|
430 |
+
###########
|
431 |
+
|
432 |
+
loss = loss_diff + loss_mlp
|
433 |
+
|
434 |
+
return loss, {'loss_diff': loss_diff, 'clip_loss': recons_loss, 'kld_loss': kld_loss, 'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device), 'clip_logit_scale': logit_scale}
|
435 |
+
|
436 |
+
|
437 |
+
def p_losses_textVAE_dit(
|
438 |
+
self,
|
439 |
+
x_start,
|
440 |
+
cond,
|
441 |
+
con_mask,
|
442 |
+
t,
|
443 |
+
nnet,
|
444 |
+
loss_coeffs,
|
445 |
+
training_step,
|
446 |
+
text_token=None,
|
447 |
+
nnet_style=None,
|
448 |
+
all_config=None,
|
449 |
+
batch_img_clip=None,
|
450 |
+
cond_ori=None, # not using
|
451 |
+
con_mask_ori=None, # not using
|
452 |
+
return_raw_loss=False,
|
453 |
+
additional_embeddings=None,
|
454 |
+
standard_diffusion=False,
|
455 |
+
noise=None,
|
456 |
+
):
|
457 |
+
"""
|
458 |
+
CrossFLow training for DiT
|
459 |
+
"""
|
460 |
+
|
461 |
+
assert noise is None
|
462 |
+
|
463 |
+
x0, mu, log_var = nnet(cond, text_encoder = True, shape = x_start.shape, mask = con_mask)
|
464 |
+
|
465 |
+
############ loss for Text VE
|
466 |
+
if batch_img_clip.shape[-1] == 512:
|
467 |
+
recon_gt = self.resizer(batch_img_clip)
|
468 |
+
else:
|
469 |
+
recon_gt = batch_img_clip
|
470 |
+
recon_gt_clip, logit_scale = nnet(recon_gt, image_clip = True)
|
471 |
+
image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
|
472 |
+
text_features = x0 / x0.norm(dim=-1, keepdim=True)
|
473 |
+
recons_loss = self.clip_loss(image_features, text_features, logit_scale)
|
474 |
+
|
475 |
+
# kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1)
|
476 |
+
kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim = 1)
|
477 |
+
kld_loss_weight = 1e-2 # 0.0005
|
478 |
+
|
479 |
+
loss_mlp = recons_loss + kld_loss * kld_loss_weight
|
480 |
+
|
481 |
+
############ loss for FM
|
482 |
+
noise = x0.reshape(x_start.shape)
|
483 |
+
|
484 |
+
if hasattr(all_config.nnet.model_args, "cfg_indicator"):
|
485 |
+
null_indicator = torch.from_numpy(np.array([random.random() < all_config.nnet.model_args.cfg_indicator for _ in range(x_start.shape[0])])).to(x_start.device)
|
486 |
+
if null_indicator.sum()<=1:
|
487 |
+
null_indicator[null_indicator==True] = False
|
488 |
+
assert null_indicator.sum() == 0
|
489 |
+
pass
|
490 |
+
else:
|
491 |
+
target_null = x_start[null_indicator]
|
492 |
+
target_null = torch.cat((target_null[1:], target_null[:1]))
|
493 |
+
x_start[null_indicator] = target_null
|
494 |
+
else:
|
495 |
+
null_indicator = None
|
496 |
+
|
497 |
+
x_noisy = self.psi(t, x=noise, x1=x_start)
|
498 |
+
target_velocity = self.Dt_psi(t, x=noise, x1=x_start)
|
499 |
+
|
500 |
+
prediction = nnet(x_noisy, t = t, null_indicator = null_indicator)[0]
|
501 |
+
|
502 |
+
loss_diff = self.mos(prediction - target_velocity)
|
503 |
+
|
504 |
+
###########
|
505 |
+
|
506 |
+
loss = loss_diff + loss_mlp
|
507 |
+
|
508 |
+
return loss, {'loss_diff': loss_diff, 'clip_loss': recons_loss, 'kld_loss': kld_loss, 'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device), 'clip_logit_scale': logit_scale}
|
509 |
+
|
510 |
+
|
511 |
+
## flow matching specific functions
|
512 |
+
def psi(self, t, x, x1):
|
513 |
+
assert (
|
514 |
+
t.shape[0] == x.shape[0]
|
515 |
+
), f"Batch size of t and x does not agree {t.shape[0]} vs. {x.shape[0]}"
|
516 |
+
assert (
|
517 |
+
t.shape[0] == x1.shape[0]
|
518 |
+
), f"Batch size of t and x1 does not agree {t.shape[0]} vs. {x1.shape[0]}"
|
519 |
+
assert t.ndim == 1
|
520 |
+
t = self.expand_t(t, x)
|
521 |
+
return (t * (self.sigma_min / self.sigma_max - 1) + 1) * x + t * x1
|
522 |
+
|
523 |
+
def Dt_psi(self, t: torch.Tensor, x: torch.Tensor, x1: torch.Tensor):
|
524 |
+
assert x.shape[0] == x1.shape[0]
|
525 |
+
return (self.sigma_min / self.sigma_max - 1) * x + x1
|
526 |
+
|
527 |
+
def expand_t(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
528 |
+
t_expanded = t
|
529 |
+
while t_expanded.ndim < x.ndim:
|
530 |
+
t_expanded = t_expanded.unsqueeze(-1)
|
531 |
+
return t_expanded.expand_as(x)
|
532 |
+
|
533 |
+
|
534 |
+
|
535 |
+
|
536 |
+
class ODEEulerFlowMatchingSolver(Solver):
|
537 |
+
"""
|
538 |
+
ODE Solver for Flow matching that uses an Euler discretization
|
539 |
+
Supports number of time steps at inference
|
540 |
+
"""
|
541 |
+
|
542 |
+
def __init__(self, *args, **kwargs):
|
543 |
+
super().__init__(*args, **kwargs)
|
544 |
+
self.step_size_type = kwargs.get("step_size_type", "step_in_dsigma")
|
545 |
+
assert self.step_size_type in ["step_in_dsigma", "step_in_dt"]
|
546 |
+
self.sample_timescale = 1.0 - 1e-5
|
547 |
+
|
548 |
+
@torch.no_grad()
|
549 |
+
def sample_euler(
|
550 |
+
self,
|
551 |
+
x_T,
|
552 |
+
unconditional_guidance_scale,
|
553 |
+
has_null_indicator,
|
554 |
+
t=[0, 1.0],
|
555 |
+
**kwargs,
|
556 |
+
):
|
557 |
+
"""
|
558 |
+
Euler solver for flow matching.
|
559 |
+
Based on https://github.com/VinAIResearch/LFM/blob/main/sampler/karras_sample.py
|
560 |
+
"""
|
561 |
+
t = torch.tensor(t)
|
562 |
+
t = t * self.sample_timescale
|
563 |
+
sigma_min = 1e-5
|
564 |
+
sigma_max = 1.0
|
565 |
+
sigma_steps = torch.linspace(
|
566 |
+
sigma_min, sigma_max, self.num_time_steps + 1, device=x_T.device
|
567 |
+
)
|
568 |
+
discrete_time_steps_for_step = torch.linspace(
|
569 |
+
t[0], t[1], self.num_time_steps + 1, device=x_T.device
|
570 |
+
)
|
571 |
+
discrete_time_steps_to_eval_model_at = torch.linspace(
|
572 |
+
t[0], t[1], self.num_time_steps, device=x_T.device
|
573 |
+
)
|
574 |
+
|
575 |
+
print("num_time_steps : " + str(self.num_time_steps))
|
576 |
+
|
577 |
+
for i in range(self.num_time_steps):
|
578 |
+
t_i = discrete_time_steps_to_eval_model_at[i]
|
579 |
+
velocity = self.get_model_output_dimr(
|
580 |
+
x_T,
|
581 |
+
has_null_indicator = has_null_indicator,
|
582 |
+
t_continuous = t_i.repeat(x_T.shape[0]),
|
583 |
+
unconditional_guidance_scale = unconditional_guidance_scale,
|
584 |
+
)
|
585 |
+
if self.step_size_type == "step_in_dsigma":
|
586 |
+
step_size = sigma_steps[i + 1] - sigma_steps[i]
|
587 |
+
elif self.step_size_type == "step_in_dt":
|
588 |
+
step_size = (
|
589 |
+
discrete_time_steps_for_step[i + 1]
|
590 |
+
- discrete_time_steps_for_step[i]
|
591 |
+
)
|
592 |
+
x_T = x_T + velocity * step_size
|
593 |
+
|
594 |
+
intermediates = None
|
595 |
+
return x_T, intermediates
|
596 |
+
|
597 |
+
@torch.no_grad()
|
598 |
+
def sample(
|
599 |
+
self,
|
600 |
+
*args,
|
601 |
+
**kwargs,
|
602 |
+
):
|
603 |
+
assert kwargs.get("ucg_schedule", None) is None
|
604 |
+
assert kwargs.get("skip_type", None) is None
|
605 |
+
assert kwargs.get("dynamic_threshold", None) is None
|
606 |
+
assert kwargs.get("x0", None) is None
|
607 |
+
assert kwargs.get("x_T") is not None
|
608 |
+
assert kwargs.get("score_corrector", None) is None
|
609 |
+
assert kwargs.get("normals_sequence", None) is None
|
610 |
+
assert kwargs.get("callback", None) is None
|
611 |
+
assert kwargs.get("quantize_x0", False) is False
|
612 |
+
assert kwargs.get("eta", 0.0) == 0.0
|
613 |
+
assert kwargs.get("mask", None) is None
|
614 |
+
assert kwargs.get("noise_dropout", 0.0) == 0.0
|
615 |
+
|
616 |
+
self.num_time_steps = kwargs.get("sample_steps")
|
617 |
+
self.x_T_uncon = kwargs.get("x_T_uncon")
|
618 |
+
|
619 |
+
samples, intermediates = super().sample(
|
620 |
+
*args,
|
621 |
+
sampling_method=self.sample_euler,
|
622 |
+
do_make_schedule=False,
|
623 |
+
**kwargs,
|
624 |
+
)
|
625 |
+
return samples, intermediates
|
626 |
+
|
627 |
+
|
628 |
+
class ODEFlowMatchingSolver(Solver):
|
629 |
+
"""
|
630 |
+
ODE Solver for Flow matching that uses `dopri5`
|
631 |
+
Does not support number of time steps based control
|
632 |
+
"""
|
633 |
+
|
634 |
+
def __init__(self, *args, **kwargs):
|
635 |
+
super().__init__(*args, **kwargs)
|
636 |
+
self.sample_timescale = 1.0 - 1e-5
|
637 |
+
|
638 |
+
# sampling for inference
|
639 |
+
@torch.no_grad()
|
640 |
+
def sample_transport(
|
641 |
+
self,
|
642 |
+
x_T,
|
643 |
+
unconditional_guidance_scale,
|
644 |
+
has_null_indicator,
|
645 |
+
t=[0, 1.0],
|
646 |
+
ode_opts={},
|
647 |
+
**kwargs,
|
648 |
+
):
|
649 |
+
num_evals = 0
|
650 |
+
t = torch.tensor(t, device=x_T.device)
|
651 |
+
if "options" not in ode_opts:
|
652 |
+
ode_opts["options"] = {}
|
653 |
+
ode_opts["options"]["step_t"] = [self.sample_timescale + 1e-6]
|
654 |
+
|
655 |
+
def ode_func(t, x_T):
|
656 |
+
nonlocal num_evals
|
657 |
+
num_evals += 1
|
658 |
+
model_output = self.get_model_output_dimr(
|
659 |
+
x_T,
|
660 |
+
has_null_indicator = has_null_indicator,
|
661 |
+
t_continuous = t.repeat(x_T.shape[0]),
|
662 |
+
unconditional_guidance_scale = unconditional_guidance_scale,
|
663 |
+
)
|
664 |
+
return model_output
|
665 |
+
|
666 |
+
z = torchdiffeq.odeint(
|
667 |
+
ode_func,
|
668 |
+
x_T,
|
669 |
+
t * self.sample_timescale,
|
670 |
+
**{"atol": 1e-5, "rtol": 1e-5, "method": "dopri5", **ode_opts},
|
671 |
+
)
|
672 |
+
# first dimension of z contains solutions to different timepoints
|
673 |
+
# we only need the last one (corresponding to t=1, i.e., image)
|
674 |
+
z = z[-1]
|
675 |
+
intermediates = None
|
676 |
+
return z, intermediates
|
677 |
+
|
678 |
+
@torch.no_grad()
|
679 |
+
def sample(
|
680 |
+
self,
|
681 |
+
*args,
|
682 |
+
**kwargs,
|
683 |
+
):
|
684 |
+
assert kwargs.get("ucg_schedule", None) is None
|
685 |
+
assert kwargs.get("skip_type", None) is None
|
686 |
+
assert kwargs.get("dynamic_threshold", None) is None
|
687 |
+
assert kwargs.get("x0", None) is None
|
688 |
+
assert kwargs.get("x_T") is not None
|
689 |
+
assert kwargs.get("score_corrector", None) is None
|
690 |
+
assert kwargs.get("normals_sequence", None) is None
|
691 |
+
assert kwargs.get("callback", None) is None
|
692 |
+
assert kwargs.get("quantize_x0", False) is False
|
693 |
+
assert kwargs.get("eta", 0.0) == 0.0
|
694 |
+
assert kwargs.get("mask", None) is None
|
695 |
+
assert kwargs.get("noise_dropout", 0.0) == 0.0
|
696 |
+
samples, intermediates = super().sample(
|
697 |
+
*args,
|
698 |
+
sampling_method=self.sample_transport,
|
699 |
+
do_make_schedule=False,
|
700 |
+
**kwargs,
|
701 |
+
)
|
702 |
+
return samples, intermediates
|
libs/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# codes from third party
|
libs/autoencoder.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
class LinearAttention(nn.Module):
|
8 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
9 |
+
super().__init__()
|
10 |
+
self.heads = heads
|
11 |
+
hidden_dim = dim_head * heads
|
12 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
13 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
b, c, h, w = x.shape
|
17 |
+
qkv = self.to_qkv(x)
|
18 |
+
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
19 |
+
k = k.softmax(dim=-1)
|
20 |
+
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
21 |
+
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
22 |
+
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
23 |
+
return self.to_out(out)
|
24 |
+
|
25 |
+
|
26 |
+
def nonlinearity(x):
|
27 |
+
# swish
|
28 |
+
return x*torch.sigmoid(x)
|
29 |
+
|
30 |
+
|
31 |
+
def Normalize(in_channels, num_groups=32):
|
32 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
33 |
+
|
34 |
+
|
35 |
+
class Upsample(nn.Module):
|
36 |
+
def __init__(self, in_channels, with_conv):
|
37 |
+
super().__init__()
|
38 |
+
self.with_conv = with_conv
|
39 |
+
if self.with_conv:
|
40 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
41 |
+
in_channels,
|
42 |
+
kernel_size=3,
|
43 |
+
stride=1,
|
44 |
+
padding=1)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
48 |
+
if self.with_conv:
|
49 |
+
x = self.conv(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class Downsample(nn.Module):
|
54 |
+
def __init__(self, in_channels, with_conv):
|
55 |
+
super().__init__()
|
56 |
+
self.with_conv = with_conv
|
57 |
+
if self.with_conv:
|
58 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
59 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
60 |
+
in_channels,
|
61 |
+
kernel_size=3,
|
62 |
+
stride=2,
|
63 |
+
padding=0)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
if self.with_conv:
|
67 |
+
pad = (0,1,0,1)
|
68 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
69 |
+
x = self.conv(x)
|
70 |
+
else:
|
71 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class ResnetBlock(nn.Module):
|
76 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
77 |
+
dropout, temb_channels=512):
|
78 |
+
super().__init__()
|
79 |
+
self.in_channels = in_channels
|
80 |
+
out_channels = in_channels if out_channels is None else out_channels
|
81 |
+
self.out_channels = out_channels
|
82 |
+
self.use_conv_shortcut = conv_shortcut
|
83 |
+
|
84 |
+
self.norm1 = Normalize(in_channels)
|
85 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
86 |
+
out_channels,
|
87 |
+
kernel_size=3,
|
88 |
+
stride=1,
|
89 |
+
padding=1)
|
90 |
+
if temb_channels > 0:
|
91 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
92 |
+
out_channels)
|
93 |
+
self.norm2 = Normalize(out_channels)
|
94 |
+
self.dropout = torch.nn.Dropout(dropout)
|
95 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
96 |
+
out_channels,
|
97 |
+
kernel_size=3,
|
98 |
+
stride=1,
|
99 |
+
padding=1)
|
100 |
+
if self.in_channels != self.out_channels:
|
101 |
+
if self.use_conv_shortcut:
|
102 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
103 |
+
out_channels,
|
104 |
+
kernel_size=3,
|
105 |
+
stride=1,
|
106 |
+
padding=1)
|
107 |
+
else:
|
108 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
109 |
+
out_channels,
|
110 |
+
kernel_size=1,
|
111 |
+
stride=1,
|
112 |
+
padding=0)
|
113 |
+
|
114 |
+
def forward(self, x, temb):
|
115 |
+
h = x
|
116 |
+
h = self.norm1(h)
|
117 |
+
h = nonlinearity(h)
|
118 |
+
h = self.conv1(h)
|
119 |
+
|
120 |
+
if temb is not None:
|
121 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
122 |
+
|
123 |
+
h = self.norm2(h)
|
124 |
+
h = nonlinearity(h)
|
125 |
+
h = self.dropout(h)
|
126 |
+
h = self.conv2(h)
|
127 |
+
|
128 |
+
if self.in_channels != self.out_channels:
|
129 |
+
if self.use_conv_shortcut:
|
130 |
+
x = self.conv_shortcut(x)
|
131 |
+
else:
|
132 |
+
x = self.nin_shortcut(x)
|
133 |
+
|
134 |
+
return x+h
|
135 |
+
|
136 |
+
|
137 |
+
class LinAttnBlock(LinearAttention):
|
138 |
+
"""to match AttnBlock usage"""
|
139 |
+
def __init__(self, in_channels):
|
140 |
+
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
141 |
+
|
142 |
+
|
143 |
+
class AttnBlock(nn.Module):
|
144 |
+
def __init__(self, in_channels):
|
145 |
+
super().__init__()
|
146 |
+
self.in_channels = in_channels
|
147 |
+
|
148 |
+
self.norm = Normalize(in_channels)
|
149 |
+
self.q = torch.nn.Conv2d(in_channels,
|
150 |
+
in_channels,
|
151 |
+
kernel_size=1,
|
152 |
+
stride=1,
|
153 |
+
padding=0)
|
154 |
+
self.k = torch.nn.Conv2d(in_channels,
|
155 |
+
in_channels,
|
156 |
+
kernel_size=1,
|
157 |
+
stride=1,
|
158 |
+
padding=0)
|
159 |
+
self.v = torch.nn.Conv2d(in_channels,
|
160 |
+
in_channels,
|
161 |
+
kernel_size=1,
|
162 |
+
stride=1,
|
163 |
+
padding=0)
|
164 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
165 |
+
in_channels,
|
166 |
+
kernel_size=1,
|
167 |
+
stride=1,
|
168 |
+
padding=0)
|
169 |
+
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
h_ = x
|
173 |
+
h_ = self.norm(h_)
|
174 |
+
q = self.q(h_)
|
175 |
+
k = self.k(h_)
|
176 |
+
v = self.v(h_)
|
177 |
+
|
178 |
+
# compute attention
|
179 |
+
b,c,h,w = q.shape
|
180 |
+
q = q.reshape(b,c,h*w)
|
181 |
+
q = q.permute(0,2,1) # b,hw,c
|
182 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
183 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
184 |
+
w_ = w_ * (int(c)**(-0.5))
|
185 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
186 |
+
|
187 |
+
# attend to values
|
188 |
+
v = v.reshape(b,c,h*w)
|
189 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
190 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
191 |
+
h_ = h_.reshape(b,c,h,w)
|
192 |
+
|
193 |
+
h_ = self.proj_out(h_)
|
194 |
+
|
195 |
+
return x+h_
|
196 |
+
|
197 |
+
|
198 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
199 |
+
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
200 |
+
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
201 |
+
if attn_type == "vanilla":
|
202 |
+
return AttnBlock(in_channels)
|
203 |
+
elif attn_type == "none":
|
204 |
+
return nn.Identity(in_channels)
|
205 |
+
else:
|
206 |
+
return LinAttnBlock(in_channels)
|
207 |
+
|
208 |
+
|
209 |
+
class Encoder(nn.Module):
|
210 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
211 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
212 |
+
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
213 |
+
**ignore_kwargs):
|
214 |
+
super().__init__()
|
215 |
+
if use_linear_attn: attn_type = "linear"
|
216 |
+
self.ch = ch
|
217 |
+
self.temb_ch = 0
|
218 |
+
self.num_resolutions = len(ch_mult)
|
219 |
+
self.num_res_blocks = num_res_blocks
|
220 |
+
self.resolution = resolution
|
221 |
+
self.in_channels = in_channels
|
222 |
+
|
223 |
+
# downsampling
|
224 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
225 |
+
self.ch,
|
226 |
+
kernel_size=3,
|
227 |
+
stride=1,
|
228 |
+
padding=1)
|
229 |
+
|
230 |
+
curr_res = resolution
|
231 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
232 |
+
self.in_ch_mult = in_ch_mult
|
233 |
+
self.down = nn.ModuleList()
|
234 |
+
for i_level in range(self.num_resolutions):
|
235 |
+
block = nn.ModuleList()
|
236 |
+
attn = nn.ModuleList()
|
237 |
+
block_in = ch*in_ch_mult[i_level]
|
238 |
+
block_out = ch*ch_mult[i_level]
|
239 |
+
for i_block in range(self.num_res_blocks):
|
240 |
+
block.append(ResnetBlock(in_channels=block_in,
|
241 |
+
out_channels=block_out,
|
242 |
+
temb_channels=self.temb_ch,
|
243 |
+
dropout=dropout))
|
244 |
+
block_in = block_out
|
245 |
+
if curr_res in attn_resolutions:
|
246 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
247 |
+
down = nn.Module()
|
248 |
+
down.block = block
|
249 |
+
down.attn = attn
|
250 |
+
if i_level != self.num_resolutions-1:
|
251 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
252 |
+
curr_res = curr_res // 2
|
253 |
+
self.down.append(down)
|
254 |
+
|
255 |
+
# middle
|
256 |
+
self.mid = nn.Module()
|
257 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
258 |
+
out_channels=block_in,
|
259 |
+
temb_channels=self.temb_ch,
|
260 |
+
dropout=dropout)
|
261 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
262 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
263 |
+
out_channels=block_in,
|
264 |
+
temb_channels=self.temb_ch,
|
265 |
+
dropout=dropout)
|
266 |
+
|
267 |
+
# end
|
268 |
+
self.norm_out = Normalize(block_in)
|
269 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
270 |
+
2*z_channels if double_z else z_channels,
|
271 |
+
kernel_size=3,
|
272 |
+
stride=1,
|
273 |
+
padding=1)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
# timestep embedding
|
277 |
+
temb = None
|
278 |
+
|
279 |
+
# downsampling
|
280 |
+
hs = [self.conv_in(x)]
|
281 |
+
for i_level in range(self.num_resolutions):
|
282 |
+
for i_block in range(self.num_res_blocks):
|
283 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
284 |
+
if len(self.down[i_level].attn) > 0:
|
285 |
+
h = self.down[i_level].attn[i_block](h)
|
286 |
+
hs.append(h)
|
287 |
+
if i_level != self.num_resolutions-1:
|
288 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
289 |
+
|
290 |
+
# middle
|
291 |
+
h = hs[-1]
|
292 |
+
h = self.mid.block_1(h, temb)
|
293 |
+
h = self.mid.attn_1(h)
|
294 |
+
h = self.mid.block_2(h, temb)
|
295 |
+
|
296 |
+
# end
|
297 |
+
h = self.norm_out(h)
|
298 |
+
h = nonlinearity(h)
|
299 |
+
h = self.conv_out(h)
|
300 |
+
return h
|
301 |
+
|
302 |
+
|
303 |
+
class Decoder(nn.Module):
|
304 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
305 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
306 |
+
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
307 |
+
attn_type="vanilla", **ignorekwargs):
|
308 |
+
super().__init__()
|
309 |
+
if use_linear_attn: attn_type = "linear"
|
310 |
+
self.ch = ch
|
311 |
+
self.temb_ch = 0
|
312 |
+
self.num_resolutions = len(ch_mult)
|
313 |
+
self.num_res_blocks = num_res_blocks
|
314 |
+
self.resolution = resolution
|
315 |
+
self.in_channels = in_channels
|
316 |
+
self.give_pre_end = give_pre_end
|
317 |
+
self.tanh_out = tanh_out
|
318 |
+
|
319 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
320 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
321 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
322 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
323 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
324 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
325 |
+
self.z_shape, np.prod(self.z_shape)))
|
326 |
+
|
327 |
+
# z to block_in
|
328 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
329 |
+
block_in,
|
330 |
+
kernel_size=3,
|
331 |
+
stride=1,
|
332 |
+
padding=1)
|
333 |
+
|
334 |
+
# middle
|
335 |
+
self.mid = nn.Module()
|
336 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
337 |
+
out_channels=block_in,
|
338 |
+
temb_channels=self.temb_ch,
|
339 |
+
dropout=dropout)
|
340 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
341 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
342 |
+
out_channels=block_in,
|
343 |
+
temb_channels=self.temb_ch,
|
344 |
+
dropout=dropout)
|
345 |
+
|
346 |
+
# upsampling
|
347 |
+
self.up = nn.ModuleList()
|
348 |
+
for i_level in reversed(range(self.num_resolutions)):
|
349 |
+
block = nn.ModuleList()
|
350 |
+
attn = nn.ModuleList()
|
351 |
+
block_out = ch*ch_mult[i_level]
|
352 |
+
for i_block in range(self.num_res_blocks+1):
|
353 |
+
block.append(ResnetBlock(in_channels=block_in,
|
354 |
+
out_channels=block_out,
|
355 |
+
temb_channels=self.temb_ch,
|
356 |
+
dropout=dropout))
|
357 |
+
block_in = block_out
|
358 |
+
if curr_res in attn_resolutions:
|
359 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
360 |
+
up = nn.Module()
|
361 |
+
up.block = block
|
362 |
+
up.attn = attn
|
363 |
+
if i_level != 0:
|
364 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
365 |
+
curr_res = curr_res * 2
|
366 |
+
self.up.insert(0, up) # prepend to get consistent order
|
367 |
+
|
368 |
+
# end
|
369 |
+
self.norm_out = Normalize(block_in)
|
370 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
371 |
+
out_ch,
|
372 |
+
kernel_size=3,
|
373 |
+
stride=1,
|
374 |
+
padding=1)
|
375 |
+
|
376 |
+
def forward(self, z):
|
377 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
378 |
+
self.last_z_shape = z.shape
|
379 |
+
|
380 |
+
# timestep embedding
|
381 |
+
temb = None
|
382 |
+
|
383 |
+
# z to block_in
|
384 |
+
h = self.conv_in(z)
|
385 |
+
|
386 |
+
# middle
|
387 |
+
h = self.mid.block_1(h, temb)
|
388 |
+
h = self.mid.attn_1(h)
|
389 |
+
h = self.mid.block_2(h, temb)
|
390 |
+
|
391 |
+
# upsampling
|
392 |
+
for i_level in reversed(range(self.num_resolutions)):
|
393 |
+
for i_block in range(self.num_res_blocks+1):
|
394 |
+
h = self.up[i_level].block[i_block](h, temb)
|
395 |
+
if len(self.up[i_level].attn) > 0:
|
396 |
+
h = self.up[i_level].attn[i_block](h)
|
397 |
+
if i_level != 0:
|
398 |
+
h = self.up[i_level].upsample(h)
|
399 |
+
|
400 |
+
# end
|
401 |
+
if self.give_pre_end:
|
402 |
+
return h
|
403 |
+
|
404 |
+
h = self.norm_out(h)
|
405 |
+
h = nonlinearity(h)
|
406 |
+
h = self.conv_out(h)
|
407 |
+
if self.tanh_out:
|
408 |
+
h = torch.tanh(h)
|
409 |
+
return h
|
410 |
+
|
411 |
+
|
412 |
+
class FrozenAutoencoderKL(nn.Module):
|
413 |
+
def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
|
414 |
+
super().__init__()
|
415 |
+
print(f'Create autoencoder with scale_factor={scale_factor}')
|
416 |
+
self.encoder = Encoder(**ddconfig)
|
417 |
+
self.decoder = Decoder(**ddconfig)
|
418 |
+
assert ddconfig["double_z"]
|
419 |
+
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
420 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
421 |
+
self.embed_dim = embed_dim
|
422 |
+
self.scale_factor = scale_factor
|
423 |
+
m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
|
424 |
+
assert len(m) == 0 and len(u) == 0
|
425 |
+
self.eval()
|
426 |
+
self.requires_grad_(False)
|
427 |
+
|
428 |
+
def encode_moments(self, x):
|
429 |
+
h = self.encoder(x)
|
430 |
+
moments = self.quant_conv(h)
|
431 |
+
return moments
|
432 |
+
|
433 |
+
def sample(self, moments):
|
434 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
435 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
436 |
+
std = torch.exp(0.5 * logvar)
|
437 |
+
z = mean + std * torch.randn_like(mean)
|
438 |
+
z = self.scale_factor * z
|
439 |
+
return z
|
440 |
+
|
441 |
+
def encode(self, x):
|
442 |
+
moments = self.encode_moments(x)
|
443 |
+
z = self.sample(moments)
|
444 |
+
return z
|
445 |
+
|
446 |
+
def decode(self, z):
|
447 |
+
z = (1. / self.scale_factor) * z
|
448 |
+
z = self.post_quant_conv(z)
|
449 |
+
dec = self.decoder(z)
|
450 |
+
return dec
|
451 |
+
|
452 |
+
def forward(self, inputs, fn):
|
453 |
+
if fn == 'encode_moments':
|
454 |
+
return self.encode_moments(inputs)
|
455 |
+
elif fn == 'encode':
|
456 |
+
return self.encode(inputs)
|
457 |
+
elif fn == 'decode':
|
458 |
+
return self.decode(inputs)
|
459 |
+
else:
|
460 |
+
raise NotImplementedError
|
461 |
+
|
462 |
+
|
463 |
+
def get_model(pretrained_path, scale_factor=0.18215):
|
464 |
+
ddconfig = dict(
|
465 |
+
double_z=True,
|
466 |
+
z_channels=4,
|
467 |
+
resolution=256,
|
468 |
+
in_channels=3,
|
469 |
+
out_ch=3,
|
470 |
+
ch=128,
|
471 |
+
ch_mult=[1, 2, 4, 4],
|
472 |
+
num_res_blocks=2,
|
473 |
+
attn_resolutions=[],
|
474 |
+
dropout=0.0
|
475 |
+
)
|
476 |
+
return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
|
477 |
+
|
478 |
+
|
479 |
+
def main():
|
480 |
+
import torchvision.transforms as transforms
|
481 |
+
from torchvision.utils import save_image
|
482 |
+
import os
|
483 |
+
from PIL import Image
|
484 |
+
|
485 |
+
model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
|
486 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
487 |
+
model = model.to(device)
|
488 |
+
|
489 |
+
scale_factor = 0.18215
|
490 |
+
T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
|
491 |
+
path = 'imgs'
|
492 |
+
fnames = os.listdir(path)
|
493 |
+
for fname in fnames:
|
494 |
+
p = os.path.join(path, fname)
|
495 |
+
img = Image.open(p)
|
496 |
+
img = T(img)
|
497 |
+
img = img * 2. - 1
|
498 |
+
img = img[None, ...]
|
499 |
+
img = img.to(device)
|
500 |
+
|
501 |
+
# with torch.cuda.amp.autocast():
|
502 |
+
# moments = model.encode_moments(img)
|
503 |
+
# mean, logvar = torch.chunk(moments, 2, dim=1)
|
504 |
+
# logvar = torch.clamp(logvar, -30.0, 20.0)
|
505 |
+
# std = torch.exp(0.5 * logvar)
|
506 |
+
# zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
|
507 |
+
# recons = [model.decode(z) for z in zs]
|
508 |
+
|
509 |
+
with torch.cuda.amp.autocast():
|
510 |
+
print('test encode & decode')
|
511 |
+
recons = [model.decode(model.encode(img)) for _ in range(4)]
|
512 |
+
|
513 |
+
out = torch.cat([img, *recons], dim=0)
|
514 |
+
out = (out + 1) * 0.5
|
515 |
+
save_image(out, f'recons_{fname}')
|
516 |
+
|
517 |
+
|
518 |
+
if __name__ == "__main__":
|
519 |
+
main()
|
libs/clip.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
3 |
+
import time
|
4 |
+
|
5 |
+
|
6 |
+
class AbstractEncoder(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
def encode(self, *args, **kwargs):
|
11 |
+
raise NotImplementedError
|
12 |
+
|
13 |
+
|
14 |
+
class FrozenCLIPEmbedder(AbstractEncoder):
|
15 |
+
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
16 |
+
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
|
17 |
+
super().__init__()
|
18 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
19 |
+
self.transformer = CLIPTextModel.from_pretrained(version)
|
20 |
+
self.device = device
|
21 |
+
self.max_length = max_length
|
22 |
+
self.freeze()
|
23 |
+
|
24 |
+
def freeze(self):
|
25 |
+
self.transformer = self.transformer.eval()
|
26 |
+
for param in self.parameters():
|
27 |
+
param.requires_grad = False
|
28 |
+
|
29 |
+
def forward(self, text):
|
30 |
+
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
31 |
+
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
32 |
+
tokens = batch_encoding["input_ids"].to(self.device)
|
33 |
+
outputs = self.transformer(input_ids=tokens)
|
34 |
+
|
35 |
+
z = outputs.last_hidden_state
|
36 |
+
return z, {'token_embedding': outputs.last_hidden_state, 'pooler_output': outputs.pooler_output, 'token_mask': batch_encoding['attention_mask'].to(self.device), 'tokens': batch_encoding["input_ids"].to(self.device)}
|
37 |
+
|
38 |
+
def encode_from_token(self, tokens):
|
39 |
+
tokens = tokens.to(self.device)
|
40 |
+
outputs = self.transformer(input_ids=tokens)
|
41 |
+
|
42 |
+
z = outputs.last_hidden_state
|
43 |
+
return z
|
44 |
+
|
45 |
+
def encode(self, text):
|
46 |
+
return self(text)
|
47 |
+
|
48 |
+
|
49 |
+
class FrozenCLIPTokenizer(AbstractEncoder):
|
50 |
+
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
51 |
+
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
|
52 |
+
super().__init__()
|
53 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
54 |
+
self.max_length = max_length
|
55 |
+
self.freeze()
|
56 |
+
|
57 |
+
def freeze(self):
|
58 |
+
for param in self.parameters():
|
59 |
+
param.requires_grad = False
|
60 |
+
|
61 |
+
def forward(self, text):
|
62 |
+
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
63 |
+
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
64 |
+
tokens = batch_encoding["input_ids"]
|
65 |
+
return tokens
|
66 |
+
|
67 |
+
def encode(self, text):
|
68 |
+
return self(text)
|
libs/model/axial_rope.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch._dynamo
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from . import flags
|
8 |
+
|
9 |
+
if flags.get_use_compile():
|
10 |
+
torch._dynamo.config.suppress_errors = True
|
11 |
+
|
12 |
+
|
13 |
+
def rotate_half(x):
|
14 |
+
x1, x2 = x[..., 0::2], x[..., 1::2]
|
15 |
+
x = torch.stack((-x2, x1), dim=-1)
|
16 |
+
*shape, d, r = x.shape
|
17 |
+
return x.view(*shape, d * r)
|
18 |
+
|
19 |
+
|
20 |
+
@flags.compile_wrap
|
21 |
+
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
|
22 |
+
freqs = freqs.to(t)
|
23 |
+
rot_dim = freqs.shape[-1]
|
24 |
+
end_index = start_index + rot_dim
|
25 |
+
assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
26 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
27 |
+
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
28 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
29 |
+
|
30 |
+
|
31 |
+
def centers(start, stop, num, dtype=None, device=None):
|
32 |
+
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
|
33 |
+
return (edges[:-1] + edges[1:]) / 2
|
34 |
+
|
35 |
+
|
36 |
+
def make_grid(h_pos, w_pos):
|
37 |
+
grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1)
|
38 |
+
h, w, d = grid.shape
|
39 |
+
return grid.view(h * w, d)
|
40 |
+
|
41 |
+
|
42 |
+
def bounding_box(h, w, pixel_aspect_ratio=1.0):
|
43 |
+
# Adjusted dimensions
|
44 |
+
w_adj = w
|
45 |
+
h_adj = h * pixel_aspect_ratio
|
46 |
+
|
47 |
+
# Adjusted aspect ratio
|
48 |
+
ar_adj = w_adj / h_adj
|
49 |
+
|
50 |
+
# Determine bounding box based on the adjusted aspect ratio
|
51 |
+
y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
|
52 |
+
if ar_adj > 1:
|
53 |
+
y_min, y_max = -1 / ar_adj, 1 / ar_adj
|
54 |
+
elif ar_adj < 1:
|
55 |
+
x_min, x_max = -ar_adj, ar_adj
|
56 |
+
|
57 |
+
return y_min, y_max, x_min, x_max
|
58 |
+
|
59 |
+
|
60 |
+
def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None):
|
61 |
+
y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
|
62 |
+
if align_corners:
|
63 |
+
h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device)
|
64 |
+
w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device)
|
65 |
+
else:
|
66 |
+
h_pos = centers(y_min, y_max, h, dtype=dtype, device=device)
|
67 |
+
w_pos = centers(x_min, x_max, w, dtype=dtype, device=device)
|
68 |
+
return make_grid(h_pos, w_pos)
|
69 |
+
|
70 |
+
|
71 |
+
def freqs_pixel(max_freq=10.0):
|
72 |
+
def init(shape):
|
73 |
+
freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi
|
74 |
+
return freqs.log().expand(shape)
|
75 |
+
return init
|
76 |
+
|
77 |
+
|
78 |
+
def freqs_pixel_log(max_freq=10.0):
|
79 |
+
def init(shape):
|
80 |
+
log_min = math.log(math.pi)
|
81 |
+
log_max = math.log(max_freq * math.pi / 2)
|
82 |
+
return torch.linspace(log_min, log_max, shape[-1]).expand(shape)
|
83 |
+
return init
|
84 |
+
|
85 |
+
|
86 |
+
class AxialRoPE(nn.Module):
|
87 |
+
def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
|
88 |
+
super().__init__()
|
89 |
+
self.n_heads = n_heads
|
90 |
+
self.start_index = start_index
|
91 |
+
log_freqs = freqs_init((n_heads, dim // 4))
|
92 |
+
self.freqs_h = nn.Parameter(log_freqs.clone())
|
93 |
+
self.freqs_w = nn.Parameter(log_freqs.clone())
|
94 |
+
|
95 |
+
def extra_repr(self):
|
96 |
+
dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2
|
97 |
+
return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}"
|
98 |
+
|
99 |
+
def get_freqs(self, pos):
|
100 |
+
if pos.shape[-1] != 2:
|
101 |
+
raise ValueError("input shape must be (..., 2)")
|
102 |
+
freqs_h = pos[..., None, None, 0] * self.freqs_h.exp()
|
103 |
+
freqs_w = pos[..., None, None, 1] * self.freqs_w.exp()
|
104 |
+
freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1)
|
105 |
+
return freqs.transpose(-2, -3)
|
106 |
+
|
107 |
+
def forward(self, x, pos):
|
108 |
+
freqs = self.get_freqs(pos)
|
109 |
+
return apply_rotary_emb(freqs, x, self.start_index)
|
libs/model/common_layers.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
from timm.models.layers import trunc_normal_
|
6 |
+
|
7 |
+
class Linear(nn.Linear):
|
8 |
+
def __init__(self, *args, **kwargs):
|
9 |
+
super().__init__(*args, **kwargs)
|
10 |
+
trunc_normal_(self.weight, mean = 0, std = 0.02)
|
11 |
+
if self.bias is not None:
|
12 |
+
nn.init.zeros_(self.bias)
|
13 |
+
|
14 |
+
class LayerNorm(nn.LayerNorm):
|
15 |
+
def __init__(self, *args, **kwargs):
|
16 |
+
super().__init__(*args, **kwargs)
|
17 |
+
trunc_normal_(self.weight, mean = 0, std = 0.02)
|
18 |
+
if self.bias is not None:
|
19 |
+
nn.init.zeros_(self.bias)
|
20 |
+
|
21 |
+
class Conv2d(nn.Conv2d):
|
22 |
+
def __init__(self, *args, **kwargs):
|
23 |
+
super().__init__(*args, **kwargs)
|
24 |
+
trunc_normal_(self.weight, mean = 0, std = 0.02)
|
25 |
+
if self.bias is not None:
|
26 |
+
nn.init.zeros_(self.bias)
|
27 |
+
|
28 |
+
class Embedding(nn.Embedding):
|
29 |
+
def __init__(self, *args, **kwargs):
|
30 |
+
super().__init__(*args, **kwargs)
|
31 |
+
trunc_normal_(self.weight, mean = 0, std = 0.02)
|
32 |
+
|
33 |
+
class ImageNorm(nn.Module):
|
34 |
+
def forward(self, x):
|
35 |
+
assert x.dim() == 4
|
36 |
+
eps = 1e-05
|
37 |
+
x = x / (x.var(dim = (1, 2, 3), keepdim = True) + eps).sqrt()
|
38 |
+
return x
|
39 |
+
|
40 |
+
class Flatten(nn.Module):
|
41 |
+
def forward(self, x):
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.reshape(B, H * W, C)
|
44 |
+
return x
|
45 |
+
|
46 |
+
class ChannelLast(nn.Module):
|
47 |
+
def forward(self, x):
|
48 |
+
assert x.dim() == 4
|
49 |
+
x = x.permute(0, 2, 3, 1) # [B, H, W, C]
|
50 |
+
return x
|
51 |
+
|
52 |
+
class ChannelFirst(nn.Module):
|
53 |
+
def forward(self, x):
|
54 |
+
assert x.dim() == 4
|
55 |
+
x = x.permute(0, 3, 1, 2) # [B, C, H, W]
|
56 |
+
return x
|
57 |
+
|
58 |
+
class OddUpInterpolate(nn.Module):
|
59 |
+
def __init__(self, ratio):
|
60 |
+
super().__init__()
|
61 |
+
self.ratio = ratio
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
if self.ratio == 1:
|
65 |
+
return x
|
66 |
+
assert x.dim() == 4
|
67 |
+
B, C, H, W = x.shape
|
68 |
+
x = F.interpolate(x, size = ((H - 1) * self.ratio + 1, (W - 1) * self.ratio + 1), mode = "bilinear", align_corners = True)
|
69 |
+
return x
|
70 |
+
|
71 |
+
def __repr__(self):
|
72 |
+
return f"UpInterpolate(ratio={self.ratio})"
|
73 |
+
|
74 |
+
class OddDownInterpolate(nn.Module):
|
75 |
+
def __init__(self, ratio):
|
76 |
+
super().__init__()
|
77 |
+
self.ratio = ratio
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
if self.ratio == 1:
|
81 |
+
return x
|
82 |
+
assert x.dim() == 4
|
83 |
+
B, C, H, W = x.shape
|
84 |
+
x = F.interpolate(x, size = ((H - 1) // self.ratio + 1, (W - 1) // self.ratio + 1), mode = "area")
|
85 |
+
return x
|
86 |
+
|
87 |
+
def __repr__(self):
|
88 |
+
return f"DownInterpolate(ratio={self.ratio})"
|
89 |
+
|
90 |
+
class EvenDownInterpolate(nn.Module):
|
91 |
+
def __init__(self, ratio):
|
92 |
+
super().__init__()
|
93 |
+
self.ratio = ratio
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
if self.ratio == 1:
|
97 |
+
return x
|
98 |
+
assert len(x.shape) == 4
|
99 |
+
B, C, H, W = x.shape
|
100 |
+
x = F.interpolate(x, size = (H // self.ratio, W // self.ratio), mode = "area")
|
101 |
+
return x
|
102 |
+
|
103 |
+
def __repr__(self):
|
104 |
+
return f"DownInterpolate(ratio={self.ratio})"
|
libs/model/dimr_t2i.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from re import A
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
import math
|
6 |
+
import einops
|
7 |
+
import torch.utils.checkpoint
|
8 |
+
from functools import partial
|
9 |
+
import open_clip
|
10 |
+
import numpy as np
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import timm
|
15 |
+
from timm.models.layers import trunc_normal_, Mlp
|
16 |
+
from .sigmoid.module import LayerNorm, RMSNorm, AdaRMSNorm, TDRMSNorm, QKNorm, TimeDependentParameter
|
17 |
+
from .common_layers import Linear, EvenDownInterpolate, ChannelFirst, ChannelLast, Embedding
|
18 |
+
from .axial_rope import AxialRoPE, make_axial_pos
|
19 |
+
from .trans_autoencoder import TransEncoder, Adaptor
|
20 |
+
|
21 |
+
def check_zip(*args):
|
22 |
+
args = [list(arg) for arg in args]
|
23 |
+
length = len(args[0])
|
24 |
+
for arg in args:
|
25 |
+
assert len(arg) == length
|
26 |
+
return zip(*args)
|
27 |
+
|
28 |
+
class PixelShuffleUpsample(nn.Module):
|
29 |
+
def __init__(self, dim_in, dim_out, ratio = 2):
|
30 |
+
super().__init__()
|
31 |
+
self.ratio = ratio
|
32 |
+
self.kernel = Linear(dim_in, dim_out * self.ratio * self.ratio)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = self.kernel(x)
|
36 |
+
B, H, W, C = x.shape
|
37 |
+
x = x.reshape(B, H, W, self.ratio, self.ratio, C // self.ratio // self.ratio)
|
38 |
+
x = x.transpose(2, 3)
|
39 |
+
x = x.reshape(B, H * self.ratio, W * self.ratio, C // self.ratio // self.ratio)
|
40 |
+
return x
|
41 |
+
|
42 |
+
class PositionEmbeddings(nn.Module):
|
43 |
+
def __init__(self, max_height, max_width, dim):
|
44 |
+
super().__init__()
|
45 |
+
self.max_height = max_height
|
46 |
+
self.max_width = max_width
|
47 |
+
self.position_embeddings = Embedding(self.max_height * self.max_width, dim)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
B, H, W, C = x.shape
|
51 |
+
height_idxes = torch.arange(H, device = x.device)[:, None].repeat(1, W)
|
52 |
+
width_idxes = torch.arange(W, device = x.device)[None, :].repeat(H, 1)
|
53 |
+
idxes = height_idxes * self.max_width + width_idxes
|
54 |
+
x = x + self.position_embeddings(idxes[None])
|
55 |
+
return x
|
56 |
+
|
57 |
+
class TextPositionEmbeddings(nn.Module):
|
58 |
+
def __init__(self, num_embeddings, embedding_dim):
|
59 |
+
super().__init__()
|
60 |
+
self.embedding = Embedding(num_embeddings, embedding_dim)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
batch_size, num_embeddings, embedding_dim = x.shape
|
64 |
+
# positions = torch.arange(height * width, device=x.device).reshape(1, height, width)
|
65 |
+
positions = torch.arange(num_embeddings, device=x.device).unsqueeze(0).expand(batch_size, num_embeddings)
|
66 |
+
x = x + self.embedding(positions)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
class MLPBlock(nn.Module):
|
71 |
+
def __init__(self, config):
|
72 |
+
super().__init__()
|
73 |
+
if config.norm_type == 'LN':
|
74 |
+
self.norm_type = 'LN'
|
75 |
+
self.norm = LayerNorm(config.dim)
|
76 |
+
elif config.norm_type == 'RMSN':
|
77 |
+
self.norm_type = 'RMSN'
|
78 |
+
self.norm = RMSNorm(config.dim)
|
79 |
+
elif config.norm_type == 'TDRMSN':
|
80 |
+
self.norm_type = 'TDRMSN'
|
81 |
+
self.norm = TDRMSNorm(config.dim)
|
82 |
+
elif config.norm_type == 'ADARMSN':
|
83 |
+
self.norm_type = 'ADARMSN'
|
84 |
+
self.norm = AdaRMSNorm(config.dim, config.dim)
|
85 |
+
self.act = nn.GELU()
|
86 |
+
self.w0 = Linear(config.dim, config.hidden_dim)
|
87 |
+
self.w1 = Linear(config.dim, config.hidden_dim)
|
88 |
+
self.w2 = Linear(config.hidden_dim, config.dim)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
if self.norm_type == 'LN' or self.norm_type == 'RMSN' or self.norm_type == 'TDRMSN':
|
92 |
+
x = self.norm(x)
|
93 |
+
elif self.norm_type == 'ADARMSN':
|
94 |
+
condition = x[:,0]
|
95 |
+
x = self.norm(x, condition)
|
96 |
+
x = self.act(self.w0(x)) * self.w1(x)
|
97 |
+
x = self.w2(x)
|
98 |
+
return x
|
99 |
+
|
100 |
+
class SelfAttention(nn.Module):
|
101 |
+
def __init__(self, config):
|
102 |
+
super().__init__()
|
103 |
+
assert config.dim % config.num_attention_heads == 0
|
104 |
+
|
105 |
+
self.num_heads = config.num_attention_heads
|
106 |
+
self.head_dim = config.dim // config.num_attention_heads
|
107 |
+
|
108 |
+
if hasattr(config, "self_att_prompt") and config.self_att_prompt:
|
109 |
+
self.condition_key_value = Linear(config.clip_dim, 2 * config.dim, bias = False)
|
110 |
+
|
111 |
+
if config.norm_type == 'LN':
|
112 |
+
self.norm_type = 'LN'
|
113 |
+
self.norm = LayerNorm(config.dim)
|
114 |
+
elif config.norm_type == 'RMSN':
|
115 |
+
self.norm_type = 'RMSN'
|
116 |
+
self.norm = RMSNorm(config.dim)
|
117 |
+
elif config.norm_type == 'TDRMSN':
|
118 |
+
self.norm_type = 'TDRMSN'
|
119 |
+
self.norm = TDRMSNorm(config.dim)
|
120 |
+
elif config.norm_type == 'ADARMSN':
|
121 |
+
self.norm_type = 'ADARMSN'
|
122 |
+
self.norm = AdaRMSNorm(config.dim, config.dim)
|
123 |
+
|
124 |
+
self.pe_type = config.pe_type
|
125 |
+
if config.pe_type == 'Axial_RoPE':
|
126 |
+
self.pos_emb = AxialRoPE(self.head_dim, self.num_heads)
|
127 |
+
self.qk_norm = QKNorm(self.num_heads)
|
128 |
+
|
129 |
+
self.query_key_value = Linear(config.dim, 3 * config.dim, bias = False)
|
130 |
+
self.dense = Linear(config.dim, config.dim)
|
131 |
+
|
132 |
+
def forward(self, x, condition_embeds, condition_masks, pos=None):
|
133 |
+
B, N, C = x.shape
|
134 |
+
|
135 |
+
if self.norm_type == 'LN' or self.norm_type == 'RMSN' or self.norm_type == 'TDRMSN':
|
136 |
+
qkv = self.query_key_value(self.norm(x))
|
137 |
+
elif self.norm_type == 'ADARMSN':
|
138 |
+
condition = x[:,0]
|
139 |
+
qkv = self.query_key_value(self.norm(x, condition))
|
140 |
+
q, k, v = qkv.reshape(B, N, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(3, dim = 1)
|
141 |
+
|
142 |
+
if self.pe_type == 'Axial_RoPE':
|
143 |
+
q = self.pos_emb(self.qk_norm(q), pos)
|
144 |
+
k = self.pos_emb(self.qk_norm(k), pos)
|
145 |
+
|
146 |
+
if condition_embeds is not None:
|
147 |
+
_, L, D = condition_embeds.shape
|
148 |
+
kcvc = self.condition_key_value(condition_embeds)
|
149 |
+
kc, vc = kcvc.reshape(B, L, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(2, dim = 1)
|
150 |
+
k = torch.cat([k, kc], dim = 2)
|
151 |
+
v = torch.cat([v, vc], dim = 2)
|
152 |
+
mask = torch.cat([torch.ones(B, N, dtype = torch.bool, device = condition_masks.device), condition_masks], dim = -1)
|
153 |
+
mask = mask[:, None, None, :]
|
154 |
+
else:
|
155 |
+
mask = None
|
156 |
+
|
157 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
|
158 |
+
x = self.dense(x.permute(0, 2, 1, 3).reshape(B, N, C))
|
159 |
+
|
160 |
+
return x
|
161 |
+
|
162 |
+
class TransformerBlock(nn.Module):
|
163 |
+
def __init__(self, config):
|
164 |
+
super().__init__()
|
165 |
+
self.block1 = SelfAttention(config)
|
166 |
+
self.block2 = MLPBlock(config)
|
167 |
+
self.dropout = nn.Dropout(config.dropout_prob)
|
168 |
+
self.gradient_checking = config.gradient_checking
|
169 |
+
|
170 |
+
def forward(self, x, condition_embeds, condition_masks, pos):
|
171 |
+
if self.gradient_checking:
|
172 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, condition_embeds, condition_masks, pos)
|
173 |
+
else:
|
174 |
+
return self._forward(x, condition_embeds, condition_masks, pos)
|
175 |
+
|
176 |
+
def _forward(self, x, condition_embeds, condition_masks, pos):
|
177 |
+
x = x + self.dropout(self.block1(x, condition_embeds, condition_masks, pos))
|
178 |
+
x = x + self.dropout(self.block2(x))
|
179 |
+
return x
|
180 |
+
|
181 |
+
class ConvNeXtBlock(nn.Module):
|
182 |
+
def __init__(self, config):
|
183 |
+
super().__init__()
|
184 |
+
self.block1 = nn.Sequential(
|
185 |
+
ChannelFirst(),
|
186 |
+
nn.Conv2d(config.dim, config.dim, kernel_size = config.kernel_size, padding = config.kernel_size // 2, stride = 1, groups = config.dim),
|
187 |
+
ChannelLast()
|
188 |
+
)
|
189 |
+
self.block2 = MLPBlock(config)
|
190 |
+
self.dropout = nn.Dropout(config.dropout_prob)
|
191 |
+
self.gradient_checking = config.gradient_checking
|
192 |
+
|
193 |
+
def forward(self, x, condition_embeds, condition_masks, pos):
|
194 |
+
if self.gradient_checking:
|
195 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x)
|
196 |
+
else:
|
197 |
+
return self._forward(x)
|
198 |
+
|
199 |
+
def _forward(self, x):
|
200 |
+
x = x + self.dropout(self.block1(x))
|
201 |
+
x = x + self.dropout(self.block2(x))
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
class Stage(nn.Module):
|
206 |
+
def __init__(self, channels, config, lowres_dim = None, lowres_height = None):
|
207 |
+
super().__init__()
|
208 |
+
if config.block_type == "TransformerBlock":
|
209 |
+
self.encoder_cls = TransformerBlock
|
210 |
+
elif config.block_type == "ConvNeXtBlock":
|
211 |
+
self.encoder_cls = ConvNeXtBlock
|
212 |
+
else:
|
213 |
+
raise Exception()
|
214 |
+
|
215 |
+
self.pe_type = config.pe_type
|
216 |
+
|
217 |
+
self.input_layer = nn.Sequential(
|
218 |
+
EvenDownInterpolate(config.image_input_ratio),
|
219 |
+
nn.Conv2d(channels, config.dim, kernel_size = config.input_feature_ratio, stride = config.input_feature_ratio),
|
220 |
+
ChannelLast(),
|
221 |
+
PositionEmbeddings(config.max_height, config.max_width, config.dim)
|
222 |
+
)
|
223 |
+
|
224 |
+
|
225 |
+
if lowres_dim is not None:
|
226 |
+
ratio = config.max_height // lowres_height
|
227 |
+
self.upsample = nn.Sequential(
|
228 |
+
LayerNorm(lowres_dim),
|
229 |
+
PixelShuffleUpsample(lowres_dim, config.dim, ratio = ratio),
|
230 |
+
LayerNorm(config.dim),
|
231 |
+
)
|
232 |
+
|
233 |
+
self.blocks = nn.ModuleList([self.encoder_cls(config) for _ in range(config.num_blocks // 2 * 2 + 1)])
|
234 |
+
self.skip_denses = nn.ModuleList([Linear(config.dim * 2, config.dim) for _ in range(config.num_blocks // 2)])
|
235 |
+
|
236 |
+
self.output_layer = nn.Sequential(
|
237 |
+
LayerNorm(config.dim),
|
238 |
+
ChannelFirst(),
|
239 |
+
nn.Conv2d(config.dim, channels, kernel_size = config.final_kernel_size, padding = config.final_kernel_size // 2),
|
240 |
+
)
|
241 |
+
|
242 |
+
self.tensor_true = torch.nn.Parameter(torch.tensor([-1.0])) if self.encoder_cls is TransformerBlock else None
|
243 |
+
self.tensor_false = torch.nn.Parameter(torch.tensor([1.0])) if self.encoder_cls is TransformerBlock else None
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
def forward(self, images, lowres_skips = None, condition_context = None, condition_embeds = None, condition_masks = None, null_indicator=None):
|
249 |
+
if self.pe_type == 'Axial_RoPE' and self.encoder_cls is TransformerBlock:
|
250 |
+
x = self.input_layer(images)
|
251 |
+
_, H, W, _ = x.shape
|
252 |
+
pos = make_axial_pos(H, W)
|
253 |
+
else:
|
254 |
+
x = self.input_layer(images)
|
255 |
+
pos = None
|
256 |
+
|
257 |
+
if lowres_skips is not None:
|
258 |
+
x = x + self.upsample(lowres_skips)
|
259 |
+
|
260 |
+
if self.encoder_cls is TransformerBlock:
|
261 |
+
B, H, W, C = x.shape
|
262 |
+
x = x.reshape(B, H * W, C)
|
263 |
+
|
264 |
+
if null_indicator is not None:
|
265 |
+
indicator_tensor = torch.where(null_indicator, self.tensor_true, self.tensor_false)
|
266 |
+
indicator_tensor = indicator_tensor.view(B, 1, 1).expand(-1, -1, C)
|
267 |
+
|
268 |
+
x = torch.cat([indicator_tensor, x], dim = 1)
|
269 |
+
|
270 |
+
external_skips = [x]
|
271 |
+
|
272 |
+
num_blocks = len(self.blocks)
|
273 |
+
in_blocks = self.blocks[:(num_blocks // 2)]
|
274 |
+
mid_block = self.blocks[(num_blocks // 2)]
|
275 |
+
out_blocks = self.blocks[(num_blocks // 2 + 1):]
|
276 |
+
|
277 |
+
skips = []
|
278 |
+
for block in in_blocks:
|
279 |
+
x = block(x, condition_embeds, condition_masks, pos=pos)
|
280 |
+
external_skips.append(x)
|
281 |
+
skips.append(x)
|
282 |
+
|
283 |
+
x = mid_block(x, condition_embeds, condition_masks, pos=pos)
|
284 |
+
external_skips.append(x)
|
285 |
+
|
286 |
+
for dense, block in check_zip(self.skip_denses, out_blocks):
|
287 |
+
x = dense(torch.cat([x, skips.pop()], dim = -1))
|
288 |
+
x = block(x, condition_embeds, condition_masks, pos=pos)
|
289 |
+
external_skips.append(x)
|
290 |
+
|
291 |
+
if self.encoder_cls is TransformerBlock:
|
292 |
+
|
293 |
+
if null_indicator is not None:
|
294 |
+
x = x[:, 1:, :]
|
295 |
+
external_skips = [skip[:, 1:, :] for skip in external_skips]
|
296 |
+
|
297 |
+
x = x.reshape(B, H, W, C)
|
298 |
+
external_skips = [skip.reshape(B, H, W, C) for skip in external_skips]
|
299 |
+
|
300 |
+
output = self.output_layer(x)
|
301 |
+
|
302 |
+
return output, external_skips
|
303 |
+
|
304 |
+
|
305 |
+
class MRModel(nn.Module):
|
306 |
+
def __init__(self, config):
|
307 |
+
super().__init__()
|
308 |
+
self.channels = config.channels
|
309 |
+
self.block_grad_to_lowres = config.block_grad_to_lowres
|
310 |
+
|
311 |
+
for stage_config in config.stage_configs:
|
312 |
+
if hasattr(config, "use_t2i"):
|
313 |
+
stage_config.use_t2i = config.use_t2i
|
314 |
+
if hasattr(config, "clip_dim"):
|
315 |
+
stage_config.clip_dim = config.clip_dim
|
316 |
+
if hasattr(config, "num_clip_token"):
|
317 |
+
stage_config.num_clip_token = config.num_clip_token
|
318 |
+
if hasattr(config, "gradient_checking"):
|
319 |
+
stage_config.gradient_checking = config.gradient_checking
|
320 |
+
if hasattr(config, "pe_type"):
|
321 |
+
stage_config.pe_type = config.pe_type
|
322 |
+
else:
|
323 |
+
stage_config.pe_type = 'APE'
|
324 |
+
if hasattr(config, "norm_type"):
|
325 |
+
stage_config.norm_type = config.norm_type
|
326 |
+
else:
|
327 |
+
stage_config.norm_type = 'LN'
|
328 |
+
|
329 |
+
|
330 |
+
#### diffusion model
|
331 |
+
if hasattr(config, "not_training_diff") and config.not_training_diff:
|
332 |
+
self.has_diff = False
|
333 |
+
else:
|
334 |
+
self.has_diff = True
|
335 |
+
|
336 |
+
lowres_dims = [None] + [stage_config.dim * (stage_config.num_blocks // 2 * 2 + 2) for stage_config in config.stage_configs[:-1]]
|
337 |
+
lowres_heights = [None] + [stage_config.max_height for stage_config in config.stage_configs[:-1]]
|
338 |
+
self.stages = nn.ModuleList([
|
339 |
+
Stage(self.channels, stage_config, lowres_dim = lowres_dim, lowres_height=lowres_height)
|
340 |
+
for stage_config, lowres_dim, lowres_height in check_zip(config.stage_configs, lowres_dims, lowres_heights)]
|
341 |
+
)
|
342 |
+
|
343 |
+
|
344 |
+
#### Text VE
|
345 |
+
if hasattr(config.textVAE, "num_down_sample_block"):
|
346 |
+
down_sample_block = config.textVAE.num_down_sample_block
|
347 |
+
else:
|
348 |
+
down_sample_block = 3
|
349 |
+
|
350 |
+
self.context_encoder = TransEncoder(d_model=config.clip_dim, N=config.textVAE.num_blocks, num_token=config.num_clip_token,
|
351 |
+
head_num=config.textVAE.num_attention_heads, d_ff=config.textVAE.hidden_dim,
|
352 |
+
latten_size=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width * 2,
|
353 |
+
down_sample_block=down_sample_block, dropout=config.textVAE.dropout_prob, last_norm=False)
|
354 |
+
|
355 |
+
|
356 |
+
|
357 |
+
#### image encoder to train VE
|
358 |
+
self.open_clip, _, self.open_clip_preprocess = open_clip.create_model_and_transforms('ViT-L-16-SigLIP-256', pretrained=None)
|
359 |
+
if config.stage_configs[-1].max_width==32:
|
360 |
+
# for 256px generation
|
361 |
+
self.open_clip_output = Mlp(in_features=1024,
|
362 |
+
hidden_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
|
363 |
+
out_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
|
364 |
+
norm_layer=nn.LayerNorm,
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
# for 512px generation
|
368 |
+
self.open_clip_output = Adaptor(input_dim=1024,
|
369 |
+
tar_dim=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width
|
370 |
+
)
|
371 |
+
del self.open_clip.text
|
372 |
+
del self.open_clip.logit_bias
|
373 |
+
|
374 |
+
|
375 |
+
def _forward(self, images, log_snr, condition_context = None, condition_text_embeds = None, condition_text_masks = None, condition_drop_prob = None, null_indicator=None):
|
376 |
+
if self.has_diff:
|
377 |
+
TimeDependentParameter.seed_time(self, log_snr)
|
378 |
+
|
379 |
+
assert condition_context is None
|
380 |
+
assert condition_text_embeds is None
|
381 |
+
|
382 |
+
if condition_text_embeds is not None:
|
383 |
+
condition_embeds = self.text_conditioning(condition_text_embeds)
|
384 |
+
condition_masks = condition_text_masks
|
385 |
+
else:
|
386 |
+
condition_embeds = None
|
387 |
+
condition_masks = None
|
388 |
+
|
389 |
+
outputs = []
|
390 |
+
lowres_skips = None
|
391 |
+
for stage in self.stages:
|
392 |
+
output, lowres_skips = stage(images, lowres_skips = lowres_skips, condition_context = condition_context, condition_embeds = condition_embeds, condition_masks = condition_masks, null_indicator=null_indicator)
|
393 |
+
outputs.append(output)
|
394 |
+
lowres_skips = torch.cat(lowres_skips, dim = -1)
|
395 |
+
if self.block_grad_to_lowres:
|
396 |
+
lowres_skips = lowres_skips.detach()
|
397 |
+
|
398 |
+
return outputs
|
399 |
+
|
400 |
+
else:
|
401 |
+
return [images]
|
402 |
+
|
403 |
+
|
404 |
+
def _reparameterize(self, mu, logvar):
|
405 |
+
std = torch.exp(0.5 * logvar)
|
406 |
+
eps = torch.randn_like(std)
|
407 |
+
return eps * std + mu
|
408 |
+
|
409 |
+
def _text_encoder(self, condition_context, tar_shape, mask):
|
410 |
+
|
411 |
+
output = self.context_encoder(condition_context, mask)
|
412 |
+
mu, log_var = torch.chunk(output, 2, dim=-1)
|
413 |
+
|
414 |
+
z = self._reparameterize(mu, log_var)
|
415 |
+
|
416 |
+
return [z, mu, log_var]
|
417 |
+
|
418 |
+
def _text_decoder(self, condition_enbedding, tar_shape):
|
419 |
+
|
420 |
+
context_token = self.context_decoder(condition_enbedding)
|
421 |
+
|
422 |
+
return context_token
|
423 |
+
|
424 |
+
def _img_clip(self, image_input):
|
425 |
+
|
426 |
+
image_latent = self.open_clip.encode_image(image_input)
|
427 |
+
image_latent = self.open_clip_output(image_latent)
|
428 |
+
|
429 |
+
return image_latent, self.open_clip.logit_scale
|
430 |
+
|
431 |
+
|
432 |
+
|
433 |
+
def forward(self, x, t = None, log_snr = None, text_encoder=False, text_decoder=False, image_clip=False, shape=None, mask=None, null_indicator=None):
|
434 |
+
if text_encoder:
|
435 |
+
return self._text_encoder(condition_context = x, tar_shape=shape, mask=mask)
|
436 |
+
elif text_decoder:
|
437 |
+
return self._text_decoder(condition_enbedding = x, tar_shape=shape) # mask is not needed for decoder
|
438 |
+
elif image_clip:
|
439 |
+
return self._img_clip(image_input = x)
|
440 |
+
else:
|
441 |
+
assert log_snr.dtype == torch.float32
|
442 |
+
return self._forward(images = x, log_snr = log_snr, null_indicator=null_indicator)
|
443 |
+
|
libs/model/dit_t2i.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
2 |
+
# --------------------------------------------------------
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
import math
|
8 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
9 |
+
|
10 |
+
import open_clip
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
|
13 |
+
from .trans_autoencoder import TransEncoder, Adaptor
|
14 |
+
|
15 |
+
|
16 |
+
def modulate(x, shift, scale):
|
17 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
18 |
+
|
19 |
+
|
20 |
+
#################################################################################
|
21 |
+
# Embedding Layers for Timesteps and Class Labels #
|
22 |
+
#################################################################################
|
23 |
+
|
24 |
+
class TimestepEmbedder(nn.Module):
|
25 |
+
"""
|
26 |
+
Embeds scalar timesteps into vector representations.
|
27 |
+
"""
|
28 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
29 |
+
super().__init__()
|
30 |
+
self.mlp = nn.Sequential(
|
31 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
32 |
+
nn.SiLU(),
|
33 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
34 |
+
)
|
35 |
+
self.frequency_embedding_size = frequency_embedding_size
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def timestep_embedding(t, dim, max_period=10000):
|
39 |
+
"""
|
40 |
+
Create sinusoidal timestep embeddings.
|
41 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
42 |
+
These may be fractional.
|
43 |
+
:param dim: the dimension of the output.
|
44 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
45 |
+
:return: an (N, D) Tensor of positional embeddings.
|
46 |
+
"""
|
47 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
48 |
+
half = dim // 2
|
49 |
+
freqs = torch.exp(
|
50 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
51 |
+
).to(device=t.device)
|
52 |
+
args = t[:, None].float() * freqs[None]
|
53 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
54 |
+
if dim % 2:
|
55 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
56 |
+
return embedding
|
57 |
+
|
58 |
+
def forward(self, t):
|
59 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
60 |
+
t_emb = self.mlp(t_freq)
|
61 |
+
return t_emb
|
62 |
+
|
63 |
+
|
64 |
+
class LabelEmbedder(nn.Module):
|
65 |
+
"""
|
66 |
+
CrossFlow: update it for CFG with indicator
|
67 |
+
"""
|
68 |
+
def __init__(self, num_classes, hidden_size):
|
69 |
+
super().__init__()
|
70 |
+
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
71 |
+
|
72 |
+
def forward(self, labels):
|
73 |
+
embeddings = self.embedding_table(labels.int())
|
74 |
+
return embeddings
|
75 |
+
|
76 |
+
|
77 |
+
#################################################################################
|
78 |
+
# Core DiT Model #
|
79 |
+
#################################################################################
|
80 |
+
|
81 |
+
class DiTBlock(nn.Module):
|
82 |
+
"""
|
83 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
84 |
+
"""
|
85 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
86 |
+
super().__init__()
|
87 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
88 |
+
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
89 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
90 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
91 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
92 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
93 |
+
self.adaLN_modulation = nn.Sequential(
|
94 |
+
nn.SiLU(),
|
95 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
96 |
+
)
|
97 |
+
|
98 |
+
def forward(self, x, c):
|
99 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, c)
|
100 |
+
# return self._forward(x, c)
|
101 |
+
|
102 |
+
def _forward(self, x, c):
|
103 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
|
104 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
105 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class FinalLayer(nn.Module):
|
110 |
+
"""
|
111 |
+
The final layer of DiT.
|
112 |
+
"""
|
113 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
114 |
+
super().__init__()
|
115 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
116 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
117 |
+
self.adaLN_modulation = nn.Sequential(
|
118 |
+
nn.SiLU(),
|
119 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
120 |
+
)
|
121 |
+
|
122 |
+
def forward(self, x, c):
|
123 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
124 |
+
x = modulate(self.norm_final(x), shift, scale)
|
125 |
+
x = self.linear(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
class DiT(nn.Module):
|
130 |
+
"""
|
131 |
+
Diffusion model with a Transformer backbone.
|
132 |
+
"""
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
config,
|
136 |
+
patch_size=2,
|
137 |
+
hidden_size=1152,
|
138 |
+
depth=28,
|
139 |
+
num_heads=16,
|
140 |
+
mlp_ratio=4.0,
|
141 |
+
num_classes=2, # for cfg indicator
|
142 |
+
):
|
143 |
+
super().__init__()
|
144 |
+
self.input_size = config.latent_size
|
145 |
+
self.learn_sigma = config.learn_sigma
|
146 |
+
self.in_channels = config.channels
|
147 |
+
self.out_channels = self.in_channels * 2 if self.learn_sigma else self.in_channels
|
148 |
+
self.patch_size = patch_size
|
149 |
+
self.num_heads = num_heads
|
150 |
+
|
151 |
+
self.x_embedder = PatchEmbed(self.input_size, patch_size, self.in_channels, hidden_size, bias=True)
|
152 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
153 |
+
self.y_embedder = LabelEmbedder(num_classes, hidden_size)
|
154 |
+
num_patches = self.x_embedder.num_patches
|
155 |
+
# Will use fixed sin-cos embedding:
|
156 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
|
157 |
+
|
158 |
+
self.blocks = nn.ModuleList([
|
159 |
+
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
|
160 |
+
])
|
161 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
162 |
+
self.initialize_weights()
|
163 |
+
|
164 |
+
######### CrossFlow related
|
165 |
+
if hasattr(config.textVAE, "num_down_sample_block"):
|
166 |
+
down_sample_block = config.textVAE.num_down_sample_block
|
167 |
+
else:
|
168 |
+
down_sample_block = 3
|
169 |
+
self.context_encoder = TransEncoder(d_model=config.clip_dim, N=config.textVAE.num_blocks, num_token=config.num_clip_token,
|
170 |
+
head_num=config.textVAE.num_attention_heads, d_ff=config.textVAE.hidden_dim,
|
171 |
+
latten_size=config.channels * config.latent_size * config.latent_size * 2,
|
172 |
+
down_sample_block=down_sample_block, dropout=config.textVAE.dropout_prob, last_norm=False)
|
173 |
+
|
174 |
+
|
175 |
+
self.open_clip, _, self.open_clip_preprocess = open_clip.create_model_and_transforms('ViT-L-16-SigLIP-256', pretrained=None)
|
176 |
+
self.open_clip_output = Adaptor(input_dim=1024,
|
177 |
+
tar_dim=config.channels * config.latent_size * config.latent_size
|
178 |
+
)
|
179 |
+
del self.open_clip.text
|
180 |
+
del self.open_clip.logit_bias
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
def initialize_weights(self):
|
185 |
+
# Initialize transformer layers:
|
186 |
+
def _basic_init(module):
|
187 |
+
if isinstance(module, nn.Linear):
|
188 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
189 |
+
if module.bias is not None:
|
190 |
+
nn.init.constant_(module.bias, 0)
|
191 |
+
self.apply(_basic_init)
|
192 |
+
|
193 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
194 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
|
195 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
196 |
+
|
197 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
198 |
+
w = self.x_embedder.proj.weight.data
|
199 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
200 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
201 |
+
|
202 |
+
# Initialize label embedding table:
|
203 |
+
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
204 |
+
|
205 |
+
# Initialize timestep embedding MLP:
|
206 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
207 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
208 |
+
|
209 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
210 |
+
for block in self.blocks:
|
211 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
212 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
213 |
+
|
214 |
+
# Zero-out output layers:
|
215 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
216 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
217 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
218 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
219 |
+
|
220 |
+
def unpatchify(self, x):
|
221 |
+
"""
|
222 |
+
x: (N, T, patch_size**2 * C)
|
223 |
+
imgs: (N, H, W, C)
|
224 |
+
"""
|
225 |
+
c = self.out_channels
|
226 |
+
p = self.x_embedder.patch_size[0]
|
227 |
+
h = w = int(x.shape[1] ** 0.5)
|
228 |
+
assert h * w == x.shape[1]
|
229 |
+
|
230 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
231 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
232 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
|
233 |
+
return imgs
|
234 |
+
|
235 |
+
def _forward(self, x, t, null_indicator):
|
236 |
+
"""
|
237 |
+
Forward pass of DiT.
|
238 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
239 |
+
t: (N,) tensor of diffusion timesteps
|
240 |
+
"""
|
241 |
+
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
242 |
+
t = self.t_embedder(t) # (N, D)
|
243 |
+
y = self.y_embedder(null_indicator) # (N, D)
|
244 |
+
c = t + y # (N, D)
|
245 |
+
for block in self.blocks:
|
246 |
+
x = block(x, c) # (N, T, D)
|
247 |
+
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
|
248 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
249 |
+
return [x]
|
250 |
+
|
251 |
+
def _forward_with_cfg(self, x, t, cfg_scale):
|
252 |
+
"""
|
253 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
254 |
+
"""
|
255 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
256 |
+
half = x[: len(x) // 2]
|
257 |
+
combined = torch.cat([half, half], dim=0)
|
258 |
+
model_out = self.forward(combined, t)
|
259 |
+
# For exact reproducibility reasons, we apply classifier-free guidance on only
|
260 |
+
# three channels by default. The standard approach to cfg applies it to all channels.
|
261 |
+
# This can be done by uncommenting the following line and commenting-out the line following that.
|
262 |
+
# eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
|
263 |
+
eps, rest = model_out[:, :3], model_out[:, 3:]
|
264 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
265 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
266 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
267 |
+
return torch.cat([eps, rest], dim=1)
|
268 |
+
|
269 |
+
def _reparameterize(self, mu, logvar):
|
270 |
+
std = torch.exp(0.5 * logvar)
|
271 |
+
eps = torch.randn_like(std)
|
272 |
+
return eps * std + mu
|
273 |
+
|
274 |
+
def _text_encoder(self, condition_context, tar_shape, mask):
|
275 |
+
|
276 |
+
output = self.context_encoder(condition_context, mask)
|
277 |
+
mu, log_var = torch.chunk(output, 2, dim=-1)
|
278 |
+
z = self._reparameterize(mu, log_var)
|
279 |
+
|
280 |
+
return [z, mu, log_var]
|
281 |
+
|
282 |
+
def _img_clip(self, image_input):
|
283 |
+
|
284 |
+
image_latent = self.open_clip.encode_image(image_input)
|
285 |
+
image_latent = self.open_clip_output(image_latent)
|
286 |
+
|
287 |
+
return image_latent, self.open_clip.logit_scale
|
288 |
+
|
289 |
+
def forward(self, x, t = None, log_snr = None, text_encoder=False, text_decoder=False, image_clip=False, shape=None, mask=None, null_indicator=None):
|
290 |
+
if text_encoder:
|
291 |
+
return self._text_encoder(condition_context = x, tar_shape=shape, mask=mask)
|
292 |
+
elif text_decoder:
|
293 |
+
raise NotImplementedError
|
294 |
+
return self._text_decoder(condition_enbedding = x, tar_shape=shape) # mask is not needed for decoder
|
295 |
+
elif image_clip:
|
296 |
+
return self._img_clip(image_input = x)
|
297 |
+
else:
|
298 |
+
return self._forward(x = x, t = t, null_indicator=null_indicator)
|
299 |
+
|
300 |
+
|
301 |
+
#################################################################################
|
302 |
+
# Sine/Cosine Positional Embedding Functions #
|
303 |
+
#################################################################################
|
304 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
305 |
+
|
306 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
307 |
+
"""
|
308 |
+
grid_size: int of the grid height and width
|
309 |
+
return:
|
310 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
311 |
+
"""
|
312 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
313 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
314 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
315 |
+
grid = np.stack(grid, axis=0)
|
316 |
+
|
317 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
318 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
319 |
+
if cls_token and extra_tokens > 0:
|
320 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
321 |
+
return pos_embed
|
322 |
+
|
323 |
+
|
324 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
325 |
+
assert embed_dim % 2 == 0
|
326 |
+
|
327 |
+
# use half of dimensions to encode grid_h
|
328 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
329 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
330 |
+
|
331 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
332 |
+
return emb
|
333 |
+
|
334 |
+
|
335 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
336 |
+
"""
|
337 |
+
embed_dim: output dimension for each position
|
338 |
+
pos: a list of positions to be encoded: size (M,)
|
339 |
+
out: (M, D)
|
340 |
+
"""
|
341 |
+
assert embed_dim % 2 == 0
|
342 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
343 |
+
omega /= embed_dim / 2.
|
344 |
+
omega = 1. / 10000**omega # (D/2,)
|
345 |
+
|
346 |
+
pos = pos.reshape(-1) # (M,)
|
347 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
348 |
+
|
349 |
+
emb_sin = np.sin(out) # (M, D/2)
|
350 |
+
emb_cos = np.cos(out) # (M, D/2)
|
351 |
+
|
352 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
353 |
+
return emb
|
354 |
+
|
355 |
+
|
356 |
+
#################################################################################
|
357 |
+
# DiT Configs #
|
358 |
+
#################################################################################
|
359 |
+
|
360 |
+
def DiT_H_2(config, **kwargs):
|
361 |
+
return DiT(config=config, depth=36, hidden_size=1280, patch_size=2, num_heads=20, **kwargs)
|
362 |
+
|
363 |
+
def DiT_XL_2(config, **kwargs):
|
364 |
+
return DiT(config=config, depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|
365 |
+
|
366 |
+
def DiT_XL_4(config, **kwargs):
|
367 |
+
return DiT(config=config, depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
|
368 |
+
|
369 |
+
def DiT_XL_8(config, **kwargs):
|
370 |
+
return DiT(config=config, depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
|
371 |
+
|
372 |
+
def DiT_L_2(config, **kwargs):
|
373 |
+
return DiT(config=config, depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
|
374 |
+
|
375 |
+
def DiT_L_4(config, **kwargs):
|
376 |
+
return DiT(config=config, depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
|
377 |
+
|
378 |
+
def DiT_L_8(config, **kwargs):
|
379 |
+
return DiT(config=config, depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
|
380 |
+
|
381 |
+
def DiT_B_2(config, **kwargs):
|
382 |
+
return DiT(config=config, depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
|
383 |
+
|
384 |
+
def DiT_B_4(config, **kwargs):
|
385 |
+
return DiT(config=config, depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
|
386 |
+
|
387 |
+
def DiT_B_8(config, **kwargs):
|
388 |
+
return DiT(config=config, depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
|
389 |
+
|
390 |
+
def DiT_S_2(config, **kwargs):
|
391 |
+
return DiT(config=config, depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
|
392 |
+
|
393 |
+
def DiT_S_4(config, **kwargs):
|
394 |
+
return DiT(config=config, depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
|
395 |
+
|
396 |
+
def DiT_S_8(config, **kwargs):
|
397 |
+
return DiT(config=config, depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
|
398 |
+
|
399 |
+
|
400 |
+
DiT_models = {
|
401 |
+
'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
|
402 |
+
'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
|
403 |
+
'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
|
404 |
+
'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
|
405 |
+
}
|
libs/model/flags.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import contextmanager
|
2 |
+
from functools import update_wrapper
|
3 |
+
import os
|
4 |
+
import threading
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def get_use_compile():
|
10 |
+
return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"
|
11 |
+
|
12 |
+
|
13 |
+
def get_use_flash_attention_2():
|
14 |
+
return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"
|
15 |
+
|
16 |
+
|
17 |
+
state = threading.local()
|
18 |
+
state.checkpointing = False
|
19 |
+
|
20 |
+
|
21 |
+
@contextmanager
|
22 |
+
def checkpointing(enable=True):
|
23 |
+
try:
|
24 |
+
old_checkpointing, state.checkpointing = state.checkpointing, enable
|
25 |
+
yield
|
26 |
+
finally:
|
27 |
+
state.checkpointing = old_checkpointing
|
28 |
+
|
29 |
+
|
30 |
+
def get_checkpointing():
|
31 |
+
return getattr(state, "checkpointing", False)
|
32 |
+
|
33 |
+
|
34 |
+
class compile_wrap:
|
35 |
+
def __init__(self, function, *args, **kwargs):
|
36 |
+
self.function = function
|
37 |
+
self.args = args
|
38 |
+
self.kwargs = kwargs
|
39 |
+
self._compiled_function = None
|
40 |
+
update_wrapper(self, function)
|
41 |
+
|
42 |
+
@property
|
43 |
+
def compiled_function(self):
|
44 |
+
if self._compiled_function is not None:
|
45 |
+
return self._compiled_function
|
46 |
+
if get_use_compile():
|
47 |
+
try:
|
48 |
+
self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
|
49 |
+
except RuntimeError:
|
50 |
+
self._compiled_function = self.function
|
51 |
+
else:
|
52 |
+
self._compiled_function = self.function
|
53 |
+
return self._compiled_function
|
54 |
+
|
55 |
+
def __call__(self, *args, **kwargs):
|
56 |
+
return self.compiled_function(*args, **kwargs)
|
libs/model/sigmoid/kernel.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.utils.cpp_extension import load
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import random
|
8 |
+
import math
|
9 |
+
from torch.utils.checkpoint import checkpoint
|
10 |
+
from torch.autograd import Function
|
11 |
+
from functools import partial
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
# curr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extension")
|
15 |
+
# src_files = ['tdp.cu', 'torch_extension.cpp']
|
16 |
+
# src_files = [os.path.join(curr_path, file) for file in src_files]
|
17 |
+
# tdp = load('tdp', src_files, verbose = True)
|
18 |
+
|
19 |
+
# import tdp
|
20 |
+
|
21 |
+
def exported_tdp(param0, param1, weight, bias, times, custom = True):
|
22 |
+
original_shape = param0.shape
|
23 |
+
param0 = param0.reshape(-1)
|
24 |
+
param1 = param1.reshape(-1)
|
25 |
+
weight = weight.reshape(-1)
|
26 |
+
bias = bias.reshape(-1)
|
27 |
+
if custom and param0.shape[0] % 2 == 0:
|
28 |
+
result = TDP.apply(param0, param1, weight, bias, times)
|
29 |
+
else:
|
30 |
+
warnings.warn(f'Using slower tdp_torch implementation for a tensor with shape {param0.shape}')
|
31 |
+
result = tdp_torch(param0, param1, weight, bias, times)
|
32 |
+
result = result.reshape(*([times.shape[0]] + [d for d in original_shape]))
|
33 |
+
return result
|
34 |
+
|
35 |
+
class TDP(Function):
|
36 |
+
@staticmethod
|
37 |
+
def forward(ctx, param0, param1, weight, bias, times):
|
38 |
+
assert param0.shape[0] % 2 == 0
|
39 |
+
param0 = param0.contiguous()
|
40 |
+
param1 = param1.contiguous()
|
41 |
+
weight = weight.contiguous()
|
42 |
+
bias = bias.contiguous()
|
43 |
+
times = times.contiguous()
|
44 |
+
assert param0.shape[0] == param1.shape[0] and param0.shape[0] == weight.shape[0] and param0.shape[0] == bias.shape[0]
|
45 |
+
assert param0.dim() == 1 and param1.dim() == 1 and weight.dim() == 1 and bias.dim() == 1 and times.dim() == 1
|
46 |
+
ctx.save_for_backward(param0, param1, weight, bias, times)
|
47 |
+
return tdp_cuda(param0, param1, weight, bias, times)
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def backward(ctx, g_result):
|
51 |
+
g_result = g_result.contiguous()
|
52 |
+
param0, param1, weight, bias, times = ctx.saved_tensors
|
53 |
+
g_param0, g_param1, g_weight, g_bias = backward_tdp_cuda(param0, param1, weight, bias, times, g_result)
|
54 |
+
return g_param0, g_param1, g_weight, g_bias, None
|
55 |
+
|
56 |
+
def backward_tdp_torch(param0, param1, weight, bias, times, g_result):
|
57 |
+
param0 = param0[None]
|
58 |
+
param1 = param1[None]
|
59 |
+
weight = weight[None]
|
60 |
+
bias = bias[None]
|
61 |
+
|
62 |
+
a = times[:, None] * weight + bias
|
63 |
+
s = torch.sigmoid(a)
|
64 |
+
g_param0 = (s * g_result).sum(0)
|
65 |
+
g_param1 = ((1 - s) * g_result).sum(0)
|
66 |
+
g_s = (param0 - param1) * g_result
|
67 |
+
g_a = g_s * s * (1 - s)
|
68 |
+
g_weight = (g_a * times[:, None]).sum(0)
|
69 |
+
g_bias = g_a.sum(0)
|
70 |
+
|
71 |
+
return g_param0, g_param1, g_weight, g_bias
|
72 |
+
|
73 |
+
def backward_tdp_cuda(param0, param1, weight, bias, times, g_result):
|
74 |
+
g_param0 = torch.empty_like(param0)
|
75 |
+
g_param1 = torch.empty_like(param0)
|
76 |
+
g_weight = torch.empty_like(param0)
|
77 |
+
g_bias = torch.empty_like(param0)
|
78 |
+
if param0.dtype == torch.half:
|
79 |
+
tdp.backward_tdp_fp16(param0, param1, weight, bias, times, g_result, g_param0, g_param1, g_weight, g_bias)
|
80 |
+
elif param0.dtype == torch.float:
|
81 |
+
tdp.backward_tdp_fp32(param0, param1, weight, bias, times, g_result, g_param0, g_param1, g_weight, g_bias)
|
82 |
+
else:
|
83 |
+
raise NotImplementedError
|
84 |
+
return g_param0, g_param1, g_weight, g_bias
|
85 |
+
|
86 |
+
def tdp_torch(param0, param1, weight, bias, times):
|
87 |
+
a = torch.addcmul(bias[None], times[:, None], weight[None])
|
88 |
+
s = torch.sigmoid(a)
|
89 |
+
result = torch.addcmul(param1[None], s, param0[None] - param1[None])
|
90 |
+
return result
|
91 |
+
|
92 |
+
def tdp_cuda(param0, param1, weight, bias, times):
|
93 |
+
result = torch.empty(times.shape[0], param0.shape[0], dtype = param0.dtype, device = param0.device)
|
94 |
+
if param0.dtype == torch.half:
|
95 |
+
tdp.tdp_fp16(param0, param1, weight, bias, times, result)
|
96 |
+
elif param0.dtype == torch.float:
|
97 |
+
tdp.tdp_fp32(param0, param1, weight, bias, times, result)
|
98 |
+
else:
|
99 |
+
raise NotImplementedError
|
100 |
+
return result
|
101 |
+
|
102 |
+
def corrcoef(x, y):
|
103 |
+
return torch.corrcoef(torch.stack([x.reshape(-1).float(), y.reshape(-1).float()], dim = 0))[0, 1]
|
104 |
+
|
105 |
+
def tdp_cuda_unit_test():
|
106 |
+
print("***** tdp_cuda_unit_test *****")
|
107 |
+
|
108 |
+
batch_size = random.randrange(1, 128)
|
109 |
+
num_params = random.randrange(1, 1000000) * 2
|
110 |
+
print("batch_size", batch_size, "num_params", num_params)
|
111 |
+
|
112 |
+
param0 = torch.randn(num_params).cuda()
|
113 |
+
param1 = torch.randn(num_params).cuda()
|
114 |
+
weight = torch.randn(num_params).cuda()
|
115 |
+
bias = torch.randn(num_params).cuda()
|
116 |
+
times = torch.rand(batch_size).cuda()
|
117 |
+
|
118 |
+
ref = tdp_torch(param0, param1, weight, bias, times)
|
119 |
+
|
120 |
+
out = tdp_cuda(param0, param1, weight, bias, times)
|
121 |
+
print(corrcoef(ref, out), (ref - out).abs().max())
|
122 |
+
|
123 |
+
out = tdp_cuda(param0.half(), param1.half(), weight.half(), bias.half(), times.half()).float()
|
124 |
+
print(corrcoef(ref, out), (ref - out).abs().max())
|
125 |
+
|
126 |
+
def backward_tdp_cuda_unit_test():
|
127 |
+
print("***** backward_tdp_cuda_unit_test *****")
|
128 |
+
|
129 |
+
batch_size = random.randrange(1, 128)
|
130 |
+
num_params = random.randrange(1, 100000) * 2
|
131 |
+
print("batch_size", batch_size, "num_params", num_params)
|
132 |
+
|
133 |
+
param0 = torch.randn(num_params).cuda()
|
134 |
+
param1 = torch.randn(num_params).cuda()
|
135 |
+
weight = torch.randn(num_params).cuda()
|
136 |
+
bias = torch.randn(num_params).cuda()
|
137 |
+
times = torch.rand(batch_size).cuda()
|
138 |
+
g_result = torch.randn(batch_size, num_params).cuda()
|
139 |
+
|
140 |
+
refs = backward_tdp_torch(param0, param1, weight, bias, times, g_result)
|
141 |
+
|
142 |
+
outs = backward_tdp_cuda(param0, param1, weight, bias, times, g_result)
|
143 |
+
for r, o in zip(refs, outs):
|
144 |
+
print(corrcoef(r, o), (r - o).abs().max())
|
145 |
+
|
146 |
+
outs = backward_tdp_cuda(param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())
|
147 |
+
for r, o in zip(refs, outs):
|
148 |
+
print(corrcoef(r, o), (r - o).abs().max())
|
149 |
+
|
150 |
+
def autograd_unit_test():
|
151 |
+
print("***** autograd_unit_test *****")
|
152 |
+
batch_size = random.randrange(1, 128)
|
153 |
+
num_params = random.randrange(1, 100000) * 2
|
154 |
+
print("batch_size", batch_size, "num_params", num_params)
|
155 |
+
|
156 |
+
def get_outputs(fn):
|
157 |
+
torch.manual_seed(1)
|
158 |
+
param0 = torch.randn(num_params, requires_grad = True).cuda()
|
159 |
+
param1 = torch.randn(num_params, requires_grad = True).cuda()
|
160 |
+
weight = torch.randn(num_params, requires_grad = True).cuda()
|
161 |
+
bias = torch.randn(num_params, requires_grad = True).cuda()
|
162 |
+
times = torch.rand(batch_size).cuda()
|
163 |
+
|
164 |
+
out = fn(param0, param1, weight, bias, times)
|
165 |
+
loss = ((out - 1.5) ** 2).mean()
|
166 |
+
|
167 |
+
param0.retain_grad()
|
168 |
+
param1.retain_grad()
|
169 |
+
weight.retain_grad()
|
170 |
+
bias.retain_grad()
|
171 |
+
|
172 |
+
loss.backward()
|
173 |
+
g_param0 = param0.grad
|
174 |
+
g_param1 = param1.grad
|
175 |
+
g_weight = weight.grad
|
176 |
+
g_bias = bias.grad
|
177 |
+
|
178 |
+
return out, g_param0, g_param1, g_weight, g_bias
|
179 |
+
|
180 |
+
refs = get_outputs(tdp_torch)
|
181 |
+
outs = get_outputs(TDP.apply)
|
182 |
+
for r, o in zip(refs, outs):
|
183 |
+
print(corrcoef(r, o), (r - o).abs().max())
|
184 |
+
|
185 |
+
def exported_tdp_unit_test():
|
186 |
+
print("***** exported_tdp_unit_test *****")
|
187 |
+
batch_size = random.randrange(1, 128)
|
188 |
+
num_params = random.randrange(1, 100000) * 2
|
189 |
+
print("batch_size", batch_size, "num_params", num_params)
|
190 |
+
|
191 |
+
def get_outputs(fn):
|
192 |
+
torch.manual_seed(1)
|
193 |
+
param0 = torch.randn(num_params, requires_grad = True).cuda()
|
194 |
+
param1 = torch.randn(num_params, requires_grad = True).cuda()
|
195 |
+
weight = torch.randn(num_params, requires_grad = True).cuda()
|
196 |
+
bias = torch.randn(num_params, requires_grad = True).cuda()
|
197 |
+
times = torch.rand(batch_size).cuda()
|
198 |
+
|
199 |
+
out = fn(param0, param1, weight, bias, times)
|
200 |
+
loss = ((out - 1.5) ** 2).mean()
|
201 |
+
|
202 |
+
param0.retain_grad()
|
203 |
+
param1.retain_grad()
|
204 |
+
weight.retain_grad()
|
205 |
+
bias.retain_grad()
|
206 |
+
|
207 |
+
loss.backward()
|
208 |
+
g_param0 = param0.grad
|
209 |
+
g_param1 = param1.grad
|
210 |
+
g_weight = weight.grad
|
211 |
+
g_bias = bias.grad
|
212 |
+
|
213 |
+
return out, g_param0, g_param1, g_weight, g_bias
|
214 |
+
|
215 |
+
refs = get_outputs(partial(exported_tdp, custom = False))
|
216 |
+
outs = get_outputs(partial(exported_tdp, custom = True))
|
217 |
+
for r, o in zip(refs, outs):
|
218 |
+
print(corrcoef(r, o), (r - o).abs().max())
|
219 |
+
|
220 |
+
def tdp_cuda_profile():
|
221 |
+
print("***** tdp_cuda_profile *****")
|
222 |
+
def profiler(fn, args):
|
223 |
+
for _ in range(10):
|
224 |
+
fn(*args)
|
225 |
+
torch.cuda.synchronize()
|
226 |
+
t0 = time.time()
|
227 |
+
for _ in range(100):
|
228 |
+
fn(*args)
|
229 |
+
torch.cuda.synchronize()
|
230 |
+
t1 = time.time()
|
231 |
+
return t1 - t0
|
232 |
+
|
233 |
+
batch_size = 16
|
234 |
+
num_params = 1024 * 1024
|
235 |
+
print("batch_size", batch_size, "num_params", num_params)
|
236 |
+
|
237 |
+
param0 = torch.randn(num_params).cuda()
|
238 |
+
param1 = torch.randn(num_params).cuda()
|
239 |
+
weight = torch.randn(num_params).cuda()
|
240 |
+
bias = torch.randn(num_params).cuda()
|
241 |
+
times = torch.rand(batch_size).cuda()
|
242 |
+
|
243 |
+
print("ref", profiler(tdp_torch, (param0, param1, weight, bias, times)))
|
244 |
+
print("cuda", profiler(tdp_cuda, (param0, param1, weight, bias, times)))
|
245 |
+
|
246 |
+
print("ref", profiler(tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
|
247 |
+
print("cuda", profiler(tdp_cuda, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
|
248 |
+
|
249 |
+
def backward_tdp_cuda_profile():
|
250 |
+
print("***** backward_tdp_cuda_profile *****")
|
251 |
+
def profiler(fn, args):
|
252 |
+
for _ in range(10):
|
253 |
+
fn(*args)
|
254 |
+
torch.cuda.synchronize()
|
255 |
+
t0 = time.time()
|
256 |
+
for _ in range(100):
|
257 |
+
fn(*args)
|
258 |
+
torch.cuda.synchronize()
|
259 |
+
t1 = time.time()
|
260 |
+
return t1 - t0
|
261 |
+
|
262 |
+
batch_size = 16
|
263 |
+
num_params = 1024 * 1024
|
264 |
+
print("batch_size", batch_size, "num_params", num_params)
|
265 |
+
|
266 |
+
param0 = torch.randn(num_params).cuda()
|
267 |
+
param1 = torch.randn(num_params).cuda()
|
268 |
+
weight = torch.randn(num_params).cuda()
|
269 |
+
bias = torch.randn(num_params).cuda()
|
270 |
+
times = torch.rand(batch_size).cuda()
|
271 |
+
g_result = torch.randn(batch_size, num_params).cuda()
|
272 |
+
|
273 |
+
|
274 |
+
print("ref", profiler(backward_tdp_torch, (param0, param1, weight, bias, times, g_result)))
|
275 |
+
print("cuda", profiler(backward_tdp_cuda, (param0, param1, weight, bias, times, g_result)))
|
276 |
+
|
277 |
+
print("ref", profiler(backward_tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())))
|
278 |
+
print("cuda", profiler(backward_tdp_cuda, (param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())))
|
279 |
+
|
280 |
+
def autogad_profile():
|
281 |
+
print("***** autogad_profile *****")
|
282 |
+
def profiler(fn, args):
|
283 |
+
for _ in range(10):
|
284 |
+
fn(*args).mean().backward()
|
285 |
+
torch.cuda.synchronize()
|
286 |
+
t0 = time.time()
|
287 |
+
for _ in range(100):
|
288 |
+
fn(*args).mean().backward()
|
289 |
+
torch.cuda.synchronize()
|
290 |
+
t1 = time.time()
|
291 |
+
return t1 - t0
|
292 |
+
|
293 |
+
batch_size = 16
|
294 |
+
num_params = 1024 * 1024
|
295 |
+
print("batch_size", batch_size, "num_params", num_params)
|
296 |
+
|
297 |
+
param0 = nn.Parameter(torch.randn(num_params)).cuda()
|
298 |
+
param1 = nn.Parameter(torch.randn(num_params)).cuda()
|
299 |
+
weight = nn.Parameter(torch.randn(num_params)).cuda()
|
300 |
+
bias = nn.Parameter(torch.randn(num_params)).cuda()
|
301 |
+
times = torch.rand(batch_size).cuda()
|
302 |
+
|
303 |
+
print("ref", profiler(tdp_torch, (param0, param1, weight, bias, times)))
|
304 |
+
print("cuda", profiler(TDP.apply, (param0, param1, weight, bias, times)))
|
305 |
+
|
306 |
+
print("ref", profiler(tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
|
307 |
+
print("cuda", profiler(TDP.apply, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
|
308 |
+
|
309 |
+
if __name__ == "__main__":
|
310 |
+
tdp_cuda_unit_test()
|
311 |
+
backward_tdp_cuda_unit_test()
|
312 |
+
autograd_unit_test()
|
313 |
+
exported_tdp_unit_test()
|
314 |
+
tdp_cuda_profile()
|
315 |
+
backward_tdp_cuda_profile()
|
316 |
+
autogad_profile()
|
libs/model/sigmoid/module.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import reduce
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
from .kernel import exported_tdp
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from functools import partial
|
9 |
+
from timm.models.layers import trunc_normal_
|
10 |
+
|
11 |
+
class TimeDependentParameter(nn.Module):
|
12 |
+
def __init__(self, shape, init_fn):
|
13 |
+
super().__init__()
|
14 |
+
self.shape = shape
|
15 |
+
|
16 |
+
w = torch.empty(*shape)
|
17 |
+
init_fn(w)
|
18 |
+
|
19 |
+
self.param0 = nn.Parameter(w.clone().detach())
|
20 |
+
self.param1 = nn.Parameter(w.clone().detach())
|
21 |
+
|
22 |
+
self.nodecay_weight = nn.Parameter(torch.zeros(*shape))
|
23 |
+
self.nodecay_bias = nn.Parameter(torch.zeros(*shape))
|
24 |
+
self.curr_weight = None
|
25 |
+
|
26 |
+
def forward(self):
|
27 |
+
weight = self.curr_weight
|
28 |
+
# self.curr_weight = None
|
29 |
+
return weight
|
30 |
+
|
31 |
+
def __repr__(self):
|
32 |
+
return f"TimeDependentParameter(shape={self.shape})"
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def seed_time(model, log_snr):
|
36 |
+
assert log_snr.dim() == 1
|
37 |
+
if torch.all(log_snr == log_snr[0]):
|
38 |
+
log_snr = log_snr[0][None]
|
39 |
+
time_condition = log_snr / 4.0
|
40 |
+
|
41 |
+
tdp_list = [module for module in model.modules() if isinstance(module, TimeDependentParameter)]
|
42 |
+
for tdp in tdp_list:
|
43 |
+
tdp.curr_weight = exported_tdp(tdp.param0, tdp.param1, tdp.nodecay_weight + 1, tdp.nodecay_bias, time_condition, custom = False)
|
44 |
+
|
45 |
+
class LayerNorm(nn.Module):
|
46 |
+
def __init__(self, dim, num_groups = 1, eps = 1e-05):
|
47 |
+
super().__init__()
|
48 |
+
self.eps = eps
|
49 |
+
self.dim = dim
|
50 |
+
self.num_groups = num_groups
|
51 |
+
self.weight = TimeDependentParameter((dim, ), nn.init.ones_)
|
52 |
+
self.bias = TimeDependentParameter((dim, ), nn.init.zeros_)
|
53 |
+
|
54 |
+
def _forward(self, x):
|
55 |
+
weight, bias = self.weight(), self.bias()
|
56 |
+
assert weight.shape[0] == bias.shape[0]
|
57 |
+
|
58 |
+
assert x.shape[-1] == self.dim
|
59 |
+
|
60 |
+
if weight.shape[0] == 1:
|
61 |
+
x = F.layer_norm(x, (self.dim, ), weight = weight[0], bias = bias[0], eps = self.eps)
|
62 |
+
else:
|
63 |
+
assert x.shape[0] == weight.shape[0]
|
64 |
+
x = F.layer_norm(x, (self.dim, ), eps = self.eps)
|
65 |
+
x = torch.addcmul(bias[:, None, :], weight[:, None, :], x)
|
66 |
+
|
67 |
+
return x
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
original_shape = x.shape
|
71 |
+
batch_size = x.shape[0]
|
72 |
+
assert self.dim == x.shape[-1]
|
73 |
+
|
74 |
+
x = x.reshape(batch_size, -1, self.dim)
|
75 |
+
x = self._forward(x)
|
76 |
+
x = x.reshape(*original_shape)
|
77 |
+
|
78 |
+
return x
|
79 |
+
|
80 |
+
class Linear(nn.Module):
|
81 |
+
def __init__(self, din, dout, bias = True, weight_init_fn = partial(trunc_normal_, std = 0.02)):
|
82 |
+
super().__init__()
|
83 |
+
self.din = din
|
84 |
+
self.dout = dout
|
85 |
+
self.weight = TimeDependentParameter((din, dout), weight_init_fn)
|
86 |
+
if bias:
|
87 |
+
self.bias = TimeDependentParameter((dout, ), nn.init.zeros_)
|
88 |
+
else:
|
89 |
+
self.bias = None
|
90 |
+
|
91 |
+
def _forward(self, x):
|
92 |
+
weight = self.weight()
|
93 |
+
bias = self.bias() if self.bias is not None else None
|
94 |
+
|
95 |
+
# if weight.shape[0] == 1:
|
96 |
+
# B, L, D = x.shape
|
97 |
+
# if bias is not None:
|
98 |
+
# assert weight.shape[0] == bias.shape[0]
|
99 |
+
# x = torch.addmm(bias, x.reshape(B * L, D), weight[0])
|
100 |
+
# else:
|
101 |
+
# x = torch.matmul(x.reshape(B * L, D), weight[0])
|
102 |
+
# x = x.reshape(B, L, -1)
|
103 |
+
# else:
|
104 |
+
if bias is not None:
|
105 |
+
x = torch.baddbmm(bias[:, None, :], x, weight)
|
106 |
+
else:
|
107 |
+
x = torch.bmm(x, weight)
|
108 |
+
|
109 |
+
return x
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
original_shape = x.shape
|
113 |
+
batch_size = x.shape[0]
|
114 |
+
|
115 |
+
x = x.reshape(batch_size, -1, self.din)
|
116 |
+
x = self._forward(x)
|
117 |
+
x = x.reshape(*(list(original_shape[:-1]) + [self.dout]))
|
118 |
+
|
119 |
+
return x
|
120 |
+
|
121 |
+
class RMSNorm(nn.Module):
|
122 |
+
def __init__(self, d, p=-1., eps=1e-8, bias=False):
|
123 |
+
"""
|
124 |
+
Root Mean Square Layer Normalization
|
125 |
+
:param d: model size
|
126 |
+
:param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
|
127 |
+
:param eps: epsilon value, default 1e-8
|
128 |
+
:param bias: whether use bias term for RMSNorm, disabled by
|
129 |
+
default because RMSNorm doesn't enforce re-centering invariance.
|
130 |
+
"""
|
131 |
+
super(RMSNorm, self).__init__()
|
132 |
+
|
133 |
+
self.eps = eps
|
134 |
+
self.d = d
|
135 |
+
self.p = p
|
136 |
+
self.bias = bias
|
137 |
+
|
138 |
+
self.scale = nn.Parameter(torch.ones(d))
|
139 |
+
self.register_parameter("scale", self.scale)
|
140 |
+
|
141 |
+
if self.bias:
|
142 |
+
self.offset = nn.Parameter(torch.zeros(d))
|
143 |
+
self.register_parameter("offset", self.offset)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
if self.p < 0. or self.p > 1.:
|
147 |
+
norm_x = x.norm(2, dim=-1, keepdim=True)
|
148 |
+
d_x = self.d
|
149 |
+
else:
|
150 |
+
partial_size = int(self.d * self.p)
|
151 |
+
partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
|
152 |
+
|
153 |
+
norm_x = partial_x.norm(2, dim=-1, keepdim=True)
|
154 |
+
d_x = partial_size
|
155 |
+
|
156 |
+
rms_x = norm_x * d_x ** (-1. / 2)
|
157 |
+
x_normed = x / (rms_x + self.eps)
|
158 |
+
|
159 |
+
if self.bias:
|
160 |
+
return self.scale * x_normed + self.offset
|
161 |
+
|
162 |
+
return self.scale * x_normed
|
163 |
+
|
164 |
+
|
165 |
+
class TDRMSNorm(nn.Module):
|
166 |
+
def __init__(self, d, p=-1., eps=1e-8, bias=False):
|
167 |
+
"""
|
168 |
+
Root Mean Square Layer Normalization
|
169 |
+
:param d: model size
|
170 |
+
:param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
|
171 |
+
:param eps: epsilon value, default 1e-8
|
172 |
+
:param bias: whether use bias term for RMSNorm, disabled by
|
173 |
+
default because RMSNorm doesn't enforce re-centering invariance.
|
174 |
+
"""
|
175 |
+
super(TDRMSNorm, self).__init__()
|
176 |
+
|
177 |
+
self.eps = eps
|
178 |
+
self.d = d
|
179 |
+
self.p = p
|
180 |
+
self.bias = bias
|
181 |
+
|
182 |
+
# self.scale = nn.Parameter(torch.ones(d))
|
183 |
+
self.scale = TimeDependentParameter((d, ), nn.init.ones_)
|
184 |
+
# self.register_parameter("scale", self.scale)
|
185 |
+
|
186 |
+
if self.bias:
|
187 |
+
# self.offset = nn.Parameter(torch.zeros(d))
|
188 |
+
self.offset = TimeDependentParameter((d, ), nn.init.zeros_)
|
189 |
+
# self.register_parameter("offset", self.offset)
|
190 |
+
|
191 |
+
def forward(self, x):
|
192 |
+
if self.p < 0. or self.p > 1.:
|
193 |
+
norm_x = x.norm(2, dim=-1, keepdim=True)
|
194 |
+
d_x = self.d
|
195 |
+
else:
|
196 |
+
partial_size = int(self.d * self.p)
|
197 |
+
partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
|
198 |
+
|
199 |
+
norm_x = partial_x.norm(2, dim=-1, keepdim=True)
|
200 |
+
d_x = partial_size
|
201 |
+
|
202 |
+
rms_x = norm_x * d_x ** (-1. / 2)
|
203 |
+
x_normed = x / (rms_x + self.eps)
|
204 |
+
|
205 |
+
_scale = self.scale()
|
206 |
+
|
207 |
+
if self.bias:
|
208 |
+
# return self.scale * x_normed + self.offset
|
209 |
+
_offset = self.offset()
|
210 |
+
if _scale.shape[0] == 1:
|
211 |
+
return _scale[0] * x_normed + _offset[0]
|
212 |
+
elif x_normed.dim() == 3:
|
213 |
+
return torch.addcmul(_offset[:, None, :], _scale[:, None, :], x_normed)
|
214 |
+
elif x_normed.dim() == 4:
|
215 |
+
return torch.addcmul(_offset[:, None, None, :], _scale[:, None, None, :], x_normed)
|
216 |
+
else:
|
217 |
+
raise NotImplementedError
|
218 |
+
|
219 |
+
# return self.scale * x_normed
|
220 |
+
if _scale.shape[0] == 1:
|
221 |
+
return _scale[0] * x_normed
|
222 |
+
elif x_normed.dim() == 3:
|
223 |
+
return _scale[:, None, :] * x_normed
|
224 |
+
elif x_normed.dim() == 4:
|
225 |
+
return _scale[:, None, None, :] * x_normed
|
226 |
+
else:
|
227 |
+
raise NotImplementedError
|
228 |
+
|
229 |
+
|
230 |
+
def zero_init(layer):
|
231 |
+
nn.init.zeros_(layer.weight)
|
232 |
+
if layer.bias is not None:
|
233 |
+
nn.init.zeros_(layer.bias)
|
234 |
+
return layer
|
235 |
+
|
236 |
+
def rms_norm(x, scale, eps):
|
237 |
+
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
238 |
+
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
239 |
+
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
240 |
+
return x * scale.to(x.dtype)
|
241 |
+
|
242 |
+
class AdaRMSNorm(nn.Module):
|
243 |
+
def __init__(self, features, cond_features, eps=1e-6):
|
244 |
+
super().__init__()
|
245 |
+
self.eps = eps
|
246 |
+
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
|
247 |
+
|
248 |
+
def extra_repr(self):
|
249 |
+
return f"eps={self.eps},"
|
250 |
+
|
251 |
+
def forward(self, x, cond):
|
252 |
+
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
|
253 |
+
|
254 |
+
class QKNorm(nn.Module):
|
255 |
+
def __init__(self, n_heads, eps=1e-6, max_scale=100.0):
|
256 |
+
super().__init__()
|
257 |
+
self.eps = eps
|
258 |
+
self.max_scale = math.log(max_scale)
|
259 |
+
self.scale = nn.Parameter(torch.full((n_heads,), math.log(10.0)))
|
260 |
+
self.proj_()
|
261 |
+
|
262 |
+
def extra_repr(self):
|
263 |
+
return f"n_heads={self.scale.shape[0]}, eps={self.eps}"
|
264 |
+
|
265 |
+
@torch.no_grad()
|
266 |
+
def proj_(self):
|
267 |
+
"""Modify the scale in-place so it doesn't get "stuck" with zero gradient if it's clamped
|
268 |
+
to the max value."""
|
269 |
+
self.scale.clamp_(max=self.max_scale)
|
270 |
+
|
271 |
+
def forward(self, x):
|
272 |
+
self.proj_()
|
273 |
+
scale = torch.exp(0.5 * self.scale - 0.25 * math.log(x.shape[-1]))
|
274 |
+
return rms_norm(x, scale[:, None, None], self.eps)
|
libs/model/trans_autoencoder.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Transformer-based varitional encoder model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import math
|
9 |
+
import copy
|
10 |
+
|
11 |
+
|
12 |
+
def clones(module, N):
|
13 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
14 |
+
|
15 |
+
|
16 |
+
def build_mask(base_mask):
|
17 |
+
assert len(base_mask.shape) == 2
|
18 |
+
batch_size, seq_len = base_mask.shape[0], base_mask.shape[-1]
|
19 |
+
|
20 |
+
# create subsequent token mask
|
21 |
+
sub_mask = torch.tril(torch.ones([seq_len, seq_len],
|
22 |
+
dtype=torch.uint8)).type_as(base_mask)
|
23 |
+
sub_mask = sub_mask.unsqueeze(0).expand(batch_size, -1, -1)
|
24 |
+
base_mask = base_mask.unsqueeze(1).expand(-1, seq_len, -1)
|
25 |
+
return sub_mask & base_mask
|
26 |
+
|
27 |
+
|
28 |
+
class Adaptor(nn.Module):
|
29 |
+
def __init__(self, input_dim, tar_dim):
|
30 |
+
super(Adaptor, self).__init__()
|
31 |
+
|
32 |
+
if tar_dim == 32768:
|
33 |
+
output_channel = 8
|
34 |
+
elif tar_dim == 16384:
|
35 |
+
output_channel = 4
|
36 |
+
else:
|
37 |
+
raise NotImplementedError("only support 512px, 256px does not need this")
|
38 |
+
|
39 |
+
self.tar_dim = tar_dim
|
40 |
+
|
41 |
+
self.fc1 = nn.Linear(input_dim, 4096)
|
42 |
+
self.ln_fc1 = nn.LayerNorm(4096)
|
43 |
+
self.fc2 = nn.Linear(4096, 4096)
|
44 |
+
self.ln_fc2 = nn.LayerNorm(4096)
|
45 |
+
|
46 |
+
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
|
47 |
+
self.ln_conv1 = nn.LayerNorm([32, 64, 64])
|
48 |
+
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
|
49 |
+
self.ln_conv2 = nn.LayerNorm([64, 64, 64])
|
50 |
+
self.conv3 = nn.Conv2d(in_channels=64, out_channels=output_channel, kernel_size=3, padding=1)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = torch.relu(self.ln_fc1(self.fc1(x)))
|
54 |
+
x = torch.relu(self.ln_fc2(self.fc2(x)))
|
55 |
+
|
56 |
+
x = x.view(-1, 1, 64, 64)
|
57 |
+
|
58 |
+
x = torch.relu(self.ln_conv1(self.conv1(x)))
|
59 |
+
x = torch.relu(self.ln_conv2(self.conv2(x)))
|
60 |
+
|
61 |
+
x = self.conv3(x)
|
62 |
+
x = x.view(-1, self.tar_dim)
|
63 |
+
|
64 |
+
return x
|
65 |
+
|
66 |
+
|
67 |
+
class Compressor(nn.Module):
|
68 |
+
def __init__(self, input_dim=4096, tar_dim=2048):
|
69 |
+
super(Compressor, self).__init__()
|
70 |
+
|
71 |
+
self.fc1 = nn.Linear(input_dim, tar_dim)
|
72 |
+
self.ln_fc1 = nn.LayerNorm(tar_dim)
|
73 |
+
self.fc2 = nn.Linear(tar_dim, tar_dim)
|
74 |
+
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
x = torch.relu(self.ln_fc1(self.fc1(x)))
|
78 |
+
x = self.fc2(x)
|
79 |
+
|
80 |
+
return x
|
81 |
+
|
82 |
+
|
83 |
+
class TransEncoder(nn.Module):
|
84 |
+
def __init__(self, d_model, N, num_token, head_num, d_ff, latten_size, down_sample_block=3, dropout=0.1, last_norm=True):
|
85 |
+
super(TransEncoder, self).__init__()
|
86 |
+
self.N = N
|
87 |
+
if d_model==4096:
|
88 |
+
# for T5-XXL, first use MLP to compress into 1024
|
89 |
+
self.compressor = Compressor(input_dim=d_model, tar_dim=1024)
|
90 |
+
d_model = 1024
|
91 |
+
else:
|
92 |
+
self.compressor = None
|
93 |
+
|
94 |
+
self.layers = clones(EncoderLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout),
|
95 |
+
FeedForward(d_model, d_ff, dropout=dropout),
|
96 |
+
LayerNorm(d_model),
|
97 |
+
LayerNorm(d_model)), N)
|
98 |
+
|
99 |
+
self.reduction_layers = nn.ModuleList()
|
100 |
+
for _ in range(down_sample_block):
|
101 |
+
self.reduction_layers.append(
|
102 |
+
EncoderReductionLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout),
|
103 |
+
FeedForward(d_model, d_ff, dropout=dropout),
|
104 |
+
nn.Linear(d_model, d_model // 2),
|
105 |
+
LayerNorm(d_model),
|
106 |
+
LayerNorm(d_model)))
|
107 |
+
d_model = d_model // 2
|
108 |
+
|
109 |
+
if latten_size == 8192 or latten_size == 4096:
|
110 |
+
self.arc = 0
|
111 |
+
self.linear = nn.Linear(d_model*num_token, latten_size)
|
112 |
+
self.norm = LayerNorm(latten_size) if last_norm else None
|
113 |
+
else:
|
114 |
+
self.arc = 1
|
115 |
+
self.adaptor = Adaptor(d_model*num_token, latten_size)
|
116 |
+
|
117 |
+
|
118 |
+
def forward(self, x, mask):
|
119 |
+
mask = mask.unsqueeze(1)
|
120 |
+
|
121 |
+
if self.compressor is not None:
|
122 |
+
x = self.compressor(x)
|
123 |
+
|
124 |
+
for i, layer in enumerate(self.layers):
|
125 |
+
x = layer(x, mask)
|
126 |
+
|
127 |
+
for i, layer in enumerate(self.reduction_layers):
|
128 |
+
x = layer(x, mask)
|
129 |
+
|
130 |
+
if self.arc == 0:
|
131 |
+
x = self.linear(x.view(x.shape[0],-1))
|
132 |
+
x = self.norm(x) if self.norm else x
|
133 |
+
else:
|
134 |
+
x = self.adaptor(x.view(x.shape[0],-1))
|
135 |
+
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class EncoderLayer(nn.Module):
|
140 |
+
def __init__(self, attn, feed_forward, norm1, norm2, dropout=0.1):
|
141 |
+
super(EncoderLayer, self).__init__()
|
142 |
+
self.attn = attn
|
143 |
+
self.feed_forward = feed_forward
|
144 |
+
self.norm1, self.norm2 = norm1, norm2
|
145 |
+
|
146 |
+
self.dropout1 = nn.Dropout(dropout)
|
147 |
+
self.dropout2 = nn.Dropout(dropout)
|
148 |
+
|
149 |
+
def forward(self, x, mask):
|
150 |
+
# multihead attn & norm
|
151 |
+
a = self.attn(x, x, x, mask)
|
152 |
+
t = self.norm1(x + self.dropout1(a))
|
153 |
+
|
154 |
+
# feed forward & norm
|
155 |
+
z = self.feed_forward(t) # linear(dropout(act(linear(x)))))
|
156 |
+
y = self.norm2(t + self.dropout2(z))
|
157 |
+
|
158 |
+
return y
|
159 |
+
|
160 |
+
|
161 |
+
class EncoderReductionLayer(nn.Module):
|
162 |
+
def __init__(self, attn, feed_forward, reduction, norm1, norm2, dropout=0.1):
|
163 |
+
super(EncoderReductionLayer, self).__init__()
|
164 |
+
self.attn = attn
|
165 |
+
self.feed_forward = feed_forward
|
166 |
+
self.reduction = reduction
|
167 |
+
self.norm1, self.norm2 = norm1, norm2
|
168 |
+
|
169 |
+
self.dropout1 = nn.Dropout(dropout)
|
170 |
+
self.dropout2 = nn.Dropout(dropout)
|
171 |
+
|
172 |
+
def forward(self, x, mask):
|
173 |
+
# multihead attn & norm
|
174 |
+
a = self.attn(x, x, x, mask)
|
175 |
+
t = self.norm1(x + self.dropout1(a))
|
176 |
+
|
177 |
+
# feed forward & norm
|
178 |
+
z = self.feed_forward(t) # linear(dropout(act(linear(x)))))
|
179 |
+
y = self.norm2(t + self.dropout2(z))
|
180 |
+
|
181 |
+
# reduction
|
182 |
+
# y = self.reduction(y).view(x.shape[0], -1, x.shape[-1])
|
183 |
+
y = self.reduction(y)
|
184 |
+
|
185 |
+
return y
|
186 |
+
|
187 |
+
|
188 |
+
class MultiHeadAttentioin(nn.Module):
|
189 |
+
def __init__(self, d_model, head_num, dropout=0.1, d_v=None):
|
190 |
+
super(MultiHeadAttentioin, self).__init__()
|
191 |
+
assert d_model % head_num == 0, "d_model must be divisible by head_num"
|
192 |
+
|
193 |
+
self.d_model = d_model
|
194 |
+
self.head_num = head_num
|
195 |
+
self.d_k = d_model // head_num
|
196 |
+
self.d_v = self.d_k if d_v is None else d_v
|
197 |
+
|
198 |
+
# d_model = d_k * head_num
|
199 |
+
self.W_Q = nn.Linear(d_model, head_num * self.d_k)
|
200 |
+
self.W_K = nn.Linear(d_model, head_num * self.d_k)
|
201 |
+
self.W_V = nn.Linear(d_model, head_num * self.d_v)
|
202 |
+
self.W_O = nn.Linear(d_model, d_model)
|
203 |
+
|
204 |
+
self.dropout = nn.Dropout(dropout)
|
205 |
+
|
206 |
+
def scaled_dp_attn(self, query, key, value, mask=None):
|
207 |
+
assert self.d_k == query.shape[-1]
|
208 |
+
|
209 |
+
# scores: [batch_size, head_num, seq_len, seq_len]
|
210 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
|
211 |
+
|
212 |
+
# if torch.isinf(scores).any():
|
213 |
+
# # to avoid leaking
|
214 |
+
# scores = torch.where(scores == float('-inf'), torch.tensor(-65504.0), scores)
|
215 |
+
# scores = torch.where(scores == float('inf'), torch.tensor(65504.0), scores)
|
216 |
+
|
217 |
+
if mask is not None:
|
218 |
+
assert mask.ndim == 3, "Mask shape {} doesn't seem right...".format(mask.shape)
|
219 |
+
mask = mask.unsqueeze(1)
|
220 |
+
try:
|
221 |
+
if scores.dtype == torch.float32:
|
222 |
+
scores = scores.masked_fill(mask == 0, -1e9)
|
223 |
+
else:
|
224 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
225 |
+
except RuntimeError:
|
226 |
+
print("- scores device: {}".format(scores.device))
|
227 |
+
print("- mask device: {}".format(mask.device))
|
228 |
+
|
229 |
+
# attn: [batch_size, head_num, seq_len, seq_len]
|
230 |
+
attn = F.softmax(scores, dim=-1)
|
231 |
+
attn = self.dropout(attn)
|
232 |
+
return torch.matmul(attn, value), attn
|
233 |
+
|
234 |
+
def forward(self, q, k, v, mask):
|
235 |
+
batch_size = q.shape[0]
|
236 |
+
|
237 |
+
query = self.W_Q(q).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
|
238 |
+
key = self.W_K(k).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
|
239 |
+
value = self.W_V(v).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
|
240 |
+
|
241 |
+
heads, attn = self.scaled_dp_attn(query, key, value, mask)
|
242 |
+
heads = heads.transpose(1, 2).contiguous().view(batch_size, -1,
|
243 |
+
self.head_num * self.d_k)
|
244 |
+
assert heads.shape[-1] == self.d_model and heads.shape[0] == batch_size
|
245 |
+
|
246 |
+
y = self.W_O(heads)
|
247 |
+
|
248 |
+
assert y.shape == q.shape
|
249 |
+
return y
|
250 |
+
|
251 |
+
|
252 |
+
class LayerNorm(nn.Module):
|
253 |
+
def __init__(self, layer_size, eps=1e-5):
|
254 |
+
super(LayerNorm, self).__init__()
|
255 |
+
self.g = nn.Parameter(torch.ones(layer_size))
|
256 |
+
self.b = nn.Parameter(torch.zeros(layer_size))
|
257 |
+
self.eps = eps
|
258 |
+
|
259 |
+
def forward(self, x):
|
260 |
+
mean = x.mean(-1, keepdim=True)
|
261 |
+
std = x.std(-1, keepdim=True)
|
262 |
+
x = (x - mean) / (std + self.eps)
|
263 |
+
return self.g * x + self.b
|
264 |
+
|
265 |
+
|
266 |
+
class FeedForward(nn.Module):
|
267 |
+
def __init__(self, d_model, d_ff, dropout=0.1, act='relu', d_output=None):
|
268 |
+
super(FeedForward, self).__init__()
|
269 |
+
self.d_model = d_model
|
270 |
+
self.d_ff = d_ff
|
271 |
+
d_output = d_model if d_output is None else d_output
|
272 |
+
|
273 |
+
self.ffn_1 = nn.Linear(d_model, d_ff)
|
274 |
+
self.ffn_2 = nn.Linear(d_ff, d_output)
|
275 |
+
|
276 |
+
if act == 'relu':
|
277 |
+
self.act = nn.ReLU()
|
278 |
+
elif act == 'rrelu':
|
279 |
+
self.act = nn.RReLU()
|
280 |
+
else:
|
281 |
+
raise NotImplementedError
|
282 |
+
|
283 |
+
self.dropout = nn.Dropout(dropout)
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
y = self.ffn_2(self.dropout(self.act(self.ffn_1(x))))
|
287 |
+
return y
|
288 |
+
|
289 |
+
|
libs/t5.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file contains code for t5 model.
|
3 |
+
|
4 |
+
Reference:
|
5 |
+
https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
|
6 |
+
"""
|
7 |
+
|
8 |
+
# -*- coding: utf-8 -*-
|
9 |
+
import os
|
10 |
+
import re
|
11 |
+
import html
|
12 |
+
import urllib.parse as ul
|
13 |
+
|
14 |
+
import ftfy
|
15 |
+
import torch
|
16 |
+
from bs4 import BeautifulSoup
|
17 |
+
from transformers import T5EncoderModel, AutoTokenizer
|
18 |
+
from huggingface_hub import hf_hub_download
|
19 |
+
|
20 |
+
|
21 |
+
class T5Embedder:
|
22 |
+
|
23 |
+
available_models = ['t5-v1_1-xxl']
|
24 |
+
bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
|
25 |
+
|
26 |
+
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True,
|
27 |
+
t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None):
|
28 |
+
self.device = torch.device(device)
|
29 |
+
self.torch_dtype = torch_dtype or torch.bfloat16
|
30 |
+
if t5_model_kwargs is None:
|
31 |
+
t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
|
32 |
+
if use_offload_folder is not None:
|
33 |
+
t5_model_kwargs['offload_folder'] = use_offload_folder
|
34 |
+
t5_model_kwargs['device_map'] = {
|
35 |
+
'shared': self.device,
|
36 |
+
'encoder.embed_tokens': self.device,
|
37 |
+
'encoder.block.0': self.device,
|
38 |
+
'encoder.block.1': self.device,
|
39 |
+
'encoder.block.2': self.device,
|
40 |
+
'encoder.block.3': self.device,
|
41 |
+
'encoder.block.4': self.device,
|
42 |
+
'encoder.block.5': self.device,
|
43 |
+
'encoder.block.6': self.device,
|
44 |
+
'encoder.block.7': self.device,
|
45 |
+
'encoder.block.8': self.device,
|
46 |
+
'encoder.block.9': self.device,
|
47 |
+
'encoder.block.10': self.device,
|
48 |
+
'encoder.block.11': self.device,
|
49 |
+
'encoder.block.12': 'disk',
|
50 |
+
'encoder.block.13': 'disk',
|
51 |
+
'encoder.block.14': 'disk',
|
52 |
+
'encoder.block.15': 'disk',
|
53 |
+
'encoder.block.16': 'disk',
|
54 |
+
'encoder.block.17': 'disk',
|
55 |
+
'encoder.block.18': 'disk',
|
56 |
+
'encoder.block.19': 'disk',
|
57 |
+
'encoder.block.20': 'disk',
|
58 |
+
'encoder.block.21': 'disk',
|
59 |
+
'encoder.block.22': 'disk',
|
60 |
+
'encoder.block.23': 'disk',
|
61 |
+
'encoder.final_layer_norm': 'disk',
|
62 |
+
'encoder.dropout': 'disk',
|
63 |
+
}
|
64 |
+
else:
|
65 |
+
t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
|
66 |
+
|
67 |
+
self.use_text_preprocessing = use_text_preprocessing
|
68 |
+
self.hf_token = hf_token
|
69 |
+
self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
|
70 |
+
self.dir_or_name = dir_or_name
|
71 |
+
|
72 |
+
tokenizer_path, path = dir_or_name, dir_or_name
|
73 |
+
if dir_or_name in self.available_models:
|
74 |
+
cache_dir = os.path.join(self.cache_dir, dir_or_name)
|
75 |
+
for filename in [
|
76 |
+
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
|
77 |
+
'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
|
78 |
+
]:
|
79 |
+
hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
|
80 |
+
force_filename=filename, token=self.hf_token)
|
81 |
+
tokenizer_path, path = cache_dir, cache_dir
|
82 |
+
else:
|
83 |
+
cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
|
84 |
+
for filename in [
|
85 |
+
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
|
86 |
+
]:
|
87 |
+
hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
|
88 |
+
force_filename=filename, token=self.hf_token)
|
89 |
+
tokenizer_path = cache_dir
|
90 |
+
|
91 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
92 |
+
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
|
93 |
+
|
94 |
+
def get_text_embeddings(self, texts):
|
95 |
+
texts = [self.text_preprocessing(text) for text in texts]
|
96 |
+
|
97 |
+
text_tokens_and_mask = self.tokenizer(
|
98 |
+
texts,
|
99 |
+
max_length=77,
|
100 |
+
padding='max_length',
|
101 |
+
truncation=True,
|
102 |
+
return_attention_mask=True,
|
103 |
+
add_special_tokens=True,
|
104 |
+
return_tensors='pt'
|
105 |
+
)
|
106 |
+
text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
|
107 |
+
text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
|
108 |
+
|
109 |
+
with torch.no_grad():
|
110 |
+
text_encoder_embs = self.model(
|
111 |
+
input_ids=text_tokens_and_mask['input_ids'].to(self.device),
|
112 |
+
attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
|
113 |
+
)['last_hidden_state'].detach()
|
114 |
+
|
115 |
+
return text_encoder_embs, {'token_embedding': text_encoder_embs, 'token_mask': text_tokens_and_mask['attention_mask'].to(self.device), 'tokens': text_tokens_and_mask['input_ids'].to(self.device)}
|
116 |
+
|
117 |
+
def text_preprocessing(self, text):
|
118 |
+
if self.use_text_preprocessing:
|
119 |
+
# The exact text cleaning as was in the training stage:
|
120 |
+
text = self.clean_caption(text)
|
121 |
+
text = self.clean_caption(text)
|
122 |
+
return text
|
123 |
+
else:
|
124 |
+
return text.lower().strip()
|
125 |
+
|
126 |
+
@staticmethod
|
127 |
+
def basic_clean(text):
|
128 |
+
text = ftfy.fix_text(text)
|
129 |
+
text = html.unescape(html.unescape(text))
|
130 |
+
return text.strip()
|
131 |
+
|
132 |
+
def clean_caption(self, caption):
|
133 |
+
caption = str(caption)
|
134 |
+
caption = ul.unquote_plus(caption)
|
135 |
+
caption = caption.strip().lower()
|
136 |
+
caption = re.sub('<person>', 'person', caption)
|
137 |
+
# urls:
|
138 |
+
caption = re.sub(
|
139 |
+
r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
|
140 |
+
'', caption) # regex for urls
|
141 |
+
caption = re.sub(
|
142 |
+
r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
|
143 |
+
'', caption) # regex for urls
|
144 |
+
# html:
|
145 |
+
caption = BeautifulSoup(caption, features='html.parser').text
|
146 |
+
|
147 |
+
# @<nickname>
|
148 |
+
caption = re.sub(r'@[\w\d]+\b', '', caption)
|
149 |
+
|
150 |
+
# 31C0—31EF CJK Strokes
|
151 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
152 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
153 |
+
# 3300—33FF CJK Compatibility
|
154 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
155 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
156 |
+
# 4E00—9FFF CJK Unified Ideographs
|
157 |
+
caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
|
158 |
+
caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
|
159 |
+
caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
|
160 |
+
caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
|
161 |
+
caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
|
162 |
+
caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
|
163 |
+
caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
|
164 |
+
#######################################################
|
165 |
+
|
166 |
+
# все виды тире / all types of dash --> "-"
|
167 |
+
caption = re.sub(
|
168 |
+
r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
|
169 |
+
'-', caption)
|
170 |
+
|
171 |
+
# кавычки к одному стандарту
|
172 |
+
caption = re.sub(r'[`´«»“”¨]', '"', caption)
|
173 |
+
caption = re.sub(r'[‘’]', "'", caption)
|
174 |
+
|
175 |
+
# "
|
176 |
+
caption = re.sub(r'"?', '', caption)
|
177 |
+
# &
|
178 |
+
caption = re.sub(r'&', '', caption)
|
179 |
+
|
180 |
+
# ip adresses:
|
181 |
+
caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
|
182 |
+
|
183 |
+
# article ids:
|
184 |
+
caption = re.sub(r'\d:\d\d\s+$', '', caption)
|
185 |
+
|
186 |
+
# \n
|
187 |
+
caption = re.sub(r'\\n', ' ', caption)
|
188 |
+
|
189 |
+
# "#123"
|
190 |
+
caption = re.sub(r'#\d{1,3}\b', '', caption)
|
191 |
+
# "#12345.."
|
192 |
+
caption = re.sub(r'#\d{5,}\b', '', caption)
|
193 |
+
# "123456.."
|
194 |
+
caption = re.sub(r'\b\d{6,}\b', '', caption)
|
195 |
+
# filenames:
|
196 |
+
caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
|
197 |
+
|
198 |
+
#
|
199 |
+
caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
|
200 |
+
caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
|
201 |
+
|
202 |
+
caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
203 |
+
caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
|
204 |
+
|
205 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
206 |
+
regex2 = re.compile(r'(?:\-|\_)')
|
207 |
+
if len(re.findall(regex2, caption)) > 3:
|
208 |
+
caption = re.sub(regex2, ' ', caption)
|
209 |
+
|
210 |
+
caption = self.basic_clean(caption)
|
211 |
+
|
212 |
+
caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
|
213 |
+
caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
|
214 |
+
caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
|
215 |
+
|
216 |
+
caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
|
217 |
+
caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
|
218 |
+
caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
|
219 |
+
caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
|
220 |
+
caption = re.sub(r'\bpage\s+\d+\b', '', caption)
|
221 |
+
|
222 |
+
caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
|
223 |
+
|
224 |
+
caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
|
225 |
+
|
226 |
+
caption = re.sub(r'\b\s+\:\s+', r': ', caption)
|
227 |
+
caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
|
228 |
+
caption = re.sub(r'\s+', ' ', caption)
|
229 |
+
|
230 |
+
caption.strip()
|
231 |
+
|
232 |
+
caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
|
233 |
+
caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
|
234 |
+
caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
|
235 |
+
caption = re.sub(r'^\.\S+$', '', caption)
|
236 |
+
|
237 |
+
return caption.strip()
|
libs/timm.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
code from timm 0.3.2
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import math
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
|
10 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
11 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
12 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
13 |
+
def norm_cdf(x):
|
14 |
+
# Computes standard normal cumulative distribution function
|
15 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
16 |
+
|
17 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
18 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
19 |
+
"The distribution of values may be incorrect.",
|
20 |
+
stacklevel=2)
|
21 |
+
|
22 |
+
with torch.no_grad():
|
23 |
+
# Values are generated by using a truncated uniform distribution and
|
24 |
+
# then using the inverse CDF for the normal distribution.
|
25 |
+
# Get upper and lower cdf values
|
26 |
+
l = norm_cdf((a - mean) / std)
|
27 |
+
u = norm_cdf((b - mean) / std)
|
28 |
+
|
29 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
30 |
+
# [2l-1, 2u-1].
|
31 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
32 |
+
|
33 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
34 |
+
# standard normal
|
35 |
+
tensor.erfinv_()
|
36 |
+
|
37 |
+
# Transform to proper mean, std
|
38 |
+
tensor.mul_(std * math.sqrt(2.))
|
39 |
+
tensor.add_(mean)
|
40 |
+
|
41 |
+
# Clamp to ensure it's in the proper range
|
42 |
+
tensor.clamp_(min=a, max=b)
|
43 |
+
return tensor
|
44 |
+
|
45 |
+
|
46 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
47 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
48 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
49 |
+
normal distribution. The values are effectively drawn from the
|
50 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
51 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
52 |
+
the bounds. The method used for generating the random values works
|
53 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
54 |
+
Args:
|
55 |
+
tensor: an n-dimensional `torch.Tensor`
|
56 |
+
mean: the mean of the normal distribution
|
57 |
+
std: the standard deviation of the normal distribution
|
58 |
+
a: the minimum cutoff value
|
59 |
+
b: the maximum cutoff value
|
60 |
+
Examples:
|
61 |
+
>>> w = torch.empty(3, 5)
|
62 |
+
>>> nn.init.trunc_normal_(w)
|
63 |
+
"""
|
64 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
65 |
+
|
66 |
+
|
67 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
68 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
69 |
+
|
70 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
71 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
72 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
73 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
74 |
+
'survival rate' as the argument.
|
75 |
+
|
76 |
+
"""
|
77 |
+
if drop_prob == 0. or not training:
|
78 |
+
return x
|
79 |
+
keep_prob = 1 - drop_prob
|
80 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
81 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
82 |
+
random_tensor.floor_() # binarize
|
83 |
+
output = x.div(keep_prob) * random_tensor
|
84 |
+
return output
|
85 |
+
|
86 |
+
|
87 |
+
class DropPath(nn.Module):
|
88 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
89 |
+
"""
|
90 |
+
def __init__(self, drop_prob=None):
|
91 |
+
super(DropPath, self).__init__()
|
92 |
+
self.drop_prob = drop_prob
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
return drop_path(x, self.drop_prob, self.training)
|
96 |
+
|
97 |
+
|
98 |
+
class Mlp(nn.Module):
|
99 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
100 |
+
super().__init__()
|
101 |
+
out_features = out_features or in_features
|
102 |
+
hidden_features = hidden_features or in_features
|
103 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
104 |
+
self.act = act_layer()
|
105 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
106 |
+
self.drop = nn.Dropout(drop)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
x = self.fc1(x)
|
110 |
+
x = self.act(x)
|
111 |
+
x = self.drop(x)
|
112 |
+
x = self.fc2(x)
|
113 |
+
x = self.drop(x)
|
114 |
+
return x
|
requirements.txt
CHANGED
@@ -1,6 +1,21 @@
|
|
1 |
-
accelerate
|
2 |
diffusers
|
3 |
-
invisible_watermark
|
4 |
torch
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
diffusers
|
|
|
2 |
torch
|
3 |
+
xformers
|
4 |
+
openai-clip
|
5 |
+
scikit-learn
|
6 |
+
opencv-python
|
7 |
+
torchdiffeq
|
8 |
+
beautifulsoup4
|
9 |
+
open_clip_torch
|
10 |
+
scikit-image
|
11 |
+
cython
|
12 |
+
matplotlib
|
13 |
+
accelerate==0.12.0
|
14 |
+
absl-py
|
15 |
+
ml_collections
|
16 |
+
einops
|
17 |
+
wandb
|
18 |
+
ftfy==6.1.1
|
19 |
+
transformers==4.23.1
|
20 |
+
timm
|
21 |
+
tensorboard
|
scripts/extract_empty_feature.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is used to extract feature of the empty prompt.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from libs.clip import FrozenCLIPEmbedder
|
13 |
+
from libs.t5 import T5Embedder
|
14 |
+
|
15 |
+
|
16 |
+
def main():
|
17 |
+
prompts = [
|
18 |
+
'',
|
19 |
+
]
|
20 |
+
|
21 |
+
device = 'cuda'
|
22 |
+
llm = 'clip'
|
23 |
+
|
24 |
+
if llm=='clip':
|
25 |
+
clip = FrozenCLIPEmbedder()
|
26 |
+
clip.eval()
|
27 |
+
clip.to(device)
|
28 |
+
elif llm=='t5':
|
29 |
+
t5 = T5Embedder(device=device)
|
30 |
+
else:
|
31 |
+
raise NotImplementedError
|
32 |
+
|
33 |
+
save_dir = f'./'
|
34 |
+
|
35 |
+
if llm=='clip':
|
36 |
+
latent, latent_and_others = clip.encode(prompts)
|
37 |
+
token_embedding = latent_and_others['token_embedding']
|
38 |
+
token_mask = latent_and_others['token_mask']
|
39 |
+
token = latent_and_others['tokens']
|
40 |
+
elif llm=='t5':
|
41 |
+
latent, latent_and_others = t5.get_text_embeddings(prompts)
|
42 |
+
token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
|
43 |
+
token_mask = latent_and_others['token_mask']
|
44 |
+
token = latent_and_others['tokens']
|
45 |
+
|
46 |
+
for i in range(len(prompts)):
|
47 |
+
data = {'token_embedding': token_embedding[i].detach().cpu().numpy(),
|
48 |
+
'token_mask': token_mask[i].detach().cpu().numpy(),
|
49 |
+
'token': token[i].detach().cpu().numpy(),
|
50 |
+
'batch_caption': prompts[i]}
|
51 |
+
np.save(os.path.join(save_dir, f'empty_context.npy'), data)
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == '__main__':
|
56 |
+
main()
|
scripts/extract_mscoco_feature.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is used to extract feature of the coco val set (to test zero-shot FID).
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from datasets import MSCOCODatabase
|
13 |
+
import argparse
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
import libs.autoencoder
|
17 |
+
from libs.clip import FrozenCLIPEmbedder
|
18 |
+
from libs.t5 import T5Embedder
|
19 |
+
|
20 |
+
|
21 |
+
def main(resolution=256):
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument('--split', default='val')
|
24 |
+
args = parser.parse_args()
|
25 |
+
print(args)
|
26 |
+
|
27 |
+
if args.split == "val":
|
28 |
+
datas = MSCOCODatabase(root='/data/qihao/dataset/coco2014/val2014',
|
29 |
+
annFile='/data/qihao/dataset/coco2014/annotations/captions_val2014.json',
|
30 |
+
size=resolution)
|
31 |
+
save_dir = f'val'
|
32 |
+
else:
|
33 |
+
raise NotImplementedError
|
34 |
+
|
35 |
+
device = "cuda"
|
36 |
+
os.makedirs(save_dir, exist_ok=True)
|
37 |
+
|
38 |
+
autoencoder = libs.autoencoder.get_model('../assets/stable-diffusion/autoencoder_kl.pth')
|
39 |
+
autoencoder.to(device)
|
40 |
+
|
41 |
+
llm = 'clip'
|
42 |
+
|
43 |
+
if llm=='clip':
|
44 |
+
clip = FrozenCLIPEmbedder()
|
45 |
+
clip.eval()
|
46 |
+
clip.to(device)
|
47 |
+
elif llm=='t5':
|
48 |
+
t5 = T5Embedder(device=device)
|
49 |
+
else:
|
50 |
+
raise NotImplementedError
|
51 |
+
|
52 |
+
with torch.no_grad():
|
53 |
+
for idx, data in tqdm(enumerate(datas)):
|
54 |
+
x, captions = data
|
55 |
+
|
56 |
+
if len(x.shape) == 3:
|
57 |
+
x = x[None, ...]
|
58 |
+
x = torch.tensor(x, device=device)
|
59 |
+
moments = autoencoder(x, fn='encode_moments').squeeze(0)
|
60 |
+
moments = moments.detach().cpu().numpy()
|
61 |
+
np.save(os.path.join(save_dir, f'{idx}.npy'), moments)
|
62 |
+
|
63 |
+
if llm=='clip':
|
64 |
+
latent, latent_and_others = clip.encode(captions)
|
65 |
+
token_embedding = latent_and_others['token_embedding']
|
66 |
+
token_mask = latent_and_others['token_mask']
|
67 |
+
token = latent_and_others['tokens']
|
68 |
+
elif llm=='t5':
|
69 |
+
latent, latent_and_others = t5.get_text_embeddings(captions)
|
70 |
+
token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
|
71 |
+
token_mask = latent_and_others['token_mask']
|
72 |
+
token = latent_and_others['tokens']
|
73 |
+
|
74 |
+
for i in range(len(captions)):
|
75 |
+
data = {'promt': captions[i],
|
76 |
+
'token_embedding': token_embedding[i].detach().cpu().numpy(),
|
77 |
+
'token_mask': token_mask[i].detach().cpu().numpy(),
|
78 |
+
'token': token[i].detach().cpu().numpy()}
|
79 |
+
np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), data)
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == '__main__':
|
83 |
+
main()
|
scripts/extract_test_prompt_feature.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is used to extract feature for visulization during training
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
import libs.autoencoder
|
15 |
+
from libs.clip import FrozenCLIPEmbedder
|
16 |
+
from libs.t5 import T5Embedder
|
17 |
+
|
18 |
+
|
19 |
+
def main():
|
20 |
+
prompts = [
|
21 |
+
'A road with traffic lights, street lights and cars.',
|
22 |
+
'A bus driving in a city area with traffic signs.',
|
23 |
+
'A bus pulls over to the curb close to an intersection.',
|
24 |
+
'A group of people are walking and one is holding an umbrella.',
|
25 |
+
'A baseball player taking a swing at an incoming ball.',
|
26 |
+
'A dog next to a white cat with black-tipped ears.',
|
27 |
+
'A tiger standing on a rooftop while singing and jamming on an electric guitar under a spotlight. anime illustration.',
|
28 |
+
'A bird wearing headphones and speaking into a high-end microphone in a recording studio.',
|
29 |
+
'A bus made of cardboard.',
|
30 |
+
'A tower in the mountains.',
|
31 |
+
'Two cups of coffee, one with latte art of a cat. The other has latter art of a bird.',
|
32 |
+
'Oil painting of a robot made of sushi, holding chopsticks.',
|
33 |
+
'Portrait of a dog wearing a hat and holding a flag that has a yin-yang symbol on it.',
|
34 |
+
'A teddy bear wearing a motorcycle helmet and cape is standing in front of Loch Awe with Kilchurn Castle behind him. dslr photo.',
|
35 |
+
'A man standing on the moon',
|
36 |
+
]
|
37 |
+
save_dir = f'run_vis'
|
38 |
+
os.makedirs(save_dir, exist_ok=True)
|
39 |
+
|
40 |
+
device = 'cuda'
|
41 |
+
llm = 'clip'
|
42 |
+
|
43 |
+
if llm=='clip':
|
44 |
+
clip = FrozenCLIPEmbedder()
|
45 |
+
clip.eval()
|
46 |
+
clip.to(device)
|
47 |
+
elif llm=='t5':
|
48 |
+
t5 = T5Embedder(device=device)
|
49 |
+
else:
|
50 |
+
raise NotImplementedError
|
51 |
+
|
52 |
+
if llm=='clip':
|
53 |
+
latent, latent_and_others = clip.encode(prompts)
|
54 |
+
token_embedding = latent_and_others['token_embedding']
|
55 |
+
token_mask = latent_and_others['token_mask']
|
56 |
+
token = latent_and_others['tokens']
|
57 |
+
elif llm=='t5':
|
58 |
+
latent, latent_and_others = t5.get_text_embeddings(prompts)
|
59 |
+
token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
|
60 |
+
token_mask = latent_and_others['token_mask']
|
61 |
+
token = latent_and_others['tokens']
|
62 |
+
|
63 |
+
for i in range(len(prompts)):
|
64 |
+
data = {'promt': prompts[i],
|
65 |
+
'token_embedding': token_embedding[i].detach().cpu().numpy(),
|
66 |
+
'token_mask': token_mask[i].detach().cpu().numpy(),
|
67 |
+
'token': token[i].detach().cpu().numpy()}
|
68 |
+
np.save(os.path.join(save_dir, f'{i}.npy'), data)
|
69 |
+
|
70 |
+
|
71 |
+
if __name__ == '__main__':
|
72 |
+
main()
|
scripts/extract_train_feature.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is used to extract feature of the demo training data.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import os
|
13 |
+
import numpy as np
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
import io
|
17 |
+
import einops
|
18 |
+
import random
|
19 |
+
import json
|
20 |
+
import libs.autoencoder
|
21 |
+
from libs.clip import FrozenCLIPEmbedder
|
22 |
+
from libs.t5 import T5Embedder
|
23 |
+
|
24 |
+
|
25 |
+
def recreate_folder(folder_path):
|
26 |
+
if os.path.exists(folder_path):
|
27 |
+
shutil.rmtree(folder_path)
|
28 |
+
os.makedirs(folder_path)
|
29 |
+
|
30 |
+
def center_crop_arr(pil_image, image_size):
|
31 |
+
while min(*pil_image.size) >= 2 * image_size:
|
32 |
+
pil_image = pil_image.resize(
|
33 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
34 |
+
)
|
35 |
+
|
36 |
+
scale = image_size / min(*pil_image.size)
|
37 |
+
pil_image = pil_image.resize(
|
38 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
39 |
+
)
|
40 |
+
|
41 |
+
arr = np.array(pil_image)
|
42 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
43 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
44 |
+
return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
|
45 |
+
|
46 |
+
|
47 |
+
def main(bz = 16):
|
48 |
+
|
49 |
+
json_path = '/path/to/JourneyDB_demo/img_text_pair.jsonl'
|
50 |
+
root_path = '/path/to/JourneyDB_demo/imgs'
|
51 |
+
|
52 |
+
dicts_list = []
|
53 |
+
with open(json_path, 'r', encoding='utf-8') as file:
|
54 |
+
for line in file:
|
55 |
+
dicts_list.append(json.loads(line))
|
56 |
+
|
57 |
+
save_dir = f'feature'
|
58 |
+
device = "cuda"
|
59 |
+
recreate_folder(save_dir)
|
60 |
+
|
61 |
+
autoencoder = libs.autoencoder.get_model('../assets/stable-diffusion/autoencoder_kl.pth')
|
62 |
+
autoencoder.to(device)
|
63 |
+
|
64 |
+
# CLIP model:
|
65 |
+
clip = FrozenCLIPEmbedder()
|
66 |
+
clip.eval()
|
67 |
+
clip.to(device)
|
68 |
+
|
69 |
+
# T5 model:
|
70 |
+
t5 = T5Embedder(device=device)
|
71 |
+
|
72 |
+
idx = 0
|
73 |
+
batch_img_256 = []
|
74 |
+
batch_img_512 = []
|
75 |
+
batch_caption = []
|
76 |
+
batch_name = []
|
77 |
+
for i, sample in enumerate(tqdm(dicts_list)):
|
78 |
+
try:
|
79 |
+
pil_image = Image.open(os.path.join(root_path,sample['img_path']))
|
80 |
+
caption = sample['prompt']
|
81 |
+
img_name = sample['img_path'].replace('.jpg','')
|
82 |
+
|
83 |
+
pil_image.load()
|
84 |
+
pil_image = pil_image.convert("RGB")
|
85 |
+
except:
|
86 |
+
with open("failed_file.txt", 'a+') as file:
|
87 |
+
file.write(sample['img_path'] + "\n")
|
88 |
+
continue
|
89 |
+
|
90 |
+
image_256 = center_crop_arr(pil_image, image_size=256)
|
91 |
+
image_512 = center_crop_arr(pil_image, image_size=512)
|
92 |
+
|
93 |
+
# if True:
|
94 |
+
# image_id = random.randint(0,20)
|
95 |
+
# Image.fromarray(image_256.astype(np.uint8)).save(f"temp_img_{image_id}_256.jpg")
|
96 |
+
# Image.fromarray(image_512.astype(np.uint8)).save(f"temp_img_{image_id}_512.jpg")
|
97 |
+
|
98 |
+
image_256 = (image_256 / 127.5 - 1.0).astype(np.float32)
|
99 |
+
image_256 = einops.rearrange(image_256, 'h w c -> c h w')
|
100 |
+
batch_img_256.append(image_256)
|
101 |
+
|
102 |
+
image_512 = (image_512 / 127.5 - 1.0).astype(np.float32)
|
103 |
+
image_512 = einops.rearrange(image_512, 'h w c -> c h w')
|
104 |
+
batch_img_512.append(image_512)
|
105 |
+
|
106 |
+
batch_caption.append(caption)
|
107 |
+
batch_name.append(img_name)
|
108 |
+
|
109 |
+
if len(batch_name) == bz or i == len(dicts_list) - 1:
|
110 |
+
batch_img_256 = torch.tensor(np.stack(batch_img_256)).to(device)
|
111 |
+
moments_256 = autoencoder(batch_img_256, fn='encode_moments').squeeze(0)
|
112 |
+
moments_256 = moments_256.detach().cpu().numpy()
|
113 |
+
|
114 |
+
batch_img_512 = torch.tensor(np.stack(batch_img_512)).to(device)
|
115 |
+
moments_512 = autoencoder(batch_img_512, fn='encode_moments').squeeze(0)
|
116 |
+
moments_512 = moments_512.detach().cpu().numpy()
|
117 |
+
|
118 |
+
_latent_clip, latent_and_others_clip = clip.encode(batch_caption)
|
119 |
+
token_embedding_clip = latent_and_others_clip['token_embedding'].detach().cpu().numpy()
|
120 |
+
token_mask_clip = latent_and_others_clip['token_mask'].detach().cpu().numpy()
|
121 |
+
token_clip = latent_and_others_clip['tokens'].detach().cpu().numpy()
|
122 |
+
|
123 |
+
_latent_t5, latent_and_others_t5 = t5.get_text_embeddings(batch_caption)
|
124 |
+
token_embedding_t5 = (latent_and_others_t5['token_embedding'].to(torch.float32) * 10.0).detach().cpu().numpy()
|
125 |
+
token_mask_t5 = latent_and_others_t5['token_mask'].detach().cpu().numpy()
|
126 |
+
token_t5 = latent_and_others_t5['tokens'].detach().cpu().numpy()
|
127 |
+
|
128 |
+
for mt_256, mt_512, te_c, te_t, tm_c, tm_t, tk_c, tk_t, bc, bn in zip(moments_256, moments_512, token_embedding_clip, token_embedding_t5, token_mask_clip, token_mask_t5, token_clip, token_t5, batch_caption, batch_name):
|
129 |
+
assert mt_256.shape == (8,32,32)
|
130 |
+
assert mt_512.shape == (8,64,64)
|
131 |
+
assert te_c.shape == (77, 768)
|
132 |
+
assert te_t.shape == (77, 4096)
|
133 |
+
tar_path_name = os.path.join(save_dir, f'{bn}.npy')
|
134 |
+
if os.path.exists(tar_path_name):
|
135 |
+
os.remove(tar_path_name)
|
136 |
+
data = {'image_latent_256': mt_256,
|
137 |
+
'image_latent_512': mt_512,
|
138 |
+
'token_embedding_clip': te_c,
|
139 |
+
'token_embedding_t5': te_t,
|
140 |
+
'token_mask_clip': tm_c,
|
141 |
+
'token_mask_t5': tm_t,
|
142 |
+
'token_clip': tk_c,
|
143 |
+
'token_t5': tk_t,
|
144 |
+
'batch_caption': bc}
|
145 |
+
try:
|
146 |
+
np.save(tar_path_name, data)
|
147 |
+
idx += 1
|
148 |
+
except:
|
149 |
+
pass
|
150 |
+
|
151 |
+
batch_img_256 = []
|
152 |
+
batch_img_512 = []
|
153 |
+
batch_caption = []
|
154 |
+
batch_name = []
|
155 |
+
|
156 |
+
print(f'save {idx} files')
|
157 |
+
|
158 |
+
if __name__ == '__main__':
|
159 |
+
main()
|
sde.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from absl import logging
|
4 |
+
import numpy as np
|
5 |
+
import math
|
6 |
+
from tqdm import tqdm
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
def check_zip(*args):
|
11 |
+
args = [list(arg) for arg in args]
|
12 |
+
length = len(args[0])
|
13 |
+
for arg in args:
|
14 |
+
assert len(arg) == length
|
15 |
+
return zip(*args)
|
16 |
+
|
17 |
+
def get_sde(name, **kwargs):
|
18 |
+
if name == 'vpsde':
|
19 |
+
return VPSDE(**kwargs)
|
20 |
+
elif name == 'vpsde_cosine':
|
21 |
+
return VPSDECosine(**kwargs)
|
22 |
+
else:
|
23 |
+
raise NotImplementedError
|
24 |
+
|
25 |
+
|
26 |
+
def stp(s, ts: torch.Tensor): # scalar tensor product
|
27 |
+
if isinstance(s, np.ndarray):
|
28 |
+
s = torch.from_numpy(s).type_as(ts)
|
29 |
+
extra_dims = (1,) * (ts.dim() - 1)
|
30 |
+
return s.view(-1, *extra_dims) * ts
|
31 |
+
|
32 |
+
|
33 |
+
def mos(a, start_dim=1): # mean of square
|
34 |
+
return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
|
35 |
+
|
36 |
+
|
37 |
+
def duplicate(tensor, *size):
|
38 |
+
return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape)
|
39 |
+
|
40 |
+
|
41 |
+
class SDE(object):
|
42 |
+
r"""
|
43 |
+
dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
|
44 |
+
f(x, t) is the drift
|
45 |
+
g(t) is the diffusion
|
46 |
+
"""
|
47 |
+
def drift(self, x, t):
|
48 |
+
raise NotImplementedError
|
49 |
+
|
50 |
+
def diffusion(self, t):
|
51 |
+
raise NotImplementedError
|
52 |
+
|
53 |
+
def cum_beta(self, t): # the variance of xt|x0
|
54 |
+
raise NotImplementedError
|
55 |
+
|
56 |
+
def cum_alpha(self, t):
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
def snr(self, t): # signal noise ratio
|
60 |
+
raise NotImplementedError
|
61 |
+
|
62 |
+
def nsr(self, t): # noise signal ratio
|
63 |
+
raise NotImplementedError
|
64 |
+
|
65 |
+
def marginal_prob(self, x0, t): # the mean and std of q(xt|x0)
|
66 |
+
alpha = self.cum_alpha(t)
|
67 |
+
beta = self.cum_beta(t)
|
68 |
+
mean = stp(alpha ** 0.5, x0) # E[xt|x0]
|
69 |
+
std = beta ** 0.5 # Cov[xt|x0] ** 0.5
|
70 |
+
return mean, std
|
71 |
+
|
72 |
+
def sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform
|
73 |
+
t = torch.rand(x0.shape[0], device=x0.device) * (1. - t_init) + t_init
|
74 |
+
mean, std = self.marginal_prob(x0, t)
|
75 |
+
eps = torch.randn_like(x0)
|
76 |
+
xt = mean + stp(std, eps)
|
77 |
+
return t, eps, xt
|
78 |
+
|
79 |
+
|
80 |
+
class VPSDE(SDE):
|
81 |
+
def __init__(self, beta_min=0.1, beta_max=20):
|
82 |
+
# 0 <= t <= 1
|
83 |
+
self.beta_0 = beta_min
|
84 |
+
self.beta_1 = beta_max
|
85 |
+
|
86 |
+
def drift(self, x, t):
|
87 |
+
return -0.5 * stp(self.squared_diffusion(t), x)
|
88 |
+
|
89 |
+
def diffusion(self, t):
|
90 |
+
return self.squared_diffusion(t) ** 0.5
|
91 |
+
|
92 |
+
def squared_diffusion(self, t): # beta(t)
|
93 |
+
return self.beta_0 + t * (self.beta_1 - self.beta_0)
|
94 |
+
|
95 |
+
def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau
|
96 |
+
return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5
|
97 |
+
|
98 |
+
def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I
|
99 |
+
return 1. - self.skip_alpha(s, t)
|
100 |
+
|
101 |
+
def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs
|
102 |
+
x = -self.squared_diffusion_integral(s, t)
|
103 |
+
return x.exp()
|
104 |
+
|
105 |
+
def cum_beta(self, t):
|
106 |
+
return self.skip_beta(0, t)
|
107 |
+
|
108 |
+
def cum_alpha(self, t):
|
109 |
+
return self.skip_alpha(0, t)
|
110 |
+
|
111 |
+
def nsr(self, t):
|
112 |
+
nsr = self.squared_diffusion_integral(0, t).expm1()
|
113 |
+
nsr = nsr.clamp(max = 1e6, min = 1e-12)
|
114 |
+
return nsr
|
115 |
+
|
116 |
+
def snr(self, t):
|
117 |
+
snr = 1. / self.nsr(t)
|
118 |
+
snr = snr.clamp(max = 1e6, min = 1e-12)
|
119 |
+
return snr
|
120 |
+
|
121 |
+
def __str__(self):
|
122 |
+
return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
|
123 |
+
|
124 |
+
def __repr__(self):
|
125 |
+
return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
|
126 |
+
|
127 |
+
|
128 |
+
class VPSDECosine(SDE):
|
129 |
+
r"""
|
130 |
+
dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
|
131 |
+
f(x, t) is the drift
|
132 |
+
g(t) is the diffusion
|
133 |
+
"""
|
134 |
+
def __init__(self, s=0.008):
|
135 |
+
self.s = s
|
136 |
+
self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
|
137 |
+
self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2
|
138 |
+
|
139 |
+
def drift(self, x, t):
|
140 |
+
ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2
|
141 |
+
return stp(ft, x)
|
142 |
+
|
143 |
+
def diffusion(self, t):
|
144 |
+
return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5
|
145 |
+
|
146 |
+
def cum_beta(self, t): # the variance of xt|x0
|
147 |
+
return 1 - self.cum_alpha(t)
|
148 |
+
|
149 |
+
def cum_alpha(self, t):
|
150 |
+
return self.F(t) / self.F0
|
151 |
+
|
152 |
+
def snr(self, t): # signal noise ratio
|
153 |
+
Ft = self.F(t)
|
154 |
+
snr = Ft / (self.F0 - Ft)
|
155 |
+
snr = snr.clamp(max = 1e6, min = 1e-12)
|
156 |
+
return snr
|
157 |
+
|
158 |
+
def nsr(self, t): # noise signal ratio
|
159 |
+
Ft = self.F(t)
|
160 |
+
nsr = self.F0 / Ft - 1
|
161 |
+
nsr = nsr.clamp(max = 1e6, min = 1e-12)
|
162 |
+
return nsr
|
163 |
+
|
164 |
+
def __str__(self):
|
165 |
+
return 'vpsde_cosine'
|
166 |
+
|
167 |
+
def __repr__(self):
|
168 |
+
return 'vpsde_cosine'
|
169 |
+
|
170 |
+
|
171 |
+
class ScoreModel(object):
|
172 |
+
r"""
|
173 |
+
The forward process is q(x_[0,T])
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(self, nnet: nn.Module, loss_coeffs:list, sde: SDE, using_cfg: bool = False, T=1):
|
177 |
+
assert T == 1
|
178 |
+
self.nnet = nnet
|
179 |
+
self.loss_coeffs = loss_coeffs
|
180 |
+
self.sde = sde
|
181 |
+
self.T = T
|
182 |
+
self.using_cfg = using_cfg
|
183 |
+
print(f'ScoreModel with loss_coeffs={loss_coeffs}, sde={sde}, T={T}')
|
184 |
+
|
185 |
+
def predict(self, xt, t, **kwargs):
|
186 |
+
if not isinstance(t, torch.Tensor):
|
187 |
+
t = torch.tensor(t)
|
188 |
+
t = t.to(xt.device)
|
189 |
+
if t.dim() == 0:
|
190 |
+
t = duplicate(t, xt.size(0))
|
191 |
+
log_snr = self.sde.snr(t).log()
|
192 |
+
|
193 |
+
return self.nnet(xt, t = t * 999, log_snr = log_snr, **kwargs) # follow SDE
|
194 |
+
# return self.nnet(xt, t = t, log_snr = log_snr, **kwargs) # follow SDE
|
195 |
+
|
196 |
+
def noise_pred(self, xt, t, sampling = True, **kwargs):
|
197 |
+
if sampling:
|
198 |
+
if self.using_cfg:
|
199 |
+
return self.predict(xt, t, **kwargs)
|
200 |
+
else:
|
201 |
+
return self.predict(xt, t, **kwargs)[-1]
|
202 |
+
else:
|
203 |
+
return self.predict(xt, t, **kwargs)
|
204 |
+
|
205 |
+
def score(self, xt, t, **kwargs):
|
206 |
+
cum_beta = self.sde.cum_beta(t)
|
207 |
+
noise_pred = self.noise_pred(xt, t, sampling = True, **kwargs)
|
208 |
+
return stp(-cum_beta.rsqrt(), noise_pred)
|
209 |
+
|
210 |
+
|
211 |
+
class ReverseSDE(object):
|
212 |
+
r"""
|
213 |
+
dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw
|
214 |
+
"""
|
215 |
+
def __init__(self, score_model):
|
216 |
+
self.sde = score_model.sde # the forward sde
|
217 |
+
self.score_model = score_model
|
218 |
+
|
219 |
+
def drift(self, x, t, **kwargs):
|
220 |
+
drift = self.sde.drift(x, t) # f(x, t)
|
221 |
+
diffusion = self.sde.diffusion(t) # g(t)
|
222 |
+
score = self.score_model.score(x, t, **kwargs)
|
223 |
+
return drift - stp(diffusion ** 2, score)
|
224 |
+
|
225 |
+
def diffusion(self, t):
|
226 |
+
return self.sde.diffusion(t)
|
227 |
+
|
228 |
+
|
229 |
+
class ODE(object):
|
230 |
+
r"""
|
231 |
+
dx = [f(x, t) - g(t)^2 s(x, t)] dt
|
232 |
+
"""
|
233 |
+
|
234 |
+
def __init__(self, score_model):
|
235 |
+
self.sde = score_model.sde # the forward sde
|
236 |
+
self.score_model = score_model
|
237 |
+
|
238 |
+
def drift(self, x, t, **kwargs):
|
239 |
+
drift = self.sde.drift(x, t) # f(x, t)
|
240 |
+
diffusion = self.sde.diffusion(t) # g(t)
|
241 |
+
score = self.score_model.score(x, t, **kwargs)
|
242 |
+
return drift - 0.5 * stp(diffusion ** 2, score)
|
243 |
+
|
244 |
+
def diffusion(self, t):
|
245 |
+
return 0
|
246 |
+
|
247 |
+
|
248 |
+
def dct2str(dct):
|
249 |
+
return str({k: f'{v:.6g}' for k, v in dct.items()})
|
250 |
+
|
251 |
+
|
252 |
+
@ torch.no_grad()
|
253 |
+
def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs):
|
254 |
+
r"""
|
255 |
+
The Euler Maruyama sampler for reverse SDE / ODE
|
256 |
+
See `Score-Based Generative Modeling through Stochastic Differential Equations`
|
257 |
+
"""
|
258 |
+
assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE)
|
259 |
+
print(f"euler_maruyama with sample_steps={sample_steps}")
|
260 |
+
timesteps = np.append(0., np.linspace(eps, T, sample_steps))
|
261 |
+
timesteps = torch.tensor(timesteps).to(x_init)
|
262 |
+
x = x_init
|
263 |
+
if trace is not None:
|
264 |
+
trace.append(x)
|
265 |
+
for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'):
|
266 |
+
drift = rsde.drift(x, t, **kwargs)
|
267 |
+
diffusion = rsde.diffusion(t)
|
268 |
+
dt = s - t
|
269 |
+
mean = x + drift * dt
|
270 |
+
sigma = diffusion * (-dt).sqrt()
|
271 |
+
x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean
|
272 |
+
if trace is not None:
|
273 |
+
trace.append(x)
|
274 |
+
statistics = dict(s=s, t=t, sigma=sigma.item())
|
275 |
+
logging.debug(dct2str(statistics))
|
276 |
+
return x
|
277 |
+
|
278 |
+
|
279 |
+
def LSimple(score_model: ScoreModel, x0, **kwargs):
|
280 |
+
t, noise, xt = score_model.sde.sample(x0)
|
281 |
+
prediction = score_model.noise_pred(xt, t, sampling = False, **kwargs)
|
282 |
+
target = multi_scale_targets(noise, levels = len(prediction), scale_correction = True)
|
283 |
+
loss = 0
|
284 |
+
for pred, coeff in check_zip(prediction, score_model.loss_coeffs):
|
285 |
+
loss = loss + coeff * mos(pred - target[pred.shape[-1]])
|
286 |
+
return loss
|
287 |
+
|
288 |
+
|
289 |
+
def odd_multi_scale_targets(target, levels, scale_correction):
|
290 |
+
B, C, H, W = target.shape
|
291 |
+
targets = {}
|
292 |
+
for l in range(levels):
|
293 |
+
ratio = int(2 ** l)
|
294 |
+
if ratio == 1:
|
295 |
+
targets[target.shape[-1]] = target
|
296 |
+
continue
|
297 |
+
assert (H - 1) % ratio == 0 and (W - 1) % ratio == 0
|
298 |
+
KS = ratio + 1
|
299 |
+
scale = KS if scale_correction else KS ** 2
|
300 |
+
kernel = torch.ones(C, 1, KS, KS, device = target.device) / scale
|
301 |
+
downsampled = F.conv2d(target, kernel, stride = ratio, padding = KS // 2, groups = C)
|
302 |
+
targets[downsampled.shape[-1]] = downsampled
|
303 |
+
return targets
|
304 |
+
|
305 |
+
def even_multi_scale_targets(target, levels, scale_correction):
|
306 |
+
B, C, H, W = target.shape
|
307 |
+
targets = {}
|
308 |
+
for l in range(levels):
|
309 |
+
ratio = int(2 ** l)
|
310 |
+
if ratio == 1:
|
311 |
+
targets[target.shape[-1]] = target
|
312 |
+
continue
|
313 |
+
assert H % ratio == 0 and W % ratio == 0
|
314 |
+
KS = ratio
|
315 |
+
scale = KS if scale_correction else KS ** 2
|
316 |
+
kernel = torch.ones(C, 1, KS, KS, device = target.device) / scale
|
317 |
+
downsampled = F.conv2d(target, kernel, stride = ratio, groups = C)
|
318 |
+
targets[downsampled.shape[-1]] = downsampled
|
319 |
+
return targets
|
320 |
+
|
321 |
+
def multi_scale_targets(target, levels, scale_correction):
|
322 |
+
B, C, H, W = target.shape
|
323 |
+
if H % 2 == 0:
|
324 |
+
return even_multi_scale_targets(target, levels, scale_correction)
|
325 |
+
else:
|
326 |
+
return odd_multi_scale_targets(target, levels, scale_correction)
|
tools/clip_score.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file computes the clip score given image and text pair
|
3 |
+
"""
|
4 |
+
import clip
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from sklearn.preprocessing import normalize
|
8 |
+
from torchvision.transforms import Compose, Normalize, Resize
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
class ClipSocre:
|
13 |
+
def __init__(self,device='cuda', prefix='A photo depicts', weight=1.0): # weight=2.5
|
14 |
+
self.device = device
|
15 |
+
|
16 |
+
self.model, _ = clip.load("ViT-B/32", device=device, jit=False)
|
17 |
+
self.model.eval()
|
18 |
+
|
19 |
+
self.transform = Compose([
|
20 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
21 |
+
])
|
22 |
+
|
23 |
+
self.prefix = prefix
|
24 |
+
if self.prefix[-1] != ' ':
|
25 |
+
self.prefix += ' '
|
26 |
+
|
27 |
+
self.w = weight
|
28 |
+
|
29 |
+
def extract_all_images(self, images):
|
30 |
+
images_input = self.transform(images)
|
31 |
+
if self.device == 'cuda':
|
32 |
+
images_input = images_input.to(torch.float16)
|
33 |
+
image_feature = self.model.encode_image(images_input)
|
34 |
+
return image_feature
|
35 |
+
|
36 |
+
def extract_all_texts(self, texts,need_prefix):
|
37 |
+
if need_prefix:
|
38 |
+
c_data = clip.tokenize(self.prefix + texts, truncate=True).to(self.device)
|
39 |
+
else:
|
40 |
+
c_data = clip.tokenize(texts, truncate=True).to(self.device)
|
41 |
+
text_feature = self.model.encode_text(c_data)
|
42 |
+
return text_feature
|
43 |
+
|
44 |
+
def get_clip_score(self, img, text, need_prefix=False):
|
45 |
+
|
46 |
+
img_f = self.extract_all_images(img)
|
47 |
+
text_f = self.extract_all_texts(text,need_prefix)
|
48 |
+
images = img_f / torch.sqrt(torch.sum(img_f**2, axis=1, keepdims=True))
|
49 |
+
candidates = text_f / torch.sqrt(torch.sum(text_f**2, axis=1, keepdims=True))
|
50 |
+
|
51 |
+
clip_per = self.w * torch.clip(torch.sum(images * candidates, axis=1), 0, None)
|
52 |
+
|
53 |
+
return clip_per
|
54 |
+
|
55 |
+
def get_text_clip_score(self, text_1, text_2, need_prefix=False):
|
56 |
+
text_1_f = self.extract_all_texts(text_1,need_prefix)
|
57 |
+
text_2_f = self.extract_all_texts(text_2,need_prefix)
|
58 |
+
|
59 |
+
candidates_1 = text_1_f / torch.sqrt(torch.sum(text_1_f**2, axis=1, keepdims=True))
|
60 |
+
candidates_2 = text_2_f / torch.sqrt(torch.sum(text_2_f**2, axis=1, keepdims=True))
|
61 |
+
|
62 |
+
per = self.w * torch.clip(torch.sum(candidates_1 * candidates_2, axis=1), 0, None)
|
63 |
+
|
64 |
+
|
65 |
+
results = 'ClipS : ' + str(format(per.item(),'.4f'))
|
66 |
+
|
67 |
+
print(results)
|
68 |
+
|
69 |
+
return per.sum()
|
70 |
+
|
71 |
+
def get_img_clip_score(self, img_1, img_2, weight = 1):
|
72 |
+
|
73 |
+
img_f_1 = self.extract_all_images(img_1)
|
74 |
+
img_f_2 = self.extract_all_images(img_2)
|
75 |
+
|
76 |
+
images_1 = img_f_1 / torch.sqrt(torch.sum(img_f_1**2, axis=1, keepdims=True))
|
77 |
+
images_2 = img_f_2 / torch.sqrt(torch.sum(img_f_2**2, axis=1, keepdims=True))
|
78 |
+
|
79 |
+
# per = self.w * torch.clip(torch.sum(images_1 * images_2, axis=1), 0, None)
|
80 |
+
per = weight * torch.clip(torch.sum(images_1 * images_2, axis=1), 0, None)
|
81 |
+
|
82 |
+
|
83 |
+
return per.sum()
|
84 |
+
|
85 |
+
|
86 |
+
def calculate_clip_score(self, caption_list, image_unprocessed):
|
87 |
+
image_unprocessed = 0.5 * (image_unprocessed + 1.)
|
88 |
+
image_unprocessed.clamp_(0., 1.)
|
89 |
+
img_resize = Resize((224))(image_unprocessed)
|
90 |
+
return self.get_clip_score(img_resize,caption_list)
|
tools/fid_score.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
|
2 |
+
|
3 |
+
The FID metric calculates the distance between two distributions of images.
|
4 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
5 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
6 |
+
|
7 |
+
When run as a stand-alone program, it compares the distribution of
|
8 |
+
images that are stored as PNG/JPEG at a specified location with a
|
9 |
+
distribution given by summary statistics (in pickle format).
|
10 |
+
|
11 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
12 |
+
the pool_3 layer of the inception net for generated samples and real world
|
13 |
+
samples respectively.
|
14 |
+
|
15 |
+
See --help to see further details.
|
16 |
+
|
17 |
+
Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
18 |
+
of Tensorflow
|
19 |
+
|
20 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
21 |
+
|
22 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
23 |
+
you may not use this file except in compliance with the License.
|
24 |
+
You may obtain a copy of the License at
|
25 |
+
|
26 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
27 |
+
|
28 |
+
Unless required by applicable law or agreed to in writing, software
|
29 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
30 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
31 |
+
See the License for the specific language governing permissions and
|
32 |
+
limitations under the License.
|
33 |
+
"""
|
34 |
+
import os
|
35 |
+
import pathlib
|
36 |
+
|
37 |
+
import numpy as np
|
38 |
+
import torch
|
39 |
+
import torchvision.transforms as TF
|
40 |
+
from PIL import Image
|
41 |
+
from scipy import linalg
|
42 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
43 |
+
from torchvision import transforms
|
44 |
+
|
45 |
+
try:
|
46 |
+
from tqdm import tqdm
|
47 |
+
except ImportError:
|
48 |
+
# If tqdm is not available, provide a mock version of it
|
49 |
+
def tqdm(x):
|
50 |
+
return x
|
51 |
+
|
52 |
+
from .inception import InceptionV3
|
53 |
+
|
54 |
+
|
55 |
+
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
|
56 |
+
'tif', 'tiff', 'webp'}
|
57 |
+
|
58 |
+
|
59 |
+
class ImagePathDataset(torch.utils.data.Dataset):
|
60 |
+
def __init__(self, files, transforms=None):
|
61 |
+
self.files = files
|
62 |
+
self.transforms = transforms
|
63 |
+
|
64 |
+
def __len__(self):
|
65 |
+
return len(self.files)
|
66 |
+
|
67 |
+
def __getitem__(self, i):
|
68 |
+
path = self.files[i]
|
69 |
+
img = Image.open(path).convert('RGB')
|
70 |
+
|
71 |
+
if img.size == (512,512):
|
72 |
+
img = img.resize((256, 256))
|
73 |
+
|
74 |
+
if self.transforms is not None:
|
75 |
+
img = self.transforms(img)
|
76 |
+
return img
|
77 |
+
|
78 |
+
|
79 |
+
def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
|
80 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
81 |
+
|
82 |
+
Params:
|
83 |
+
-- files : List of image files paths
|
84 |
+
-- model : Instance of inception model
|
85 |
+
-- batch_size : Batch size of images for the model to process at once.
|
86 |
+
Make sure that the number of samples is a multiple of
|
87 |
+
the batch size, otherwise some samples are ignored. This
|
88 |
+
behavior is retained to match the original FID score
|
89 |
+
implementation.
|
90 |
+
-- dims : Dimensionality of features returned by Inception
|
91 |
+
-- device : Device to run calculations
|
92 |
+
-- num_workers : Number of parallel dataloader workers
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
96 |
+
activations of the given tensor when feeding inception with the
|
97 |
+
query tensor.
|
98 |
+
"""
|
99 |
+
model.eval()
|
100 |
+
|
101 |
+
if batch_size > len(files):
|
102 |
+
print(('Warning: batch size is bigger than the data size. '
|
103 |
+
'Setting batch size to data size'))
|
104 |
+
batch_size = len(files)
|
105 |
+
|
106 |
+
dataset = ImagePathDataset(files, transforms=TF.ToTensor())
|
107 |
+
dataloader = torch.utils.data.DataLoader(dataset,
|
108 |
+
batch_size=batch_size,
|
109 |
+
shuffle=False,
|
110 |
+
drop_last=False,
|
111 |
+
num_workers=num_workers)
|
112 |
+
|
113 |
+
pred_arr = np.empty((len(files), dims))
|
114 |
+
|
115 |
+
start_idx = 0
|
116 |
+
|
117 |
+
# resizer = transforms.Resize(256) # for clip
|
118 |
+
|
119 |
+
for batch in tqdm(dataloader):
|
120 |
+
batch = batch.to(device)
|
121 |
+
|
122 |
+
with torch.no_grad():
|
123 |
+
|
124 |
+
pred = model(batch)[0]
|
125 |
+
|
126 |
+
# If model output is not scalar, apply global spatial average pooling.
|
127 |
+
# This happens if you choose a dimensionality not equal 2048.
|
128 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
129 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
130 |
+
|
131 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
132 |
+
|
133 |
+
pred_arr[start_idx:start_idx + pred.shape[0]] = pred
|
134 |
+
|
135 |
+
start_idx = start_idx + pred.shape[0]
|
136 |
+
|
137 |
+
return pred_arr
|
138 |
+
|
139 |
+
|
140 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
141 |
+
"""Numpy implementation of the Frechet Distance.
|
142 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
143 |
+
and X_2 ~ N(mu_2, C_2) is
|
144 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
145 |
+
|
146 |
+
Stable version by Dougal J. Sutherland.
|
147 |
+
|
148 |
+
Params:
|
149 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
150 |
+
inception net (like returned by the function 'get_predictions')
|
151 |
+
for generated samples.
|
152 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
153 |
+
representative data set.
|
154 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
155 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
156 |
+
representative data set.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
-- : The Frechet Distance.
|
160 |
+
"""
|
161 |
+
|
162 |
+
mu1 = np.atleast_1d(mu1)
|
163 |
+
mu2 = np.atleast_1d(mu2)
|
164 |
+
|
165 |
+
sigma1 = np.atleast_2d(sigma1)
|
166 |
+
sigma2 = np.atleast_2d(sigma2)
|
167 |
+
|
168 |
+
assert mu1.shape == mu2.shape, \
|
169 |
+
'Training and test mean vectors have different lengths'
|
170 |
+
assert sigma1.shape == sigma2.shape, \
|
171 |
+
'Training and test covariances have different dimensions'
|
172 |
+
|
173 |
+
diff = mu1 - mu2
|
174 |
+
|
175 |
+
# Product might be almost singular
|
176 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
177 |
+
if not np.isfinite(covmean).all():
|
178 |
+
msg = ('fid calculation produces singular product; '
|
179 |
+
'adding %s to diagonal of cov estimates') % eps
|
180 |
+
print(msg)
|
181 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
182 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
183 |
+
|
184 |
+
# Numerical error might give slight imaginary component
|
185 |
+
if np.iscomplexobj(covmean):
|
186 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
187 |
+
m = np.max(np.abs(covmean.imag))
|
188 |
+
raise ValueError('Imaginary component {}'.format(m))
|
189 |
+
covmean = covmean.real
|
190 |
+
|
191 |
+
tr_covmean = np.trace(covmean)
|
192 |
+
|
193 |
+
return (diff.dot(diff) + np.trace(sigma1)
|
194 |
+
+ np.trace(sigma2) - 2 * tr_covmean)
|
195 |
+
|
196 |
+
|
197 |
+
def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
|
198 |
+
device='cpu', num_workers=8):
|
199 |
+
"""Calculation of the statistics used by the FID.
|
200 |
+
Params:
|
201 |
+
-- files : List of image files paths
|
202 |
+
-- model : Instance of inception model
|
203 |
+
-- batch_size : The images numpy array is split into batches with
|
204 |
+
batch size batch_size. A reasonable batch size
|
205 |
+
depends on the hardware.
|
206 |
+
-- dims : Dimensionality of features returned by Inception
|
207 |
+
-- device : Device to run calculations
|
208 |
+
-- num_workers : Number of parallel dataloader workers
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
212 |
+
the inception model.
|
213 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
214 |
+
the inception model.
|
215 |
+
"""
|
216 |
+
act = get_activations(files, model, batch_size, dims, device, num_workers)
|
217 |
+
mu = np.mean(act, axis=0)
|
218 |
+
sigma = np.cov(act, rowvar=False)
|
219 |
+
return mu, sigma
|
220 |
+
|
221 |
+
|
222 |
+
def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
|
223 |
+
if path.endswith('.npz'):
|
224 |
+
with np.load(path) as f:
|
225 |
+
m, s = f['mu'][:], f['sigma'][:]
|
226 |
+
else:
|
227 |
+
path = pathlib.Path(path)
|
228 |
+
files = sorted([file for ext in IMAGE_EXTENSIONS
|
229 |
+
for file in path.glob('*.{}'.format(ext))])
|
230 |
+
m, s = calculate_activation_statistics(files, model, batch_size,
|
231 |
+
dims, device, num_workers)
|
232 |
+
|
233 |
+
return m, s
|
234 |
+
|
235 |
+
|
236 |
+
def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
|
237 |
+
if device is None:
|
238 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
239 |
+
else:
|
240 |
+
device = torch.device(device)
|
241 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
242 |
+
model = InceptionV3([block_idx]).to(device)
|
243 |
+
m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
|
244 |
+
np.savez(out_path, mu=m1, sigma=s1)
|
245 |
+
|
246 |
+
|
247 |
+
def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
|
248 |
+
"""Calculates the FID of two paths"""
|
249 |
+
if device is None:
|
250 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
251 |
+
else:
|
252 |
+
device = torch.device(device)
|
253 |
+
|
254 |
+
for p in paths:
|
255 |
+
if not os.path.exists(p):
|
256 |
+
raise RuntimeError('Invalid path: %s' % p)
|
257 |
+
|
258 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
259 |
+
|
260 |
+
model = InceptionV3([block_idx]).to(device)
|
261 |
+
|
262 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
263 |
+
dims, device, num_workers)
|
264 |
+
m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
|
265 |
+
dims, device, num_workers)
|
266 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
267 |
+
|
268 |
+
return fid_value
|
tools/inception.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
try:
|
7 |
+
from torchvision.models.utils import load_state_dict_from_url
|
8 |
+
except ImportError:
|
9 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
10 |
+
|
11 |
+
# Inception weights ported to Pytorch from
|
12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
13 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
|
14 |
+
|
15 |
+
|
16 |
+
class InceptionV3(nn.Module):
|
17 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
18 |
+
|
19 |
+
# Index of default block of inception to return,
|
20 |
+
# corresponds to output of final average pooling
|
21 |
+
DEFAULT_BLOCK_INDEX = 3
|
22 |
+
|
23 |
+
# Maps feature dimensionality to their output blocks indices
|
24 |
+
BLOCK_INDEX_BY_DIM = {
|
25 |
+
64: 0, # First max pooling features
|
26 |
+
192: 1, # Second max pooling featurs
|
27 |
+
768: 2, # Pre-aux classifier features
|
28 |
+
2048: 3 # Final average pooling features
|
29 |
+
}
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
output_blocks=(DEFAULT_BLOCK_INDEX,),
|
33 |
+
resize_input=True,
|
34 |
+
normalize_input=True,
|
35 |
+
requires_grad=False,
|
36 |
+
use_fid_inception=True):
|
37 |
+
"""Build pretrained InceptionV3
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
output_blocks : list of int
|
42 |
+
Indices of blocks to return features of. Possible values are:
|
43 |
+
- 0: corresponds to output of first max pooling
|
44 |
+
- 1: corresponds to output of second max pooling
|
45 |
+
- 2: corresponds to output which is fed to aux classifier
|
46 |
+
- 3: corresponds to output of final average pooling
|
47 |
+
resize_input : bool
|
48 |
+
If true, bilinearly resizes input to width and height 299 before
|
49 |
+
feeding input to model. As the network without fully connected
|
50 |
+
layers is fully convolutional, it should be able to handle inputs
|
51 |
+
of arbitrary size, so resizing might not be strictly needed
|
52 |
+
normalize_input : bool
|
53 |
+
If true, scales the input from range (0, 1) to the range the
|
54 |
+
pretrained Inception network expects, namely (-1, 1)
|
55 |
+
requires_grad : bool
|
56 |
+
If true, parameters of the model require gradients. Possibly useful
|
57 |
+
for finetuning the network
|
58 |
+
use_fid_inception : bool
|
59 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
60 |
+
FID implementation. If false, uses the pretrained Inception model
|
61 |
+
available in torchvision. The FID Inception model has different
|
62 |
+
weights and a slightly different structure from torchvision's
|
63 |
+
Inception model. If you want to compute FID scores, you are
|
64 |
+
strongly advised to set this parameter to true to get comparable
|
65 |
+
results.
|
66 |
+
"""
|
67 |
+
super(InceptionV3, self).__init__()
|
68 |
+
|
69 |
+
self.resize_input = resize_input
|
70 |
+
self.normalize_input = normalize_input
|
71 |
+
self.output_blocks = sorted(output_blocks)
|
72 |
+
self.last_needed_block = max(output_blocks)
|
73 |
+
|
74 |
+
assert self.last_needed_block <= 3, \
|
75 |
+
'Last possible output block index is 3'
|
76 |
+
|
77 |
+
self.blocks = nn.ModuleList()
|
78 |
+
|
79 |
+
if use_fid_inception:
|
80 |
+
inception = fid_inception_v3()
|
81 |
+
else:
|
82 |
+
inception = _inception_v3(pretrained=True)
|
83 |
+
|
84 |
+
# Block 0: input to maxpool1
|
85 |
+
block0 = [
|
86 |
+
inception.Conv2d_1a_3x3,
|
87 |
+
inception.Conv2d_2a_3x3,
|
88 |
+
inception.Conv2d_2b_3x3,
|
89 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
90 |
+
]
|
91 |
+
self.blocks.append(nn.Sequential(*block0))
|
92 |
+
|
93 |
+
# Block 1: maxpool1 to maxpool2
|
94 |
+
if self.last_needed_block >= 1:
|
95 |
+
block1 = [
|
96 |
+
inception.Conv2d_3b_1x1,
|
97 |
+
inception.Conv2d_4a_3x3,
|
98 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
99 |
+
]
|
100 |
+
self.blocks.append(nn.Sequential(*block1))
|
101 |
+
|
102 |
+
# Block 2: maxpool2 to aux classifier
|
103 |
+
if self.last_needed_block >= 2:
|
104 |
+
block2 = [
|
105 |
+
inception.Mixed_5b,
|
106 |
+
inception.Mixed_5c,
|
107 |
+
inception.Mixed_5d,
|
108 |
+
inception.Mixed_6a,
|
109 |
+
inception.Mixed_6b,
|
110 |
+
inception.Mixed_6c,
|
111 |
+
inception.Mixed_6d,
|
112 |
+
inception.Mixed_6e,
|
113 |
+
]
|
114 |
+
self.blocks.append(nn.Sequential(*block2))
|
115 |
+
|
116 |
+
# Block 3: aux classifier to final avgpool
|
117 |
+
if self.last_needed_block >= 3:
|
118 |
+
block3 = [
|
119 |
+
inception.Mixed_7a,
|
120 |
+
inception.Mixed_7b,
|
121 |
+
inception.Mixed_7c,
|
122 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
123 |
+
]
|
124 |
+
self.blocks.append(nn.Sequential(*block3))
|
125 |
+
|
126 |
+
for param in self.parameters():
|
127 |
+
param.requires_grad = requires_grad
|
128 |
+
|
129 |
+
def forward(self, inp):
|
130 |
+
"""Get Inception feature maps
|
131 |
+
|
132 |
+
Parameters
|
133 |
+
----------
|
134 |
+
inp : torch.autograd.Variable
|
135 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
136 |
+
range (0, 1)
|
137 |
+
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
141 |
+
block, sorted ascending by index
|
142 |
+
"""
|
143 |
+
outp = []
|
144 |
+
x = inp
|
145 |
+
|
146 |
+
if self.resize_input:
|
147 |
+
x = F.interpolate(x,
|
148 |
+
size=(299, 299),
|
149 |
+
mode='bilinear',
|
150 |
+
align_corners=False)
|
151 |
+
|
152 |
+
if self.normalize_input:
|
153 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
154 |
+
|
155 |
+
for idx, block in enumerate(self.blocks):
|
156 |
+
x = block(x)
|
157 |
+
if idx in self.output_blocks:
|
158 |
+
outp.append(x)
|
159 |
+
|
160 |
+
if idx == self.last_needed_block:
|
161 |
+
break
|
162 |
+
|
163 |
+
return outp
|
164 |
+
|
165 |
+
|
166 |
+
def _inception_v3(*args, **kwargs):
|
167 |
+
"""Wraps `torchvision.models.inception_v3`
|
168 |
+
|
169 |
+
Skips default weight inititialization if supported by torchvision version.
|
170 |
+
See https://github.com/mseitzer/pytorch-fid/issues/28.
|
171 |
+
"""
|
172 |
+
try:
|
173 |
+
version = tuple(map(int, torchvision.__version__.split('.')[:2]))
|
174 |
+
except ValueError:
|
175 |
+
# Just a caution against weird version strings
|
176 |
+
version = (0,)
|
177 |
+
|
178 |
+
if version >= (0, 6):
|
179 |
+
kwargs['init_weights'] = False
|
180 |
+
|
181 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
182 |
+
|
183 |
+
|
184 |
+
def fid_inception_v3():
|
185 |
+
"""Build pretrained Inception model for FID computation
|
186 |
+
|
187 |
+
The Inception model for FID computation uses a different set of weights
|
188 |
+
and has a slightly different structure than torchvision's Inception.
|
189 |
+
|
190 |
+
This method first constructs torchvision's Inception and then patches the
|
191 |
+
necessary parts that are different in the FID Inception model.
|
192 |
+
"""
|
193 |
+
inception = _inception_v3(num_classes=1008,
|
194 |
+
aux_logits=False,
|
195 |
+
pretrained=False)
|
196 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
197 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
198 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
199 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
200 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
201 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
202 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
203 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
204 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
205 |
+
|
206 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, model_dir="checkpoints", progress=True)
|
207 |
+
inception.load_state_dict(state_dict)
|
208 |
+
return inception
|
209 |
+
|
210 |
+
|
211 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
212 |
+
"""InceptionA block patched for FID computation"""
|
213 |
+
def __init__(self, in_channels, pool_features):
|
214 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
215 |
+
|
216 |
+
def forward(self, x):
|
217 |
+
branch1x1 = self.branch1x1(x)
|
218 |
+
|
219 |
+
branch5x5 = self.branch5x5_1(x)
|
220 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
221 |
+
|
222 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
223 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
224 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
225 |
+
|
226 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
227 |
+
# its average calculation
|
228 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
229 |
+
count_include_pad=False)
|
230 |
+
branch_pool = self.branch_pool(branch_pool)
|
231 |
+
|
232 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
233 |
+
return torch.cat(outputs, 1)
|
234 |
+
|
235 |
+
|
236 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
237 |
+
"""InceptionC block patched for FID computation"""
|
238 |
+
def __init__(self, in_channels, channels_7x7):
|
239 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
branch1x1 = self.branch1x1(x)
|
243 |
+
|
244 |
+
branch7x7 = self.branch7x7_1(x)
|
245 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
246 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
247 |
+
|
248 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
249 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
250 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
251 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
252 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
253 |
+
|
254 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
255 |
+
# its average calculation
|
256 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
257 |
+
count_include_pad=False)
|
258 |
+
branch_pool = self.branch_pool(branch_pool)
|
259 |
+
|
260 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
261 |
+
return torch.cat(outputs, 1)
|
262 |
+
|
263 |
+
|
264 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
265 |
+
"""First InceptionE block patched for FID computation"""
|
266 |
+
def __init__(self, in_channels):
|
267 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
268 |
+
|
269 |
+
def forward(self, x):
|
270 |
+
branch1x1 = self.branch1x1(x)
|
271 |
+
|
272 |
+
branch3x3 = self.branch3x3_1(x)
|
273 |
+
branch3x3 = [
|
274 |
+
self.branch3x3_2a(branch3x3),
|
275 |
+
self.branch3x3_2b(branch3x3),
|
276 |
+
]
|
277 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
278 |
+
|
279 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
280 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
281 |
+
branch3x3dbl = [
|
282 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
283 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
284 |
+
]
|
285 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
286 |
+
|
287 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
288 |
+
# its average calculation
|
289 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
290 |
+
count_include_pad=False)
|
291 |
+
branch_pool = self.branch_pool(branch_pool)
|
292 |
+
|
293 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
294 |
+
return torch.cat(outputs, 1)
|
295 |
+
|
296 |
+
|
297 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
298 |
+
"""Second InceptionE block patched for FID computation"""
|
299 |
+
def __init__(self, in_channels):
|
300 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
301 |
+
|
302 |
+
def forward(self, x):
|
303 |
+
branch1x1 = self.branch1x1(x)
|
304 |
+
|
305 |
+
branch3x3 = self.branch3x3_1(x)
|
306 |
+
branch3x3 = [
|
307 |
+
self.branch3x3_2a(branch3x3),
|
308 |
+
self.branch3x3_2b(branch3x3),
|
309 |
+
]
|
310 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
311 |
+
|
312 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
313 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
314 |
+
branch3x3dbl = [
|
315 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
316 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
317 |
+
]
|
318 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
319 |
+
|
320 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
321 |
+
# pooling. This is likely an error in this specific Inception
|
322 |
+
# implementation, as other Inception models use average pooling here
|
323 |
+
# (which matches the description in the paper).
|
324 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
325 |
+
branch_pool = self.branch_pool(branch_pool)
|
326 |
+
|
327 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
328 |
+
return torch.cat(outputs, 1)
|
train_t2i.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
import torch
|
3 |
+
from torch import multiprocessing as mp
|
4 |
+
from datasets import get_dataset
|
5 |
+
from torchvision.utils import make_grid, save_image
|
6 |
+
import utils
|
7 |
+
import einops
|
8 |
+
from torch.utils._pytree import tree_map
|
9 |
+
import accelerate
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
import tempfile
|
13 |
+
from absl import logging
|
14 |
+
import builtins
|
15 |
+
import os
|
16 |
+
import wandb
|
17 |
+
import numpy as np
|
18 |
+
import time
|
19 |
+
import random
|
20 |
+
|
21 |
+
import libs.autoencoder
|
22 |
+
from libs.t5 import T5Embedder
|
23 |
+
from libs.clip import FrozenCLIPEmbedder
|
24 |
+
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
|
25 |
+
from tools.fid_score import calculate_fid_given_paths
|
26 |
+
from tools.clip_score import ClipSocre
|
27 |
+
|
28 |
+
|
29 |
+
def train(config):
|
30 |
+
if config.get('benchmark', False):
|
31 |
+
torch.backends.cudnn.benchmark = True
|
32 |
+
torch.backends.cudnn.deterministic = False
|
33 |
+
|
34 |
+
mp.set_start_method('spawn')
|
35 |
+
accelerator = accelerate.Accelerator()
|
36 |
+
device = accelerator.device
|
37 |
+
accelerate.utils.set_seed(config.seed, device_specific=True)
|
38 |
+
logging.info(f'Process {accelerator.process_index} using device: {device}')
|
39 |
+
|
40 |
+
config.mixed_precision = accelerator.mixed_precision
|
41 |
+
config = ml_collections.FrozenConfigDict(config)
|
42 |
+
|
43 |
+
assert config.train.batch_size % accelerator.num_processes == 0
|
44 |
+
mini_batch_size = config.train.batch_size // accelerator.num_processes
|
45 |
+
|
46 |
+
if accelerator.is_main_process:
|
47 |
+
os.makedirs(config.ckpt_root, exist_ok=True)
|
48 |
+
os.makedirs(config.sample_dir, exist_ok=True)
|
49 |
+
accelerator.wait_for_everyone()
|
50 |
+
if accelerator.is_main_process:
|
51 |
+
wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
|
52 |
+
name=config.hparams, job_type='train', mode='offline')
|
53 |
+
utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
|
54 |
+
logging.info(config)
|
55 |
+
else:
|
56 |
+
utils.set_logger(log_level='error')
|
57 |
+
builtins.print = lambda *args: None
|
58 |
+
logging.info(f'Run on {accelerator.num_processes} devices')
|
59 |
+
|
60 |
+
dataset = get_dataset(**config.dataset)
|
61 |
+
assert os.path.exists(dataset.fid_stat)
|
62 |
+
|
63 |
+
gpu_model = torch.cuda.get_device_name(torch.cuda.current_device())
|
64 |
+
num_workers = 8
|
65 |
+
|
66 |
+
train_dataset = dataset.get_split(split='train', labeled=True)
|
67 |
+
train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
|
68 |
+
num_workers=num_workers, pin_memory=True, persistent_workers=True)
|
69 |
+
|
70 |
+
test_dataset = dataset.get_split(split='test', labeled=True) # for sampling
|
71 |
+
test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
|
72 |
+
num_workers=num_workers, pin_memory=True, persistent_workers=True)
|
73 |
+
|
74 |
+
train_state = utils.initialize_train_state(config, device)
|
75 |
+
nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
|
76 |
+
train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
|
77 |
+
lr_scheduler = train_state.lr_scheduler
|
78 |
+
train_state.resume(config.ckpt_root)
|
79 |
+
|
80 |
+
autoencoder = libs.autoencoder.get_model(**config.autoencoder)
|
81 |
+
autoencoder.to(device)
|
82 |
+
|
83 |
+
if config.nnet.model_args.clip_dim == 4096:
|
84 |
+
llm = "t5"
|
85 |
+
t5 = T5Embedder(device=device)
|
86 |
+
elif config.nnet.model_args.clip_dim == 768:
|
87 |
+
llm = "clip"
|
88 |
+
clip = FrozenCLIPEmbedder()
|
89 |
+
clip.eval()
|
90 |
+
clip.to(device)
|
91 |
+
else:
|
92 |
+
raise NotImplementedError
|
93 |
+
|
94 |
+
ss_empty_context = None
|
95 |
+
|
96 |
+
ClipSocre_model = ClipSocre(device=device)
|
97 |
+
|
98 |
+
@ torch.cuda.amp.autocast()
|
99 |
+
def encode(_batch):
|
100 |
+
return autoencoder.encode(_batch)
|
101 |
+
|
102 |
+
@ torch.cuda.amp.autocast()
|
103 |
+
def decode(_batch):
|
104 |
+
return autoencoder.decode(_batch)
|
105 |
+
|
106 |
+
def get_data_generator():
|
107 |
+
while True:
|
108 |
+
for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
|
109 |
+
yield data
|
110 |
+
|
111 |
+
data_generator = get_data_generator()
|
112 |
+
|
113 |
+
def get_context_generator(autoencoder):
|
114 |
+
while True:
|
115 |
+
for data in test_dataset_loader:
|
116 |
+
if len(data) == 5:
|
117 |
+
_img, _context, _token_mask, _token, _caption = data
|
118 |
+
else:
|
119 |
+
_img, _context = data
|
120 |
+
_token_mask = None
|
121 |
+
_token = None
|
122 |
+
_caption = None
|
123 |
+
|
124 |
+
if len(_img.shape)==5:
|
125 |
+
_testbatch_img_blurred = autoencoder.sample(_img[:,1,:])
|
126 |
+
yield _context, _token_mask, _token, _caption, _testbatch_img_blurred
|
127 |
+
else:
|
128 |
+
assert len(_img.shape)==4
|
129 |
+
yield _context, _token_mask, _token, _caption, None
|
130 |
+
|
131 |
+
context_generator = get_context_generator(autoencoder)
|
132 |
+
|
133 |
+
_flow_mathcing_model = FlowMatching()
|
134 |
+
|
135 |
+
def train_step(_batch, _ss_empty_context):
|
136 |
+
_metrics = dict()
|
137 |
+
optimizer.zero_grad()
|
138 |
+
|
139 |
+
assert len(_batch)==6
|
140 |
+
assert not config.dataset.cfg
|
141 |
+
_batch_img = _batch[0]
|
142 |
+
_batch_con = _batch[1]
|
143 |
+
_batch_mask = _batch[2]
|
144 |
+
_batch_token = _batch[3]
|
145 |
+
_batch_caption = _batch[4]
|
146 |
+
_batch_img_ori = _batch[5]
|
147 |
+
|
148 |
+
_z = autoencoder.sample(_batch_img)
|
149 |
+
|
150 |
+
loss, loss_dict = _flow_mathcing_model(_z, nnet, loss_coeffs=config.loss_coeffs, cond=_batch_con, con_mask=_batch_mask, batch_img_clip=_batch_img_ori, \
|
151 |
+
nnet_style=config.nnet.name, text_token=_batch_token, model_config=config.nnet.model_args, all_config=config, training_step=train_state.step)
|
152 |
+
|
153 |
+
_metrics['loss'] = accelerator.gather(loss.detach()).mean()
|
154 |
+
for key in loss_dict.keys():
|
155 |
+
_metrics[key] = accelerator.gather(loss_dict[key].detach()).mean()
|
156 |
+
accelerator.backward(loss.mean())
|
157 |
+
optimizer.step()
|
158 |
+
lr_scheduler.step()
|
159 |
+
train_state.ema_update(config.get('ema_rate', 0.9999))
|
160 |
+
train_state.step += 1
|
161 |
+
return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
|
162 |
+
|
163 |
+
def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token_mask=None, return_clipScore=False, ClipSocre_model=None):
|
164 |
+
with torch.no_grad():
|
165 |
+
_z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)
|
166 |
+
|
167 |
+
_z_x0, _mu, _log_var = nnet_ema(context, text_encoder = True, shape = _z_gaussian.shape, mask=token_mask)
|
168 |
+
_z_init = _z_x0.reshape(_z_gaussian.shape)
|
169 |
+
|
170 |
+
assert config.sample.scale > 1
|
171 |
+
_cfg = config.sample.scale
|
172 |
+
|
173 |
+
has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
|
174 |
+
|
175 |
+
ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, step_size_type="step_in_dsigma", guidance_scale=_cfg)
|
176 |
+
_z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)
|
177 |
+
|
178 |
+
image_unprocessed = decode(_z)
|
179 |
+
|
180 |
+
if return_clipScore:
|
181 |
+
clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)
|
182 |
+
return image_unprocessed, clip_score
|
183 |
+
else:
|
184 |
+
return image_unprocessed
|
185 |
+
|
186 |
+
def eval_step(n_samples, sample_steps):
|
187 |
+
logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=ODE_Euler_Flow_Matching_Solver, '
|
188 |
+
f'mini_batch_size={config.sample.mini_batch_size}')
|
189 |
+
|
190 |
+
def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
|
191 |
+
_context, _token_mask, _token, _caption, _testbatch_img_blurred = next(context_generator)
|
192 |
+
assert _context.size(0) == _n_samples
|
193 |
+
assert not return_caption # during training we should not use this
|
194 |
+
if return_caption:
|
195 |
+
return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask), _caption
|
196 |
+
elif return_clipScore:
|
197 |
+
return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
|
198 |
+
else:
|
199 |
+
return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask)
|
200 |
+
|
201 |
+
with tempfile.TemporaryDirectory() as temp_path:
|
202 |
+
path = config.sample.path or temp_path
|
203 |
+
if accelerator.is_main_process:
|
204 |
+
os.makedirs(path, exist_ok=True)
|
205 |
+
clip_score_list = utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
|
206 |
+
_fid = 0
|
207 |
+
if accelerator.is_main_process:
|
208 |
+
_fid = calculate_fid_given_paths((dataset.fid_stat, path))
|
209 |
+
_clip_score_list = torch.cat(clip_score_list)
|
210 |
+
logging.info(f'step={train_state.step} fid{n_samples}={_fid} clip_score{len(_clip_score_list)} = {_clip_score_list.mean().item()}')
|
211 |
+
with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
|
212 |
+
print(f'step={train_state.step} fid{n_samples}={_fid} clip_score{len(_clip_score_list)} = {_clip_score_list.mean().item()}', file=f)
|
213 |
+
wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
|
214 |
+
_fid = torch.tensor(_fid, device=device)
|
215 |
+
_fid = accelerator.reduce(_fid, reduction='sum')
|
216 |
+
|
217 |
+
return _fid.item()
|
218 |
+
|
219 |
+
logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
|
220 |
+
|
221 |
+
step_fid = []
|
222 |
+
while train_state.step < config.train.n_steps:
|
223 |
+
nnet.train()
|
224 |
+
batch = tree_map(lambda x: x, next(data_generator))
|
225 |
+
metrics = train_step(batch, ss_empty_context)
|
226 |
+
|
227 |
+
nnet.eval()
|
228 |
+
if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
|
229 |
+
logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
|
230 |
+
logging.info(config.workdir)
|
231 |
+
wandb.log(metrics, step=train_state.step)
|
232 |
+
|
233 |
+
############# save rigid image
|
234 |
+
if train_state.step % config.train.eval_interval == 0:
|
235 |
+
torch.cuda.empty_cache()
|
236 |
+
logging.info('Save a grid of images...')
|
237 |
+
if hasattr(dataset, "token_embedding"):
|
238 |
+
contexts = torch.tensor(dataset.token_embedding, device=device)[ : config.train.n_samples_eval]
|
239 |
+
token_mask = torch.tensor(dataset.token_mask, device=device)[ : config.train.n_samples_eval]
|
240 |
+
elif hasattr(dataset, "contexts"):
|
241 |
+
contexts = torch.tensor(dataset.contexts, device=device)[ : config.train.n_samples_eval]
|
242 |
+
token_mask = None
|
243 |
+
else:
|
244 |
+
raise NotImplementedError
|
245 |
+
samples = ode_fm_solver_sample(nnet_ema, _n_samples=config.train.n_samples_eval, _sample_steps=50, context=contexts, token_mask=token_mask)
|
246 |
+
samples = make_grid(dataset.unpreprocess(samples), 5)
|
247 |
+
if accelerator.is_main_process:
|
248 |
+
save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
|
249 |
+
wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
|
250 |
+
accelerator.wait_for_everyone()
|
251 |
+
torch.cuda.empty_cache()
|
252 |
+
|
253 |
+
############ save checkpoint and evaluate results
|
254 |
+
if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
|
255 |
+
torch.cuda.empty_cache()
|
256 |
+
logging.info(f'Save and eval checkpoint {train_state.step}...')
|
257 |
+
|
258 |
+
if accelerator.local_process_index == 0:
|
259 |
+
train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
|
260 |
+
accelerator.wait_for_everyone()
|
261 |
+
|
262 |
+
fid = eval_step(n_samples=10000, sample_steps=50) # calculate fid of the saved checkpoint
|
263 |
+
step_fid.append((train_state.step, fid))
|
264 |
+
|
265 |
+
torch.cuda.empty_cache()
|
266 |
+
accelerator.wait_for_everyone()
|
267 |
+
|
268 |
+
logging.info(f'Finish fitting, step={train_state.step}')
|
269 |
+
logging.info(f'step_fid: {step_fid}')
|
270 |
+
step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
|
271 |
+
logging.info(f'step_best: {step_best}')
|
272 |
+
train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
|
273 |
+
del metrics
|
274 |
+
accelerator.wait_for_everyone()
|
275 |
+
eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
from absl import flags
|
280 |
+
from absl import app
|
281 |
+
from ml_collections import config_flags
|
282 |
+
import sys
|
283 |
+
from pathlib import Path
|
284 |
+
|
285 |
+
|
286 |
+
FLAGS = flags.FLAGS
|
287 |
+
config_flags.DEFINE_config_file(
|
288 |
+
"config", None, "Training configuration.", lock_config=False)
|
289 |
+
flags.mark_flags_as_required(["config"])
|
290 |
+
flags.DEFINE_string("workdir", None, "Work unit directory.")
|
291 |
+
|
292 |
+
|
293 |
+
def get_config_name():
|
294 |
+
argv = sys.argv
|
295 |
+
for i in range(1, len(argv)):
|
296 |
+
if argv[i].startswith('--config='):
|
297 |
+
return Path(argv[i].split('=')[-1]).stem
|
298 |
+
|
299 |
+
|
300 |
+
def get_hparams():
|
301 |
+
argv = sys.argv
|
302 |
+
lst = []
|
303 |
+
for i in range(1, len(argv)):
|
304 |
+
assert '=' in argv[i]
|
305 |
+
if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
|
306 |
+
hparam, val = argv[i].split('=')
|
307 |
+
hparam = hparam.split('.')[-1]
|
308 |
+
if hparam.endswith('path'):
|
309 |
+
val = Path(val).stem
|
310 |
+
lst.append(f'{hparam}={val}')
|
311 |
+
hparams = '-'.join(lst)
|
312 |
+
if hparams == '':
|
313 |
+
hparams = 'default'
|
314 |
+
return hparams
|
315 |
+
|
316 |
+
|
317 |
+
def main(argv):
|
318 |
+
config = FLAGS.config
|
319 |
+
config.config_name = get_config_name()
|
320 |
+
config.hparams = get_hparams()
|
321 |
+
config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
|
322 |
+
config.ckpt_root = os.path.join(config.workdir, 'ckpts')
|
323 |
+
config.sample_dir = os.path.join(config.workdir, 'samples')
|
324 |
+
train(config)
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == "__main__":
|
328 |
+
app.run(main)
|
utils.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file contains some tools
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
from tqdm import tqdm
|
9 |
+
from torchvision import transforms
|
10 |
+
from torchvision.utils import save_image
|
11 |
+
from absl import logging
|
12 |
+
from PIL import Image, ImageDraw, ImageFont
|
13 |
+
import textwrap
|
14 |
+
|
15 |
+
def save_image_with_caption(image_tensor, caption, filename, font_size=20, font_path='/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf'):
|
16 |
+
"""
|
17 |
+
Save an image with a caption
|
18 |
+
"""
|
19 |
+
image_tensor = image_tensor.clone().detach()
|
20 |
+
image_tensor = torch.clamp(image_tensor, min=0, max=1)
|
21 |
+
image_pil = transforms.ToPILImage()(image_tensor)
|
22 |
+
draw = ImageDraw.Draw(image_pil)
|
23 |
+
|
24 |
+
font = ImageFont.truetype(font_path, font_size)
|
25 |
+
wrap_text = textwrap.wrap(caption, width=len(caption)//4 + 1)
|
26 |
+
text_sizes = [draw.textsize(line, font=font) for line in wrap_text]
|
27 |
+
max_text_width = max(size[0] for size in text_sizes)
|
28 |
+
total_text_height = sum(size[1] for size in text_sizes) + 15
|
29 |
+
|
30 |
+
new_height = image_pil.height + total_text_height + 25
|
31 |
+
new_image = Image.new('RGB', (image_pil.width, new_height), 'white')
|
32 |
+
new_image.paste(image_pil, (0, 0))
|
33 |
+
current_y = image_pil.height + 5
|
34 |
+
draw = ImageDraw.Draw(new_image)
|
35 |
+
|
36 |
+
for line, size in zip(wrap_text, text_sizes):
|
37 |
+
x = (new_image.width - size[0]) / 2
|
38 |
+
draw.text((x, current_y), line, font=font, fill='black')
|
39 |
+
current_y += size[1] + 5
|
40 |
+
new_image.save(filename)
|
41 |
+
|
42 |
+
|
43 |
+
def set_logger(log_level='info', fname=None):
|
44 |
+
import logging as _logging
|
45 |
+
handler = logging.get_absl_handler()
|
46 |
+
formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
|
47 |
+
handler.setFormatter(formatter)
|
48 |
+
logging.set_verbosity(log_level)
|
49 |
+
if fname is not None:
|
50 |
+
handler = _logging.FileHandler(fname)
|
51 |
+
handler.setFormatter(formatter)
|
52 |
+
logging.get_absl_logger().addHandler(handler)
|
53 |
+
|
54 |
+
|
55 |
+
def dct2str(dct):
|
56 |
+
return str({k: f'{v:.6g}' for k, v in dct.items()})
|
57 |
+
|
58 |
+
|
59 |
+
def get_nnet(name, **kwargs):
|
60 |
+
if name == 'dimr':
|
61 |
+
from libs.model.dimr_t2i import MRModel
|
62 |
+
return MRModel(kwargs["model_args"])
|
63 |
+
elif name == 'dit':
|
64 |
+
from libs.model.dit_t2i import DiT_H_2
|
65 |
+
return DiT_H_2(kwargs["model_args"])
|
66 |
+
else:
|
67 |
+
raise NotImplementedError(name)
|
68 |
+
|
69 |
+
|
70 |
+
def set_seed(seed: int):
|
71 |
+
if seed is not None:
|
72 |
+
torch.manual_seed(seed)
|
73 |
+
np.random.seed(seed)
|
74 |
+
|
75 |
+
|
76 |
+
def get_optimizer(params, name, **kwargs):
|
77 |
+
if name == 'adam':
|
78 |
+
from torch.optim import Adam
|
79 |
+
return Adam(params, **kwargs)
|
80 |
+
elif name == 'adamw':
|
81 |
+
from torch.optim import AdamW
|
82 |
+
return AdamW(params, **kwargs)
|
83 |
+
else:
|
84 |
+
raise NotImplementedError(name)
|
85 |
+
|
86 |
+
|
87 |
+
def customized_lr_scheduler(optimizer, warmup_steps=-1):
|
88 |
+
from torch.optim.lr_scheduler import LambdaLR
|
89 |
+
def fn(step):
|
90 |
+
if warmup_steps > 0:
|
91 |
+
return min(step / warmup_steps, 1)
|
92 |
+
else:
|
93 |
+
return 1
|
94 |
+
return LambdaLR(optimizer, fn)
|
95 |
+
|
96 |
+
|
97 |
+
def get_lr_scheduler(optimizer, name, **kwargs):
|
98 |
+
if name == 'customized':
|
99 |
+
return customized_lr_scheduler(optimizer, **kwargs)
|
100 |
+
elif name == 'cosine':
|
101 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
102 |
+
return CosineAnnealingLR(optimizer, **kwargs)
|
103 |
+
else:
|
104 |
+
raise NotImplementedError(name)
|
105 |
+
|
106 |
+
|
107 |
+
def ema(model_dest: nn.Module, model_src: nn.Module, rate):
|
108 |
+
param_dict_src = dict(model_src.named_parameters())
|
109 |
+
for p_name, p_dest in model_dest.named_parameters():
|
110 |
+
p_src = param_dict_src[p_name]
|
111 |
+
assert p_src is not p_dest
|
112 |
+
p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
|
113 |
+
|
114 |
+
|
115 |
+
class TrainState(object):
|
116 |
+
def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
|
117 |
+
self.optimizer = optimizer
|
118 |
+
self.lr_scheduler = lr_scheduler
|
119 |
+
self.step = step
|
120 |
+
self.nnet = nnet
|
121 |
+
self.nnet_ema = nnet_ema
|
122 |
+
|
123 |
+
def ema_update(self, rate=0.9999):
|
124 |
+
if self.nnet_ema is not None:
|
125 |
+
ema(self.nnet_ema, self.nnet, rate)
|
126 |
+
|
127 |
+
def save(self, path):
|
128 |
+
os.makedirs(path, exist_ok=True)
|
129 |
+
torch.save(self.step, os.path.join(path, 'step.pth'))
|
130 |
+
for key, val in self.__dict__.items():
|
131 |
+
if key != 'step' and val is not None:
|
132 |
+
torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
|
133 |
+
|
134 |
+
def load(self, path):
|
135 |
+
logging.info(f'load from {path}')
|
136 |
+
self.step = torch.load(os.path.join(path, 'step.pth'))
|
137 |
+
for key, val in self.__dict__.items():
|
138 |
+
if key != 'step' and val is not None:
|
139 |
+
val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))
|
140 |
+
|
141 |
+
def resume(self, ckpt_root, step=None):
|
142 |
+
if not os.path.exists(ckpt_root):
|
143 |
+
return
|
144 |
+
if step is None:
|
145 |
+
ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
|
146 |
+
if not ckpts:
|
147 |
+
return
|
148 |
+
steps = map(lambda x: int(x.split(".")[0]), ckpts)
|
149 |
+
step = max(steps)
|
150 |
+
ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
|
151 |
+
logging.info(f'resume from {ckpt_path}')
|
152 |
+
self.load(ckpt_path)
|
153 |
+
|
154 |
+
def to(self, device):
|
155 |
+
for key, val in self.__dict__.items():
|
156 |
+
if isinstance(val, nn.Module):
|
157 |
+
val.to(device)
|
158 |
+
|
159 |
+
|
160 |
+
def trainable_parameters(nnet):
|
161 |
+
params_decay = []
|
162 |
+
params_nodecay = []
|
163 |
+
for name, param in nnet.named_parameters():
|
164 |
+
if name.endswith(".nodecay_weight") or name.endswith(".nodecay_bias"):
|
165 |
+
params_nodecay.append(param)
|
166 |
+
else:
|
167 |
+
params_decay.append(param)
|
168 |
+
print("params_decay", len(params_decay))
|
169 |
+
print("params_nodecay", len(params_nodecay))
|
170 |
+
params = [
|
171 |
+
{'params': params_decay},
|
172 |
+
{'params': params_nodecay, 'weight_decay': 0.0}
|
173 |
+
]
|
174 |
+
return params
|
175 |
+
|
176 |
+
|
177 |
+
def initialize_train_state(config, device):
|
178 |
+
|
179 |
+
nnet = get_nnet(**config.nnet)
|
180 |
+
nnet_ema = get_nnet(**config.nnet)
|
181 |
+
nnet_ema.eval()
|
182 |
+
|
183 |
+
optimizer = get_optimizer(trainable_parameters(nnet), **config.optimizer)
|
184 |
+
lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
|
185 |
+
|
186 |
+
train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
|
187 |
+
nnet=nnet, nnet_ema=nnet_ema)
|
188 |
+
train_state.ema_update(0)
|
189 |
+
train_state.to(device)
|
190 |
+
return train_state
|
191 |
+
|
192 |
+
|
193 |
+
def amortize(n_samples, batch_size):
|
194 |
+
k = n_samples // batch_size
|
195 |
+
r = n_samples % batch_size
|
196 |
+
return k * [batch_size] if r == 0 else k * [batch_size] + [r]
|
197 |
+
|
198 |
+
|
199 |
+
def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, return_clipScore=False, ClipSocre_model=None, config=None):
|
200 |
+
os.makedirs(path, exist_ok=True)
|
201 |
+
idx = 0
|
202 |
+
batch_size = mini_batch_size * accelerator.num_processes
|
203 |
+
clip_score_list = []
|
204 |
+
|
205 |
+
if return_clipScore:
|
206 |
+
assert ClipSocre_model is not None
|
207 |
+
|
208 |
+
for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
|
209 |
+
samples, clip_score = sample_fn(mini_batch_size, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, config=config)
|
210 |
+
samples = unpreprocess_fn(samples)
|
211 |
+
samples = accelerator.gather(samples.contiguous())[:_batch_size]
|
212 |
+
clip_score_list.append(accelerator.gather(clip_score)[:_batch_size])
|
213 |
+
if accelerator.is_main_process:
|
214 |
+
for sample in samples:
|
215 |
+
save_image(sample, os.path.join(path, f"{idx}.png"))
|
216 |
+
idx += 1
|
217 |
+
|
218 |
+
if return_clipScore:
|
219 |
+
return clip_score_list
|
220 |
+
else:
|
221 |
+
return None
|
222 |
+
|
223 |
+
|
224 |
+
def sample2dir_wCLIP(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, return_clipScore=False, ClipSocre_model=None, config=None):
|
225 |
+
os.makedirs(path, exist_ok=True)
|
226 |
+
idx = 0
|
227 |
+
batch_size = mini_batch_size * accelerator.num_processes
|
228 |
+
clip_score_list = []
|
229 |
+
|
230 |
+
if return_clipScore:
|
231 |
+
assert ClipSocre_model is not None
|
232 |
+
|
233 |
+
for _batch_size in amortize(n_samples, batch_size):
|
234 |
+
samples, clip_score = sample_fn(mini_batch_size, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, config=config)
|
235 |
+
samples = unpreprocess_fn(samples)
|
236 |
+
samples = accelerator.gather(samples.contiguous())[:_batch_size]
|
237 |
+
clip_score_list.append(accelerator.gather(clip_score)[:_batch_size])
|
238 |
+
if accelerator.is_main_process:
|
239 |
+
for sample in samples:
|
240 |
+
save_image(sample, os.path.join(path, f"{idx}.png"))
|
241 |
+
idx += 1
|
242 |
+
break
|
243 |
+
|
244 |
+
if return_clipScore:
|
245 |
+
return clip_score_list
|
246 |
+
else:
|
247 |
+
return None
|
248 |
+
|
249 |
+
|
250 |
+
def sample2dir_wPrompt(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, config=None):
|
251 |
+
os.makedirs(path, exist_ok=True)
|
252 |
+
idx = 0
|
253 |
+
batch_size = mini_batch_size * accelerator.num_processes
|
254 |
+
|
255 |
+
for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
|
256 |
+
samples, samples_caption = sample_fn(mini_batch_size, return_caption=True, config=config)
|
257 |
+
samples = unpreprocess_fn(samples)
|
258 |
+
samples = accelerator.gather(samples.contiguous())[:_batch_size]
|
259 |
+
if accelerator.is_main_process:
|
260 |
+
for sample, caption in zip(samples,samples_caption):
|
261 |
+
try:
|
262 |
+
save_image_with_caption(sample, caption, os.path.join(path, f"{idx}.png"))
|
263 |
+
except:
|
264 |
+
save_image(sample, os.path.join(path, f"{idx}.png"))
|
265 |
+
idx += 1
|
266 |
+
|
267 |
+
|
268 |
+
def grad_norm(model):
|
269 |
+
total_norm = 0.
|
270 |
+
for p in model.parameters():
|
271 |
+
param_norm = p.grad.data.norm(2)
|
272 |
+
total_norm += param_norm.item() ** 2
|
273 |
+
total_norm = total_norm ** (1. / 2)
|
274 |
+
return total_norm
|