QHL067 committed · Commit f9567e5 · 1 Parent(s): 9772b52
app.py CHANGED
@@ -1,8 +1,83 @@
 import gradio as gr
+
+ from absl import flags
+ from absl import app
+ from ml_collections import config_flags
+ import os
+
+ import ml_collections
+ import torch
+ from torch import multiprocessing as mp
+ import torch.nn as nn
+ import accelerate
+ import utils
+ import tempfile
+ from absl import logging
+ import builtins
+ import einops
+ import math
 import numpy as np
+ import time
+ from PIL import Image
 import random

+ from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
+ from tools.clip_score import ClipSocre
+ import libs.autoencoder
+ from libs.clip import FrozenCLIPEmbedder
+ from libs.t5 import T5Embedder
+
+
+ def unpreprocess(x):
+     x = 0.5 * (x + 1.)
+     x.clamp_(0., 1.)
+     return x
+
+ def batch_decode(_z, decode, batch_size=10):
+     """
+     The VAE decoder requires a lot of GPU memory. To run the interpolation model on GPUs with 24 GB of RAM or less, you can use this helper to reduce the VAE's memory usage.
+     It works by splitting the input tensor into smaller chunks.
+     """
+     num_samples = _z.size(0)
+     decoded_batches = []
+
+     for i in range(0, num_samples, batch_size):
+         batch = _z[i:i + batch_size]
+         decoded_batch = decode(batch)
+         decoded_batches.append(decoded_batch)
+
+     image_unprocessed = torch.cat(decoded_batches, dim=0)
+     return image_unprocessed
+
+ def get_caption(llm, text_model, prompt_dict, batch_size):
+
+     if batch_size == 3:
+         # only addition or only subtraction
+         assert len(prompt_dict) == 2
+         _batch_con = list(prompt_dict.values()) + [' ']
+     elif batch_size == 4:
+         # addition and subtraction
+         assert len(prompt_dict) == 3
+         _batch_con = list(prompt_dict.values()) + [' ']
+     elif batch_size >= 5:
+         # linear interpolation
+         assert len(prompt_dict) == 2
+         _batch_con = [prompt_dict['prompt_1']] + [' '] * (batch_size - 2) + [prompt_dict['prompt_2']]
+
+     if llm == "clip":
+         _latent, _latent_and_others = text_model.encode(_batch_con)
+         _con = _latent_and_others['token_embedding'].detach()
+     elif llm == "t5":
+         _latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
+         _con = (_latent_and_others['token_embedding'] * 10.0).detach()
+     else:
+         raise NotImplementedError
+     _con_mask = _latent_and_others['token_mask'].detach()
+     _batch_token = _latent_and_others['tokens'].detach()
+     _batch_caption = _batch_con
+     return (_con, _con_mask, _batch_token, _batch_caption)
+
- # import spaces #[uncomment to use ZeroGPU]
+ import spaces #[uncomment to use ZeroGPU]
 from diffusers import DiffusionPipeline
 import torch

@@ -21,7 +96,7 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024


- # @spaces.GPU #[uncomment to use ZeroGPU]
+ @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
     prompt1,
     prompt2,
@@ -69,7 +144,8 @@ css = """

 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
+         gr.Markdown(" # CrossFlow")
+         gr.Markdown(" CrossFlow directly transforms text representations into images for text-to-image generation, enabling interpolation in the input text latent space.")

     with gr.Row():
         prompt1 = gr.Text(
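For reference, the new get_caption helper assembles its text batch differently for the three supported modes (arithmetic with 3 or 4 slots, interpolation with 5 or more). The minimal sketch below reproduces only that batch-construction branch with invented prompts and skips the CLIP/T5 encoding step; it is an illustration, not part of the diff.

def build_batch(prompt_dict, batch_size):
    # Mirrors the branching at the top of get_caption() in app.py above.
    if batch_size == 3:        # addition only, or subtraction only
        assert len(prompt_dict) == 2
        return list(prompt_dict.values()) + [' ']
    if batch_size == 4:        # addition and subtraction together
        assert len(prompt_dict) == 3
        return list(prompt_dict.values()) + [' ']
    if batch_size >= 5:        # linear interpolation between two prompts
        assert len(prompt_dict) == 2
        return [prompt_dict['prompt_1']] + [' '] * (batch_size - 2) + [prompt_dict['prompt_2']]
    raise ValueError(f"unsupported batch_size: {batch_size}")

print(build_batch({'prompt_1': 'a photo of a cat', 'prompt_2': 'a photo of a dog'}, 7))
# ['a photo of a cat', ' ', ' ', ' ', ' ', ' ', 'a photo of a dog']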
configs/t2i_256px_clip_dimr.py ADDED
@@ -0,0 +1,118 @@
1
+ import ml_collections
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Args:
6
+ def __init__(self, **kwargs):
7
+ for key, value in kwargs.items():
8
+ setattr(self, key, value)
9
+
10
+ model = Args(
11
+ channels = 4,
12
+ block_grad_to_lowres = False,
13
+ norm_type = "TDRMSN",
14
+ use_t2i = True,
15
+ clip_dim=768,
16
+ num_clip_token=77,
17
+ gradient_checking=True,
18
+ cfg_indicator=0.1,
19
+ textVAE = Args(
20
+ num_blocks = 11,
21
+ hidden_dim = 1024,
22
+ hidden_token_length = 256,
23
+ num_attention_heads = 8,
24
+ dropout_prob = 0.1,
25
+ ),
26
+ stage_configs = [
27
+ Args(
28
+ block_type = "TransformerBlock",
29
+ dim = 1024, # channel
30
+ hidden_dim = 2048,
31
+ num_attention_heads = 16,
32
+ num_blocks = 65, # depth
33
+ max_height = 16,
34
+ max_width = 16,
35
+ image_input_ratio = 1,
36
+ input_feature_ratio = 2,
37
+ final_kernel_size = 3,
38
+ dropout_prob = 0,
39
+ ),
40
+ Args(
41
+ block_type = "ConvNeXtBlock",
42
+ dim = 512,
43
+ hidden_dim = 1024,
44
+ kernel_size = 7,
45
+ num_blocks = 33,
46
+ max_height = 32,
47
+ max_width = 32,
48
+ image_input_ratio = 1,
49
+ input_feature_ratio = 1,
50
+ final_kernel_size = 3,
51
+ dropout_prob = 0,
52
+ ),
53
+ ],
54
+ )
55
+
56
+ def d(**kwargs):
57
+ """Helper for creating a config dict."""
58
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
59
+
60
+
61
+ def get_config():
62
+ config = ml_collections.ConfigDict()
63
+
64
+ config.seed = 1234
65
+ config.z_shape = (4, 32, 32)
66
+
67
+ config.autoencoder = d(
68
+ pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
69
+ scale_factor=0.23010
70
+ )
71
+
72
+ config.train = d(
73
+ n_steps=1000000,
74
+ batch_size=1024,
75
+ mode='cond',
76
+ log_interval=10,
77
+ eval_interval=5000,
78
+ save_interval=50000,
79
+ )
80
+
81
+ config.optimizer = d(
82
+ name='adamw',
83
+ lr=0.00001,
84
+ weight_decay=0.03,
85
+ betas=(0.9, 0.9),
86
+ )
87
+
88
+ config.lr_scheduler = d(
89
+ name='customized',
90
+ warmup_steps=5000
91
+ )
92
+
93
+ global model
94
+ config.nnet = d(
95
+ name='dimr',
96
+ model_args=model,
97
+ )
98
+ config.loss_coeffs = [1/4, 1]
99
+
100
+ config.dataset = d(
101
+ name='JDB_demo_features',
102
+ resolution=256,
103
+ llm='clip',
104
+ train_path='/data/qihao/dataset/JDB_demo_feature/',
105
+ val_path='/data/qihao/dataset/coco_val_features/',
106
+ cfg=False
107
+ )
108
+
109
+ config.sample = d(
110
+ sample_steps=50,
111
+ n_samples=30000,
112
+ mini_batch_size=20,
113
+ cfg=False,
114
+ scale=7,
115
+ path=''
116
+ )
117
+
118
+ return config
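These config modules are consumed through ml_collections: typically passed as --config to the demo/training scripts (via config_flags, as in demo_t2i.py later in this diff), or loaded directly. A minimal sketch of direct use, assuming the repository root is on PYTHONPATH; the dataset paths inside the config remain placeholders:

from configs.t2i_256px_clip_dimr import get_config

config = get_config()
print(config.z_shape)                    # (4, 32, 32): latent shape for 256px images
print(config.nnet.name)                  # 'dimr'
print(config.nnet.model_args.clip_dim)   # 768, i.e. CLIP text embeddings
config.sample.mini_batch_size = 5        # ConfigDict fields can be overridden before use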
configs/t2i_256px_t5_dimr.py ADDED
@@ -0,0 +1,118 @@
1
+ import ml_collections
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Args:
6
+ def __init__(self, **kwargs):
7
+ for key, value in kwargs.items():
8
+ setattr(self, key, value)
9
+
10
+ model = Args(
11
+ channels = 4,
12
+ block_grad_to_lowres = False,
13
+ norm_type = "TDRMSN",
14
+ use_t2i = True,
15
+ clip_dim=4096,
16
+ num_clip_token=77,
17
+ gradient_checking=True,
18
+ cfg_indicator=0.1,
19
+ textVAE = Args(
20
+ num_blocks = 11,
21
+ hidden_dim = 1024,
22
+ hidden_token_length = 256,
23
+ num_attention_heads = 8,
24
+ dropout_prob = 0.1,
25
+ ),
26
+ stage_configs = [
27
+ Args(
28
+ block_type = "TransformerBlock",
29
+ dim = 1024, # channel
30
+ hidden_dim = 2048,
31
+ num_attention_heads = 16,
32
+ num_blocks = 65, # depth
33
+ max_height = 16,
34
+ max_width = 16,
35
+ image_input_ratio = 1,
36
+ input_feature_ratio = 2,
37
+ final_kernel_size = 3,
38
+ dropout_prob = 0,
39
+ ),
40
+ Args(
41
+ block_type = "ConvNeXtBlock",
42
+ dim = 512,
43
+ hidden_dim = 1024,
44
+ kernel_size = 7,
45
+ num_blocks = 33,
46
+ max_height = 32,
47
+ max_width = 32,
48
+ image_input_ratio = 1,
49
+ input_feature_ratio = 1,
50
+ final_kernel_size = 3,
51
+ dropout_prob = 0,
52
+ ),
53
+ ],
54
+ )
55
+
56
+ def d(**kwargs):
57
+ """Helper of creating a config dict."""
58
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
59
+
60
+
61
+ def get_config():
62
+ config = ml_collections.ConfigDict()
63
+
64
+ config.seed = 1234
65
+ config.z_shape = (4, 32, 32)
66
+
67
+ config.autoencoder = d(
68
+ pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
69
+ scale_factor=0.23010
70
+ )
71
+
72
+ config.train = d(
73
+ n_steps=1000000,
74
+ batch_size=1024,
75
+ mode='cond',
76
+ log_interval=10,
77
+ eval_interval=5000,
78
+ save_interval=50000,
79
+ )
80
+
81
+ config.optimizer = d(
82
+ name='adamw',
83
+ lr=0.00005,
84
+ weight_decay=0.03,
85
+ betas=(0.9, 0.9),
86
+ )
87
+
88
+ config.lr_scheduler = d(
89
+ name='customized',
90
+ warmup_steps=5000
91
+ )
92
+
93
+ global model
94
+ config.nnet = d(
95
+ name='dimr',
96
+ model_args=model,
97
+ )
98
+ config.loss_coeffs = [1/4, 1]
99
+
100
+ config.dataset = d(
101
+ name='JDB_demo_features',
102
+ resolution=256,
103
+ llm='t5',
104
+ train_path='/data/qihao/dataset/JDB_demo_feature/',
105
+ val_path='/data/qihao/dataset/coco_val_features/',
106
+ cfg=False
107
+ )
108
+
109
+ config.sample = d(
110
+ sample_steps=50,
111
+ n_samples=30000,
112
+ mini_batch_size=20,
113
+ cfg=False,
114
+ scale=7,
115
+ path=''
116
+ )
117
+
118
+ return config
configs/t2i_512px_clip_dimr.py ADDED
@@ -0,0 +1,131 @@
1
+ import ml_collections
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Args:
6
+ def __init__(self, **kwargs):
7
+ for key, value in kwargs.items():
8
+ setattr(self, key, value)
9
+
10
+ model = Args(
11
+ channels = 4,
12
+ block_grad_to_lowres = False,
13
+ norm_type = "TDRMSN",
14
+ use_t2i = True,
15
+ clip_dim=768,
16
+ num_clip_token=77,
17
+ gradient_checking=True,
18
+ cfg_indicator=0.15,
19
+ textVAE = Args(
20
+ num_blocks = 11,
21
+ hidden_dim = 1024,
22
+ hidden_token_length = 256,
23
+ num_attention_heads = 8,
24
+ dropout_prob = 0.1,
25
+ ),
26
+ stage_configs = [
27
+ Args(
28
+ block_type = "TransformerBlock",
29
+ dim = 1024, # channel
30
+ hidden_dim = 2048,
31
+ num_attention_heads = 16,
32
+ num_blocks = 65, # depth
33
+ max_height = 16,
34
+ max_width = 16,
35
+ image_input_ratio = 1,
36
+ input_feature_ratio = 4,
37
+ final_kernel_size = 3,
38
+ dropout_prob = 0,
39
+ ),
40
+ Args(
41
+ block_type = "ConvNeXtBlock",
42
+ dim = 512,
43
+ hidden_dim = 1024,
44
+ kernel_size = 7,
45
+ num_blocks = 33,
46
+ max_height = 32,
47
+ max_width = 32,
48
+ image_input_ratio = 1,
49
+ input_feature_ratio = 2,
50
+ final_kernel_size = 3,
51
+ dropout_prob = 0,
52
+ ),
53
+ Args(
54
+ block_type = "ConvNeXtBlock",
55
+ dim = 256,
56
+ hidden_dim = 512,
57
+ kernel_size = 7,
58
+ num_blocks = 33,
59
+ max_height = 64,
60
+ max_width = 64,
61
+ image_input_ratio = 1,
62
+ input_feature_ratio = 1,
63
+ final_kernel_size = 3,
64
+ dropout_prob = 0,
65
+ ),
66
+ ],
67
+ )
68
+
69
+ def d(**kwargs):
70
+ """Helper of creating a config dict."""
71
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
72
+
73
+
74
+ def get_config():
75
+ config = ml_collections.ConfigDict()
76
+
77
+ config.seed = 1234
78
+ config.z_shape = (4, 64, 64)
79
+
80
+ config.autoencoder = d(
81
+ pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
82
+ scale_factor=0.23010
83
+ )
84
+
85
+ config.train = d(
86
+ n_steps=1000000,
87
+ batch_size=1024,
88
+ mode='cond',
89
+ log_interval=10,
90
+ eval_interval=5000,
91
+ save_interval=50000,
92
+ )
93
+
94
+ config.optimizer = d(
95
+ name='adamw',
96
+ lr=0.00001,
97
+ weight_decay=0.03,
98
+ betas=(0.9, 0.9),
99
+ )
100
+
101
+ config.lr_scheduler = d(
102
+ name='customized',
103
+ warmup_steps=5000
104
+ )
105
+
106
+ global model
107
+ config.nnet = d(
108
+ name='dimr',
109
+ model_args=model,
110
+ )
111
+ config.loss_coeffs = [1/4, 1/2, 1]
112
+
113
+ config.dataset = d(
114
+ name='JDB_demo_features',
115
+ resolution=512,
116
+ llm='clip',
117
+ train_path='/data/qihao/dataset/JDB_demo_feature/',
118
+ val_path='/data/qihao/dataset/coco_val_features/',
119
+ cfg=False
120
+ )
121
+
122
+ config.sample = d(
123
+ sample_steps=50,
124
+ n_samples=30000,
125
+ mini_batch_size=10,
126
+ cfg=False,
127
+ scale=7,
128
+ path=''
129
+ )
130
+
131
+ return config
configs/t2i_512px_t5_dimr.py ADDED
@@ -0,0 +1,131 @@
1
+ import ml_collections
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Args:
6
+ def __init__(self, **kwargs):
7
+ for key, value in kwargs.items():
8
+ setattr(self, key, value)
9
+
10
+ model = Args(
11
+ channels = 4,
12
+ block_grad_to_lowres = False,
13
+ norm_type = "TDRMSN",
14
+ use_t2i = True,
15
+ clip_dim=4096,
16
+ num_clip_token=77,
17
+ gradient_checking=True,
18
+ cfg_indicator=0.15,
19
+ textVAE = Args(
20
+ num_blocks = 11,
21
+ hidden_dim = 1024,
22
+ hidden_token_length = 256,
23
+ num_attention_heads = 8,
24
+ dropout_prob = 0.1,
25
+ ),
26
+ stage_configs = [
27
+ Args(
28
+ block_type = "TransformerBlock",
29
+ dim = 1024, # channel
30
+ hidden_dim = 2048,
31
+ num_attention_heads = 16,
32
+ num_blocks = 65, # depth
33
+ max_height = 16,
34
+ max_width = 16,
35
+ image_input_ratio = 1,
36
+ input_feature_ratio = 4,
37
+ final_kernel_size = 3,
38
+ dropout_prob = 0,
39
+ ),
40
+ Args(
41
+ block_type = "ConvNeXtBlock",
42
+ dim = 512,
43
+ hidden_dim = 1024,
44
+ kernel_size = 7,
45
+ num_blocks = 33,
46
+ max_height = 32,
47
+ max_width = 32,
48
+ image_input_ratio = 1,
49
+ input_feature_ratio = 2,
50
+ final_kernel_size = 3,
51
+ dropout_prob = 0,
52
+ ),
53
+ Args(
54
+ block_type = "ConvNeXtBlock",
55
+ dim = 256,
56
+ hidden_dim = 512,
57
+ kernel_size = 7,
58
+ num_blocks = 33,
59
+ max_height = 64,
60
+ max_width = 64,
61
+ image_input_ratio = 1,
62
+ input_feature_ratio = 1,
63
+ final_kernel_size = 3,
64
+ dropout_prob = 0,
65
+ ),
66
+ ],
67
+ )
68
+
69
+ def d(**kwargs):
70
+ """Helper of creating a config dict."""
71
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
72
+
73
+
74
+ def get_config():
75
+ config = ml_collections.ConfigDict()
76
+
77
+ config.seed = 1234
78
+ config.z_shape = (4, 64, 64)
79
+
80
+ config.autoencoder = d(
81
+ pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
82
+ scale_factor=0.23010
83
+ )
84
+
85
+ config.train = d(
86
+ n_steps=1000000,
87
+ batch_size=1024,
88
+ mode='cond',
89
+ log_interval=10,
90
+ eval_interval=5000,
91
+ save_interval=50000,
92
+ )
93
+
94
+ config.optimizer = d(
95
+ name='adamw',
96
+ lr=0.00001,
97
+ weight_decay=0.03,
98
+ betas=(0.9, 0.9),
99
+ )
100
+
101
+ config.lr_scheduler = d(
102
+ name='customized',
103
+ warmup_steps=5000
104
+ )
105
+
106
+ global model
107
+ config.nnet = d(
108
+ name='dimr',
109
+ model_args=model,
110
+ )
111
+ config.loss_coeffs = [1/4, 1/2, 1]
112
+
113
+ config.dataset = d(
114
+ name='JDB_demo_features',
115
+ resolution=512,
116
+ llm='t5',
117
+ train_path='/data/qihao/dataset/JDB_demo_feature/',
118
+ val_path='/data/qihao/dataset/coco_val_features/',
119
+ cfg=False
120
+ )
121
+
122
+ config.sample = d(
123
+ sample_steps=50,
124
+ n_samples=30000,
125
+ mini_batch_size=10,
126
+ cfg=False,
127
+ scale=7,
128
+ path=''
129
+ )
130
+
131
+ return config
configs/t2i_512px_t5_dit.py ADDED
@@ -0,0 +1,92 @@
1
+ import ml_collections
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Args:
6
+ def __init__(self, **kwargs):
7
+ for key, value in kwargs.items():
8
+ setattr(self, key, value)
9
+
10
+ model = Args(
11
+ latent_size = 64,
12
+ learn_sigma = False, # different from DiT, we directly predict the noise here
13
+ channels = 4,
14
+ block_grad_to_lowres = False,
15
+ norm_type = "TDRMSN",
16
+ use_t2i = True,
17
+ clip_dim=4096,
18
+ num_clip_token=77,
19
+ gradient_checking=True, # for larger model
20
+ cfg_indicator=0.10,
21
+ textVAE = Args(
22
+ num_blocks = 11,
23
+ hidden_dim = 1024,
24
+ hidden_token_length = 256,
25
+ num_attention_heads = 8,
26
+ dropout_prob = 0.1,
27
+ ),
28
+ )
29
+
30
+ def d(**kwargs):
31
+ """Helper for creating a config dict."""
32
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
33
+
34
+
35
+ def get_config():
36
+ config = ml_collections.ConfigDict()
37
+
38
+ config.seed = 1234
39
+ config.z_shape = (4, 64, 64)
40
+
41
+ config.autoencoder = d(
42
+ pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',
43
+ scale_factor=0.23010
44
+ )
45
+
46
+ config.train = d(
47
+ n_steps=1000000,
48
+ batch_size=1024,
49
+ mode='cond',
50
+ log_interval=10,
51
+ eval_interval=5000,
52
+ save_interval=50000,
53
+ )
54
+
55
+ config.optimizer = d(
56
+ name='adamw',
57
+ lr=0.00002,
58
+ weight_decay=0.03,
59
+ betas=(0.9, 0.9),
60
+ )
61
+
62
+ config.lr_scheduler = d(
63
+ name='customized',
64
+ warmup_steps=5000
65
+ )
66
+
67
+ global model
68
+ config.nnet = d(
69
+ name='dit',
70
+ model_args=model,
71
+ )
72
+ config.loss_coeffs = []
73
+
74
+ config.dataset = d(
75
+ name='JDB_demo_features',
76
+ resolution=512,
77
+ llm='t5',
78
+ train_path='/data/qihao/dataset/JDB_demo_feature/',
79
+ val_path='/data/qihao/dataset/coco_val_features/',
80
+ cfg=False
81
+ )
82
+
83
+ config.sample = d(
84
+ sample_steps=50,
85
+ n_samples=30000,
86
+ mini_batch_size=10,
87
+ cfg=False,
88
+ scale=7,
89
+ path=''
90
+ )
91
+
92
+ return config
configs/t2i_training_demo.py ADDED
@@ -0,0 +1,132 @@
1
+ import ml_collections
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Args:
6
+ def __init__(self, **kwargs):
7
+ for key, value in kwargs.items():
8
+ setattr(self, key, value)
9
+
10
+ model = Args(
11
+ channels = 4,
12
+ block_grad_to_lowres = False,
13
+ norm_type = "TDRMSN",
14
+ use_t2i = True,
15
+ clip_dim=768, # 768 for CLIP, 4096 for T5-XXL
16
+ num_clip_token=77,
17
+ gradient_checking=True,
18
+ cfg_indicator=0.1,
19
+ textVAE = Args(
20
+ num_blocks = 11,
21
+ hidden_dim = 1024,
22
+ hidden_token_length = 256,
23
+ num_attention_heads = 8,
24
+ dropout_prob = 0.1,
25
+ ),
26
+ stage_configs = [ # this is just an example
27
+ Args(
28
+ block_type = "TransformerBlock",
29
+ dim = 960,
30
+ hidden_dim = 1920,
31
+ num_attention_heads = 16,
32
+ num_blocks = 29,
33
+ max_height = 16,
34
+ max_width = 16,
35
+ image_input_ratio = 1,
36
+ input_feature_ratio = 4,
37
+ final_kernel_size = 3,
38
+ dropout_prob = 0,
39
+ ),
40
+ Args(
41
+ block_type = "ConvNeXtBlock",
42
+ dim = 480,
43
+ hidden_dim = 960,
44
+ kernel_size = 7,
45
+ num_blocks = 15,
46
+ max_height = 32,
47
+ max_width = 32,
48
+ image_input_ratio = 1,
49
+ input_feature_ratio = 2,
50
+ final_kernel_size = 3,
51
+ dropout_prob = 0,
52
+ ),
53
+ Args(
54
+ block_type = "ConvNeXtBlock",
55
+ dim = 240,
56
+ hidden_dim = 480,
57
+ kernel_size = 7,
58
+ num_blocks = 15,
59
+ max_height = 64,
60
+ max_width = 64,
61
+ image_input_ratio = 1,
62
+ input_feature_ratio = 1,
63
+ final_kernel_size = 3,
64
+ dropout_prob = 0,
65
+ ),
66
+ ],
67
+ )
68
+
69
+ def d(**kwargs):
70
+ """Helper of creating a config dict."""
71
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
72
+
73
+
74
+ def get_config():
75
+ config = ml_collections.ConfigDict()
76
+
77
+ config.seed = 1234 # random seed
78
+ config.z_shape = (4, 64, 64) # image latent size
79
+
80
+ config.autoencoder = d(
81
+ pretrained_path='assets/stable-diffusion/autoencoder_kl.pth', # path of pretrained VAE CKPT from LDM
82
+ scale_factor=0.23010
83
+ )
84
+
85
+ config.train = d(
86
+ n_steps=1000000, # total training iterations
87
+ batch_size=4, # overall batch size across ALL gpus, where batch_size_per_gpu == batch_size / number_of_gpus
88
+ mode='cond',
89
+ log_interval=10,
90
+ eval_interval=10, # iteration interval for visual testing on the specified prompt
91
+ save_interval=100, # iteration interval for saving checkpoints and testing FID
92
+ n_samples_eval=5, # number of samples during visual testing. This depends on your GPU memory and can be any integer between 1 and 15 (as we provide only 15 prompts).
93
+ )
94
+
95
+ config.optimizer = d(
96
+ name='adamw',
97
+ lr=0.00001, # learning rate
98
+ weight_decay=0.03,
99
+ betas=(0.9, 0.9),
100
+ )
101
+
102
+ config.lr_scheduler = d(
103
+ name='customized',
104
+ warmup_steps=5000 # warmup steps
105
+ )
106
+
107
+ global model
108
+ config.nnet = d(
109
+ name='dimr',
110
+ model_args=model,
111
+ )
112
+ config.loss_coeffs = [1/4, 1/2, 1] # per-stage loss weights, only needed for DiMR. Here, loss = 1/4 * loss_block1 + 1/2 * loss_block2 + 1 * loss_block3
113
+
114
+ config.dataset = d(
115
+ name='JDB_demo_features', # dataset name
116
+ resolution=512, # dataset resolution
117
+ llm='clip', # language model to generate language embedding
118
+ train_path='/data/qihao/dataset/JDB_demo_feature/', # training set path
119
+ val_path='/data/qihao/dataset/coco_val_features/', # val set path
120
+ cfg=False
121
+ )
122
+
123
+ config.sample = d(
124
+ sample_steps=50, # sample steps during inference/testing
125
+ n_samples=30000, # number of samples for testing (during training, we sample 10K images, which is hardcoded in the training script)
126
+ mini_batch_size=10, # batch size for testing (i.e., the number of images generated per GPU)
127
+ cfg=False,
128
+ scale=7, # cfg scale
129
+ path=''
130
+ )
131
+
132
+ return config
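The loss_coeffs comment above spells out how the per-stage DiMR losses are combined. As a small numeric illustration (the per-stage loss values below are invented):

loss_coeffs = [1/4, 1/2, 1]          # one weight per entry in stage_configs (loss_block1..3 above)
per_stage_losses = [0.8, 0.5, 0.3]   # hypothetical values produced during training

total_loss = sum(w * l for w, l in zip(loss_coeffs, per_stage_losses))
print(total_loss)                    # 1/4*0.8 + 1/2*0.5 + 1*0.3 = 0.75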
datasets.py ADDED
@@ -0,0 +1,305 @@
1
+ from torch.utils.data import Dataset
2
+ from torchvision import datasets
3
+ import torchvision.transforms as transforms
4
+ from scipy.signal import convolve2d
5
+ import numpy as np
6
+ import torch
7
+ import math
8
+ import random
9
+ from PIL import Image
10
+ import os
11
+ import glob
12
+ import einops
13
+ import torchvision.transforms.functional as F
14
+ import time
15
+ from tqdm import tqdm
16
+ import json
17
+ import pickle
18
+ import io
19
+ import cv2
20
+
21
+ import libs.clip
22
+ import bisect
23
+
24
+
25
+ class UnlabeledDataset(Dataset):
26
+ def __init__(self, dataset):
27
+ self.dataset = dataset
28
+
29
+ def __len__(self):
30
+ return len(self.dataset)
31
+
32
+ def __getitem__(self, item):
33
+ data = tuple(self.dataset[item][:-1]) # remove label
34
+ if len(data) == 1:
35
+ data = data[0]
36
+ return data
37
+
38
+
39
+ class LabeledDataset(Dataset):
40
+ def __init__(self, dataset, labels):
41
+ self.dataset = dataset
42
+ self.labels = labels
43
+
44
+ def __len__(self):
45
+ return len(self.dataset)
46
+
47
+ def __getitem__(self, item):
48
+ return self.dataset[item], self.labels[item]
49
+
50
+
51
+ class DatasetFactory(object):
52
+
53
+ def __init__(self):
54
+ self.train = None
55
+ self.test = None
56
+
57
+ def get_split(self, split, labeled=False):
58
+ if split == "train":
59
+ dataset = self.train
60
+ elif split == "test":
61
+ dataset = self.test
62
+ else:
63
+ raise ValueError
64
+
65
+ if self.has_label:
66
+ return dataset if labeled else UnlabeledDataset(dataset)
67
+ else:
68
+ assert not labeled
69
+ return dataset
70
+
71
+ def unpreprocess(self, v): # to B C H W and [0, 1]
72
+ v = 0.5 * (v + 1.)
73
+ v.clamp_(0., 1.)
74
+ return v
75
+
76
+ @property
77
+ def has_label(self):
78
+ return True
79
+
80
+ @property
81
+ def data_shape(self):
82
+ raise NotImplementedError
83
+
84
+ @property
85
+ def data_dim(self):
86
+ return int(np.prod(self.data_shape))
87
+
88
+ @property
89
+ def fid_stat(self):
90
+ return None
91
+
92
+ def sample_label(self, n_samples, device):
93
+ raise NotImplementedError
94
+
95
+ def label_prob(self, k):
96
+ raise NotImplementedError
97
+
98
+
99
+ def center_crop_arr(pil_image, image_size):
100
+ # We are not on a new enough PIL to support the `reducing_gap`
101
+ # argument, which uses BOX downsampling at powers of two first.
102
+ # Thus, we do it by hand to improve downsample quality.
103
+ while min(*pil_image.size) >= 2 * image_size:
104
+ pil_image = pil_image.resize(
105
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
106
+ )
107
+
108
+ scale = image_size / min(*pil_image.size)
109
+ pil_image = pil_image.resize(
110
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
111
+ )
112
+
113
+ arr = np.array(pil_image)
114
+ crop_y = (arr.shape[0] - image_size) // 2
115
+ crop_x = (arr.shape[1] - image_size) // 2
116
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
117
+
118
+
119
+ # MS COCO
120
+
121
+
122
+ def center_crop(width, height, img):
123
+ resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
124
+ crop = np.min(img.shape[:2])
125
+ img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
126
+ (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
127
+ try:
128
+ img = Image.fromarray(img, 'RGB')
129
+ except:
130
+ img = Image.fromarray(img)
131
+ img = img.resize((width, height), resample)
132
+
133
+ return np.array(img).astype(np.uint8)
134
+
135
+
136
+ class MSCOCODatabase(Dataset):
137
+ def __init__(self, root, annFile, size=None):
138
+ from pycocotools.coco import COCO
139
+ self.root = root
140
+ self.height = self.width = size
141
+
142
+ self.coco = COCO(annFile)
143
+ self.keys = list(sorted(self.coco.imgs.keys()))
144
+
145
+ def _load_image(self, key: int):
146
+ path = self.coco.loadImgs(key)[0]["file_name"]
147
+ return Image.open(os.path.join(self.root, path)).convert("RGB")
148
+
149
+ def _load_target(self, key: int):
150
+ return self.coco.loadAnns(self.coco.getAnnIds(key))
151
+
152
+ def __len__(self):
153
+ return len(self.keys)
154
+
155
+ def __getitem__(self, index):
156
+ key = self.keys[index]
157
+ image = self._load_image(key)
158
+ image = np.array(image).astype(np.uint8)
159
+ image = center_crop(self.width, self.height, image).astype(np.float32)
160
+ image = (image / 127.5 - 1.0).astype(np.float32)
161
+ image = einops.rearrange(image, 'h w c -> c h w')
162
+
163
+ anns = self._load_target(key)
164
+ target = []
165
+ for ann in anns:
166
+ target.append(ann['caption'])
167
+
168
+ return image, target
169
+
170
+
171
+ def get_feature_dir_info(root):
172
+ files = glob.glob(os.path.join(root, '*.npy'))
173
+ files_caption = glob.glob(os.path.join(root, '*_*.npy'))
174
+ num_data = len(files) - len(files_caption)
175
+ n_captions = {k: 0 for k in range(num_data)}
176
+ for f in files_caption:
177
+ name = os.path.split(f)[-1]
178
+ k1, k2 = os.path.splitext(name)[0].split('_')
179
+ n_captions[int(k1)] += 1
180
+ return num_data, n_captions
181
+
182
+
183
+ class MSCOCOFeatureDataset(Dataset):
184
+ # the image features are obtained through sampling
185
+ def __init__(self, root, need_squeeze=False, full_feature=False, fix_test_order=False):
186
+ self.root = root
187
+ self.num_data, self.n_captions = get_feature_dir_info(root)
188
+ self.need_squeeze = need_squeeze
189
+ self.full_feature = full_feature
190
+ self.fix_test_order = fix_test_order
191
+
192
+ def __len__(self):
193
+ return self.num_data
194
+
195
+ def __getitem__(self, index):
196
+ if self.full_feature:
197
+ z = np.load(os.path.join(self.root, f'{index}.npy'))
198
+
199
+ if self.fix_test_order:
200
+ k = self.n_captions[index] - 1
201
+ else:
202
+ k = random.randint(0, self.n_captions[index] - 1)
203
+
204
+ test_item = np.load(os.path.join(self.root, f'{index}_{k}.npy'), allow_pickle=True).item()
205
+ token_embedding = test_item['token_embedding']
206
+ token_mask = test_item['token_mask']
207
+ token = test_item['token']
208
+ caption = test_item['promt']
209
+ return z, token_embedding, token_mask, token, caption
210
+ else:
211
+ z = np.load(os.path.join(self.root, f'{index}.npy'))
212
+ k = random.randint(0, self.n_captions[index] - 1)
213
+ c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
214
+ if self.need_squeeze:
215
+ return z, c.squeeze()
216
+ else:
217
+ return z, c
218
+
219
+
220
+ class JDBFeatureDataset(Dataset):
221
+ def __init__(self, root, resolution, llm):
222
+ super().__init__()
223
+ json_path = os.path.join(root,'img_text_pair.jsonl')
224
+ self.img_root = os.path.join(root,'imgs')
225
+ self.feature_root = os.path.join(root,'features')
226
+ self.resolution = resolution
227
+ self.llm = llm
228
+ self.file_list = []
229
+ with open(json_path, 'r', encoding='utf-8') as file:
230
+ for line in file:
231
+ self.file_list.append(json.loads(line)['img_path'])
232
+
233
+ def __len__(self):
234
+ return len(self.file_list)
235
+
236
+ def __getitem__(self, idx):
237
+ data_item = self.file_list[idx]
238
+ feature_path = os.path.join(self.feature_root, data_item.split('/')[-1].replace('.jpg','.npy'))
239
+ img_path = os.path.join(self.img_root, data_item)
240
+
241
+ train_item = np.load(feature_path, allow_pickle=True).item()
242
+ pil_image = Image.open(img_path)
243
+ pil_image.load()
244
+ pil_image = pil_image.convert("RGB")
245
+
246
+
247
+ z = train_item[f'image_latent_{self.resolution}']
248
+ token_embedding = train_item[f'token_embedding_{self.llm}']
249
+ token_mask = train_item[f'token_mask_{self.llm}']
250
+ token = train_item[f'token_{self.llm}']
251
+ caption = train_item['batch_caption']
252
+
253
+ img = center_crop_arr(pil_image, image_size=self.resolution)
254
+ img = (img / 127.5 - 1.0).astype(np.float32)
255
+ img = einops.rearrange(img, 'h w c -> c h w')
256
+
257
+ # return z, token_embedding, token_mask, token, caption, 0, img, 0, 0
258
+ return z, token_embedding, token_mask, token, caption, img
259
+
260
+
261
+ class JDBFullFeatures(DatasetFactory): # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip
262
+ def __init__(self, train_path, val_path, resolution, llm, cfg=False, p_uncond=None, fix_test_order=False):
263
+ super().__init__()
264
+ print('Prepare dataset...')
265
+ self.resolution = resolution
266
+
267
+ self.train = JDBFeatureDataset(train_path, resolution=resolution, llm=llm)
268
+ self.test = MSCOCOFeatureDataset(os.path.join(val_path, 'val'), full_feature=True, fix_test_order=fix_test_order)
269
+ assert len(self.test) == 40504
270
+
271
+ print('Prepare dataset ok')
272
+
273
+ self.empty_context = np.load(os.path.join(val_path, 'empty_context.npy'), allow_pickle=True).item()
274
+
275
+ assert not cfg
276
+
277
+ # text embedding extracted by clip
278
+ self.prompts, self.token_embedding, self.token_mask, self.token = [], [], [], []
279
+ for f in sorted(os.listdir(os.path.join(val_path, 'run_vis')), key=lambda x: int(x.split('.')[0])):
280
+ vis_item = np.load(os.path.join(val_path, 'run_vis', f), allow_pickle=True).item()
281
+ self.prompts.append(vis_item['promt'])
282
+ self.token_embedding.append(vis_item['token_embedding'])
283
+ self.token_mask.append(vis_item['token_mask'])
284
+ self.token.append(vis_item['token'])
285
+ self.token_embedding = np.array(self.token_embedding)
286
+ self.token_mask = np.array(self.token_mask)
287
+ self.token = np.array(self.token)
288
+
289
+ @property
290
+ def data_shape(self):
291
+ if self.resolution==512:
292
+ return 4, 64, 64
293
+ else:
294
+ return 4, 32, 32
295
+
296
+ @property
297
+ def fid_stat(self):
298
+ return f'assets/fid_stats/fid_stats_mscoco256_val.npz'
299
+
300
+
301
+ def get_dataset(name, **kwargs):
302
+ if name == 'JDB_demo_features':
303
+ return JDBFullFeatures(**kwargs)
304
+ else:
305
+ raise NotImplementedError(name)
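A sketch of how this module is typically wired up from the configs above. The paths are the placeholder values from the config files and must exist locally, and the import assumes the repo's own datasets.py (not the Hugging Face `datasets` package) resolves first:

import torch
from datasets import get_dataset  # the datasets.py in this diff, not HF's `datasets`

factory = get_dataset(
    name='JDB_demo_features',
    resolution=256,
    llm='clip',
    train_path='/data/qihao/dataset/JDB_demo_feature/',  # placeholder path from the configs
    val_path='/data/qihao/dataset/coco_val_features/',   # placeholder path from the configs
    cfg=False,
)
train_set = factory.get_split('train', labeled=True)     # JDBFeatureDataset
z, token_embedding, token_mask, token, caption, img = train_set[0]
loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)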
demo_t2i.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ This file is used for T2I generation; it also computes the CLIP similarity between the generated images and the input prompt.
3
+ """
4
+ from absl import flags
5
+ from absl import app
6
+ from ml_collections import config_flags
7
+ import os
8
+
9
+ import ml_collections
10
+ import torch
11
+ from torch import multiprocessing as mp
12
+ import torch.nn as nn
13
+ import accelerate
14
+ import utils
15
+ import tempfile
16
+ from absl import logging
17
+ import builtins
18
+ import einops
19
+ import math
20
+ import numpy as np
21
+ import time
22
+ from PIL import Image
23
+
24
+ from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
25
+ from tools.clip_score import ClipSocre
26
+ import libs.autoencoder
27
+ from libs.clip import FrozenCLIPEmbedder
28
+ from libs.t5 import T5Embedder
29
+
30
+
31
+ def unpreprocess(x):
32
+ x = 0.5 * (x + 1.)
33
+ x.clamp_(0., 1.)
34
+ return x
35
+
36
+ def get_caption(llm, text_model, _batch_prompt):
37
+ _batch_con = _batch_prompt
38
+ if llm == "clip":
39
+ _latent, _latent_and_others = text_model.encode(_batch_con)
40
+ _con = _latent_and_others['token_embedding'].detach()
41
+ elif llm == "t5":
42
+ _latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
43
+ _con = (_latent_and_others['token_embedding'] * 10.0).detach()
44
+ else:
45
+ raise NotImplementedError
46
+ _con_mask = _latent_and_others['token_mask'].detach()
47
+ _batch_token = _latent_and_others['tokens'].detach()
48
+ _batch_caption = _batch_con
49
+ return (_con, _con_mask, _batch_token, _batch_caption)
50
+
51
+
52
+ def evaluate(config):
53
+
54
+ if config.get('benchmark', False):
55
+ torch.backends.cudnn.benchmark = True
56
+ torch.backends.cudnn.deterministic = False
57
+
58
+ mp.set_start_method('spawn')
59
+ accelerator = accelerate.Accelerator()
60
+ device = accelerator.device
61
+ accelerate.utils.set_seed(config.seed, device_specific=True)
62
+ logging.info(f'Process {accelerator.process_index} using device: {device}')
63
+
64
+ config.mixed_precision = accelerator.mixed_precision
65
+ config = ml_collections.FrozenConfigDict(config)
66
+ if accelerator.is_main_process:
67
+ utils.set_logger(log_level='info', fname=config.output_path)
68
+ else:
69
+ utils.set_logger(log_level='error')
70
+ builtins.print = lambda *args: None
71
+
72
+ nnet = utils.get_nnet(**config.nnet)
73
+ nnet = accelerator.prepare(nnet)
74
+ logging.info(f'load nnet from {config.nnet_path}')
75
+ accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
76
+ nnet.eval()
77
+
78
+ ##
79
+
80
+ if config.nnet.model_args.clip_dim == 4096:
81
+ llm = "t5"
82
+ t5 = T5Embedder(device=device)
83
+ elif config.nnet.model_args.clip_dim == 768:
84
+ llm = "clip"
85
+ clip = FrozenCLIPEmbedder()
86
+ clip.eval()
87
+ clip.to(device)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ if llm == "clip":
92
+ context_generator = get_caption(llm, clip, _batch_prompt=[config.prompt]*config.sample.mini_batch_size)
93
+ elif llm == "t5":
94
+ context_generator = get_caption(llm, t5, _batch_prompt=[config.prompt]*config.sample.mini_batch_size)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ ##
99
+
100
+ autoencoder = libs.autoencoder.get_model(**config.autoencoder)
101
+ autoencoder.to(device)
102
+
103
+ @torch.cuda.amp.autocast()
104
+ def encode(_batch):
105
+ return autoencoder.encode(_batch)
106
+
107
+ @torch.cuda.amp.autocast()
108
+ def decode(_batch):
109
+ return autoencoder.decode(_batch)
110
+
111
+ bdv_nnet = None # We don't use Autoguidance
112
+ ClipSocre_model = ClipSocre(device=device) # we also return clip score
113
+
114
+ #######
115
+ logging.info(config.sample)
116
+ logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')
117
+
118
+
119
+ def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, bdv_nnet=bdv_nnet, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token=None, token_mask=None, return_clipScore=False, ClipSocre_model=None):
120
+ with torch.no_grad():
121
+ del testbatch_img_blurred
122
+
123
+ _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)
124
+
125
+ if 'dimr' in config.nnet.name or 'dit' in config.nnet.name:
126
+ _z_x0, _mu, _log_var = nnet_ema(context, text_encoder = True, shape = _z_gaussian.shape, mask=token_mask)
127
+ _z_init = _z_x0.reshape(_z_gaussian.shape)
128
+ else:
129
+ raise NotImplementedError
130
+
131
+ assert config.sample.scale > 1
132
+ if config.cfg != -1:
133
+ _cfg = config.cfg
134
+ else:
135
+ _cfg = config.sample.scale
136
+
137
+ has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
138
+
139
+ _sample_steps = config.sample.sample_steps
140
+
141
+ ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, bdv_model_fn=bdv_nnet, step_size_type="step_in_dsigma", guidance_scale=_cfg)
142
+ _z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)
143
+
144
+ image_unprocessed = decode(_z)
145
+ clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)
146
+
147
+ return image_unprocessed, clip_score
148
+
149
+
150
+ def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
151
+ _context, _token_mask, _token, _caption = context_generator
152
+ assert _context.size(0) == _n_samples
153
+ assert return_clipScore
154
+ assert not return_caption
155
+ return ode_fm_solver_sample(nnet, _n_samples, config.sample.sample_steps, bdv_nnet=bdv_nnet, context=_context, token=_token, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
156
+
157
+
158
+ with tempfile.TemporaryDirectory() as temp_path:
159
+ path = config.img_save_path or config.sample.path or temp_path
160
+ if accelerator.is_main_process:
161
+ os.makedirs(path, exist_ok=True)
162
+ logging.info(f'Samples are saved in {path}')
163
+
164
+ clip_score_list = utils.sample2dir_wCLIP(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
165
+ if clip_score_list is not None:
166
+ _clip_score_list = torch.cat(clip_score_list)
167
+ if accelerator.is_main_process:
168
+ logging.info(f'nnet_path={config.nnet_path}, clip_score{len(_clip_score_list)}={_clip_score_list.mean().item()}')
169
+
170
+
171
+ FLAGS = flags.FLAGS
172
+ config_flags.DEFINE_config_file(
173
+ "config", None, "Training configuration.", lock_config=False)
174
+
175
+ flags.mark_flags_as_required(["config"])
176
+ flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
177
+ flags.DEFINE_string("prompt", None, "The prompt used for generation.")
178
+ flags.DEFINE_string("output_path", None, "The path to output log.")
179
+ flags.DEFINE_float("cfg", -1, 'cfg scale; if not assigned, the scale defined in the config file is used')
180
+ flags.DEFINE_string("img_save_path", None, "The path to image log.")
181
+
182
+
183
+ def main(argv):
184
+ config = FLAGS.config
185
+ config.nnet_path = FLAGS.nnet_path
186
+ config.prompt = FLAGS.prompt
187
+ config.output_path = FLAGS.output_path
188
+ config.img_save_path = FLAGS.img_save_path
189
+ config.cfg = FLAGS.cfg
190
+ evaluate(config)
191
+
192
+
193
+ if __name__ == "__main__":
194
+ app.run(main)
demo_t2i_arith.py ADDED
@@ -0,0 +1,290 @@
1
+ """
2
+ This file is used for T2I generation; it also computes the CLIP similarity between the generated images and the input prompt.
3
+ """
4
+ from absl import flags
5
+ from absl import app
6
+ from ml_collections import config_flags
7
+ import os
8
+
9
+ import ml_collections
10
+ import torch
11
+ from torch import multiprocessing as mp
12
+ import torch.nn as nn
13
+ import accelerate
14
+ import utils
15
+ import tempfile
16
+ from absl import logging
17
+ import builtins
18
+ import einops
19
+ import math
20
+ import numpy as np
21
+ import time
22
+ from PIL import Image
23
+
24
+ from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
25
+ from tools.clip_score import ClipSocre
26
+ import libs.autoencoder
27
+ from libs.clip import FrozenCLIPEmbedder
28
+ from libs.t5 import T5Embedder
29
+
30
+
31
+ def unpreprocess(x):
32
+ x = 0.5 * (x + 1.)
33
+ x.clamp_(0., 1.)
34
+ return x
35
+
36
+
37
+ def batch_decode(_z, decode, batch_size=10):
38
+ """
39
+ The VAE decoder requires a lot of GPU memory. To run the interpolation model on GPUs with 24 GB of RAM or less, you can use this helper to reduce the VAE's memory usage.
40
+ It works by splitting the input tensor into smaller chunks.
41
+ """
42
+ num_samples = _z.size(0)
43
+ decoded_batches = []
44
+
45
+ for i in range(0, num_samples, batch_size):
46
+ batch = _z[i:i + batch_size]
47
+ decoded_batch = decode(batch)
48
+ decoded_batches.append(decoded_batch)
49
+
50
+ image_unprocessed = torch.cat(decoded_batches, dim=0)
51
+ return image_unprocessed
52
+
53
+ def get_caption(llm, text_model, prompt_dict, batch_size):
54
+
55
+ if batch_size == 3:
56
+ # only addition or only subtraction
57
+ assert len(prompt_dict) == 2
58
+ _batch_con = list(prompt_dict.values()) + [' ']
59
+ elif batch_size == 4:
60
+ # addition and subtraction
61
+ assert len(prompt_dict) == 3
62
+ _batch_con = list(prompt_dict.values()) + [' ']
63
+ elif batch_size >= 5:
64
+ # linear interpolation
65
+ assert len(prompt_dict) == 2
66
+ _batch_con = [prompt_dict['prompt_1']] + [' '] * (batch_size-2) + [prompt_dict['prompt_2']]
67
+
68
+ if llm == "clip":
69
+ _latent, _latent_and_others = text_model.encode(_batch_con)
70
+ _con = _latent_and_others['token_embedding'].detach()
71
+ elif llm == "t5":
72
+ _latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
73
+ _con = (_latent_and_others['token_embedding'] * 10.0).detach()
74
+ else:
75
+ raise NotImplementedError
76
+ _con_mask = _latent_and_others['token_mask'].detach()
77
+ _batch_token = _latent_and_others['tokens'].detach()
78
+ _batch_caption = _batch_con
79
+ return (_con, _con_mask, _batch_token, _batch_caption)
80
+
81
+
82
+ def evaluate(config):
83
+
84
+ if config.get('benchmark', False):
85
+ torch.backends.cudnn.benchmark = True
86
+ torch.backends.cudnn.deterministic = False
87
+
88
+ mp.set_start_method('spawn')
89
+ accelerator = accelerate.Accelerator()
90
+ device = accelerator.device
91
+ accelerate.utils.set_seed(config.seed, device_specific=True)
92
+ logging.info(f'Process {accelerator.process_index} using device: {device}')
93
+
94
+ config.mixed_precision = accelerator.mixed_precision
95
+ config = ml_collections.FrozenConfigDict(config)
96
+ if accelerator.is_main_process:
97
+ utils.set_logger(log_level='info', fname=config.output_path)
98
+ else:
99
+ utils.set_logger(log_level='error')
100
+ builtins.print = lambda *args: None
101
+
102
+ nnet = utils.get_nnet(**config.nnet)
103
+ nnet = accelerator.prepare(nnet)
104
+ logging.info(f'load nnet from {config.nnet_path}')
105
+ accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
106
+ nnet.eval()
107
+
108
+ ##
109
+
110
+ if config.nnet.model_args.clip_dim == 4096:
111
+ llm = "t5"
112
+ t5 = T5Embedder(device=device)
113
+ elif config.nnet.model_args.clip_dim == 768:
114
+ llm = "clip"
115
+ clip = FrozenCLIPEmbedder()
116
+ clip.eval()
117
+ clip.to(device)
118
+ else:
119
+ raise NotImplementedError
120
+
121
+
122
+ config = ml_collections.ConfigDict(config)
123
+
124
+ if config.test_type == 'interpolation':
125
+ prompt_dict = {'prompt_1':config.prompt_1, 'prompt_2':config.prompt_2}
126
+ for key in prompt_dict.keys():
127
+ assert prompt_dict[key] is not None
128
+ config.sample.mini_batch_size = config.num_of_interpolation
129
+ assert config.sample.mini_batch_size >= 5, "for linear interpolation, please sample at least five images"
130
+ elif config.test_type == 'arithmetic':
131
+ prompt_dict = {'prompt_ori':config.prompt_ori, 'prompt_a':config.prompt_a, 'prompt_s':config.prompt_s}
132
+ keys_to_remove = [key for key, value in prompt_dict.items() if value is None]
133
+ for key in keys_to_remove:
134
+ del prompt_dict[key]
135
+ counter = len(prompt_dict)
136
+ assert prompt_dict['prompt_ori'] is not None
137
+ assert counter == 2 or counter == 3
138
+ config.sample.mini_batch_size = counter + 1
139
+ else:
140
+ raise NotImplementedError
141
+
142
+ config = ml_collections.FrozenConfigDict(config)
143
+
144
+ if llm == "clip":
145
+ context_generator = get_caption(llm, clip, prompt_dict=prompt_dict, batch_size=config.sample.mini_batch_size)
146
+ elif llm == "t5":
147
+ context_generator = get_caption(llm, t5, prompt_dict=prompt_dict, batch_size=config.sample.mini_batch_size)
148
+ else:
149
+ raise NotImplementedError
150
+
151
+ ##
152
+
153
+ autoencoder = libs.autoencoder.get_model(**config.autoencoder)
154
+ autoencoder.to(device)
155
+
156
+ @torch.cuda.amp.autocast()
157
+ def encode(_batch):
158
+ return autoencoder.encode(_batch)
159
+
160
+ @torch.cuda.amp.autocast()
161
+ def decode(_batch):
162
+ return autoencoder.decode(_batch)
163
+
164
+ bdv_nnet = None # We don't use Autoguidance
165
+ ClipSocre_model = ClipSocre(device=device) # we also return clip score
166
+
167
+ #######
168
+ logging.info(config.sample)
169
+ logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')
170
+
171
+
172
+ def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, bdv_nnet=bdv_nnet, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token=None, token_mask=None, return_clipScore=False, ClipSocre_model=None):
173
+ with torch.no_grad():
174
+ del testbatch_img_blurred
175
+
176
+ _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)
177
+
178
+ if 'dimr' in config.nnet.name or 'dit' in config.nnet.name:
179
+ _z_x0, _mu, _log_var = nnet_ema(context, text_encoder = True, shape = _z_gaussian.shape, mask=token_mask)
180
+ _z_init = _z_x0.reshape(_z_gaussian.shape)
181
+ else:
182
+ raise NotImplementedError
183
+
184
+ if len(_z_init) == 3:
185
+ if config.prompt_a is not None:
186
+ assert config.prompt_s is None
187
+ _z_x0_temp = _z_x0[0] + _z_x0[1]
188
+ elif config.prompt_s is not None:
189
+ assert config.prompt_a is None
190
+ _z_x0_temp = _z_x0[0] - _z_x0[1]
191
+ else:
192
+ raise NotImplementedError
193
+ mean = _z_x0_temp.mean()
194
+ std = _z_x0_temp.std()
195
+ _z_x0[2] = (_z_x0_temp - mean) / std
196
+ elif len(_z_init) == 4:
197
+ _z_x0_temp = _z_x0[0] + _z_x0[1] - _z_x0[2]
198
+ mean = _z_x0_temp.mean()
199
+ std = _z_x0_temp.std()
200
+ _z_x0[3] = (_z_x0_temp - mean) / std
201
+ elif len(_z_init) >= 5:
202
+ tensor_a = _z_init[0]
203
+ tensor_b = _z_init[-1]
204
+ num_interpolations = len(_z_init) - 2
205
+ interpolations = [tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1)) for i in range(1, num_interpolations + 1)]
206
+ _z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
207
+
208
+ assert config.sample.scale > 1
209
+ if config.cfg != -1:
210
+ _cfg = config.cfg
211
+ else:
212
+ _cfg = config.sample.scale
213
+
214
+ has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
215
+
216
+ _sample_steps = config.sample.sample_steps
217
+
218
+ ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, bdv_model_fn=bdv_nnet, step_size_type="step_in_dsigma", guidance_scale=_cfg)
219
+ _z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)
220
+
221
+ if config.save_gpu_memory:
222
+ image_unprocessed = batch_decode(_z, decode)
223
+ else:
224
+ image_unprocessed = decode(_z)
225
+ clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)
226
+
227
+ return image_unprocessed, clip_score
228
+
229
+
230
+ def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
231
+ _context, _token_mask, _token, _caption = context_generator
232
+ assert return_clipScore
233
+ assert not return_caption
234
+ return ode_fm_solver_sample(nnet, _n_samples, config.sample.sample_steps, bdv_nnet=bdv_nnet, context=_context, token=_token, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
235
+
236
+
237
+ with tempfile.TemporaryDirectory() as temp_path:
238
+ path = config.img_save_path or config.sample.path or temp_path
239
+ if accelerator.is_main_process:
240
+ os.makedirs(path, exist_ok=True)
241
+ logging.info(f'Samples are saved in {path}')
242
+
243
+ clip_score_list = utils.sample2dir_wCLIP(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
244
+ if clip_score_list is not None:
245
+ _clip_score_list = torch.cat(clip_score_list)
246
+ if accelerator.is_main_process:
247
+ logging.info(f'nnet_path={config.nnet_path}, clip_score{len(_clip_score_list)}={_clip_score_list.mean().item()}')
248
+
249
+
250
+ FLAGS = flags.FLAGS
251
+ config_flags.DEFINE_config_file(
252
+ "config", None, "Training configuration.", lock_config=False)
253
+
254
+ flags.mark_flags_as_required(["config"])
255
+ flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
256
+ flags.DEFINE_string("output_path", None, "The path to output log.")
257
+ flags.DEFINE_float("cfg", -1, 'cfg scale; if not assigned, the scale defined in the config file is used')
258
+ flags.DEFINE_string("img_save_path", None, "The path to image log.")
259
+
260
+ flags.DEFINE_string("test_type", None, "The test type: 'interpolation' or 'arithmetic'.")
261
+
262
+ flags.DEFINE_string("prompt_1", None, "The prompt used for linear interpolation.")
263
+ flags.DEFINE_string("prompt_2", None, "The prompt used for linear interpolation.")
264
+ flags.DEFINE_integer("num_of_interpolation", -1, 'number of images to sample for linear interpolation')
265
+ flags.DEFINE_boolean('save_gpu_memory', False, 'To save VRAM')
266
+
267
+ flags.DEFINE_string("prompt_ori", None, "The prompt used for arithmetic operations.")
268
+ flags.DEFINE_string("prompt_a", None, "The prompt used for arithmetic operations (addition).")
269
+ flags.DEFINE_string("prompt_s", None, "The prompt used for arithmetic operations (subtraction).")
270
+
271
+
272
+ def main(argv):
273
+ config = FLAGS.config
274
+ config.nnet_path = FLAGS.nnet_path
275
+ config.output_path = FLAGS.output_path
276
+ config.img_save_path = FLAGS.img_save_path
277
+ config.cfg = FLAGS.cfg
278
+ config.test_type = FLAGS.test_type
279
+ config.prompt_1 = FLAGS.prompt_1
280
+ config.prompt_2 = FLAGS.prompt_2
281
+ config.num_of_interpolation = FLAGS.num_of_interpolation
282
+ config.save_gpu_memory = FLAGS.save_gpu_memory
283
+ config.prompt_ori = FLAGS.prompt_ori
284
+ config.prompt_a = FLAGS.prompt_a
285
+ config.prompt_s = FLAGS.prompt_s
286
+ evaluate(config)
287
+
288
+
289
+ if __name__ == "__main__":
290
+ app.run(main)
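The latent-space arithmetic and interpolation performed inside ode_fm_solver_sample above reduce to the following operations; the tensors here are random stand-ins for the text-latent x0 predictions, shown only to make the index conventions explicit:

import torch

# Arithmetic, batch of 4 (prompt_ori, prompt_a, prompt_s, result slot):
# combine as ori + a - s, then re-standardize before writing into the last slot.
z = torch.randn(4, 4, 64, 64)
combo = z[0] + z[1] - z[2]
z[3] = (combo - combo.mean()) / combo.std()

# Linear interpolation, batch of >= 5: keep the two endpoint latents and
# replace the interior slots with evenly spaced blends between them.
z = torch.randn(7, 4, 64, 64)
a, b = z[0], z[-1]
n = len(z) - 2
z = torch.stack([a] + [a + (b - a) * (i / (n + 1)) for i in range(1, n + 1)] + [b], dim=0)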
diffusion/base_solver.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ This file contains the solver base class, including the cfg indicator
3
+ """
4
+
5
+ import enum
6
+ import logging
7
+ from collections import defaultdict
8
+ from typing import Callable, Dict, List, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ import random
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ _default_cfg_processor = {"caption": lambda x, T, t: x}
19
+
20
+
21
+ class ConditionTypes(enum.Enum):
22
+ IMAGE_EMBED: str = "image_conditioning" # not implemented yet
23
+ TEXT_EMBED: str = "caption"
24
+ HINT_EMBED: str = "hint" # not implemented yet
25
+
26
+ class Solver:
27
+ def __init__(
28
+ self,
29
+ model_fn,
30
+ bdv_model_fn=None,
31
+ schedule="linear",
32
+ conditioning_types: List[str] = ["caption"],
33
+ guidance_scale: Union[float, Dict[ConditionTypes, float]] = 1.0,
34
+ cfg_processor: Callable = _default_cfg_processor,
35
+ **kwargs,
36
+ ):
37
+ self.model = model_fn
38
+ self.bdv_model = bdv_model_fn
39
+ self.schedule = schedule
40
+ # This list (conditioning_types) determines which conditioning variable is given priority
41
+ # For multi_cfg with 2 variables c,i, the cfg equation is
42
+ # output = e(null,null) + scale_c * (e(i,c) - e(i,null)) + scale_i * (e(i,null) - e(null,null))
43
+ # Note that the marginalization can be changed slightly to obtain a different equation
44
+ # output = e(null,null) + scale_i * (e(c,i) - e(c,null)) + scale_c * (e(c,null) - e(null,null))
45
+ # The order of the conditioning variables in the list decides which of the two equations above are used
46
+ # If the list is ["image", "caption"] then the first equation is used and
47
+ # if the list is ["caption", "image"] then the second is used
48
+ self.condition_types = [ConditionTypes(el) for el in conditioning_types]
49
+
50
+ self.unconditional_guidance_scale = guidance_scale
51
+ if isinstance(guidance_scale, dict):
52
+ self.unconditional_guidance_scale = {
53
+ ConditionTypes(k): v for k, v in guidance_scale.items()
54
+ }
55
+ else:
56
+ # If a single float is provided, we assume it is for text conditioning
57
+ self.unconditional_guidance_scale = {
58
+ ConditionTypes.TEXT_EMBED: guidance_scale
59
+ }
60
+ assert all(
61
+ [
62
+ el in self.unconditional_guidance_scale.keys()
63
+ for el in self.condition_types
64
+ ]
65
+ )
66
+ self.cfg_processor = cfg_processor
67
+ if self.cfg_processor is None:
68
+ self.cfg_processor = _default_cfg_processor
69
+ if isinstance(self.cfg_processor, dict):
70
+ assert all(callable(v) for k, v in self.cfg_processor.items())
71
+ self.cfg_processor = {
72
+ ConditionTypes(k): v for k, v in self.cfg_processor.items()
73
+ }
74
+ else:
75
+ assert callable(self.cfg_processor)
76
+ self.cfg_processor = {ConditionTypes.TEXT_EMBED: cfg_processor}
77
+
78
+ if self.cfg_processor is not None:
79
+ assert all([el in self.cfg_processor.keys() for el in self.condition_types])
80
+ self.inf_steps_completed = 0
81
+
82
+ @property
83
+ def device(self):
84
+ return self.model.device
85
+
86
+ def register_buffer(self, name, attr):
87
+ if isinstance(attr, torch.Tensor):
88
+ attr = attr.to(self.device)
89
+ setattr(self, name, attr)
90
+
91
+ def _check_the_conditioning(self, conditioning, batch_size):
92
+ # Checks if batch sizes match
93
+ if conditioning is not None:
94
+ if isinstance(conditioning, dict):
95
+ ctmp = conditioning[list(conditioning.keys())[0]]
96
+ while isinstance(ctmp, list):
97
+ ctmp = ctmp[0]
98
+ if isinstance(ctmp, dict):
99
+ if isinstance(ctmp["c"], list):
100
+ cbs = ctmp["c"][0].shape[0]
101
+ else:
102
+ cbs = ctmp["c"].shape[0]
103
+ else:
104
+ cbs = ctmp.shape[0]
105
+ if cbs != batch_size:
106
+ logger.info(
107
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
108
+ )
109
+
110
+ elif isinstance(conditioning, list):
111
+ for ctmp in conditioning:
112
+ if ctmp.shape[0] != batch_size:
113
+ logger.info(
114
+ f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}"
115
+ )
116
+
117
+ else:
118
+ if conditioning.shape[0] != batch_size:
119
+ logger.info(
120
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
121
+ )
122
+
123
+ def sample(
124
+ self,
125
+ sample_steps,
126
+ batch_size,
127
+ sampling_method,
128
+ unconditional_guidance_scale,
129
+ has_null_indicator,
130
+ shape=None, # no longer used
131
+ callback=None,
132
+ normals_sequence=None,
133
+ img_callback=None,
134
+ quantize_x0=False,
135
+ eta=0.0,
136
+ mask=None,
137
+ x0=None,
138
+ temperature=1.0,
139
+ noise_dropout=0.0,
140
+ verbose=True,
141
+ x_T=None,
142
+ log_every_t=100,
143
+ dynamic_threshold=None,
144
+ ucg_schedule=None,
145
+ t_schedule=None, # Default value is set below
146
+ skip_type=None, # Deprecated, kept for backward compatibility. Use `t_schedule` instead.
147
+ start_timestep=None,
148
+ num_timesteps=None,
149
+ do_make_schedule=True,
150
+ **kwargs,
151
+ ):
152
+ self.num_inf_timesteps = sample_steps
153
+ assert skip_type is None
154
+
155
+ t_schedule = t_schedule or "time_uniform"
156
+
157
+ if self.unconditional_guidance_scale is None:
158
+ self.unconditional_guidance_scale = unconditional_guidance_scale
159
+
160
+ assert isinstance(sampling_method, Callable)
161
+ samples, intermediates = sampling_method(
162
+ x_T=x_T,
163
+ # Hardcoded in PLMS file
164
+ ddim_use_original_steps=False,
165
+ callback=callback,
166
+ num_timesteps=num_timesteps,
167
+ quantize_denoised=quantize_x0,
168
+ mask=mask,
169
+ x0=x0,
170
+ img_callback=img_callback,
171
+ log_every_t=log_every_t,
172
+ temperature=temperature,
173
+ noise_dropout=noise_dropout,
174
+ unconditional_guidance_scale=unconditional_guidance_scale,
175
+ has_null_indicator=has_null_indicator,
176
+ dynamic_threshold=dynamic_threshold,
177
+ verbose=verbose,
178
+ ucg_schedule=ucg_schedule,
179
+ start_timestep=start_timestep,
180
+ )
181
+ return samples, intermediates
182
+
183
+ @torch.no_grad()
184
+ def get_model_output_dimr(
185
+ self,
186
+ x,
187
+ t_continuous,
188
+ unconditional_guidance_scale,
189
+ has_null_indicator,
190
+ ):
191
+
192
+ log_snr = 4 - t_continuous * 8 # inverted schedule: maps t in [0, 1] to log-SNR in [4, -4]
193
+
194
+ if has_null_indicator:
195
+ _cond = self.model(x, t=t_continuous, log_snr=log_snr, null_indicator=torch.tensor([False] * x.shape[0]).to(x.device))[-1]
196
+ _uncond = self.model(x, t=t_continuous, log_snr=log_snr, null_indicator=torch.tensor([True] * x.shape[0]).to(x.device))[-1]
197
+
198
+ assert unconditional_guidance_scale > 1
199
+ return _uncond + unconditional_guidance_scale * (_cond - _uncond)
200
+ else:
201
+ _cond = self.model(x, log_snr=log_snr)[-1]
202
+ return _cond
203
+
diffusion/flow_matching.py ADDED
@@ -0,0 +1,702 @@
1
+
2
+ import logging
3
+ from typing import Callable, Dict, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ import torchdiffeq
9
+ import random
10
+
11
+ from sde import multi_scale_targets
12
+ from diffusion.base_solver import Solver
13
+ import numpy as np
14
+ from torchvision import transforms
15
+
16
+
17
+ def check_zip(*args):
18
+ args = [list(arg) for arg in args]
19
+ length = len(args[0])
20
+ for arg in args:
21
+ assert len(arg) == length
22
+ return zip(*args)
23
+
24
+
25
+ def kl_divergence(source, target):
26
+ q_raw = source.view(-1)
27
+ p_raw = target.view(-1)
28
+
29
+ p = F.softmax(p_raw, dim=0)
30
+ q = F.softmax(q_raw, dim=0)
31
+
32
+
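+ # F.kl_div(log q, p) computes KL(p || q), with p = softmax(target) and q = softmax(source).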
33
+ q_log = torch.log(q)
34
+ kl_div_1 = F.kl_div(q_log, p, reduction='sum')
35
+
36
+ return kl_div_1
37
+
38
+
39
+
40
+ class TimeStepSampler:
41
+ """
42
+ Abstract class to sample timesteps for flow matching.
43
+ """
44
+
45
+ def sample_time(self, x_start):
46
+ # In flow matching, time is in range [0, 1] and 1 indicates the original image; 0 is pure noise
47
+ # this convention is the *reverse* of the usual diffusion convention
48
+ raise NotImplementedError
49
+
50
+ class ClipLoss(nn.Module):
51
+
52
+ def __init__(
53
+ self,
54
+ local_loss=False,
55
+ gather_with_grad=False,
56
+ cache_labels=False,
57
+ rank=0,
58
+ world_size=1,
59
+ use_horovod=False,
60
+ ):
61
+ super().__init__()
62
+ self.local_loss = local_loss
63
+ self.gather_with_grad = gather_with_grad
64
+ self.cache_labels = cache_labels
65
+ self.rank = rank
66
+ self.world_size = world_size
67
+ self.use_horovod = use_horovod
68
+
69
+ # cache state
70
+ self.prev_num_logits = 0
71
+ self.labels = {}
72
+
73
+ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
74
+ # calculate ground-truth labels and cache them if enabled
75
+ if self.prev_num_logits != num_logits or device not in self.labels:
76
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
77
+ if self.world_size > 1 and self.local_loss:
78
+ labels = labels + num_logits * self.rank
79
+ if self.cache_labels:
80
+ self.labels[device] = labels
81
+ self.prev_num_logits = num_logits
82
+ else:
83
+ labels = self.labels[device]
84
+ return labels
85
+
86
+ def get_logits(self, image_features, text_features, logit_scale):
87
+ if self.world_size > 1:
88
+ all_image_features, all_text_features = gather_features(
89
+ image_features, text_features,
90
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
91
+
92
+ if self.local_loss:
93
+ logits_per_image = logit_scale * image_features @ all_text_features.T
94
+ logits_per_text = logit_scale * text_features @ all_image_features.T
95
+ else:
96
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
97
+ logits_per_text = logits_per_image.T
98
+ else:
99
+ logits_per_image = logit_scale * image_features @ text_features.T
100
+ logits_per_text = logit_scale * text_features @ image_features.T
101
+
102
+ return logits_per_image, logits_per_text
103
+
104
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
105
+ device = image_features.device
106
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
107
+
108
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
109
+
110
+ total_loss = (
111
+ F.cross_entropy(logits_per_image, labels) +
112
+ F.cross_entropy(logits_per_text, labels)
113
+ ) / 2
114
+
115
+ return {"contrastive_loss": total_loss} if output_dict else total_loss
116
+
117
+
118
+ class SigLipLoss(nn.Module):
119
+ """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
120
+
121
+ @article{zhai2023sigmoid,
122
+ title={Sigmoid loss for language image pre-training},
123
+ author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
124
+ journal={arXiv preprint arXiv:2303.15343},
125
+ year={2023}
126
+ }
127
+ """
128
+ def __init__(
129
+ self,
130
+ cache_labels=False,
131
+ rank=0,
132
+ world_size=1,
133
+ bidir=True,
134
+ use_horovod=False,
135
+ ):
136
+ super().__init__()
137
+ self.cache_labels = cache_labels
138
+ self.rank = rank
139
+ self.world_size = world_size
140
+ assert not use_horovod # FIXME need to look at hvd ops for ring transfers
141
+ self.use_horovod = use_horovod
142
+ self.bidir = bidir
143
+
144
+ # cache state FIXME cache not currently used, worthwhile?
145
+ self.prev_num_logits = 0
146
+ self.labels = {}
147
+
148
+ def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
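+ # Labels are +1 on the diagonal (matched image-text pairs) and -1 elsewhere; with negative_only=True all entries are -1.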
149
+ labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
150
+ if not negative_only:
151
+ labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
152
+ return labels
153
+
154
+ def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
155
+ logits = logit_scale * image_features @ text_features.T
156
+ if logit_bias is not None:
157
+ logits += logit_bias
158
+ return logits
159
+
160
+ def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
161
+ logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
162
+ labels = self.get_ground_truth(
163
+ image_features.device,
164
+ image_features.dtype,
165
+ image_features.shape[0],
166
+ negative_only=negative_only,
167
+ )
168
+ loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
169
+ return loss
170
+
171
+ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
172
+ loss = self._loss(image_features, text_features, logit_scale, logit_bias)
173
+
174
+ if self.world_size > 1:
175
+ # exchange text features w/ neighbour world_size - 1 times
176
+ right_rank = (self.rank + 1) % self.world_size
177
+ left_rank = (self.rank - 1 + self.world_size) % self.world_size
178
+ if self.bidir:
179
+ text_features_to_right = text_features_to_left = text_features
180
+ num_bidir, remainder = divmod(self.world_size - 1, 2)
181
+ for i in range(num_bidir):
182
+ text_features_recv = neighbour_exchange_bidir_with_grad(
183
+ left_rank,
184
+ right_rank,
185
+ text_features_to_left,
186
+ text_features_to_right,
187
+ )
188
+
189
+ for f in text_features_recv:
190
+ loss += self._loss(
191
+ image_features,
192
+ f,
193
+ logit_scale,
194
+ logit_bias,
195
+ negative_only=True,
196
+ )
197
+ text_features_to_left, text_features_to_right = text_features_recv
198
+
199
+ if remainder:
200
+ text_features_recv = neighbour_exchange_with_grad(
201
+ left_rank, right_rank, text_features_to_right)
202
+
203
+ loss += self._loss(
204
+ image_features,
205
+ text_features_recv,
206
+ logit_scale,
207
+ logit_bias,
208
+ negative_only=True,
209
+ )
210
+ else:
211
+ text_features_to_right = text_features
212
+ for i in range(self.world_size - 1):
213
+ text_features_from_left = neighbour_exchange_with_grad(
214
+ left_rank, right_rank, text_features_to_right)
215
+
216
+ loss += self._loss(
217
+ image_features,
218
+ text_features_from_left,
219
+ logit_scale,
220
+ logit_bias,
221
+ negative_only=True,
222
+ )
223
+ text_features_to_right = text_features_from_left
224
+
225
+ return {"contrastive_loss": loss} if output_dict else loss
226
+
227
+
228
+ class ResolutionScaledTimeStepSampler(TimeStepSampler):
229
+ def __init__(self, scale: float, base_time_step_sampler: TimeStepSampler):
230
+ self.scale = scale
231
+ self.base_time_step_sampler = base_time_step_sampler
232
+
233
+ @torch.no_grad()
234
+ def sample_time(self, x_start):
235
+ base_time = self.base_time_step_sampler.sample_time(x_start)
236
+ # based on eq (23) of https://arxiv.org/abs/2403.03206
237
+ scaled_time = (base_time * self.scale) / (1 + (self.scale - 1) * base_time)
238
+ return scaled_time
239
+
240
+
241
+ class LogitNormalSampler(TimeStepSampler):
242
+ def __init__(self, normal_mean: float = 0, normal_std: float = 1):
243
+ # follows https://arxiv.org/pdf/2403.03206.pdf
244
+ # sample from a normal distribution
245
+ # pass the output through standard logistic function, i.e., sigmoid
246
+ self.normal_mean = float(normal_mean)
247
+ self.normal_std = float(normal_std)
248
+
249
+ @torch.no_grad()
250
+ def sample_time(self, x_start):
251
+ x_normal = torch.normal(
252
+ mean=self.normal_mean,
253
+ std=self.normal_std,
254
+ size=(x_start.shape[0],),
255
+ device=x_start.device,
256
+ )
257
+ x_logistic = torch.nn.functional.sigmoid(x_normal)
258
+ return x_logistic
259
+
260
+
261
+ class UniformTimeSampler(TimeStepSampler):
262
+ @torch.no_grad()
263
+ def sample_time(self, x_start):
264
+ # [0, 1] and 1 indicates the original image; 0 is pure noise
265
+ return torch.rand(x_start.shape[0], device=x_start.device)
266
+
267
+
268
+ class FlowMatching(nn.Module):
269
+ def __init__(
270
+ self,
271
+ sigma_min: float = 1e-5,
272
+ sigma_max: float = 1.0,
273
+ timescale: float = 1.0,
274
+ **kwargs,
275
+ ):
276
+ # LatentDiffusion/DDPM will create too many class variables we do not need
277
+ super().__init__(**kwargs)
278
+ self.time_step_sampler = LogitNormalSampler()
279
+ self.sigma_min = sigma_min
280
+ self.sigma_max = sigma_max
281
+ self.timescale = timescale
282
+
283
+ self.clip_loss = ClipLoss()
284
+ # self.SigLipLoss = SigLipLoss()
285
+
286
+ self.resizer = transforms.Resize(256) # for CLIP
287
+
288
+ def sample_noise(self, x_start):
289
+ # simple IID noise
290
+ return torch.randn_like(x_start, device=x_start.device) * self.sigma_max
291
+
292
+ def mos(self, err, start_dim=1, con_mask=None): # mean of squares
293
+ if con_mask is not None:
294
+ return (err.pow(2).mean(dim=-1) * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
295
+ else:
296
+ return err.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
297
+
298
+
299
+ def Xentropy(self, pred, tar, con_mask=None):
300
+ if con_mask is not None:
301
+ return (nn.functional.cross_entropy(pred, tar, reduction='none') * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
302
+ else:
303
+ return nn.functional.cross_entropy(pred, tar, reduction='none').mean(dim=-1)
304
+
305
+ def l2_reg(self, pred, lam = 0.0001):
306
+ return lam * torch.norm(pred, p=2, dim=(1, 2, 3)) ** 2
307
+
308
+ # model forward and prediction
309
+ def forward(
310
+ self,
311
+ x,
312
+ nnet,
313
+ loss_coeffs,
314
+ cond,
315
+ con_mask,
316
+ nnet_style,
317
+ training_step,
318
+ cond_ori=None, # not used
319
+ con_mask_ori=None, # not used
320
+ batch_img_clip=None, # forwarded to the VE reconstruction loss
321
+ model_config=None,
322
+ all_config=None,
323
+ text_token=None,
324
+ return_raw_loss=False,
325
+ additional_embeddings=None,
326
+ timesteps: Optional[Tuple[int, int]] = None,
327
+ *args,
328
+ **kwargs,
329
+ ):
330
+ assert timesteps is None, "timesteps must be None"
331
+
332
+ timesteps = self.time_step_sampler.sample_time(x)
333
+
334
+ if nnet_style == 'dimr':
335
+ if hasattr(model_config, "standard_diffusion") and model_config.standard_diffusion:
336
+ standard_diffusion=True
337
+ else:
338
+ standard_diffusion=False
339
+ return self.p_losses_textVAE(
340
+ x, cond, con_mask, timesteps, nnet, batch_img_clip=batch_img_clip, cond_ori=cond_ori, con_mask_ori=con_mask_ori, text_token=text_token, loss_coeffs=loss_coeffs, return_raw_loss=return_raw_loss, nnet_style=nnet_style, standard_diffusion=standard_diffusion, all_config=all_config, training_step=training_step, *args, **kwargs
341
+ )
342
+ elif nnet_style == 'dit':
343
+ if hasattr(model_config, "standard_diffusion") and model_config.standard_diffusion:
344
+ standard_diffusion=True
345
+ raise NotImplementedError("need update")
346
+ else:
347
+ standard_diffusion=False
348
+ return self.p_losses_textVAE_dit(
349
+ x, cond, con_mask, timesteps, nnet, batch_img_clip=batch_img_clip, cond_ori=cond_ori, con_mask_ori=con_mask_ori, text_token=text_token, loss_coeffs=loss_coeffs, return_raw_loss=return_raw_loss, nnet_style=nnet_style, standard_diffusion=standard_diffusion, all_config=all_config, training_step=training_step, *args, **kwargs
350
+ )
351
+ else:
352
+ raise NotImplementedError
353
+
354
+
355
+
356
+ def p_losses_textVAE(
357
+ self,
358
+ x_start,
359
+ cond,
360
+ con_mask,
361
+ t,
362
+ nnet,
363
+ loss_coeffs,
364
+ training_step,
365
+ text_token=None,
366
+ nnet_style=None,
367
+ all_config=None,
368
+ batch_img_clip=None,
369
+ cond_ori=None, # not used
370
+ con_mask_ori=None, # not used
371
+ return_raw_loss=False,
372
+ additional_embeddings=None,
373
+ standard_diffusion=False,
374
+ noise=None,
375
+ ):
376
+ """
377
+ CrossFlow training for DiMR
378
+ """
379
+
380
+ assert noise is None
381
+
382
+ x0, mu, log_var = nnet(cond, text_encoder = True, shape = x_start.shape, mask = con_mask)
383
+
384
+ ############ loss for the Text Variational Encoder (Text VE)
385
+ if batch_img_clip.shape[-1] == 512:
386
+ recon_gt = self.resizer(batch_img_clip)
387
+ else:
388
+ recon_gt = batch_img_clip
389
+ recon_gt_clip, logit_scale = nnet(recon_gt, image_clip = True)
390
+ image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
391
+ text_features = x0 / x0.norm(dim=-1, keepdim=True)
392
+ recons_loss = self.clip_loss(image_features, text_features, logit_scale)
393
+
394
+ # kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1)
395
+ kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim = 1) # modified KL loss: the (0.3 * mu) ** 6 term pushes mu -> 0 and the remaining terms push var -> 1
396
+ kld_loss_weight = 1e-2 # 0.0005
397
+
398
+ loss_mlp = recons_loss + kld_loss * kld_loss_weight
399
+
400
+
401
+ ############ loss for FM
402
+ noise = x0.reshape(x_start.shape)
403
+
404
+ if hasattr(all_config.nnet.model_args, "cfg_indicator"):
405
+ null_indicator = torch.from_numpy(np.array([random.random() < all_config.nnet.model_args.cfg_indicator for _ in range(x_start.shape[0])])).to(x_start.device)
406
+ if null_indicator.sum()<=1:
407
+ null_indicator[null_indicator==True] = False
408
+ assert null_indicator.sum() == 0
409
+ pass
410
+ else:
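+ # CFG dropout: cyclically shift the image targets among the flagged samples so target and text no longer correspond (these act as the "null" examples).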
411
+ target_null = x_start[null_indicator]
412
+ target_null = torch.cat((target_null[1:], target_null[:1]))
413
+ x_start[null_indicator] = target_null
414
+ else:
415
+ null_indicator = None
416
+
417
+
418
+ x_noisy = self.psi(t, x=noise, x1=x_start)
419
+ target_velocity = self.Dt_psi(t, x=noise, x1=x_start)
420
+ log_snr = 4 - t * 8 # log-SNR computed from the timestep, inverted: maps t in [0, 1] to [4, -4]
421
+
422
+ prediction = nnet(x_noisy, log_snr = log_snr, null_indicator=null_indicator)
423
+
424
+ target = multi_scale_targets(target_velocity, levels = len(prediction), scale_correction = True)
425
+
426
+ loss_diff = 0
427
+ for pred, coeff in check_zip(prediction, loss_coeffs):
428
+ loss_diff = loss_diff + coeff * self.mos(pred - target[pred.shape[-1]])
429
+
430
+ ###########
431
+
432
+ loss = loss_diff + loss_mlp
433
+
434
+ return loss, {'loss_diff': loss_diff, 'clip_loss': recons_loss, 'kld_loss': kld_loss, 'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device), 'clip_logit_scale': logit_scale}
435
+
436
+
437
+ def p_losses_textVAE_dit(
438
+ self,
439
+ x_start,
440
+ cond,
441
+ con_mask,
442
+ t,
443
+ nnet,
444
+ loss_coeffs,
445
+ training_step,
446
+ text_token=None,
447
+ nnet_style=None,
448
+ all_config=None,
449
+ batch_img_clip=None,
450
+ cond_ori=None, # not used
451
+ con_mask_ori=None, # not used
452
+ return_raw_loss=False,
453
+ additional_embeddings=None,
454
+ standard_diffusion=False,
455
+ noise=None,
456
+ ):
457
+ """
458
+ CrossFlow training for DiT
459
+ """
460
+
461
+ assert noise is None
462
+
463
+ x0, mu, log_var = nnet(cond, text_encoder = True, shape = x_start.shape, mask = con_mask)
464
+
465
+ ############ loss for the Text Variational Encoder (Text VE)
466
+ if batch_img_clip.shape[-1] == 512:
467
+ recon_gt = self.resizer(batch_img_clip)
468
+ else:
469
+ recon_gt = batch_img_clip
470
+ recon_gt_clip, logit_scale = nnet(recon_gt, image_clip = True)
471
+ image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
472
+ text_features = x0 / x0.norm(dim=-1, keepdim=True)
473
+ recons_loss = self.clip_loss(image_features, text_features, logit_scale)
474
+
475
+ # kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1)
476
+ kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim = 1)
477
+ kld_loss_weight = 1e-2 # 0.0005
478
+
479
+ loss_mlp = recons_loss + kld_loss * kld_loss_weight
480
+
481
+ ############ loss for FM
482
+ noise = x0.reshape(x_start.shape)
483
+
484
+ if hasattr(all_config.nnet.model_args, "cfg_indicator"):
485
+ null_indicator = torch.from_numpy(np.array([random.random() < all_config.nnet.model_args.cfg_indicator for _ in range(x_start.shape[0])])).to(x_start.device)
486
+ if null_indicator.sum()<=1:
487
+ null_indicator[null_indicator==True] = False
488
+ assert null_indicator.sum() == 0
489
+ pass
490
+ else:
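+ # CFG dropout: shuffle image targets among the flagged samples, as in p_losses_textVAE above.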
491
+ target_null = x_start[null_indicator]
492
+ target_null = torch.cat((target_null[1:], target_null[:1]))
493
+ x_start[null_indicator] = target_null
494
+ else:
495
+ null_indicator = None
496
+
497
+ x_noisy = self.psi(t, x=noise, x1=x_start)
498
+ target_velocity = self.Dt_psi(t, x=noise, x1=x_start)
499
+
500
+ prediction = nnet(x_noisy, t = t, null_indicator = null_indicator)[0]
501
+
502
+ loss_diff = self.mos(prediction - target_velocity)
503
+
504
+ ###########
505
+
506
+ loss = loss_diff + loss_mlp
507
+
508
+ return loss, {'loss_diff': loss_diff, 'clip_loss': recons_loss, 'kld_loss': kld_loss, 'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device), 'clip_logit_scale': logit_scale}
509
+
510
+
511
+ ## flow matching specific functions
512
+ def psi(self, t, x, x1):
513
+ assert (
514
+ t.shape[0] == x.shape[0]
515
+ ), f"Batch size of t and x does not agree {t.shape[0]} vs. {x.shape[0]}"
516
+ assert (
517
+ t.shape[0] == x1.shape[0]
518
+ ), f"Batch size of t and x1 does not agree {t.shape[0]} vs. {x1.shape[0]}"
519
+ assert t.ndim == 1
520
+ t = self.expand_t(t, x)
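+ # Interpolation path: psi_t(x, x1) = (1 - (1 - sigma_min / sigma_max) * t) * x + t * x1, where x is noise and x1 is data.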
521
+ return (t * (self.sigma_min / self.sigma_max - 1) + 1) * x + t * x1
522
+
523
+ def Dt_psi(self, t: torch.Tensor, x: torch.Tensor, x1: torch.Tensor):
524
+ assert x.shape[0] == x1.shape[0]
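+ # Time derivative of psi: the constant velocity (sigma_min / sigma_max - 1) * x + x1, used as the flow-matching regression target.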
525
+ return (self.sigma_min / self.sigma_max - 1) * x + x1
526
+
527
+ def expand_t(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
528
+ t_expanded = t
529
+ while t_expanded.ndim < x.ndim:
530
+ t_expanded = t_expanded.unsqueeze(-1)
531
+ return t_expanded.expand_as(x)
532
+
533
+
534
+
535
+
536
+ class ODEEulerFlowMatchingSolver(Solver):
537
+ """
538
+ ODE Solver for Flow matching that uses an Euler discretization
539
+ Supports choosing the number of time steps at inference
540
+ """
541
+
542
+ def __init__(self, *args, **kwargs):
543
+ super().__init__(*args, **kwargs)
544
+ self.step_size_type = kwargs.get("step_size_type", "step_in_dsigma")
545
+ assert self.step_size_type in ["step_in_dsigma", "step_in_dt"]
546
+ self.sample_timescale = 1.0 - 1e-5
547
+
548
+ @torch.no_grad()
549
+ def sample_euler(
550
+ self,
551
+ x_T,
552
+ unconditional_guidance_scale,
553
+ has_null_indicator,
554
+ t=[0, 1.0],
555
+ **kwargs,
556
+ ):
557
+ """
558
+ Euler solver for flow matching.
559
+ Based on https://github.com/VinAIResearch/LFM/blob/main/sampler/karras_sample.py
560
+ """
561
+ t = torch.tensor(t)
562
+ t = t * self.sample_timescale
563
+ sigma_min = 1e-5
564
+ sigma_max = 1.0
565
+ sigma_steps = torch.linspace(
566
+ sigma_min, sigma_max, self.num_time_steps + 1, device=x_T.device
567
+ )
568
+ discrete_time_steps_for_step = torch.linspace(
569
+ t[0], t[1], self.num_time_steps + 1, device=x_T.device
570
+ )
571
+ discrete_time_steps_to_eval_model_at = torch.linspace(
572
+ t[0], t[1], self.num_time_steps, device=x_T.device
573
+ )
574
+
575
+ print("num_time_steps : " + str(self.num_time_steps))
576
+
577
+ for i in range(self.num_time_steps):
578
+ t_i = discrete_time_steps_to_eval_model_at[i]
579
+ velocity = self.get_model_output_dimr(
580
+ x_T,
581
+ has_null_indicator = has_null_indicator,
582
+ t_continuous = t_i.repeat(x_T.shape[0]),
583
+ unconditional_guidance_scale = unconditional_guidance_scale,
584
+ )
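+ # Explicit Euler update: x <- x + velocity * step_size, with the step taken either in sigma or in t.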
585
+ if self.step_size_type == "step_in_dsigma":
586
+ step_size = sigma_steps[i + 1] - sigma_steps[i]
587
+ elif self.step_size_type == "step_in_dt":
588
+ step_size = (
589
+ discrete_time_steps_for_step[i + 1]
590
+ - discrete_time_steps_for_step[i]
591
+ )
592
+ x_T = x_T + velocity * step_size
593
+
594
+ intermediates = None
595
+ return x_T, intermediates
596
+
597
+ @torch.no_grad()
598
+ def sample(
599
+ self,
600
+ *args,
601
+ **kwargs,
602
+ ):
603
+ assert kwargs.get("ucg_schedule", None) is None
604
+ assert kwargs.get("skip_type", None) is None
605
+ assert kwargs.get("dynamic_threshold", None) is None
606
+ assert kwargs.get("x0", None) is None
607
+ assert kwargs.get("x_T") is not None
608
+ assert kwargs.get("score_corrector", None) is None
609
+ assert kwargs.get("normals_sequence", None) is None
610
+ assert kwargs.get("callback", None) is None
611
+ assert kwargs.get("quantize_x0", False) is False
612
+ assert kwargs.get("eta", 0.0) == 0.0
613
+ assert kwargs.get("mask", None) is None
614
+ assert kwargs.get("noise_dropout", 0.0) == 0.0
615
+
616
+ self.num_time_steps = kwargs.get("sample_steps")
617
+ self.x_T_uncon = kwargs.get("x_T_uncon")
618
+
619
+ samples, intermediates = super().sample(
620
+ *args,
621
+ sampling_method=self.sample_euler,
622
+ do_make_schedule=False,
623
+ **kwargs,
624
+ )
625
+ return samples, intermediates
626
+
627
+
628
+ class ODEFlowMatchingSolver(Solver):
629
+ """
630
+ ODE Solver for Flow matching that uses `dopri5`
631
+ Does not support number of time steps based control
632
+ """
633
+
634
+ def __init__(self, *args, **kwargs):
635
+ super().__init__(*args, **kwargs)
636
+ self.sample_timescale = 1.0 - 1e-5
637
+
638
+ # sampling for inference
639
+ @torch.no_grad()
640
+ def sample_transport(
641
+ self,
642
+ x_T,
643
+ unconditional_guidance_scale,
644
+ has_null_indicator,
645
+ t=[0, 1.0],
646
+ ode_opts={},
647
+ **kwargs,
648
+ ):
649
+ num_evals = 0
650
+ t = torch.tensor(t, device=x_T.device)
651
+ if "options" not in ode_opts:
652
+ ode_opts["options"] = {}
653
+ ode_opts["options"]["step_t"] = [self.sample_timescale + 1e-6]
654
+
655
+ def ode_func(t, x_T):
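+ # Velocity field queried by the adaptive dopri5 solver; num_evals counts model evaluations.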
656
+ nonlocal num_evals
657
+ num_evals += 1
658
+ model_output = self.get_model_output_dimr(
659
+ x_T,
660
+ has_null_indicator = has_null_indicator,
661
+ t_continuous = t.repeat(x_T.shape[0]),
662
+ unconditional_guidance_scale = unconditional_guidance_scale,
663
+ )
664
+ return model_output
665
+
666
+ z = torchdiffeq.odeint(
667
+ ode_func,
668
+ x_T,
669
+ t * self.sample_timescale,
670
+ **{"atol": 1e-5, "rtol": 1e-5, "method": "dopri5", **ode_opts},
671
+ )
672
+ # first dimension of z contains solutions to different timepoints
673
+ # we only need the last one (corresponding to t=1, i.e., image)
674
+ z = z[-1]
675
+ intermediates = None
676
+ return z, intermediates
677
+
678
+ @torch.no_grad()
679
+ def sample(
680
+ self,
681
+ *args,
682
+ **kwargs,
683
+ ):
684
+ assert kwargs.get("ucg_schedule", None) is None
685
+ assert kwargs.get("skip_type", None) is None
686
+ assert kwargs.get("dynamic_threshold", None) is None
687
+ assert kwargs.get("x0", None) is None
688
+ assert kwargs.get("x_T") is not None
689
+ assert kwargs.get("score_corrector", None) is None
690
+ assert kwargs.get("normals_sequence", None) is None
691
+ assert kwargs.get("callback", None) is None
692
+ assert kwargs.get("quantize_x0", False) is False
693
+ assert kwargs.get("eta", 0.0) == 0.0
694
+ assert kwargs.get("mask", None) is None
695
+ assert kwargs.get("noise_dropout", 0.0) == 0.0
696
+ samples, intermediates = super().sample(
697
+ *args,
698
+ sampling_method=self.sample_transport,
699
+ do_make_schedule=False,
700
+ **kwargs,
701
+ )
702
+ return samples, intermediates
libs/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # codes from third party
libs/autoencoder.py ADDED
@@ -0,0 +1,519 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+
6
+
7
+ class LinearAttention(nn.Module):
8
+ def __init__(self, dim, heads=4, dim_head=32):
9
+ super().__init__()
10
+ self.heads = heads
11
+ hidden_dim = dim_head * heads
12
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
13
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
14
+
15
+ def forward(self, x):
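+ # Linear attention: softmax over keys along the spatial axis, accumulate a d x d context, then read it out with the queries (linear in the number of positions).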
16
+ b, c, h, w = x.shape
17
+ qkv = self.to_qkv(x)
18
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
19
+ k = k.softmax(dim=-1)
20
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
21
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
22
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
23
+ return self.to_out(out)
24
+
25
+
26
+ def nonlinearity(x):
27
+ # swish
28
+ return x*torch.sigmoid(x)
29
+
30
+
31
+ def Normalize(in_channels, num_groups=32):
32
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
33
+
34
+
35
+ class Upsample(nn.Module):
36
+ def __init__(self, in_channels, with_conv):
37
+ super().__init__()
38
+ self.with_conv = with_conv
39
+ if self.with_conv:
40
+ self.conv = torch.nn.Conv2d(in_channels,
41
+ in_channels,
42
+ kernel_size=3,
43
+ stride=1,
44
+ padding=1)
45
+
46
+ def forward(self, x):
47
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
48
+ if self.with_conv:
49
+ x = self.conv(x)
50
+ return x
51
+
52
+
53
+ class Downsample(nn.Module):
54
+ def __init__(self, in_channels, with_conv):
55
+ super().__init__()
56
+ self.with_conv = with_conv
57
+ if self.with_conv:
58
+ # no asymmetric padding in torch conv, must do it ourselves
59
+ self.conv = torch.nn.Conv2d(in_channels,
60
+ in_channels,
61
+ kernel_size=3,
62
+ stride=2,
63
+ padding=0)
64
+
65
+ def forward(self, x):
66
+ if self.with_conv:
67
+ pad = (0,1,0,1)
68
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
69
+ x = self.conv(x)
70
+ else:
71
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
72
+ return x
73
+
74
+
75
+ class ResnetBlock(nn.Module):
76
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
77
+ dropout, temb_channels=512):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ out_channels = in_channels if out_channels is None else out_channels
81
+ self.out_channels = out_channels
82
+ self.use_conv_shortcut = conv_shortcut
83
+
84
+ self.norm1 = Normalize(in_channels)
85
+ self.conv1 = torch.nn.Conv2d(in_channels,
86
+ out_channels,
87
+ kernel_size=3,
88
+ stride=1,
89
+ padding=1)
90
+ if temb_channels > 0:
91
+ self.temb_proj = torch.nn.Linear(temb_channels,
92
+ out_channels)
93
+ self.norm2 = Normalize(out_channels)
94
+ self.dropout = torch.nn.Dropout(dropout)
95
+ self.conv2 = torch.nn.Conv2d(out_channels,
96
+ out_channels,
97
+ kernel_size=3,
98
+ stride=1,
99
+ padding=1)
100
+ if self.in_channels != self.out_channels:
101
+ if self.use_conv_shortcut:
102
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ else:
108
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
109
+ out_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ padding=0)
113
+
114
+ def forward(self, x, temb):
115
+ h = x
116
+ h = self.norm1(h)
117
+ h = nonlinearity(h)
118
+ h = self.conv1(h)
119
+
120
+ if temb is not None:
121
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
122
+
123
+ h = self.norm2(h)
124
+ h = nonlinearity(h)
125
+ h = self.dropout(h)
126
+ h = self.conv2(h)
127
+
128
+ if self.in_channels != self.out_channels:
129
+ if self.use_conv_shortcut:
130
+ x = self.conv_shortcut(x)
131
+ else:
132
+ x = self.nin_shortcut(x)
133
+
134
+ return x+h
135
+
136
+
137
+ class LinAttnBlock(LinearAttention):
138
+ """to match AttnBlock usage"""
139
+ def __init__(self, in_channels):
140
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
141
+
142
+
143
+ class AttnBlock(nn.Module):
144
+ def __init__(self, in_channels):
145
+ super().__init__()
146
+ self.in_channels = in_channels
147
+
148
+ self.norm = Normalize(in_channels)
149
+ self.q = torch.nn.Conv2d(in_channels,
150
+ in_channels,
151
+ kernel_size=1,
152
+ stride=1,
153
+ padding=0)
154
+ self.k = torch.nn.Conv2d(in_channels,
155
+ in_channels,
156
+ kernel_size=1,
157
+ stride=1,
158
+ padding=0)
159
+ self.v = torch.nn.Conv2d(in_channels,
160
+ in_channels,
161
+ kernel_size=1,
162
+ stride=1,
163
+ padding=0)
164
+ self.proj_out = torch.nn.Conv2d(in_channels,
165
+ in_channels,
166
+ kernel_size=1,
167
+ stride=1,
168
+ padding=0)
169
+
170
+
171
+ def forward(self, x):
172
+ h_ = x
173
+ h_ = self.norm(h_)
174
+ q = self.q(h_)
175
+ k = self.k(h_)
176
+ v = self.v(h_)
177
+
178
+ # compute attention
179
+ b,c,h,w = q.shape
180
+ q = q.reshape(b,c,h*w)
181
+ q = q.permute(0,2,1) # b,hw,c
182
+ k = k.reshape(b,c,h*w) # b,c,hw
183
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
184
+ w_ = w_ * (int(c)**(-0.5))
185
+ w_ = torch.nn.functional.softmax(w_, dim=2)
186
+
187
+ # attend to values
188
+ v = v.reshape(b,c,h*w)
189
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
190
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
191
+ h_ = h_.reshape(b,c,h,w)
192
+
193
+ h_ = self.proj_out(h_)
194
+
195
+ return x+h_
196
+
197
+
198
+ def make_attn(in_channels, attn_type="vanilla"):
199
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
200
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
201
+ if attn_type == "vanilla":
202
+ return AttnBlock(in_channels)
203
+ elif attn_type == "none":
204
+ return nn.Identity(in_channels)
205
+ else:
206
+ return LinAttnBlock(in_channels)
207
+
208
+
209
+ class Encoder(nn.Module):
210
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
211
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
212
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
213
+ **ignore_kwargs):
214
+ super().__init__()
215
+ if use_linear_attn: attn_type = "linear"
216
+ self.ch = ch
217
+ self.temb_ch = 0
218
+ self.num_resolutions = len(ch_mult)
219
+ self.num_res_blocks = num_res_blocks
220
+ self.resolution = resolution
221
+ self.in_channels = in_channels
222
+
223
+ # downsampling
224
+ self.conv_in = torch.nn.Conv2d(in_channels,
225
+ self.ch,
226
+ kernel_size=3,
227
+ stride=1,
228
+ padding=1)
229
+
230
+ curr_res = resolution
231
+ in_ch_mult = (1,)+tuple(ch_mult)
232
+ self.in_ch_mult = in_ch_mult
233
+ self.down = nn.ModuleList()
234
+ for i_level in range(self.num_resolutions):
235
+ block = nn.ModuleList()
236
+ attn = nn.ModuleList()
237
+ block_in = ch*in_ch_mult[i_level]
238
+ block_out = ch*ch_mult[i_level]
239
+ for i_block in range(self.num_res_blocks):
240
+ block.append(ResnetBlock(in_channels=block_in,
241
+ out_channels=block_out,
242
+ temb_channels=self.temb_ch,
243
+ dropout=dropout))
244
+ block_in = block_out
245
+ if curr_res in attn_resolutions:
246
+ attn.append(make_attn(block_in, attn_type=attn_type))
247
+ down = nn.Module()
248
+ down.block = block
249
+ down.attn = attn
250
+ if i_level != self.num_resolutions-1:
251
+ down.downsample = Downsample(block_in, resamp_with_conv)
252
+ curr_res = curr_res // 2
253
+ self.down.append(down)
254
+
255
+ # middle
256
+ self.mid = nn.Module()
257
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
258
+ out_channels=block_in,
259
+ temb_channels=self.temb_ch,
260
+ dropout=dropout)
261
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
262
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
263
+ out_channels=block_in,
264
+ temb_channels=self.temb_ch,
265
+ dropout=dropout)
266
+
267
+ # end
268
+ self.norm_out = Normalize(block_in)
269
+ self.conv_out = torch.nn.Conv2d(block_in,
270
+ 2*z_channels if double_z else z_channels,
271
+ kernel_size=3,
272
+ stride=1,
273
+ padding=1)
274
+
275
+ def forward(self, x):
276
+ # timestep embedding
277
+ temb = None
278
+
279
+ # downsampling
280
+ hs = [self.conv_in(x)]
281
+ for i_level in range(self.num_resolutions):
282
+ for i_block in range(self.num_res_blocks):
283
+ h = self.down[i_level].block[i_block](hs[-1], temb)
284
+ if len(self.down[i_level].attn) > 0:
285
+ h = self.down[i_level].attn[i_block](h)
286
+ hs.append(h)
287
+ if i_level != self.num_resolutions-1:
288
+ hs.append(self.down[i_level].downsample(hs[-1]))
289
+
290
+ # middle
291
+ h = hs[-1]
292
+ h = self.mid.block_1(h, temb)
293
+ h = self.mid.attn_1(h)
294
+ h = self.mid.block_2(h, temb)
295
+
296
+ # end
297
+ h = self.norm_out(h)
298
+ h = nonlinearity(h)
299
+ h = self.conv_out(h)
300
+ return h
301
+
302
+
303
+ class Decoder(nn.Module):
304
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
305
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
306
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
307
+ attn_type="vanilla", **ignorekwargs):
308
+ super().__init__()
309
+ if use_linear_attn: attn_type = "linear"
310
+ self.ch = ch
311
+ self.temb_ch = 0
312
+ self.num_resolutions = len(ch_mult)
313
+ self.num_res_blocks = num_res_blocks
314
+ self.resolution = resolution
315
+ self.in_channels = in_channels
316
+ self.give_pre_end = give_pre_end
317
+ self.tanh_out = tanh_out
318
+
319
+ # compute in_ch_mult, block_in and curr_res at lowest res
320
+ in_ch_mult = (1,)+tuple(ch_mult)
321
+ block_in = ch*ch_mult[self.num_resolutions-1]
322
+ curr_res = resolution // 2**(self.num_resolutions-1)
323
+ self.z_shape = (1,z_channels,curr_res,curr_res)
324
+ print("Working with z of shape {} = {} dimensions.".format(
325
+ self.z_shape, np.prod(self.z_shape)))
326
+
327
+ # z to block_in
328
+ self.conv_in = torch.nn.Conv2d(z_channels,
329
+ block_in,
330
+ kernel_size=3,
331
+ stride=1,
332
+ padding=1)
333
+
334
+ # middle
335
+ self.mid = nn.Module()
336
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
337
+ out_channels=block_in,
338
+ temb_channels=self.temb_ch,
339
+ dropout=dropout)
340
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
341
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
342
+ out_channels=block_in,
343
+ temb_channels=self.temb_ch,
344
+ dropout=dropout)
345
+
346
+ # upsampling
347
+ self.up = nn.ModuleList()
348
+ for i_level in reversed(range(self.num_resolutions)):
349
+ block = nn.ModuleList()
350
+ attn = nn.ModuleList()
351
+ block_out = ch*ch_mult[i_level]
352
+ for i_block in range(self.num_res_blocks+1):
353
+ block.append(ResnetBlock(in_channels=block_in,
354
+ out_channels=block_out,
355
+ temb_channels=self.temb_ch,
356
+ dropout=dropout))
357
+ block_in = block_out
358
+ if curr_res in attn_resolutions:
359
+ attn.append(make_attn(block_in, attn_type=attn_type))
360
+ up = nn.Module()
361
+ up.block = block
362
+ up.attn = attn
363
+ if i_level != 0:
364
+ up.upsample = Upsample(block_in, resamp_with_conv)
365
+ curr_res = curr_res * 2
366
+ self.up.insert(0, up) # prepend to get consistent order
367
+
368
+ # end
369
+ self.norm_out = Normalize(block_in)
370
+ self.conv_out = torch.nn.Conv2d(block_in,
371
+ out_ch,
372
+ kernel_size=3,
373
+ stride=1,
374
+ padding=1)
375
+
376
+ def forward(self, z):
377
+ #assert z.shape[1:] == self.z_shape[1:]
378
+ self.last_z_shape = z.shape
379
+
380
+ # timestep embedding
381
+ temb = None
382
+
383
+ # z to block_in
384
+ h = self.conv_in(z)
385
+
386
+ # middle
387
+ h = self.mid.block_1(h, temb)
388
+ h = self.mid.attn_1(h)
389
+ h = self.mid.block_2(h, temb)
390
+
391
+ # upsampling
392
+ for i_level in reversed(range(self.num_resolutions)):
393
+ for i_block in range(self.num_res_blocks+1):
394
+ h = self.up[i_level].block[i_block](h, temb)
395
+ if len(self.up[i_level].attn) > 0:
396
+ h = self.up[i_level].attn[i_block](h)
397
+ if i_level != 0:
398
+ h = self.up[i_level].upsample(h)
399
+
400
+ # end
401
+ if self.give_pre_end:
402
+ return h
403
+
404
+ h = self.norm_out(h)
405
+ h = nonlinearity(h)
406
+ h = self.conv_out(h)
407
+ if self.tanh_out:
408
+ h = torch.tanh(h)
409
+ return h
410
+
411
+
412
+ class FrozenAutoencoderKL(nn.Module):
413
+ def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
414
+ super().__init__()
415
+ print(f'Create autoencoder with scale_factor={scale_factor}')
416
+ self.encoder = Encoder(**ddconfig)
417
+ self.decoder = Decoder(**ddconfig)
418
+ assert ddconfig["double_z"]
419
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
420
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
421
+ self.embed_dim = embed_dim
422
+ self.scale_factor = scale_factor
423
+ m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
424
+ assert len(m) == 0 and len(u) == 0
425
+ self.eval()
426
+ self.requires_grad_(False)
427
+
428
+ def encode_moments(self, x):
429
+ h = self.encoder(x)
430
+ moments = self.quant_conv(h)
431
+ return moments
432
+
433
+ def sample(self, moments):
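+ # Reparameterization: split moments into mean and log-variance, sample z = mean + std * eps, then apply the latent scale factor.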
434
+ mean, logvar = torch.chunk(moments, 2, dim=1)
435
+ logvar = torch.clamp(logvar, -30.0, 20.0)
436
+ std = torch.exp(0.5 * logvar)
437
+ z = mean + std * torch.randn_like(mean)
438
+ z = self.scale_factor * z
439
+ return z
440
+
441
+ def encode(self, x):
442
+ moments = self.encode_moments(x)
443
+ z = self.sample(moments)
444
+ return z
445
+
446
+ def decode(self, z):
447
+ z = (1. / self.scale_factor) * z
448
+ z = self.post_quant_conv(z)
449
+ dec = self.decoder(z)
450
+ return dec
451
+
452
+ def forward(self, inputs, fn):
453
+ if fn == 'encode_moments':
454
+ return self.encode_moments(inputs)
455
+ elif fn == 'encode':
456
+ return self.encode(inputs)
457
+ elif fn == 'decode':
458
+ return self.decode(inputs)
459
+ else:
460
+ raise NotImplementedError
461
+
462
+
463
+ def get_model(pretrained_path, scale_factor=0.18215):
464
+ ddconfig = dict(
465
+ double_z=True,
466
+ z_channels=4,
467
+ resolution=256,
468
+ in_channels=3,
469
+ out_ch=3,
470
+ ch=128,
471
+ ch_mult=[1, 2, 4, 4],
472
+ num_res_blocks=2,
473
+ attn_resolutions=[],
474
+ dropout=0.0
475
+ )
476
+ return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
477
+
478
+
479
+ def main():
480
+ import torchvision.transforms as transforms
481
+ from torchvision.utils import save_image
482
+ import os
483
+ from PIL import Image
484
+
485
+ model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
486
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
487
+ model = model.to(device)
488
+
489
+ scale_factor = 0.18215
490
+ T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
491
+ path = 'imgs'
492
+ fnames = os.listdir(path)
493
+ for fname in fnames:
494
+ p = os.path.join(path, fname)
495
+ img = Image.open(p)
496
+ img = T(img)
497
+ img = img * 2. - 1
498
+ img = img[None, ...]
499
+ img = img.to(device)
500
+
501
+ # with torch.cuda.amp.autocast():
502
+ # moments = model.encode_moments(img)
503
+ # mean, logvar = torch.chunk(moments, 2, dim=1)
504
+ # logvar = torch.clamp(logvar, -30.0, 20.0)
505
+ # std = torch.exp(0.5 * logvar)
506
+ # zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
507
+ # recons = [model.decode(z) for z in zs]
508
+
509
+ with torch.cuda.amp.autocast():
510
+ print('test encode & decode')
511
+ recons = [model.decode(model.encode(img)) for _ in range(4)]
512
+
513
+ out = torch.cat([img, *recons], dim=0)
514
+ out = (out + 1) * 0.5
515
+ save_image(out, f'recons_{fname}')
516
+
517
+
518
+ if __name__ == "__main__":
519
+ main()
libs/clip.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch.nn as nn
2
+ from transformers import CLIPTokenizer, CLIPTextModel
3
+ import time
4
+
5
+
6
+ class AbstractEncoder(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def encode(self, *args, **kwargs):
11
+ raise NotImplementedError
12
+
13
+
14
+ class FrozenCLIPEmbedder(AbstractEncoder):
15
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
16
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
17
+ super().__init__()
18
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
19
+ self.transformer = CLIPTextModel.from_pretrained(version)
20
+ self.device = device
21
+ self.max_length = max_length
22
+ self.freeze()
23
+
24
+ def freeze(self):
25
+ self.transformer = self.transformer.eval()
26
+ for param in self.parameters():
27
+ param.requires_grad = False
28
+
29
+ def forward(self, text):
30
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
31
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
32
+ tokens = batch_encoding["input_ids"].to(self.device)
33
+ outputs = self.transformer(input_ids=tokens)
34
+
35
+ z = outputs.last_hidden_state
36
+ return z, {'token_embedding': outputs.last_hidden_state, 'pooler_output': outputs.pooler_output, 'token_mask': batch_encoding['attention_mask'].to(self.device), 'tokens': batch_encoding["input_ids"].to(self.device)}
37
+
38
+ def encode_from_token(self, tokens):
39
+ tokens = tokens.to(self.device)
40
+ outputs = self.transformer(input_ids=tokens)
41
+
42
+ z = outputs.last_hidden_state
43
+ return z
44
+
45
+ def encode(self, text):
46
+ return self(text)
47
+
48
+
49
+ class FrozenCLIPTokenizer(AbstractEncoder):
50
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
51
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
52
+ super().__init__()
53
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
54
+ self.max_length = max_length
55
+ self.freeze()
56
+
57
+ def freeze(self):
58
+ for param in self.parameters():
59
+ param.requires_grad = False
60
+
61
+ def forward(self, text):
62
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
63
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
64
+ tokens = batch_encoding["input_ids"]
65
+ return tokens
66
+
67
+ def encode(self, text):
68
+ return self(text)
libs/model/axial_rope.py ADDED
@@ -0,0 +1,109 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch._dynamo
5
+ from torch import nn
6
+
7
+ from . import flags
8
+
9
+ if flags.get_use_compile():
10
+ torch._dynamo.config.suppress_errors = True
11
+
12
+
13
+ def rotate_half(x):
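+ # Rotate channel pairs (x_even, x_odd) -> (-x_odd, x_even), the 90-degree rotation used by rotary embeddings.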
14
+ x1, x2 = x[..., 0::2], x[..., 1::2]
15
+ x = torch.stack((-x2, x1), dim=-1)
16
+ *shape, d, r = x.shape
17
+ return x.view(*shape, d * r)
18
+
19
+
20
+ @flags.compile_wrap
21
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
22
+ freqs = freqs.to(t)
23
+ rot_dim = freqs.shape[-1]
24
+ end_index = start_index + rot_dim
25
+ assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
26
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
27
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
28
+ return torch.cat((t_left, t, t_right), dim=-1)
29
+
30
+
31
+ def centers(start, stop, num, dtype=None, device=None):
32
+ edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
33
+ return (edges[:-1] + edges[1:]) / 2
34
+
35
+
36
+ def make_grid(h_pos, w_pos):
37
+ grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1)
38
+ h, w, d = grid.shape
39
+ return grid.view(h * w, d)
40
+
41
+
42
+ def bounding_box(h, w, pixel_aspect_ratio=1.0):
43
+ # Adjusted dimensions
44
+ w_adj = w
45
+ h_adj = h * pixel_aspect_ratio
46
+
47
+ # Adjusted aspect ratio
48
+ ar_adj = w_adj / h_adj
49
+
50
+ # Determine bounding box based on the adjusted aspect ratio
51
+ y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
52
+ if ar_adj > 1:
53
+ y_min, y_max = -1 / ar_adj, 1 / ar_adj
54
+ elif ar_adj < 1:
55
+ x_min, x_max = -ar_adj, ar_adj
56
+
57
+ return y_min, y_max, x_min, x_max
58
+
59
+
60
+ def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None):
61
+ y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
62
+ if align_corners:
63
+ h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device)
64
+ w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device)
65
+ else:
66
+ h_pos = centers(y_min, y_max, h, dtype=dtype, device=device)
67
+ w_pos = centers(x_min, x_max, w, dtype=dtype, device=device)
68
+ return make_grid(h_pos, w_pos)
69
+
70
+
71
+ def freqs_pixel(max_freq=10.0):
72
+ def init(shape):
73
+ freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi
74
+ return freqs.log().expand(shape)
75
+ return init
76
+
77
+
78
+ def freqs_pixel_log(max_freq=10.0):
79
+ def init(shape):
80
+ log_min = math.log(math.pi)
81
+ log_max = math.log(max_freq * math.pi / 2)
82
+ return torch.linspace(log_min, log_max, shape[-1]).expand(shape)
83
+ return init
84
+
85
+
86
+ class AxialRoPE(nn.Module):
87
+ def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
88
+ super().__init__()
89
+ self.n_heads = n_heads
90
+ self.start_index = start_index
91
+ log_freqs = freqs_init((n_heads, dim // 4))
92
+ self.freqs_h = nn.Parameter(log_freqs.clone())
93
+ self.freqs_w = nn.Parameter(log_freqs.clone())
94
+
95
+ def extra_repr(self):
96
+ dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2
97
+ return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}"
98
+
99
+ def get_freqs(self, pos):
100
+ if pos.shape[-1] != 2:
101
+ raise ValueError("input shape must be (..., 2)")
102
+ freqs_h = pos[..., None, None, 0] * self.freqs_h.exp()
103
+ freqs_w = pos[..., None, None, 1] * self.freqs_w.exp()
104
+ freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1)
105
+ return freqs.transpose(-2, -3)
106
+
107
+ def forward(self, x, pos):
108
+ freqs = self.get_freqs(pos)
109
+ return apply_rotary_emb(freqs, x, self.start_index)
libs/model/common_layers.py ADDED
@@ -0,0 +1,104 @@
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from timm.models.layers import trunc_normal_
6
+
7
+ class Linear(nn.Linear):
8
+ def __init__(self, *args, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ trunc_normal_(self.weight, mean = 0, std = 0.02)
11
+ if self.bias is not None:
12
+ nn.init.zeros_(self.bias)
13
+
14
+ class LayerNorm(nn.LayerNorm):
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__(*args, **kwargs)
17
+ trunc_normal_(self.weight, mean = 0, std = 0.02)
18
+ if self.bias is not None:
19
+ nn.init.zeros_(self.bias)
20
+
21
+ class Conv2d(nn.Conv2d):
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ trunc_normal_(self.weight, mean = 0, std = 0.02)
25
+ if self.bias is not None:
26
+ nn.init.zeros_(self.bias)
27
+
28
+ class Embedding(nn.Embedding):
29
+ def __init__(self, *args, **kwargs):
30
+ super().__init__(*args, **kwargs)
31
+ trunc_normal_(self.weight, mean = 0, std = 0.02)
32
+
33
+ class ImageNorm(nn.Module):
34
+ def forward(self, x):
35
+ assert x.dim() == 4
36
+ eps = 1e-05
37
+ x = x / (x.var(dim = (1, 2, 3), keepdim = True) + eps).sqrt()
38
+ return x
39
+
40
+ class Flatten(nn.Module):
41
+ def forward(self, x):
42
+ B, H, W, C = x.shape
43
+ x = x.reshape(B, H * W, C)
44
+ return x
45
+
46
+ class ChannelLast(nn.Module):
47
+ def forward(self, x):
48
+ assert x.dim() == 4
49
+ x = x.permute(0, 2, 3, 1) # [B, H, W, C]
50
+ return x
51
+
52
+ class ChannelFirst(nn.Module):
53
+ def forward(self, x):
54
+ assert x.dim() == 4
55
+ x = x.permute(0, 3, 1, 2) # [B, C, H, W]
56
+ return x
57
+
58
+ class OddUpInterpolate(nn.Module):
59
+ def __init__(self, ratio):
60
+ super().__init__()
61
+ self.ratio = ratio
62
+
63
+ def forward(self, x):
64
+ if self.ratio == 1:
65
+ return x
66
+ assert x.dim() == 4
67
+ B, C, H, W = x.shape
68
+ x = F.interpolate(x, size = ((H - 1) * self.ratio + 1, (W - 1) * self.ratio + 1), mode = "bilinear", align_corners = True)
69
+ return x
70
+
71
+ def __repr__(self):
72
+ return f"UpInterpolate(ratio={self.ratio})"
73
+
74
+ class OddDownInterpolate(nn.Module):
75
+ def __init__(self, ratio):
76
+ super().__init__()
77
+ self.ratio = ratio
78
+
79
+ def forward(self, x):
80
+ if self.ratio == 1:
81
+ return x
82
+ assert x.dim() == 4
83
+ B, C, H, W = x.shape
84
+ x = F.interpolate(x, size = ((H - 1) // self.ratio + 1, (W - 1) // self.ratio + 1), mode = "area")
85
+ return x
86
+
87
+ def __repr__(self):
88
+ return f"DownInterpolate(ratio={self.ratio})"
89
+
90
+ class EvenDownInterpolate(nn.Module):
91
+ def __init__(self, ratio):
92
+ super().__init__()
93
+ self.ratio = ratio
94
+
95
+ def forward(self, x):
96
+ if self.ratio == 1:
97
+ return x
98
+ assert len(x.shape) == 4
99
+ B, C, H, W = x.shape
100
+ x = F.interpolate(x, size = (H // self.ratio, W // self.ratio), mode = "area")
101
+ return x
102
+
103
+ def __repr__(self):
104
+ return f"DownInterpolate(ratio={self.ratio})"
libs/model/dimr_t2i.py ADDED
@@ -0,0 +1,443 @@
1
+ from re import A
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ import math
6
+ import einops
7
+ import torch.utils.checkpoint
8
+ from functools import partial
9
+ import open_clip
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ import torch.nn.functional as F
14
+ import timm
15
+ from timm.models.layers import trunc_normal_, Mlp
16
+ from .sigmoid.module import LayerNorm, RMSNorm, AdaRMSNorm, TDRMSNorm, QKNorm, TimeDependentParameter
17
+ from .common_layers import Linear, EvenDownInterpolate, ChannelFirst, ChannelLast, Embedding
18
+ from .axial_rope import AxialRoPE, make_axial_pos
19
+ from .trans_autoencoder import TransEncoder, Adaptor
20
+
21
+ def check_zip(*args):
22
+ args = [list(arg) for arg in args]
23
+ length = len(args[0])
24
+ for arg in args:
25
+ assert len(arg) == length
26
+ return zip(*args)
27
+
28
+ class PixelShuffleUpsample(nn.Module):
29
+ def __init__(self, dim_in, dim_out, ratio = 2):
30
+ super().__init__()
31
+ self.ratio = ratio
32
+ self.kernel = Linear(dim_in, dim_out * self.ratio * self.ratio)
33
+
34
+ def forward(self, x):
35
+ x = self.kernel(x)
36
+ B, H, W, C = x.shape
37
+ x = x.reshape(B, H, W, self.ratio, self.ratio, C // self.ratio // self.ratio)
38
+ x = x.transpose(2, 3)
39
+ x = x.reshape(B, H * self.ratio, W * self.ratio, C // self.ratio // self.ratio)
40
+ return x
41
+
42
+ class PositionEmbeddings(nn.Module):
43
+ def __init__(self, max_height, max_width, dim):
44
+ super().__init__()
45
+ self.max_height = max_height
46
+ self.max_width = max_width
47
+ self.position_embeddings = Embedding(self.max_height * self.max_width, dim)
48
+
49
+ def forward(self, x):
50
+ B, H, W, C = x.shape
51
+ height_idxes = torch.arange(H, device = x.device)[:, None].repeat(1, W)
52
+ width_idxes = torch.arange(W, device = x.device)[None, :].repeat(H, 1)
53
+ idxes = height_idxes * self.max_width + width_idxes
54
+ x = x + self.position_embeddings(idxes[None])
55
+ return x
56
+
57
+ class TextPositionEmbeddings(nn.Module):
58
+ def __init__(self, num_embeddings, embedding_dim):
59
+ super().__init__()
60
+ self.embedding = Embedding(num_embeddings, embedding_dim)
61
+
62
+ def forward(self, x):
63
+ batch_size, num_embeddings, embedding_dim = x.shape
64
+ # positions = torch.arange(height * width, device=x.device).reshape(1, height, width)
65
+ positions = torch.arange(num_embeddings, device=x.device).unsqueeze(0).expand(batch_size, num_embeddings)
66
+ x = x + self.embedding(positions)
67
+ return x
68
+
69
+
70
+ class MLPBlock(nn.Module):
71
+ def __init__(self, config):
72
+ super().__init__()
73
+ if config.norm_type == 'LN':
74
+ self.norm_type = 'LN'
75
+ self.norm = LayerNorm(config.dim)
76
+ elif config.norm_type == 'RMSN':
77
+ self.norm_type = 'RMSN'
78
+ self.norm = RMSNorm(config.dim)
79
+ elif config.norm_type == 'TDRMSN':
80
+ self.norm_type = 'TDRMSN'
81
+ self.norm = TDRMSNorm(config.dim)
82
+ elif config.norm_type == 'ADARMSN':
83
+ self.norm_type = 'ADARMSN'
84
+ self.norm = AdaRMSNorm(config.dim, config.dim)
85
+ self.act = nn.GELU()
86
+ self.w0 = Linear(config.dim, config.hidden_dim)
87
+ self.w1 = Linear(config.dim, config.hidden_dim)
88
+ self.w2 = Linear(config.hidden_dim, config.dim)
89
+
90
+ def forward(self, x):
91
+ if self.norm_type == 'LN' or self.norm_type == 'RMSN' or self.norm_type == 'TDRMSN':
92
+ x = self.norm(x)
93
+ elif self.norm_type == 'ADARMSN':
94
+ condition = x[:,0]
95
+ x = self.norm(x, condition)
96
+ x = self.act(self.w0(x)) * self.w1(x)
97
+ x = self.w2(x)
98
+ return x
99
+
100
+ class SelfAttention(nn.Module):
101
+ def __init__(self, config):
102
+ super().__init__()
103
+ assert config.dim % config.num_attention_heads == 0
104
+
105
+ self.num_heads = config.num_attention_heads
106
+ self.head_dim = config.dim // config.num_attention_heads
107
+
108
+ if hasattr(config, "self_att_prompt") and config.self_att_prompt:
109
+ self.condition_key_value = Linear(config.clip_dim, 2 * config.dim, bias = False)
110
+
111
+ if config.norm_type == 'LN':
112
+ self.norm_type = 'LN'
113
+ self.norm = LayerNorm(config.dim)
114
+ elif config.norm_type == 'RMSN':
115
+ self.norm_type = 'RMSN'
116
+ self.norm = RMSNorm(config.dim)
117
+ elif config.norm_type == 'TDRMSN':
118
+ self.norm_type = 'TDRMSN'
119
+ self.norm = TDRMSNorm(config.dim)
120
+ elif config.norm_type == 'ADARMSN':
121
+ self.norm_type = 'ADARMSN'
122
+ self.norm = AdaRMSNorm(config.dim, config.dim)
123
+
124
+ self.pe_type = config.pe_type
125
+ if config.pe_type == 'Axial_RoPE':
126
+ self.pos_emb = AxialRoPE(self.head_dim, self.num_heads)
127
+ self.qk_norm = QKNorm(self.num_heads)
128
+
129
+ self.query_key_value = Linear(config.dim, 3 * config.dim, bias = False)
130
+ self.dense = Linear(config.dim, config.dim)
131
+
132
+ def forward(self, x, condition_embeds, condition_masks, pos=None):
133
+ B, N, C = x.shape
134
+
135
+ if self.norm_type == 'LN' or self.norm_type == 'RMSN' or self.norm_type == 'TDRMSN':
136
+ qkv = self.query_key_value(self.norm(x))
137
+ elif self.norm_type == 'ADARMSN':
138
+ condition = x[:,0]
139
+ qkv = self.query_key_value(self.norm(x, condition))
140
+ q, k, v = qkv.reshape(B, N, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(3, dim = 1)
141
+
142
+ if self.pe_type == 'Axial_RoPE':
143
+ q = self.pos_emb(self.qk_norm(q), pos)
144
+ k = self.pos_emb(self.qk_norm(k), pos)
145
+
146
+ if condition_embeds is not None:
147
+ _, L, D = condition_embeds.shape
148
+ kcvc = self.condition_key_value(condition_embeds)
149
+ kc, vc = kcvc.reshape(B, L, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3).float().chunk(2, dim = 1)
150
+ k = torch.cat([k, kc], dim = 2)
151
+ v = torch.cat([v, vc], dim = 2)
152
+ mask = torch.cat([torch.ones(B, N, dtype = torch.bool, device = condition_masks.device), condition_masks], dim = -1)
153
+ mask = mask[:, None, None, :]
154
+ else:
155
+ mask = None
156
+
157
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
158
+ x = self.dense(x.permute(0, 2, 1, 3).reshape(B, N, C))
159
+
160
+ return x
161
+
162
+ class TransformerBlock(nn.Module):
163
+ def __init__(self, config):
164
+ super().__init__()
165
+ self.block1 = SelfAttention(config)
166
+ self.block2 = MLPBlock(config)
167
+ self.dropout = nn.Dropout(config.dropout_prob)
168
+ self.gradient_checking = config.gradient_checking
169
+
170
+ def forward(self, x, condition_embeds, condition_masks, pos):
171
+ if self.gradient_checking:
172
+ return torch.utils.checkpoint.checkpoint(self._forward, x, condition_embeds, condition_masks, pos)
173
+ else:
174
+ return self._forward(x, condition_embeds, condition_masks, pos)
175
+
176
+ def _forward(self, x, condition_embeds, condition_masks, pos):
177
+ x = x + self.dropout(self.block1(x, condition_embeds, condition_masks, pos))
178
+ x = x + self.dropout(self.block2(x))
179
+ return x
180
+
181
+ class ConvNeXtBlock(nn.Module):
182
+ def __init__(self, config):
183
+ super().__init__()
184
+ self.block1 = nn.Sequential(
185
+ ChannelFirst(),
186
+ nn.Conv2d(config.dim, config.dim, kernel_size = config.kernel_size, padding = config.kernel_size // 2, stride = 1, groups = config.dim),
187
+ ChannelLast()
188
+ )
189
+ self.block2 = MLPBlock(config)
190
+ self.dropout = nn.Dropout(config.dropout_prob)
191
+ self.gradient_checking = config.gradient_checking
192
+
193
+ def forward(self, x, condition_embeds, condition_masks, pos):
194
+ if self.gradient_checking:
195
+ return torch.utils.checkpoint.checkpoint(self._forward, x)
196
+ else:
197
+ return self._forward(x)
198
+
199
+ def _forward(self, x):
200
+ x = x + self.dropout(self.block1(x))
201
+ x = x + self.dropout(self.block2(x))
202
+ return x
203
+
204
+
205
+ class Stage(nn.Module):
206
+ def __init__(self, channels, config, lowres_dim = None, lowres_height = None):
207
+ super().__init__()
208
+ if config.block_type == "TransformerBlock":
209
+ self.encoder_cls = TransformerBlock
210
+ elif config.block_type == "ConvNeXtBlock":
211
+ self.encoder_cls = ConvNeXtBlock
212
+ else:
213
+ raise Exception()
214
+
215
+ self.pe_type = config.pe_type
216
+
217
+ self.input_layer = nn.Sequential(
218
+ EvenDownInterpolate(config.image_input_ratio),
219
+ nn.Conv2d(channels, config.dim, kernel_size = config.input_feature_ratio, stride = config.input_feature_ratio),
220
+ ChannelLast(),
221
+ PositionEmbeddings(config.max_height, config.max_width, config.dim)
222
+ )
223
+
224
+
225
+ if lowres_dim is not None:
226
+ ratio = config.max_height // lowres_height
227
+ self.upsample = nn.Sequential(
228
+ LayerNorm(lowres_dim),
229
+ PixelShuffleUpsample(lowres_dim, config.dim, ratio = ratio),
230
+ LayerNorm(config.dim),
231
+ )
232
+
233
+ self.blocks = nn.ModuleList([self.encoder_cls(config) for _ in range(config.num_blocks // 2 * 2 + 1)])
234
+ self.skip_denses = nn.ModuleList([Linear(config.dim * 2, config.dim) for _ in range(config.num_blocks // 2)])
235
+
236
+ self.output_layer = nn.Sequential(
237
+ LayerNorm(config.dim),
238
+ ChannelFirst(),
239
+ nn.Conv2d(config.dim, channels, kernel_size = config.final_kernel_size, padding = config.final_kernel_size // 2),
240
+ )
241
+
242
+ self.tensor_true = torch.nn.Parameter(torch.tensor([-1.0])) if self.encoder_cls is TransformerBlock else None
243
+ self.tensor_false = torch.nn.Parameter(torch.tensor([1.0])) if self.encoder_cls is TransformerBlock else None
244
+
245
+
246
+
247
+
248
+ def forward(self, images, lowres_skips = None, condition_context = None, condition_embeds = None, condition_masks = None, null_indicator=None):
249
+ if self.pe_type == 'Axial_RoPE' and self.encoder_cls is TransformerBlock:
250
+ x = self.input_layer(images)
251
+ _, H, W, _ = x.shape
252
+ pos = make_axial_pos(H, W)
253
+ else:
254
+ x = self.input_layer(images)
255
+ pos = None
256
+
257
+ if lowres_skips is not None:
258
+ x = x + self.upsample(lowres_skips)
259
+
260
+ if self.encoder_cls is TransformerBlock:
261
+ B, H, W, C = x.shape
262
+ x = x.reshape(B, H * W, C)
263
+
264
+ if null_indicator is not None:
265
+ indicator_tensor = torch.where(null_indicator, self.tensor_true, self.tensor_false)
266
+ indicator_tensor = indicator_tensor.view(B, 1, 1).expand(-1, -1, C)
267
+
268
+ x = torch.cat([indicator_tensor, x], dim = 1)
269
+
270
+ external_skips = [x]
271
+
272
+ num_blocks = len(self.blocks)
273
+ in_blocks = self.blocks[:(num_blocks // 2)]
274
+ mid_block = self.blocks[(num_blocks // 2)]
275
+ out_blocks = self.blocks[(num_blocks // 2 + 1):]
276
+
277
+ skips = []
278
+ for block in in_blocks:
279
+ x = block(x, condition_embeds, condition_masks, pos=pos)
280
+ external_skips.append(x)
281
+ skips.append(x)
282
+
283
+ x = mid_block(x, condition_embeds, condition_masks, pos=pos)
284
+ external_skips.append(x)
285
+
286
+ for dense, block in check_zip(self.skip_denses, out_blocks):
287
+ x = dense(torch.cat([x, skips.pop()], dim = -1))
288
+ x = block(x, condition_embeds, condition_masks, pos=pos)
289
+ external_skips.append(x)
290
+
291
+ if self.encoder_cls is TransformerBlock:
292
+
293
+ if null_indicator is not None:
294
+ x = x[:, 1:, :]
295
+ external_skips = [skip[:, 1:, :] for skip in external_skips]
296
+
297
+ x = x.reshape(B, H, W, C)
298
+ external_skips = [skip.reshape(B, H, W, C) for skip in external_skips]
299
+
300
+ output = self.output_layer(x)
301
+
302
+ return output, external_skips
303
+
304
+
305
+ class MRModel(nn.Module):
306
+ def __init__(self, config):
307
+ super().__init__()
308
+ self.channels = config.channels
309
+ self.block_grad_to_lowres = config.block_grad_to_lowres
310
+
311
+ for stage_config in config.stage_configs:
312
+ if hasattr(config, "use_t2i"):
313
+ stage_config.use_t2i = config.use_t2i
314
+ if hasattr(config, "clip_dim"):
315
+ stage_config.clip_dim = config.clip_dim
316
+ if hasattr(config, "num_clip_token"):
317
+ stage_config.num_clip_token = config.num_clip_token
318
+ if hasattr(config, "gradient_checking"):
319
+ stage_config.gradient_checking = config.gradient_checking
320
+ if hasattr(config, "pe_type"):
321
+ stage_config.pe_type = config.pe_type
322
+ else:
323
+ stage_config.pe_type = 'APE'
324
+ if hasattr(config, "norm_type"):
325
+ stage_config.norm_type = config.norm_type
326
+ else:
327
+ stage_config.norm_type = 'LN'
328
+
329
+
330
+ #### diffusion model
331
+ if hasattr(config, "not_training_diff") and config.not_training_diff:
332
+ self.has_diff = False
333
+ else:
334
+ self.has_diff = True
335
+
336
+ lowres_dims = [None] + [stage_config.dim * (stage_config.num_blocks // 2 * 2 + 2) for stage_config in config.stage_configs[:-1]]
337
+ lowres_heights = [None] + [stage_config.max_height for stage_config in config.stage_configs[:-1]]
338
+ self.stages = nn.ModuleList([
339
+ Stage(self.channels, stage_config, lowres_dim = lowres_dim, lowres_height=lowres_height)
340
+ for stage_config, lowres_dim, lowres_height in check_zip(config.stage_configs, lowres_dims, lowres_heights)]
341
+ )
342
+
343
+
344
+ #### Text VE
345
+ if hasattr(config.textVAE, "num_down_sample_block"):
346
+ down_sample_block = config.textVAE.num_down_sample_block
347
+ else:
348
+ down_sample_block = 3
349
+
350
+ self.context_encoder = TransEncoder(d_model=config.clip_dim, N=config.textVAE.num_blocks, num_token=config.num_clip_token,
351
+ head_num=config.textVAE.num_attention_heads, d_ff=config.textVAE.hidden_dim,
352
+ latten_size=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width * 2,
353
+ down_sample_block=down_sample_block, dropout=config.textVAE.dropout_prob, last_norm=False)
354
+
355
+
356
+
357
+ #### image encoder to train VE
358
+ self.open_clip, _, self.open_clip_preprocess = open_clip.create_model_and_transforms('ViT-L-16-SigLIP-256', pretrained=None)
359
+ if config.stage_configs[-1].max_width==32:
360
+ # for 256px generation
361
+ self.open_clip_output = Mlp(in_features=1024,
362
+ hidden_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
363
+ out_features=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width,
364
+ norm_layer=nn.LayerNorm,
365
+ )
366
+ else:
367
+ # for 512px generation
368
+ self.open_clip_output = Adaptor(input_dim=1024,
369
+ tar_dim=config.channels*config.stage_configs[-1].max_height*config.stage_configs[-1].max_width
370
+ )
371
+ del self.open_clip.text
372
+ del self.open_clip.logit_bias
373
+
374
+
375
+ def _forward(self, images, log_snr, condition_context = None, condition_text_embeds = None, condition_text_masks = None, condition_drop_prob = None, null_indicator=None):
376
+ if self.has_diff:
377
+ TimeDependentParameter.seed_time(self, log_snr)
378
+
379
+ assert condition_context is None
380
+ assert condition_text_embeds is None
381
+
382
+ if condition_text_embeds is not None:
383
+ condition_embeds = self.text_conditioning(condition_text_embeds)
384
+ condition_masks = condition_text_masks
385
+ else:
386
+ condition_embeds = None
387
+ condition_masks = None
388
+
389
+ outputs = []
390
+ lowres_skips = None
391
+ for stage in self.stages:
392
+ output, lowres_skips = stage(images, lowres_skips = lowres_skips, condition_context = condition_context, condition_embeds = condition_embeds, condition_masks = condition_masks, null_indicator=null_indicator)
393
+ outputs.append(output)
394
+ lowres_skips = torch.cat(lowres_skips, dim = -1)
395
+ if self.block_grad_to_lowres:
396
+ lowres_skips = lowres_skips.detach()
397
+
398
+ return outputs
399
+
400
+ else:
401
+ return [images]
402
+
403
+
404
+ def _reparameterize(self, mu, logvar):
405
+ std = torch.exp(0.5 * logvar)
406
+ eps = torch.randn_like(std)
407
+ return eps * std + mu
408
+
409
+ def _text_encoder(self, condition_context, tar_shape, mask):
410
+
411
+ output = self.context_encoder(condition_context, mask)
412
+ mu, log_var = torch.chunk(output, 2, dim=-1)
413
+
414
+ z = self._reparameterize(mu, log_var)
415
+
416
+ return [z, mu, log_var]
417
+
418
+ def _text_decoder(self, condition_enbedding, tar_shape):
419
+
420
+ context_token = self.context_decoder(condition_enbedding)
421
+
422
+ return context_token
423
+
424
+ def _img_clip(self, image_input):
425
+
426
+ image_latent = self.open_clip.encode_image(image_input)
427
+ image_latent = self.open_clip_output(image_latent)
428
+
429
+ return image_latent, self.open_clip.logit_scale
430
+
431
+
432
+
433
+ def forward(self, x, t = None, log_snr = None, text_encoder=False, text_decoder=False, image_clip=False, shape=None, mask=None, null_indicator=None):
434
+ if text_encoder:
435
+ return self._text_encoder(condition_context = x, tar_shape=shape, mask=mask)
436
+ elif text_decoder:
437
+ return self._text_decoder(condition_enbedding = x, tar_shape=shape) # mask is not needed for decoder
438
+ elif image_clip:
439
+ return self._img_clip(image_input = x)
440
+ else:
441
+ assert log_snr.dtype == torch.float32
442
+ return self._forward(images = x, log_snr = log_snr, null_indicator=null_indicator)
443
+
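A minimal sketch of the reparameterization step used by _text_encoder above, assuming nothing beyond torch: the context encoder's output is split along the last dimension into mu and log_var, and a latent is sampled as z = mu + exp(0.5 * log_var) * eps.

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps, with sigma = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu

# illustrative shapes only: batch of 2, encoder output dim 8 -> mu / log_var of dim 4 each
output = torch.randn(2, 8)
mu, log_var = torch.chunk(output, 2, dim=-1)
z = reparameterize(mu, log_var)
assert z.shape == mu.shape == (2, 4)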
libs/model/dit_t2i.py ADDED
@@ -0,0 +1,405 @@
1
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
2
+ # --------------------------------------------------------
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ import math
8
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
9
+
10
+ import open_clip
11
+ import torch.utils.checkpoint
12
+
13
+ from .trans_autoencoder import TransEncoder, Adaptor
14
+
15
+
16
+ def modulate(x, shift, scale):
17
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
18
+
19
+
20
+ #################################################################################
21
+ # Embedding Layers for Timesteps and Class Labels #
22
+ #################################################################################
23
+
24
+ class TimestepEmbedder(nn.Module):
25
+ """
26
+ Embeds scalar timesteps into vector representations.
27
+ """
28
+ def __init__(self, hidden_size, frequency_embedding_size=256):
29
+ super().__init__()
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
32
+ nn.SiLU(),
33
+ nn.Linear(hidden_size, hidden_size, bias=True),
34
+ )
35
+ self.frequency_embedding_size = frequency_embedding_size
36
+
37
+ @staticmethod
38
+ def timestep_embedding(t, dim, max_period=10000):
39
+ """
40
+ Create sinusoidal timestep embeddings.
41
+ :param t: a 1-D Tensor of N indices, one per batch element.
42
+ These may be fractional.
43
+ :param dim: the dimension of the output.
44
+ :param max_period: controls the minimum frequency of the embeddings.
45
+ :return: an (N, D) Tensor of positional embeddings.
46
+ """
47
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
51
+ ).to(device=t.device)
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+
64
+ class LabelEmbedder(nn.Module):
65
+ """
66
+ CrossFlow: updated for classifier-free guidance (CFG) with a null indicator.
67
+ """
68
+ def __init__(self, num_classes, hidden_size):
69
+ super().__init__()
70
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
71
+
72
+ def forward(self, labels):
73
+ embeddings = self.embedding_table(labels.int())
74
+ return embeddings
75
+
76
+
77
+ #################################################################################
78
+ # Core DiT Model #
79
+ #################################################################################
80
+
81
+ class DiTBlock(nn.Module):
82
+ """
83
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
84
+ """
85
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
86
+ super().__init__()
87
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
88
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
89
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
90
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
91
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
92
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
93
+ self.adaLN_modulation = nn.Sequential(
94
+ nn.SiLU(),
95
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
96
+ )
97
+
98
+ def forward(self, x, c):
99
+ return torch.utils.checkpoint.checkpoint(self._forward, x, c)
100
+ # return self._forward(x, c)
101
+
102
+ def _forward(self, x, c):
103
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
104
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
105
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
106
+ return x
107
+
108
+
109
+ class FinalLayer(nn.Module):
110
+ """
111
+ The final layer of DiT.
112
+ """
113
+ def __init__(self, hidden_size, patch_size, out_channels):
114
+ super().__init__()
115
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
116
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
117
+ self.adaLN_modulation = nn.Sequential(
118
+ nn.SiLU(),
119
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
120
+ )
121
+
122
+ def forward(self, x, c):
123
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
124
+ x = modulate(self.norm_final(x), shift, scale)
125
+ x = self.linear(x)
126
+ return x
127
+
128
+
129
+ class DiT(nn.Module):
130
+ """
131
+ Diffusion model with a Transformer backbone.
132
+ """
133
+ def __init__(
134
+ self,
135
+ config,
136
+ patch_size=2,
137
+ hidden_size=1152,
138
+ depth=28,
139
+ num_heads=16,
140
+ mlp_ratio=4.0,
141
+ num_classes=2, # for cfg indicator
142
+ ):
143
+ super().__init__()
144
+ self.input_size = config.latent_size
145
+ self.learn_sigma = config.learn_sigma
146
+ self.in_channels = config.channels
147
+ self.out_channels = self.in_channels * 2 if self.learn_sigma else self.in_channels
148
+ self.patch_size = patch_size
149
+ self.num_heads = num_heads
150
+
151
+ self.x_embedder = PatchEmbed(self.input_size, patch_size, self.in_channels, hidden_size, bias=True)
152
+ self.t_embedder = TimestepEmbedder(hidden_size)
153
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size)
154
+ num_patches = self.x_embedder.num_patches
155
+ # Will use fixed sin-cos embedding:
156
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
157
+
158
+ self.blocks = nn.ModuleList([
159
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
160
+ ])
161
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
162
+ self.initialize_weights()
163
+
164
+ ######### CrossFlow related
165
+ if hasattr(config.textVAE, "num_down_sample_block"):
166
+ down_sample_block = config.textVAE.num_down_sample_block
167
+ else:
168
+ down_sample_block = 3
169
+ self.context_encoder = TransEncoder(d_model=config.clip_dim, N=config.textVAE.num_blocks, num_token=config.num_clip_token,
170
+ head_num=config.textVAE.num_attention_heads, d_ff=config.textVAE.hidden_dim,
171
+ latten_size=config.channels * config.latent_size * config.latent_size * 2,
172
+ down_sample_block=down_sample_block, dropout=config.textVAE.dropout_prob, last_norm=False)
173
+
174
+
175
+ self.open_clip, _, self.open_clip_preprocess = open_clip.create_model_and_transforms('ViT-L-16-SigLIP-256', pretrained=None)
176
+ self.open_clip_output = Adaptor(input_dim=1024,
177
+ tar_dim=config.channels * config.latent_size * config.latent_size
178
+ )
179
+ del self.open_clip.text
180
+ del self.open_clip.logit_bias
181
+
182
+
183
+
184
+ def initialize_weights(self):
185
+ # Initialize transformer layers:
186
+ def _basic_init(module):
187
+ if isinstance(module, nn.Linear):
188
+ torch.nn.init.xavier_uniform_(module.weight)
189
+ if module.bias is not None:
190
+ nn.init.constant_(module.bias, 0)
191
+ self.apply(_basic_init)
192
+
193
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
194
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
195
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
196
+
197
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
198
+ w = self.x_embedder.proj.weight.data
199
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
200
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
201
+
202
+ # Initialize label embedding table:
203
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
204
+
205
+ # Initialize timestep embedding MLP:
206
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
207
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
208
+
209
+ # Zero-out adaLN modulation layers in DiT blocks:
210
+ for block in self.blocks:
211
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
212
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
213
+
214
+ # Zero-out output layers:
215
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
216
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
217
+ nn.init.constant_(self.final_layer.linear.weight, 0)
218
+ nn.init.constant_(self.final_layer.linear.bias, 0)
219
+
220
+ def unpatchify(self, x):
221
+ """
222
+ x: (N, T, patch_size**2 * C)
223
+ imgs: (N, H, W, C)
224
+ """
225
+ c = self.out_channels
226
+ p = self.x_embedder.patch_size[0]
227
+ h = w = int(x.shape[1] ** 0.5)
228
+ assert h * w == x.shape[1]
229
+
230
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
231
+ x = torch.einsum('nhwpqc->nchpwq', x)
232
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
233
+ return imgs
234
+
235
+ def _forward(self, x, t, null_indicator):
236
+ """
237
+ Forward pass of DiT.
238
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
239
+ t: (N,) tensor of diffusion timesteps
240
+ """
241
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
242
+ t = self.t_embedder(t) # (N, D)
243
+ y = self.y_embedder(null_indicator) # (N, D)
244
+ c = t + y # (N, D)
245
+ for block in self.blocks:
246
+ x = block(x, c) # (N, T, D)
247
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
248
+ x = self.unpatchify(x) # (N, out_channels, H, W)
249
+ return [x]
250
+
251
+ def _forward_with_cfg(self, x, t, cfg_scale):
252
+ """
253
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
254
+ """
255
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
256
+ half = x[: len(x) // 2]
257
+ combined = torch.cat([half, half], dim=0)
258
+ model_out = self.forward(combined, t)
259
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
260
+ # three channels by default. The standard approach to cfg applies it to all channels.
261
+ # This can be done by uncommenting the following line and commenting-out the line following that.
262
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
263
+ eps, rest = model_out[:, :3], model_out[:, 3:]
264
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
265
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
266
+ eps = torch.cat([half_eps, half_eps], dim=0)
267
+ return torch.cat([eps, rest], dim=1)
268
+
269
+ def _reparameterize(self, mu, logvar):
270
+ std = torch.exp(0.5 * logvar)
271
+ eps = torch.randn_like(std)
272
+ return eps * std + mu
273
+
274
+ def _text_encoder(self, condition_context, tar_shape, mask):
275
+
276
+ output = self.context_encoder(condition_context, mask)
277
+ mu, log_var = torch.chunk(output, 2, dim=-1)
278
+ z = self._reparameterize(mu, log_var)
279
+
280
+ return [z, mu, log_var]
281
+
282
+ def _img_clip(self, image_input):
283
+
284
+ image_latent = self.open_clip.encode_image(image_input)
285
+ image_latent = self.open_clip_output(image_latent)
286
+
287
+ return image_latent, self.open_clip.logit_scale
288
+
289
+ def forward(self, x, t = None, log_snr = None, text_encoder=False, text_decoder=False, image_clip=False, shape=None, mask=None, null_indicator=None):
290
+ if text_encoder:
291
+ return self._text_encoder(condition_context = x, tar_shape=shape, mask=mask)
292
+ elif text_decoder:
293
+ raise NotImplementedError
294
+ return self._text_decoder(condition_enbedding = x, tar_shape=shape) # mask is not needed for decoder
295
+ elif image_clip:
296
+ return self._img_clip(image_input = x)
297
+ else:
298
+ return self._forward(x = x, t = t, null_indicator=null_indicator)
299
+
300
+
301
+ #################################################################################
302
+ # Sine/Cosine Positional Embedding Functions #
303
+ #################################################################################
304
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
305
+
306
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
307
+ """
308
+ grid_size: int of the grid height and width
309
+ return:
310
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
311
+ """
312
+ grid_h = np.arange(grid_size, dtype=np.float32)
313
+ grid_w = np.arange(grid_size, dtype=np.float32)
314
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
315
+ grid = np.stack(grid, axis=0)
316
+
317
+ grid = grid.reshape([2, 1, grid_size, grid_size])
318
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
319
+ if cls_token and extra_tokens > 0:
320
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
321
+ return pos_embed
322
+
323
+
324
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
325
+ assert embed_dim % 2 == 0
326
+
327
+ # use half of dimensions to encode grid_h
328
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
329
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
330
+
331
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
332
+ return emb
333
+
334
+
335
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
336
+ """
337
+ embed_dim: output dimension for each position
338
+ pos: a list of positions to be encoded: size (M,)
339
+ out: (M, D)
340
+ """
341
+ assert embed_dim % 2 == 0
342
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
343
+ omega /= embed_dim / 2.
344
+ omega = 1. / 10000**omega # (D/2,)
345
+
346
+ pos = pos.reshape(-1) # (M,)
347
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
348
+
349
+ emb_sin = np.sin(out) # (M, D/2)
350
+ emb_cos = np.cos(out) # (M, D/2)
351
+
352
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
353
+ return emb
354
+
355
+
356
+ #################################################################################
357
+ # DiT Configs #
358
+ #################################################################################
359
+
360
+ def DiT_H_2(config, **kwargs):
361
+ return DiT(config=config, depth=36, hidden_size=1280, patch_size=2, num_heads=20, **kwargs)
362
+
363
+ def DiT_XL_2(config, **kwargs):
364
+ return DiT(config=config, depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
365
+
366
+ def DiT_XL_4(config, **kwargs):
367
+ return DiT(config=config, depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
368
+
369
+ def DiT_XL_8(config, **kwargs):
370
+ return DiT(config=config, depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
371
+
372
+ def DiT_L_2(config, **kwargs):
373
+ return DiT(config=config, depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
374
+
375
+ def DiT_L_4(config, **kwargs):
376
+ return DiT(config=config, depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
377
+
378
+ def DiT_L_8(config, **kwargs):
379
+ return DiT(config=config, depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
380
+
381
+ def DiT_B_2(config, **kwargs):
382
+ return DiT(config=config, depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
383
+
384
+ def DiT_B_4(config, **kwargs):
385
+ return DiT(config=config, depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
386
+
387
+ def DiT_B_8(config, **kwargs):
388
+ return DiT(config=config, depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
389
+
390
+ def DiT_S_2(config, **kwargs):
391
+ return DiT(config=config, depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
392
+
393
+ def DiT_S_4(config, **kwargs):
394
+ return DiT(config=config, depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
395
+
396
+ def DiT_S_8(config, **kwargs):
397
+ return DiT(config=config, depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
398
+
399
+
400
+ DiT_models = {
401
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
402
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
403
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
404
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
405
+ }
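A minimal check of the fixed sin/cos positional-embedding helpers defined above, with the import path assumed from the file name: for a 32 x 32 latent and patch size 2 there are 16 x 16 patches, so the table has 256 rows that initialize_weights copies into self.pos_embed.

import numpy as np
from libs.model.dit_t2i import get_2d_sincos_pos_embed   # assumed import path

hidden_size, grid = 1152, 16                  # DiT-XL/2 hidden size, 16 x 16 patch grid
pos = get_2d_sincos_pos_embed(hidden_size, grid)
assert pos.shape == (grid * grid, hidden_size)
assert np.abs(pos).max() <= 1.0               # entries are pure sin / cos values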
libs/model/flags.py ADDED
@@ -0,0 +1,56 @@
1
+ from contextlib import contextmanager
2
+ from functools import update_wrapper
3
+ import os
4
+ import threading
5
+
6
+ import torch
7
+
8
+
9
+ def get_use_compile():
10
+ return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"
11
+
12
+
13
+ def get_use_flash_attention_2():
14
+ return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"
15
+
16
+
17
+ state = threading.local()
18
+ state.checkpointing = False
19
+
20
+
21
+ @contextmanager
22
+ def checkpointing(enable=True):
23
+ try:
24
+ old_checkpointing, state.checkpointing = state.checkpointing, enable
25
+ yield
26
+ finally:
27
+ state.checkpointing = old_checkpointing
28
+
29
+
30
+ def get_checkpointing():
31
+ return getattr(state, "checkpointing", False)
32
+
33
+
34
+ class compile_wrap:
35
+ def __init__(self, function, *args, **kwargs):
36
+ self.function = function
37
+ self.args = args
38
+ self.kwargs = kwargs
39
+ self._compiled_function = None
40
+ update_wrapper(self, function)
41
+
42
+ @property
43
+ def compiled_function(self):
44
+ if self._compiled_function is not None:
45
+ return self._compiled_function
46
+ if get_use_compile():
47
+ try:
48
+ self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
49
+ except RuntimeError:
50
+ self._compiled_function = self.function
51
+ else:
52
+ self._compiled_function = self.function
53
+ return self._compiled_function
54
+
55
+ def __call__(self, *args, **kwargs):
56
+ return self.compiled_function(*args, **kwargs)
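A minimal usage sketch for the helpers in flags.py, assuming a torch build that provides torch.compile (otherwise set K_DIFFUSION_USE_COMPILE=0): compile_wrap compiles the wrapped function lazily on first call and falls back to the plain function if compilation raises a RuntimeError, while checkpointing(...) toggles a thread-local flag read by get_checkpointing().

import torch
from libs.model.flags import compile_wrap, checkpointing, get_checkpointing  # assumed import path

@compile_wrap
def gelu_mul(a, b):
    return torch.nn.functional.gelu(a) * b

x = torch.randn(4, 8)
assert gelu_mul(x, x).shape == x.shape        # first call compiles (or falls back), later calls reuse it

assert get_checkpointing() is False
with checkpointing():
    assert get_checkpointing() is True        # flag is only set inside the context
assert get_checkpointing() is False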
libs/model/sigmoid/kernel.py ADDED
@@ -0,0 +1,316 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.cpp_extension import load
5
+ import os
6
+ import time
7
+ import random
8
+ import math
9
+ from torch.utils.checkpoint import checkpoint
10
+ from torch.autograd import Function
11
+ from functools import partial
12
+ import warnings
13
+
14
+ # curr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extension")
15
+ # src_files = ['tdp.cu', 'torch_extension.cpp']
16
+ # src_files = [os.path.join(curr_path, file) for file in src_files]
17
+ # tdp = load('tdp', src_files, verbose = True)
18
+
19
+ # import tdp
20
+
21
+ def exported_tdp(param0, param1, weight, bias, times, custom = True):
22
+ original_shape = param0.shape
23
+ param0 = param0.reshape(-1)
24
+ param1 = param1.reshape(-1)
25
+ weight = weight.reshape(-1)
26
+ bias = bias.reshape(-1)
27
+ if custom and param0.shape[0] % 2 == 0:
28
+ result = TDP.apply(param0, param1, weight, bias, times)
29
+ else:
30
+ warnings.warn(f'Using slower tdp_torch implementation for a tensor with shape {param0.shape}')
31
+ result = tdp_torch(param0, param1, weight, bias, times)
32
+ result = result.reshape(*([times.shape[0]] + [d for d in original_shape]))
33
+ return result
34
+
35
+ class TDP(Function):
36
+ @staticmethod
37
+ def forward(ctx, param0, param1, weight, bias, times):
38
+ assert param0.shape[0] % 2 == 0
39
+ param0 = param0.contiguous()
40
+ param1 = param1.contiguous()
41
+ weight = weight.contiguous()
42
+ bias = bias.contiguous()
43
+ times = times.contiguous()
44
+ assert param0.shape[0] == param1.shape[0] and param0.shape[0] == weight.shape[0] and param0.shape[0] == bias.shape[0]
45
+ assert param0.dim() == 1 and param1.dim() == 1 and weight.dim() == 1 and bias.dim() == 1 and times.dim() == 1
46
+ ctx.save_for_backward(param0, param1, weight, bias, times)
47
+ return tdp_cuda(param0, param1, weight, bias, times)
48
+
49
+ @staticmethod
50
+ def backward(ctx, g_result):
51
+ g_result = g_result.contiguous()
52
+ param0, param1, weight, bias, times = ctx.saved_tensors
53
+ g_param0, g_param1, g_weight, g_bias = backward_tdp_cuda(param0, param1, weight, bias, times, g_result)
54
+ return g_param0, g_param1, g_weight, g_bias, None
55
+
56
+ def backward_tdp_torch(param0, param1, weight, bias, times, g_result):
57
+ param0 = param0[None]
58
+ param1 = param1[None]
59
+ weight = weight[None]
60
+ bias = bias[None]
61
+
62
+ a = times[:, None] * weight + bias
63
+ s = torch.sigmoid(a)
64
+ g_param0 = (s * g_result).sum(0)
65
+ g_param1 = ((1 - s) * g_result).sum(0)
66
+ g_s = (param0 - param1) * g_result
67
+ g_a = g_s * s * (1 - s)
68
+ g_weight = (g_a * times[:, None]).sum(0)
69
+ g_bias = g_a.sum(0)
70
+
71
+ return g_param0, g_param1, g_weight, g_bias
72
+
73
+ def backward_tdp_cuda(param0, param1, weight, bias, times, g_result):
74
+ g_param0 = torch.empty_like(param0)
75
+ g_param1 = torch.empty_like(param0)
76
+ g_weight = torch.empty_like(param0)
77
+ g_bias = torch.empty_like(param0)
78
+ if param0.dtype == torch.half:
79
+ tdp.backward_tdp_fp16(param0, param1, weight, bias, times, g_result, g_param0, g_param1, g_weight, g_bias)
80
+ elif param0.dtype == torch.float:
81
+ tdp.backward_tdp_fp32(param0, param1, weight, bias, times, g_result, g_param0, g_param1, g_weight, g_bias)
82
+ else:
83
+ raise NotImplementedError
84
+ return g_param0, g_param1, g_weight, g_bias
85
+
86
+ def tdp_torch(param0, param1, weight, bias, times):
87
+ a = torch.addcmul(bias[None], times[:, None], weight[None])
88
+ s = torch.sigmoid(a)
89
+ result = torch.addcmul(param1[None], s, param0[None] - param1[None])
90
+ return result
91
+
92
+ def tdp_cuda(param0, param1, weight, bias, times):
93
+ result = torch.empty(times.shape[0], param0.shape[0], dtype = param0.dtype, device = param0.device)
94
+ if param0.dtype == torch.half:
95
+ tdp.tdp_fp16(param0, param1, weight, bias, times, result)
96
+ elif param0.dtype == torch.float:
97
+ tdp.tdp_fp32(param0, param1, weight, bias, times, result)
98
+ else:
99
+ raise NotImplementedError
100
+ return result
101
+
102
+ def corrcoef(x, y):
103
+ return torch.corrcoef(torch.stack([x.reshape(-1).float(), y.reshape(-1).float()], dim = 0))[0, 1]
104
+
105
+ def tdp_cuda_unit_test():
106
+ print("***** tdp_cuda_unit_test *****")
107
+
108
+ batch_size = random.randrange(1, 128)
109
+ num_params = random.randrange(1, 1000000) * 2
110
+ print("batch_size", batch_size, "num_params", num_params)
111
+
112
+ param0 = torch.randn(num_params).cuda()
113
+ param1 = torch.randn(num_params).cuda()
114
+ weight = torch.randn(num_params).cuda()
115
+ bias = torch.randn(num_params).cuda()
116
+ times = torch.rand(batch_size).cuda()
117
+
118
+ ref = tdp_torch(param0, param1, weight, bias, times)
119
+
120
+ out = tdp_cuda(param0, param1, weight, bias, times)
121
+ print(corrcoef(ref, out), (ref - out).abs().max())
122
+
123
+ out = tdp_cuda(param0.half(), param1.half(), weight.half(), bias.half(), times.half()).float()
124
+ print(corrcoef(ref, out), (ref - out).abs().max())
125
+
126
+ def backward_tdp_cuda_unit_test():
127
+ print("***** backward_tdp_cuda_unit_test *****")
128
+
129
+ batch_size = random.randrange(1, 128)
130
+ num_params = random.randrange(1, 100000) * 2
131
+ print("batch_size", batch_size, "num_params", num_params)
132
+
133
+ param0 = torch.randn(num_params).cuda()
134
+ param1 = torch.randn(num_params).cuda()
135
+ weight = torch.randn(num_params).cuda()
136
+ bias = torch.randn(num_params).cuda()
137
+ times = torch.rand(batch_size).cuda()
138
+ g_result = torch.randn(batch_size, num_params).cuda()
139
+
140
+ refs = backward_tdp_torch(param0, param1, weight, bias, times, g_result)
141
+
142
+ outs = backward_tdp_cuda(param0, param1, weight, bias, times, g_result)
143
+ for r, o in zip(refs, outs):
144
+ print(corrcoef(r, o), (r - o).abs().max())
145
+
146
+ outs = backward_tdp_cuda(param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())
147
+ for r, o in zip(refs, outs):
148
+ print(corrcoef(r, o), (r - o).abs().max())
149
+
150
+ def autograd_unit_test():
151
+ print("***** autograd_unit_test *****")
152
+ batch_size = random.randrange(1, 128)
153
+ num_params = random.randrange(1, 100000) * 2
154
+ print("batch_size", batch_size, "num_params", num_params)
155
+
156
+ def get_outputs(fn):
157
+ torch.manual_seed(1)
158
+ param0 = torch.randn(num_params, requires_grad = True).cuda()
159
+ param1 = torch.randn(num_params, requires_grad = True).cuda()
160
+ weight = torch.randn(num_params, requires_grad = True).cuda()
161
+ bias = torch.randn(num_params, requires_grad = True).cuda()
162
+ times = torch.rand(batch_size).cuda()
163
+
164
+ out = fn(param0, param1, weight, bias, times)
165
+ loss = ((out - 1.5) ** 2).mean()
166
+
167
+ param0.retain_grad()
168
+ param1.retain_grad()
169
+ weight.retain_grad()
170
+ bias.retain_grad()
171
+
172
+ loss.backward()
173
+ g_param0 = param0.grad
174
+ g_param1 = param1.grad
175
+ g_weight = weight.grad
176
+ g_bias = bias.grad
177
+
178
+ return out, g_param0, g_param1, g_weight, g_bias
179
+
180
+ refs = get_outputs(tdp_torch)
181
+ outs = get_outputs(TDP.apply)
182
+ for r, o in zip(refs, outs):
183
+ print(corrcoef(r, o), (r - o).abs().max())
184
+
185
+ def exported_tdp_unit_test():
186
+ print("***** exported_tdp_unit_test *****")
187
+ batch_size = random.randrange(1, 128)
188
+ num_params = random.randrange(1, 100000) * 2
189
+ print("batch_size", batch_size, "num_params", num_params)
190
+
191
+ def get_outputs(fn):
192
+ torch.manual_seed(1)
193
+ param0 = torch.randn(num_params, requires_grad = True).cuda()
194
+ param1 = torch.randn(num_params, requires_grad = True).cuda()
195
+ weight = torch.randn(num_params, requires_grad = True).cuda()
196
+ bias = torch.randn(num_params, requires_grad = True).cuda()
197
+ times = torch.rand(batch_size).cuda()
198
+
199
+ out = fn(param0, param1, weight, bias, times)
200
+ loss = ((out - 1.5) ** 2).mean()
201
+
202
+ param0.retain_grad()
203
+ param1.retain_grad()
204
+ weight.retain_grad()
205
+ bias.retain_grad()
206
+
207
+ loss.backward()
208
+ g_param0 = param0.grad
209
+ g_param1 = param1.grad
210
+ g_weight = weight.grad
211
+ g_bias = bias.grad
212
+
213
+ return out, g_param0, g_param1, g_weight, g_bias
214
+
215
+ refs = get_outputs(partial(exported_tdp, custom = False))
216
+ outs = get_outputs(partial(exported_tdp, custom = True))
217
+ for r, o in zip(refs, outs):
218
+ print(corrcoef(r, o), (r - o).abs().max())
219
+
220
+ def tdp_cuda_profile():
221
+ print("***** tdp_cuda_profile *****")
222
+ def profiler(fn, args):
223
+ for _ in range(10):
224
+ fn(*args)
225
+ torch.cuda.synchronize()
226
+ t0 = time.time()
227
+ for _ in range(100):
228
+ fn(*args)
229
+ torch.cuda.synchronize()
230
+ t1 = time.time()
231
+ return t1 - t0
232
+
233
+ batch_size = 16
234
+ num_params = 1024 * 1024
235
+ print("batch_size", batch_size, "num_params", num_params)
236
+
237
+ param0 = torch.randn(num_params).cuda()
238
+ param1 = torch.randn(num_params).cuda()
239
+ weight = torch.randn(num_params).cuda()
240
+ bias = torch.randn(num_params).cuda()
241
+ times = torch.rand(batch_size).cuda()
242
+
243
+ print("ref", profiler(tdp_torch, (param0, param1, weight, bias, times)))
244
+ print("cuda", profiler(tdp_cuda, (param0, param1, weight, bias, times)))
245
+
246
+ print("ref", profiler(tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
247
+ print("cuda", profiler(tdp_cuda, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
248
+
249
+ def backward_tdp_cuda_profile():
250
+ print("***** backward_tdp_cuda_profile *****")
251
+ def profiler(fn, args):
252
+ for _ in range(10):
253
+ fn(*args)
254
+ torch.cuda.synchronize()
255
+ t0 = time.time()
256
+ for _ in range(100):
257
+ fn(*args)
258
+ torch.cuda.synchronize()
259
+ t1 = time.time()
260
+ return t1 - t0
261
+
262
+ batch_size = 16
263
+ num_params = 1024 * 1024
264
+ print("batch_size", batch_size, "num_params", num_params)
265
+
266
+ param0 = torch.randn(num_params).cuda()
267
+ param1 = torch.randn(num_params).cuda()
268
+ weight = torch.randn(num_params).cuda()
269
+ bias = torch.randn(num_params).cuda()
270
+ times = torch.rand(batch_size).cuda()
271
+ g_result = torch.randn(batch_size, num_params).cuda()
272
+
273
+
274
+ print("ref", profiler(backward_tdp_torch, (param0, param1, weight, bias, times, g_result)))
275
+ print("cuda", profiler(backward_tdp_cuda, (param0, param1, weight, bias, times, g_result)))
276
+
277
+ print("ref", profiler(backward_tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())))
278
+ print("cuda", profiler(backward_tdp_cuda, (param0.half(), param1.half(), weight.half(), bias.half(), times.half(), g_result.half())))
279
+
280
+ def autograd_profile():
281
+ print("***** autogad_profile *****")
282
+ def profiler(fn, args):
283
+ for _ in range(10):
284
+ fn(*args).mean().backward()
285
+ torch.cuda.synchronize()
286
+ t0 = time.time()
287
+ for _ in range(100):
288
+ fn(*args).mean().backward()
289
+ torch.cuda.synchronize()
290
+ t1 = time.time()
291
+ return t1 - t0
292
+
293
+ batch_size = 16
294
+ num_params = 1024 * 1024
295
+ print("batch_size", batch_size, "num_params", num_params)
296
+
297
+ param0 = nn.Parameter(torch.randn(num_params)).cuda()
298
+ param1 = nn.Parameter(torch.randn(num_params)).cuda()
299
+ weight = nn.Parameter(torch.randn(num_params)).cuda()
300
+ bias = nn.Parameter(torch.randn(num_params)).cuda()
301
+ times = torch.rand(batch_size).cuda()
302
+
303
+ print("ref", profiler(tdp_torch, (param0, param1, weight, bias, times)))
304
+ print("cuda", profiler(TDP.apply, (param0, param1, weight, bias, times)))
305
+
306
+ print("ref", profiler(tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
307
+ print("cuda", profiler(TDP.apply, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
308
+
309
+ if __name__ == "__main__":
310
+ tdp_cuda_unit_test()
311
+ backward_tdp_cuda_unit_test()
312
+ autograd_unit_test()
313
+ exported_tdp_unit_test()
314
+ tdp_cuda_profile()
315
+ backward_tdp_cuda_profile()
316
+ autograd_profile()
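For reference, tdp_torch above blends two parameter banks with a per-parameter, time-dependent sigmoid gate: result[t] = s * param0 + (1 - s) * param1 with s = sigmoid(times[t] * weight + bias). A CPU-only sketch of the same formula (the custom CUDA kernel is left commented out in this file):

import torch

def tdp_reference(param0, param1, weight, bias, times):
    # gate s has shape (batch, num_params); one sigmoid blend per time step
    s = torch.sigmoid(times[:, None] * weight[None] + bias[None])
    return s * param0[None] + (1 - s) * param1[None]

p0, p1, w, b = (torch.randn(6) for _ in range(4))
t = torch.rand(3)

out = tdp_reference(p0, p1, w, b, t)
assert out.shape == (3, 6)
# tdp_torch writes the same blend as addcmul(param1, s, param0 - param1)
expected = torch.addcmul(p1[None],
                         torch.sigmoid(torch.addcmul(b[None], t[:, None], w[None])),
                         p0[None] - p1[None])
assert torch.allclose(out, expected, atol=1e-6)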
libs/model/sigmoid/module.py ADDED
@@ -0,0 +1,274 @@
1
+ from functools import reduce
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from .kernel import exported_tdp
7
+ import torch.nn.functional as F
8
+ from functools import partial
9
+ from timm.models.layers import trunc_normal_
10
+
11
+ class TimeDependentParameter(nn.Module):
12
+ def __init__(self, shape, init_fn):
13
+ super().__init__()
14
+ self.shape = shape
15
+
16
+ w = torch.empty(*shape)
17
+ init_fn(w)
18
+
19
+ self.param0 = nn.Parameter(w.clone().detach())
20
+ self.param1 = nn.Parameter(w.clone().detach())
21
+
22
+ self.nodecay_weight = nn.Parameter(torch.zeros(*shape))
23
+ self.nodecay_bias = nn.Parameter(torch.zeros(*shape))
24
+ self.curr_weight = None
25
+
26
+ def forward(self):
27
+ weight = self.curr_weight
28
+ # self.curr_weight = None
29
+ return weight
30
+
31
+ def __repr__(self):
32
+ return f"TimeDependentParameter(shape={self.shape})"
33
+
34
+ @staticmethod
35
+ def seed_time(model, log_snr):
36
+ assert log_snr.dim() == 1
37
+ if torch.all(log_snr == log_snr[0]):
38
+ log_snr = log_snr[0][None]
39
+ time_condition = log_snr / 4.0
40
+
41
+ tdp_list = [module for module in model.modules() if isinstance(module, TimeDependentParameter)]
42
+ for tdp in tdp_list:
43
+ tdp.curr_weight = exported_tdp(tdp.param0, tdp.param1, tdp.nodecay_weight + 1, tdp.nodecay_bias, time_condition, custom = False)
44
+
45
+ class LayerNorm(nn.Module):
46
+ def __init__(self, dim, num_groups = 1, eps = 1e-05):
47
+ super().__init__()
48
+ self.eps = eps
49
+ self.dim = dim
50
+ self.num_groups = num_groups
51
+ self.weight = TimeDependentParameter((dim, ), nn.init.ones_)
52
+ self.bias = TimeDependentParameter((dim, ), nn.init.zeros_)
53
+
54
+ def _forward(self, x):
55
+ weight, bias = self.weight(), self.bias()
56
+ assert weight.shape[0] == bias.shape[0]
57
+
58
+ assert x.shape[-1] == self.dim
59
+
60
+ if weight.shape[0] == 1:
61
+ x = F.layer_norm(x, (self.dim, ), weight = weight[0], bias = bias[0], eps = self.eps)
62
+ else:
63
+ assert x.shape[0] == weight.shape[0]
64
+ x = F.layer_norm(x, (self.dim, ), eps = self.eps)
65
+ x = torch.addcmul(bias[:, None, :], weight[:, None, :], x)
66
+
67
+ return x
68
+
69
+ def forward(self, x):
70
+ original_shape = x.shape
71
+ batch_size = x.shape[0]
72
+ assert self.dim == x.shape[-1]
73
+
74
+ x = x.reshape(batch_size, -1, self.dim)
75
+ x = self._forward(x)
76
+ x = x.reshape(*original_shape)
77
+
78
+ return x
79
+
80
+ class Linear(nn.Module):
81
+ def __init__(self, din, dout, bias = True, weight_init_fn = partial(trunc_normal_, std = 0.02)):
82
+ super().__init__()
83
+ self.din = din
84
+ self.dout = dout
85
+ self.weight = TimeDependentParameter((din, dout), weight_init_fn)
86
+ if bias:
87
+ self.bias = TimeDependentParameter((dout, ), nn.init.zeros_)
88
+ else:
89
+ self.bias = None
90
+
91
+ def _forward(self, x):
92
+ weight = self.weight()
93
+ bias = self.bias() if self.bias is not None else None
94
+
95
+ # if weight.shape[0] == 1:
96
+ # B, L, D = x.shape
97
+ # if bias is not None:
98
+ # assert weight.shape[0] == bias.shape[0]
99
+ # x = torch.addmm(bias, x.reshape(B * L, D), weight[0])
100
+ # else:
101
+ # x = torch.matmul(x.reshape(B * L, D), weight[0])
102
+ # x = x.reshape(B, L, -1)
103
+ # else:
104
+ if bias is not None:
105
+ x = torch.baddbmm(bias[:, None, :], x, weight)
106
+ else:
107
+ x = torch.bmm(x, weight)
108
+
109
+ return x
110
+
111
+ def forward(self, x):
112
+ original_shape = x.shape
113
+ batch_size = x.shape[0]
114
+
115
+ x = x.reshape(batch_size, -1, self.din)
116
+ x = self._forward(x)
117
+ x = x.reshape(*(list(original_shape[:-1]) + [self.dout]))
118
+
119
+ return x
120
+
121
+ class RMSNorm(nn.Module):
122
+ def __init__(self, d, p=-1., eps=1e-8, bias=False):
123
+ """
124
+ Root Mean Square Layer Normalization
125
+ :param d: model size
126
+ :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
127
+ :param eps: epsilon value, default 1e-8
128
+ :param bias: whether use bias term for RMSNorm, disabled by
129
+ default because RMSNorm doesn't enforce re-centering invariance.
130
+ """
131
+ super(RMSNorm, self).__init__()
132
+
133
+ self.eps = eps
134
+ self.d = d
135
+ self.p = p
136
+ self.bias = bias
137
+
138
+ self.scale = nn.Parameter(torch.ones(d))
139
+ self.register_parameter("scale", self.scale)
140
+
141
+ if self.bias:
142
+ self.offset = nn.Parameter(torch.zeros(d))
143
+ self.register_parameter("offset", self.offset)
144
+
145
+ def forward(self, x):
146
+ if self.p < 0. or self.p > 1.:
147
+ norm_x = x.norm(2, dim=-1, keepdim=True)
148
+ d_x = self.d
149
+ else:
150
+ partial_size = int(self.d * self.p)
151
+ partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
152
+
153
+ norm_x = partial_x.norm(2, dim=-1, keepdim=True)
154
+ d_x = partial_size
155
+
156
+ rms_x = norm_x * d_x ** (-1. / 2)
157
+ x_normed = x / (rms_x + self.eps)
158
+
159
+ if self.bias:
160
+ return self.scale * x_normed + self.offset
161
+
162
+ return self.scale * x_normed
163
+
164
+
165
+ class TDRMSNorm(nn.Module):
166
+ def __init__(self, d, p=-1., eps=1e-8, bias=False):
167
+ """
168
+ Root Mean Square Layer Normalization
169
+ :param d: model size
170
+ :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
171
+ :param eps: epsilon value, default 1e-8
172
+ :param bias: whether use bias term for RMSNorm, disabled by
173
+ default because RMSNorm doesn't enforce re-centering invariance.
174
+ """
175
+ super(TDRMSNorm, self).__init__()
176
+
177
+ self.eps = eps
178
+ self.d = d
179
+ self.p = p
180
+ self.bias = bias
181
+
182
+ # self.scale = nn.Parameter(torch.ones(d))
183
+ self.scale = TimeDependentParameter((d, ), nn.init.ones_)
184
+ # self.register_parameter("scale", self.scale)
185
+
186
+ if self.bias:
187
+ # self.offset = nn.Parameter(torch.zeros(d))
188
+ self.offset = TimeDependentParameter((d, ), nn.init.zeros_)
189
+ # self.register_parameter("offset", self.offset)
190
+
191
+ def forward(self, x):
192
+ if self.p < 0. or self.p > 1.:
193
+ norm_x = x.norm(2, dim=-1, keepdim=True)
194
+ d_x = self.d
195
+ else:
196
+ partial_size = int(self.d * self.p)
197
+ partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
198
+
199
+ norm_x = partial_x.norm(2, dim=-1, keepdim=True)
200
+ d_x = partial_size
201
+
202
+ rms_x = norm_x * d_x ** (-1. / 2)
203
+ x_normed = x / (rms_x + self.eps)
204
+
205
+ _scale = self.scale()
206
+
207
+ if self.bias:
208
+ # return self.scale * x_normed + self.offset
209
+ _offset = self.offset()
210
+ if _scale.shape[0] == 1:
211
+ return _scale[0] * x_normed + _offset[0]
212
+ elif x_normed.dim() == 3:
213
+ return torch.addcmul(_offset[:, None, :], _scale[:, None, :], x_normed)
214
+ elif x_normed.dim() == 4:
215
+ return torch.addcmul(_offset[:, None, None, :], _scale[:, None, None, :], x_normed)
216
+ else:
217
+ raise NotImplementedError
218
+
219
+ # return self.scale * x_normed
220
+ if _scale.shape[0] == 1:
221
+ return _scale[0] * x_normed
222
+ elif x_normed.dim() == 3:
223
+ return _scale[:, None, :] * x_normed
224
+ elif x_normed.dim() == 4:
225
+ return _scale[:, None, None, :] * x_normed
226
+ else:
227
+ raise NotImplementedError
228
+
229
+
230
+ def zero_init(layer):
231
+ nn.init.zeros_(layer.weight)
232
+ if layer.bias is not None:
233
+ nn.init.zeros_(layer.bias)
234
+ return layer
235
+
236
+ def rms_norm(x, scale, eps):
237
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
238
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
239
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
240
+ return x * scale.to(x.dtype)
241
+
242
+ class AdaRMSNorm(nn.Module):
243
+ def __init__(self, features, cond_features, eps=1e-6):
244
+ super().__init__()
245
+ self.eps = eps
246
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
247
+
248
+ def extra_repr(self):
249
+ return f"eps={self.eps},"
250
+
251
+ def forward(self, x, cond):
252
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
253
+
254
+ class QKNorm(nn.Module):
255
+ def __init__(self, n_heads, eps=1e-6, max_scale=100.0):
256
+ super().__init__()
257
+ self.eps = eps
258
+ self.max_scale = math.log(max_scale)
259
+ self.scale = nn.Parameter(torch.full((n_heads,), math.log(10.0)))
260
+ self.proj_()
261
+
262
+ def extra_repr(self):
263
+ return f"n_heads={self.scale.shape[0]}, eps={self.eps}"
264
+
265
+ @torch.no_grad()
266
+ def proj_(self):
267
+ """Modify the scale in-place so it doesn't get "stuck" with zero gradient if it's clamped
268
+ to the max value."""
269
+ self.scale.clamp_(max=self.max_scale)
270
+
271
+ def forward(self, x):
272
+ self.proj_()
273
+ scale = torch.exp(0.5 * self.scale - 0.25 * math.log(x.shape[-1]))
274
+ return rms_norm(x, scale[:, None, None], self.eps)
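For reference, the RMSNorm variants above (RMSNorm, TDRMSNorm, and the rms_norm helper behind AdaRMSNorm and QKNorm) all normalize by the root mean square of the feature vector; the partial variant (p in [0, 1]) computes the statistic over only the first floor(p*d) channels, and TDRMSNorm swaps the learned gain for a time-dependent parameter. In compact form:

\[
\mathrm{RMS}(x) = \sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^{2}},
\qquad
y_i = g_i \,\frac{x_i}{\mathrm{RMS}(x) + \epsilon} \;\big(+\, b_i \text{ if bias is enabled}\big),
\]

with the rms_norm helper folding the epsilon inside the square root instead.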
libs/model/trans_autoencoder.py ADDED
@@ -0,0 +1,289 @@
1
+ """
2
+ Transformer-based variational encoder model.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import math
9
+ import copy
10
+
11
+
12
+ def clones(module, N):
13
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
14
+
15
+
16
+ def build_mask(base_mask):
17
+ assert len(base_mask.shape) == 2
18
+ batch_size, seq_len = base_mask.shape[0], base_mask.shape[-1]
19
+
20
+ # create subsequent token mask
21
+ sub_mask = torch.tril(torch.ones([seq_len, seq_len],
22
+ dtype=torch.uint8)).type_as(base_mask)
23
+ sub_mask = sub_mask.unsqueeze(0).expand(batch_size, -1, -1)
24
+ base_mask = base_mask.unsqueeze(1).expand(-1, seq_len, -1)
25
+ return sub_mask & base_mask
26
+
27
+
28
+ class Adaptor(nn.Module):
29
+ def __init__(self, input_dim, tar_dim):
30
+ super(Adaptor, self).__init__()
31
+
32
+ if tar_dim == 32768:
33
+ output_channel = 8
34
+ elif tar_dim == 16384:
35
+ output_channel = 4
36
+ else:
37
+ raise NotImplementedError("only 512px is supported; 256px does not need this")
38
+
39
+ self.tar_dim = tar_dim
40
+
41
+ self.fc1 = nn.Linear(input_dim, 4096)
42
+ self.ln_fc1 = nn.LayerNorm(4096)
43
+ self.fc2 = nn.Linear(4096, 4096)
44
+ self.ln_fc2 = nn.LayerNorm(4096)
45
+
46
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
47
+ self.ln_conv1 = nn.LayerNorm([32, 64, 64])
48
+ self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
49
+ self.ln_conv2 = nn.LayerNorm([64, 64, 64])
50
+ self.conv3 = nn.Conv2d(in_channels=64, out_channels=output_channel, kernel_size=3, padding=1)
51
+
52
+ def forward(self, x):
53
+ x = torch.relu(self.ln_fc1(self.fc1(x)))
54
+ x = torch.relu(self.ln_fc2(self.fc2(x)))
55
+
56
+ x = x.view(-1, 1, 64, 64)
57
+
58
+ x = torch.relu(self.ln_conv1(self.conv1(x)))
59
+ x = torch.relu(self.ln_conv2(self.conv2(x)))
60
+
61
+ x = self.conv3(x)
62
+ x = x.view(-1, self.tar_dim)
63
+
64
+ return x
65
+
66
+
67
+ class Compressor(nn.Module):
68
+ def __init__(self, input_dim=4096, tar_dim=2048):
69
+ super(Compressor, self).__init__()
70
+
71
+ self.fc1 = nn.Linear(input_dim, tar_dim)
72
+ self.ln_fc1 = nn.LayerNorm(tar_dim)
73
+ self.fc2 = nn.Linear(tar_dim, tar_dim)
74
+
75
+
76
+ def forward(self, x):
77
+ x = torch.relu(self.ln_fc1(self.fc1(x)))
78
+ x = self.fc2(x)
79
+
80
+ return x
81
+
82
+
83
+ class TransEncoder(nn.Module):
84
+ def __init__(self, d_model, N, num_token, head_num, d_ff, latten_size, down_sample_block=3, dropout=0.1, last_norm=True):
85
+ super(TransEncoder, self).__init__()
86
+ self.N = N
87
+ if d_model==4096:
88
+ # for T5-XXL, first use MLP to compress into 1024
89
+ self.compressor = Compressor(input_dim=d_model, tar_dim=1024)
90
+ d_model = 1024
91
+ else:
92
+ self.compressor = None
93
+
94
+ self.layers = clones(EncoderLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout),
95
+ FeedForward(d_model, d_ff, dropout=dropout),
96
+ LayerNorm(d_model),
97
+ LayerNorm(d_model)), N)
98
+
99
+ self.reduction_layers = nn.ModuleList()
100
+ for _ in range(down_sample_block):
101
+ self.reduction_layers.append(
102
+ EncoderReductionLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout),
103
+ FeedForward(d_model, d_ff, dropout=dropout),
104
+ nn.Linear(d_model, d_model // 2),
105
+ LayerNorm(d_model),
106
+ LayerNorm(d_model)))
107
+ d_model = d_model // 2
108
+
109
+ if latten_size == 8192 or latten_size == 4096:
110
+ self.arc = 0
111
+ self.linear = nn.Linear(d_model*num_token, latten_size)
112
+ self.norm = LayerNorm(latten_size) if last_norm else None
113
+ else:
114
+ self.arc = 1
115
+ self.adaptor = Adaptor(d_model*num_token, latten_size)
116
+
117
+
118
+ def forward(self, x, mask):
119
+ mask = mask.unsqueeze(1)
120
+
121
+ if self.compressor is not None:
122
+ x = self.compressor(x)
123
+
124
+ for i, layer in enumerate(self.layers):
125
+ x = layer(x, mask)
126
+
127
+ for i, layer in enumerate(self.reduction_layers):
128
+ x = layer(x, mask)
129
+
130
+ if self.arc == 0:
131
+ x = self.linear(x.view(x.shape[0],-1))
132
+ x = self.norm(x) if self.norm else x
133
+ else:
134
+ x = self.adaptor(x.view(x.shape[0],-1))
135
+
136
+ return x
137
+
138
+
139
+ class EncoderLayer(nn.Module):
140
+ def __init__(self, attn, feed_forward, norm1, norm2, dropout=0.1):
141
+ super(EncoderLayer, self).__init__()
142
+ self.attn = attn
143
+ self.feed_forward = feed_forward
144
+ self.norm1, self.norm2 = norm1, norm2
145
+
146
+ self.dropout1 = nn.Dropout(dropout)
147
+ self.dropout2 = nn.Dropout(dropout)
148
+
149
+ def forward(self, x, mask):
150
+ # multihead attn & norm
151
+ a = self.attn(x, x, x, mask)
152
+ t = self.norm1(x + self.dropout1(a))
153
+
154
+ # feed forward & norm
155
+ z = self.feed_forward(t) # linear(dropout(act(linear(x))))
156
+ y = self.norm2(t + self.dropout2(z))
157
+
158
+ return y
159
+
160
+
161
+ class EncoderReductionLayer(nn.Module):
162
+ def __init__(self, attn, feed_forward, reduction, norm1, norm2, dropout=0.1):
163
+ super(EncoderReductionLayer, self).__init__()
164
+ self.attn = attn
165
+ self.feed_forward = feed_forward
166
+ self.reduction = reduction
167
+ self.norm1, self.norm2 = norm1, norm2
168
+
169
+ self.dropout1 = nn.Dropout(dropout)
170
+ self.dropout2 = nn.Dropout(dropout)
171
+
172
+ def forward(self, x, mask):
173
+ # multihead attn & norm
174
+ a = self.attn(x, x, x, mask)
175
+ t = self.norm1(x + self.dropout1(a))
176
+
177
+ # feed forward & norm
178
+ z = self.feed_forward(t) # linear(dropout(act(linear(x))))
179
+ y = self.norm2(t + self.dropout2(z))
180
+
181
+ # reduction
182
+ # y = self.reduction(y).view(x.shape[0], -1, x.shape[-1])
183
+ y = self.reduction(y)
184
+
185
+ return y
186
+
187
+
188
+ class MultiHeadAttentioin(nn.Module):
189
+ def __init__(self, d_model, head_num, dropout=0.1, d_v=None):
190
+ super(MultiHeadAttentioin, self).__init__()
191
+ assert d_model % head_num == 0, "d_model must be divisible by head_num"
192
+
193
+ self.d_model = d_model
194
+ self.head_num = head_num
195
+ self.d_k = d_model // head_num
196
+ self.d_v = self.d_k if d_v is None else d_v
197
+
198
+ # d_model = d_k * head_num
199
+ self.W_Q = nn.Linear(d_model, head_num * self.d_k)
200
+ self.W_K = nn.Linear(d_model, head_num * self.d_k)
201
+ self.W_V = nn.Linear(d_model, head_num * self.d_v)
202
+ self.W_O = nn.Linear(d_model, d_model)
203
+
204
+ self.dropout = nn.Dropout(dropout)
205
+
206
+ def scaled_dp_attn(self, query, key, value, mask=None):
207
+ assert self.d_k == query.shape[-1]
208
+
209
+ # scores: [batch_size, head_num, seq_len, seq_len]
210
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
211
+
212
+ # if torch.isinf(scores).any():
213
+ # # to avoid leaking
214
+ # scores = torch.where(scores == float('-inf'), torch.tensor(-65504.0), scores)
215
+ # scores = torch.where(scores == float('inf'), torch.tensor(65504.0), scores)
216
+
217
+ if mask is not None:
218
+ assert mask.ndim == 3, "Mask shape {} doesn't seem right...".format(mask.shape)
219
+ mask = mask.unsqueeze(1)
220
+ try:
221
+ if scores.dtype == torch.float32:
222
+ scores = scores.masked_fill(mask == 0, -1e9)
223
+ else:
224
+ scores = scores.masked_fill(mask == 0, -1e4)
225
+ except RuntimeError:
226
+ print("- scores device: {}".format(scores.device))
227
+ print("- mask device: {}".format(mask.device))
228
+
229
+ # attn: [batch_size, head_num, seq_len, seq_len]
230
+ attn = F.softmax(scores, dim=-1)
231
+ attn = self.dropout(attn)
232
+ return torch.matmul(attn, value), attn
233
+
234
+ def forward(self, q, k, v, mask):
235
+ batch_size = q.shape[0]
236
+
237
+ query = self.W_Q(q).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
238
+ key = self.W_K(k).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
239
+ value = self.W_V(v).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
240
+
241
+ heads, attn = self.scaled_dp_attn(query, key, value, mask)
242
+ heads = heads.transpose(1, 2).contiguous().view(batch_size, -1,
243
+ self.head_num * self.d_k)
244
+ assert heads.shape[-1] == self.d_model and heads.shape[0] == batch_size
245
+
246
+ y = self.W_O(heads)
247
+
248
+ assert y.shape == q.shape
249
+ return y
250
+
251
+
252
+ class LayerNorm(nn.Module):
253
+ def __init__(self, layer_size, eps=1e-5):
254
+ super(LayerNorm, self).__init__()
255
+ self.g = nn.Parameter(torch.ones(layer_size))
256
+ self.b = nn.Parameter(torch.zeros(layer_size))
257
+ self.eps = eps
258
+
259
+ def forward(self, x):
260
+ mean = x.mean(-1, keepdim=True)
261
+ std = x.std(-1, keepdim=True)
262
+ x = (x - mean) / (std + self.eps)
263
+ return self.g * x + self.b
264
+
265
+
266
+ class FeedForward(nn.Module):
267
+ def __init__(self, d_model, d_ff, dropout=0.1, act='relu', d_output=None):
268
+ super(FeedForward, self).__init__()
269
+ self.d_model = d_model
270
+ self.d_ff = d_ff
271
+ d_output = d_model if d_output is None else d_output
272
+
273
+ self.ffn_1 = nn.Linear(d_model, d_ff)
274
+ self.ffn_2 = nn.Linear(d_ff, d_output)
275
+
276
+ if act == 'relu':
277
+ self.act = nn.ReLU()
278
+ elif act == 'rrelu':
279
+ self.act = nn.RReLU()
280
+ else:
281
+ raise NotImplementedError
282
+
283
+ self.dropout = nn.Dropout(dropout)
284
+
285
+ def forward(self, x):
286
+ y = self.ffn_2(self.dropout(self.act(self.ffn_1(x))))
287
+ return y
288
+
289
+
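A minimal shape-check sketch for the TransEncoder defined above; the hyperparameter values are illustrative assumptions, not taken from the repo's training configs:

import torch
from libs.model.trans_autoencoder import TransEncoder

# Illustrative settings: CLIP-style token embeddings (77 tokens, width 768).
# Passing d_model=4096 would instead trigger the T5 Compressor path in __init__.
encoder = TransEncoder(
    d_model=768,
    N=2,                  # plain encoder layers
    num_token=77,
    head_num=8,
    d_ff=2048,
    latten_size=4096,     # selects the Linear + LayerNorm head (self.arc == 0)
    down_sample_block=3,  # token width halves three times: 768 -> 384 -> 192 -> 96
)

x = torch.randn(4, 77, 768)                  # (batch, tokens, d_model)
mask = torch.ones(4, 77, dtype=torch.long)   # 1 = real token, 0 = padding
z = encoder(x, mask)
print(z.shape)  # expected: torch.Size([4, 4096])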
libs/t5.py ADDED
@@ -0,0 +1,237 @@
1
+ """
2
+ This file contains code for the T5 model.
3
+
4
+ Reference:
5
+ https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
6
+ """
7
+
8
+ # -*- coding: utf-8 -*-
9
+ import os
10
+ import re
11
+ import html
12
+ import urllib.parse as ul
13
+
14
+ import ftfy
15
+ import torch
16
+ from bs4 import BeautifulSoup
17
+ from transformers import T5EncoderModel, AutoTokenizer
18
+ from huggingface_hub import hf_hub_download
19
+
20
+
21
+ class T5Embedder:
22
+
23
+ available_models = ['t5-v1_1-xxl']
24
+ bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
25
+
26
+ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True,
27
+ t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None):
28
+ self.device = torch.device(device)
29
+ self.torch_dtype = torch_dtype or torch.bfloat16
30
+ if t5_model_kwargs is None:
31
+ t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
32
+ if use_offload_folder is not None:
33
+ t5_model_kwargs['offload_folder'] = use_offload_folder
34
+ t5_model_kwargs['device_map'] = {
35
+ 'shared': self.device,
36
+ 'encoder.embed_tokens': self.device,
37
+ 'encoder.block.0': self.device,
38
+ 'encoder.block.1': self.device,
39
+ 'encoder.block.2': self.device,
40
+ 'encoder.block.3': self.device,
41
+ 'encoder.block.4': self.device,
42
+ 'encoder.block.5': self.device,
43
+ 'encoder.block.6': self.device,
44
+ 'encoder.block.7': self.device,
45
+ 'encoder.block.8': self.device,
46
+ 'encoder.block.9': self.device,
47
+ 'encoder.block.10': self.device,
48
+ 'encoder.block.11': self.device,
49
+ 'encoder.block.12': 'disk',
50
+ 'encoder.block.13': 'disk',
51
+ 'encoder.block.14': 'disk',
52
+ 'encoder.block.15': 'disk',
53
+ 'encoder.block.16': 'disk',
54
+ 'encoder.block.17': 'disk',
55
+ 'encoder.block.18': 'disk',
56
+ 'encoder.block.19': 'disk',
57
+ 'encoder.block.20': 'disk',
58
+ 'encoder.block.21': 'disk',
59
+ 'encoder.block.22': 'disk',
60
+ 'encoder.block.23': 'disk',
61
+ 'encoder.final_layer_norm': 'disk',
62
+ 'encoder.dropout': 'disk',
63
+ }
64
+ else:
65
+ t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
66
+
67
+ self.use_text_preprocessing = use_text_preprocessing
68
+ self.hf_token = hf_token
69
+ self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
70
+ self.dir_or_name = dir_or_name
71
+
72
+ tokenizer_path, path = dir_or_name, dir_or_name
73
+ if dir_or_name in self.available_models:
74
+ cache_dir = os.path.join(self.cache_dir, dir_or_name)
75
+ for filename in [
76
+ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
77
+ 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
78
+ ]:
79
+ hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
80
+ force_filename=filename, token=self.hf_token)
81
+ tokenizer_path, path = cache_dir, cache_dir
82
+ else:
83
+ cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
84
+ for filename in [
85
+ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
86
+ ]:
87
+ hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
88
+ force_filename=filename, token=self.hf_token)
89
+ tokenizer_path = cache_dir
90
+
91
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
92
+ self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
93
+
94
+ def get_text_embeddings(self, texts):
95
+ texts = [self.text_preprocessing(text) for text in texts]
96
+
97
+ text_tokens_and_mask = self.tokenizer(
98
+ texts,
99
+ max_length=77,
100
+ padding='max_length',
101
+ truncation=True,
102
+ return_attention_mask=True,
103
+ add_special_tokens=True,
104
+ return_tensors='pt'
105
+ )
106
+ text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
107
+ text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
108
+
109
+ with torch.no_grad():
110
+ text_encoder_embs = self.model(
111
+ input_ids=text_tokens_and_mask['input_ids'].to(self.device),
112
+ attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
113
+ )['last_hidden_state'].detach()
114
+
115
+ return text_encoder_embs, {'token_embedding': text_encoder_embs, 'token_mask': text_tokens_and_mask['attention_mask'].to(self.device), 'tokens': text_tokens_and_mask['input_ids'].to(self.device)}
116
+
117
+ def text_preprocessing(self, text):
118
+ if self.use_text_preprocessing:
119
+ # The exact text cleaning that was used in the training stage:
120
+ text = self.clean_caption(text)
121
+ text = self.clean_caption(text)
122
+ return text
123
+ else:
124
+ return text.lower().strip()
125
+
126
+ @staticmethod
127
+ def basic_clean(text):
128
+ text = ftfy.fix_text(text)
129
+ text = html.unescape(html.unescape(text))
130
+ return text.strip()
131
+
132
+ def clean_caption(self, caption):
133
+ caption = str(caption)
134
+ caption = ul.unquote_plus(caption)
135
+ caption = caption.strip().lower()
136
+ caption = re.sub('<person>', 'person', caption)
137
+ # urls:
138
+ caption = re.sub(
139
+ r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
140
+ '', caption) # regex for urls
141
+ caption = re.sub(
142
+ r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
143
+ '', caption) # regex for urls
144
+ # html:
145
+ caption = BeautifulSoup(caption, features='html.parser').text
146
+
147
+ # @<nickname>
148
+ caption = re.sub(r'@[\w\d]+\b', '', caption)
149
+
150
+ # 31C0—31EF CJK Strokes
151
+ # 31F0—31FF Katakana Phonetic Extensions
152
+ # 3200—32FF Enclosed CJK Letters and Months
153
+ # 3300—33FF CJK Compatibility
154
+ # 3400—4DBF CJK Unified Ideographs Extension A
155
+ # 4DC0—4DFF Yijing Hexagram Symbols
156
+ # 4E00—9FFF CJK Unified Ideographs
157
+ caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
158
+ caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
159
+ caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
160
+ caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
161
+ caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
162
+ caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
163
+ caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
164
+ #######################################################
165
+
166
+ # все виды тире / all types of dash --> "-"
167
+ caption = re.sub(
168
+ r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
169
+ '-', caption)
170
+
171
+ # normalize all quotation marks to a single standard
172
+ caption = re.sub(r'[`´«»“”¨]', '"', caption)
173
+ caption = re.sub(r'[‘’]', "'", caption)
174
+
175
+ # &quot;
176
+ caption = re.sub(r'&quot;?', '', caption)
177
+ # &amp
178
+ caption = re.sub(r'&amp', '', caption)
179
+
180
+ # ip addresses:
181
+ caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
182
+
183
+ # article ids:
184
+ caption = re.sub(r'\d:\d\d\s+$', '', caption)
185
+
186
+ # \n
187
+ caption = re.sub(r'\\n', ' ', caption)
188
+
189
+ # "#123"
190
+ caption = re.sub(r'#\d{1,3}\b', '', caption)
191
+ # "#12345.."
192
+ caption = re.sub(r'#\d{5,}\b', '', caption)
193
+ # "123456.."
194
+ caption = re.sub(r'\b\d{6,}\b', '', caption)
195
+ # filenames:
196
+ caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
197
+
198
+ #
199
+ caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
200
+ caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
201
+
202
+ caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
203
+ caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
204
+
205
+ # this-is-my-cute-cat / this_is_my_cute_cat
206
+ regex2 = re.compile(r'(?:\-|\_)')
207
+ if len(re.findall(regex2, caption)) > 3:
208
+ caption = re.sub(regex2, ' ', caption)
209
+
210
+ caption = self.basic_clean(caption)
211
+
212
+ caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
213
+ caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
214
+ caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
215
+
216
+ caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
217
+ caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
218
+ caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
219
+ caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
220
+ caption = re.sub(r'\bpage\s+\d+\b', '', caption)
221
+
222
+ caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
223
+
224
+ caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
225
+
226
+ caption = re.sub(r'\b\s+\:\s+', r': ', caption)
227
+ caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
228
+ caption = re.sub(r'\s+', ' ', caption)
229
+
230
+ caption = caption.strip()
231
+
232
+ caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
233
+ caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
234
+ caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
235
+ caption = re.sub(r'^\.\S+$', '', caption)
236
+
237
+ return caption.strip()
libs/timm.py ADDED
@@ -0,0 +1,114 @@
1
+ """
2
+ code from timm 0.3.2
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+ import warnings
8
+
9
+
10
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
11
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
12
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
13
+ def norm_cdf(x):
14
+ # Computes standard normal cumulative distribution function
15
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
16
+
17
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
18
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
19
+ "The distribution of values may be incorrect.",
20
+ stacklevel=2)
21
+
22
+ with torch.no_grad():
23
+ # Values are generated by using a truncated uniform distribution and
24
+ # then using the inverse CDF for the normal distribution.
25
+ # Get upper and lower cdf values
26
+ l = norm_cdf((a - mean) / std)
27
+ u = norm_cdf((b - mean) / std)
28
+
29
+ # Uniformly fill tensor with values from [l, u], then translate to
30
+ # [2l-1, 2u-1].
31
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
32
+
33
+ # Use inverse cdf transform for normal distribution to get truncated
34
+ # standard normal
35
+ tensor.erfinv_()
36
+
37
+ # Transform to proper mean, std
38
+ tensor.mul_(std * math.sqrt(2.))
39
+ tensor.add_(mean)
40
+
41
+ # Clamp to ensure it's in the proper range
42
+ tensor.clamp_(min=a, max=b)
43
+ return tensor
44
+
45
+
46
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
47
+ # type: (Tensor, float, float, float, float) -> Tensor
48
+ r"""Fills the input Tensor with values drawn from a truncated
49
+ normal distribution. The values are effectively drawn from the
50
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
51
+ with values outside :math:`[a, b]` redrawn until they are within
52
+ the bounds. The method used for generating the random values works
53
+ best when :math:`a \leq \text{mean} \leq b`.
54
+ Args:
55
+ tensor: an n-dimensional `torch.Tensor`
56
+ mean: the mean of the normal distribution
57
+ std: the standard deviation of the normal distribution
58
+ a: the minimum cutoff value
59
+ b: the maximum cutoff value
60
+ Examples:
61
+ >>> w = torch.empty(3, 5)
62
+ >>> nn.init.trunc_normal_(w)
63
+ """
64
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
65
+
66
+
67
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
68
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
69
+
70
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
71
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
72
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
73
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
74
+ 'survival rate' as the argument.
75
+
76
+ """
77
+ if drop_prob == 0. or not training:
78
+ return x
79
+ keep_prob = 1 - drop_prob
80
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
81
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
82
+ random_tensor.floor_() # binarize
83
+ output = x.div(keep_prob) * random_tensor
84
+ return output
85
+
86
+
87
+ class DropPath(nn.Module):
88
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
89
+ """
90
+ def __init__(self, drop_prob=None):
91
+ super(DropPath, self).__init__()
92
+ self.drop_prob = drop_prob
93
+
94
+ def forward(self, x):
95
+ return drop_path(x, self.drop_prob, self.training)
96
+
97
+
98
+ class Mlp(nn.Module):
99
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
100
+ super().__init__()
101
+ out_features = out_features or in_features
102
+ hidden_features = hidden_features or in_features
103
+ self.fc1 = nn.Linear(in_features, hidden_features)
104
+ self.act = act_layer()
105
+ self.fc2 = nn.Linear(hidden_features, out_features)
106
+ self.drop = nn.Dropout(drop)
107
+
108
+ def forward(self, x):
109
+ x = self.fc1(x)
110
+ x = self.act(x)
111
+ x = self.drop(x)
112
+ x = self.fc2(x)
113
+ x = self.drop(x)
114
+ return x
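A small sanity-check sketch for drop_path above (purely illustrative): surviving samples are rescaled by 1 / keep_prob, so the expected activation is unchanged while roughly drop_prob of the samples in a batch are zeroed.

import torch
from libs.timm import drop_path

x = torch.ones(10_000, 8)
y = drop_path(x, drop_prob=0.2, training=True)

zeroed = (y.sum(dim=1) == 0).float().mean().item()
print(f"zeroed rows ~ {zeroed:.2f}")               # close to 0.20
print(f"mean activation ~ {y.mean().item():.2f}")  # close to 1.00 (survivors scaled by 1/0.8)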
requirements.txt CHANGED
@@ -1,6 +1,21 @@
1
- accelerate
2
  diffusers
3
- invisible_watermark
4
  torch
5
- transformers
6
- xformers
1
  diffusers
 
2
  torch
3
+ xformers
4
+ openai-clip
5
+ scikit-learn
6
+ opencv-python
7
+ torchdiffeq
8
+ beautifulsoup4
9
+ open_clip_torch
10
+ scikit-image
11
+ cython
12
+ matplotlib
13
+ accelerate==0.12.0
14
+ absl-py
15
+ ml_collections
16
+ einops
17
+ wandb
18
+ ftfy==6.1.1
19
+ transformers==4.23.1
20
+ timm
21
+ tensorboard
scripts/extract_empty_feature.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ This file is used to extract the feature of the empty prompt.
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+
9
+ import torch
10
+ import os
11
+ import numpy as np
12
+ from libs.clip import FrozenCLIPEmbedder
13
+ from libs.t5 import T5Embedder
14
+
15
+
16
+ def main():
17
+ prompts = [
18
+ '',
19
+ ]
20
+
21
+ device = 'cuda'
22
+ llm = 'clip'
23
+
24
+ if llm=='clip':
25
+ clip = FrozenCLIPEmbedder()
26
+ clip.eval()
27
+ clip.to(device)
28
+ elif llm=='t5':
29
+ t5 = T5Embedder(device=device)
30
+ else:
31
+ raise NotImplementedError
32
+
33
+ save_dir = f'./'
34
+
35
+ if llm=='clip':
36
+ latent, latent_and_others = clip.encode(prompts)
37
+ token_embedding = latent_and_others['token_embedding']
38
+ token_mask = latent_and_others['token_mask']
39
+ token = latent_and_others['tokens']
40
+ elif llm=='t5':
41
+ latent, latent_and_others = t5.get_text_embeddings(prompts)
42
+ token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
43
+ token_mask = latent_and_others['token_mask']
44
+ token = latent_and_others['tokens']
45
+
46
+ for i in range(len(prompts)):
47
+ data = {'token_embedding': token_embedding[i].detach().cpu().numpy(),
48
+ 'token_mask': token_mask[i].detach().cpu().numpy(),
49
+ 'token': token[i].detach().cpu().numpy(),
50
+ 'batch_caption': prompts[i]}
51
+ np.save(os.path.join(save_dir, f'empty_context.npy'), data)
52
+
53
+
54
+
55
+ if __name__ == '__main__':
56
+ main()
scripts/extract_mscoco_feature.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ This file is used to extract features of the COCO val set (to test zero-shot FID).
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+
9
+ import torch
10
+ import os
11
+ import numpy as np
12
+ from datasets import MSCOCODatabase
13
+ import argparse
14
+ from tqdm import tqdm
15
+
16
+ import libs.autoencoder
17
+ from libs.clip import FrozenCLIPEmbedder
18
+ from libs.t5 import T5Embedder
19
+
20
+
21
+ def main(resolution=256):
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('--split', default='val')
24
+ args = parser.parse_args()
25
+ print(args)
26
+
27
+ if args.split == "val":
28
+ datas = MSCOCODatabase(root='/data/qihao/dataset/coco2014/val2014',
29
+ annFile='/data/qihao/dataset/coco2014/annotations/captions_val2014.json',
30
+ size=resolution)
31
+ save_dir = f'val'
32
+ else:
33
+ raise NotImplementedError
34
+
35
+ device = "cuda"
36
+ os.makedirs(save_dir, exist_ok=True)
37
+
38
+ autoencoder = libs.autoencoder.get_model('../assets/stable-diffusion/autoencoder_kl.pth')
39
+ autoencoder.to(device)
40
+
41
+ llm = 'clip'
42
+
43
+ if llm=='clip':
44
+ clip = FrozenCLIPEmbedder()
45
+ clip.eval()
46
+ clip.to(device)
47
+ elif llm=='t5':
48
+ t5 = T5Embedder(device=device)
49
+ else:
50
+ raise NotImplementedError
51
+
52
+ with torch.no_grad():
53
+ for idx, data in tqdm(enumerate(datas)):
54
+ x, captions = data
55
+
56
+ if len(x.shape) == 3:
57
+ x = x[None, ...]
58
+ x = torch.tensor(x, device=device)
59
+ moments = autoencoder(x, fn='encode_moments').squeeze(0)
60
+ moments = moments.detach().cpu().numpy()
61
+ np.save(os.path.join(save_dir, f'{idx}.npy'), moments)
62
+
63
+ if llm=='clip':
64
+ latent, latent_and_others = clip.encode(captions)
65
+ token_embedding = latent_and_others['token_embedding']
66
+ token_mask = latent_and_others['token_mask']
67
+ token = latent_and_others['tokens']
68
+ elif llm=='t5':
69
+ latent, latent_and_others = t5.get_text_embeddings(captions)
70
+ token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
71
+ token_mask = latent_and_others['token_mask']
72
+ token = latent_and_others['tokens']
73
+
74
+ for i in range(len(captions)):
75
+ data = {'promt': captions[i],
76
+ 'token_embedding': token_embedding[i].detach().cpu().numpy(),
77
+ 'token_mask': token_mask[i].detach().cpu().numpy(),
78
+ 'token': token[i].detach().cpu().numpy()}
79
+ np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), data)
80
+
81
+
82
+ if __name__ == '__main__':
83
+ main()
scripts/extract_test_prompt_feature.py ADDED
@@ -0,0 +1,72 @@
1
+ """
2
+ This file is used to extract features for visualization during training.
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+
9
+ import torch
10
+ import os
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ import libs.autoencoder
15
+ from libs.clip import FrozenCLIPEmbedder
16
+ from libs.t5 import T5Embedder
17
+
18
+
19
+ def main():
20
+ prompts = [
21
+ 'A road with traffic lights, street lights and cars.',
22
+ 'A bus driving in a city area with traffic signs.',
23
+ 'A bus pulls over to the curb close to an intersection.',
24
+ 'A group of people are walking and one is holding an umbrella.',
25
+ 'A baseball player taking a swing at an incoming ball.',
26
+ 'A dog next to a white cat with black-tipped ears.',
27
+ 'A tiger standing on a rooftop while singing and jamming on an electric guitar under a spotlight. anime illustration.',
28
+ 'A bird wearing headphones and speaking into a high-end microphone in a recording studio.',
29
+ 'A bus made of cardboard.',
30
+ 'A tower in the mountains.',
31
+ 'Two cups of coffee, one with latte art of a cat. The other has latte art of a bird.',
32
+ 'Oil painting of a robot made of sushi, holding chopsticks.',
33
+ 'Portrait of a dog wearing a hat and holding a flag that has a yin-yang symbol on it.',
34
+ 'A teddy bear wearing a motorcycle helmet and cape is standing in front of Loch Awe with Kilchurn Castle behind him. dslr photo.',
35
+ 'A man standing on the moon',
36
+ ]
37
+ save_dir = f'run_vis'
38
+ os.makedirs(save_dir, exist_ok=True)
39
+
40
+ device = 'cuda'
41
+ llm = 'clip'
42
+
43
+ if llm=='clip':
44
+ clip = FrozenCLIPEmbedder()
45
+ clip.eval()
46
+ clip.to(device)
47
+ elif llm=='t5':
48
+ t5 = T5Embedder(device=device)
49
+ else:
50
+ raise NotImplementedError
51
+
52
+ if llm=='clip':
53
+ latent, latent_and_others = clip.encode(prompts)
54
+ token_embedding = latent_and_others['token_embedding']
55
+ token_mask = latent_and_others['token_mask']
56
+ token = latent_and_others['tokens']
57
+ elif llm=='t5':
58
+ latent, latent_and_others = t5.get_text_embeddings(prompts)
59
+ token_embedding = latent_and_others['token_embedding'].to(torch.float32) * 10.0
60
+ token_mask = latent_and_others['token_mask']
61
+ token = latent_and_others['tokens']
62
+
63
+ for i in range(len(prompts)):
64
+ data = {'promt': prompts[i],
65
+ 'token_embedding': token_embedding[i].detach().cpu().numpy(),
66
+ 'token_mask': token_mask[i].detach().cpu().numpy(),
67
+ 'token': token[i].detach().cpu().numpy()}
68
+ np.save(os.path.join(save_dir, f'{i}.npy'), data)
69
+
70
+
71
+ if __name__ == '__main__':
72
+ main()
scripts/extract_train_feature.py ADDED
@@ -0,0 +1,159 @@
1
+ """
2
+ This file is used to extract features of the demo training data.
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import sys
8
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import os
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ import io
17
+ import einops
18
+ import random
19
+ import json
20
+ import libs.autoencoder
21
+ from libs.clip import FrozenCLIPEmbedder
22
+ from libs.t5 import T5Embedder
23
+
24
+
25
+ def recreate_folder(folder_path):
26
+ if os.path.exists(folder_path):
27
+ shutil.rmtree(folder_path)
28
+ os.makedirs(folder_path)
29
+
30
+ def center_crop_arr(pil_image, image_size):
31
+ while min(*pil_image.size) >= 2 * image_size:
32
+ pil_image = pil_image.resize(
33
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
34
+ )
35
+
36
+ scale = image_size / min(*pil_image.size)
37
+ pil_image = pil_image.resize(
38
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
39
+ )
40
+
41
+ arr = np.array(pil_image)
42
+ crop_y = (arr.shape[0] - image_size) // 2
43
+ crop_x = (arr.shape[1] - image_size) // 2
44
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
45
+
46
+
47
+ def main(bz = 16):
48
+
49
+ json_path = '/path/to/JourneyDB_demo/img_text_pair.jsonl'
50
+ root_path = '/path/to/JourneyDB_demo/imgs'
51
+
52
+ dicts_list = []
53
+ with open(json_path, 'r', encoding='utf-8') as file:
54
+ for line in file:
55
+ dicts_list.append(json.loads(line))
56
+
57
+ save_dir = f'feature'
58
+ device = "cuda"
59
+ recreate_folder(save_dir)
60
+
61
+ autoencoder = libs.autoencoder.get_model('../assets/stable-diffusion/autoencoder_kl.pth')
62
+ autoencoder.to(device)
63
+
64
+ # CLIP model:
65
+ clip = FrozenCLIPEmbedder()
66
+ clip.eval()
67
+ clip.to(device)
68
+
69
+ # T5 model:
70
+ t5 = T5Embedder(device=device)
71
+
72
+ idx = 0
73
+ batch_img_256 = []
74
+ batch_img_512 = []
75
+ batch_caption = []
76
+ batch_name = []
77
+ for i, sample in enumerate(tqdm(dicts_list)):
78
+ try:
79
+ pil_image = Image.open(os.path.join(root_path,sample['img_path']))
80
+ caption = sample['prompt']
81
+ img_name = sample['img_path'].replace('.jpg','')
82
+
83
+ pil_image.load()
84
+ pil_image = pil_image.convert("RGB")
85
+ except:
86
+ with open("failed_file.txt", 'a+') as file:
87
+ file.write(sample['img_path'] + "\n")
88
+ continue
89
+
90
+ image_256 = center_crop_arr(pil_image, image_size=256)
91
+ image_512 = center_crop_arr(pil_image, image_size=512)
92
+
93
+ # if True:
94
+ # image_id = random.randint(0,20)
95
+ # Image.fromarray(image_256.astype(np.uint8)).save(f"temp_img_{image_id}_256.jpg")
96
+ # Image.fromarray(image_512.astype(np.uint8)).save(f"temp_img_{image_id}_512.jpg")
97
+
98
+ image_256 = (image_256 / 127.5 - 1.0).astype(np.float32)
99
+ image_256 = einops.rearrange(image_256, 'h w c -> c h w')
100
+ batch_img_256.append(image_256)
101
+
102
+ image_512 = (image_512 / 127.5 - 1.0).astype(np.float32)
103
+ image_512 = einops.rearrange(image_512, 'h w c -> c h w')
104
+ batch_img_512.append(image_512)
105
+
106
+ batch_caption.append(caption)
107
+ batch_name.append(img_name)
108
+
109
+ if len(batch_name) == bz or i == len(dicts_list) - 1:
110
+ batch_img_256 = torch.tensor(np.stack(batch_img_256)).to(device)
111
+ moments_256 = autoencoder(batch_img_256, fn='encode_moments').squeeze(0)
112
+ moments_256 = moments_256.detach().cpu().numpy()
113
+
114
+ batch_img_512 = torch.tensor(np.stack(batch_img_512)).to(device)
115
+ moments_512 = autoencoder(batch_img_512, fn='encode_moments').squeeze(0)
116
+ moments_512 = moments_512.detach().cpu().numpy()
117
+
118
+ _latent_clip, latent_and_others_clip = clip.encode(batch_caption)
119
+ token_embedding_clip = latent_and_others_clip['token_embedding'].detach().cpu().numpy()
120
+ token_mask_clip = latent_and_others_clip['token_mask'].detach().cpu().numpy()
121
+ token_clip = latent_and_others_clip['tokens'].detach().cpu().numpy()
122
+
123
+ _latent_t5, latent_and_others_t5 = t5.get_text_embeddings(batch_caption)
124
+ token_embedding_t5 = (latent_and_others_t5['token_embedding'].to(torch.float32) * 10.0).detach().cpu().numpy()
125
+ token_mask_t5 = latent_and_others_t5['token_mask'].detach().cpu().numpy()
126
+ token_t5 = latent_and_others_t5['tokens'].detach().cpu().numpy()
127
+
128
+ for mt_256, mt_512, te_c, te_t, tm_c, tm_t, tk_c, tk_t, bc, bn in zip(moments_256, moments_512, token_embedding_clip, token_embedding_t5, token_mask_clip, token_mask_t5, token_clip, token_t5, batch_caption, batch_name):
129
+ assert mt_256.shape == (8,32,32)
130
+ assert mt_512.shape == (8,64,64)
131
+ assert te_c.shape == (77, 768)
132
+ assert te_t.shape == (77, 4096)
133
+ tar_path_name = os.path.join(save_dir, f'{bn}.npy')
134
+ if os.path.exists(tar_path_name):
135
+ os.remove(tar_path_name)
136
+ data = {'image_latent_256': mt_256,
137
+ 'image_latent_512': mt_512,
138
+ 'token_embedding_clip': te_c,
139
+ 'token_embedding_t5': te_t,
140
+ 'token_mask_clip': tm_c,
141
+ 'token_mask_t5': tm_t,
142
+ 'token_clip': tk_c,
143
+ 'token_t5': tk_t,
144
+ 'batch_caption': bc}
145
+ try:
146
+ np.save(tar_path_name, data)
147
+ idx += 1
148
+ except:
149
+ pass
150
+
151
+ batch_img_256 = []
152
+ batch_img_512 = []
153
+ batch_caption = []
154
+ batch_name = []
155
+
156
+ print(f'save {idx} files')
157
+
158
+ if __name__ == '__main__':
159
+ main()
sde.py ADDED
@@ -0,0 +1,326 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from absl import logging
4
+ import numpy as np
5
+ import math
6
+ from tqdm import tqdm
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def check_zip(*args):
11
+ args = [list(arg) for arg in args]
12
+ length = len(args[0])
13
+ for arg in args:
14
+ assert len(arg) == length
15
+ return zip(*args)
16
+
17
+ def get_sde(name, **kwargs):
18
+ if name == 'vpsde':
19
+ return VPSDE(**kwargs)
20
+ elif name == 'vpsde_cosine':
21
+ return VPSDECosine(**kwargs)
22
+ else:
23
+ raise NotImplementedError
24
+
25
+
26
+ def stp(s, ts: torch.Tensor): # scalar tensor product
27
+ if isinstance(s, np.ndarray):
28
+ s = torch.from_numpy(s).type_as(ts)
29
+ extra_dims = (1,) * (ts.dim() - 1)
30
+ return s.view(-1, *extra_dims) * ts
31
+
32
+
33
+ def mos(a, start_dim=1): # mean of square
34
+ return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
35
+
36
+
37
+ def duplicate(tensor, *size):
38
+ return tensor.unsqueeze(dim=0).expand(*size, *tensor.shape)
39
+
40
+
41
+ class SDE(object):
42
+ r"""
43
+ dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
44
+ f(x, t) is the drift
45
+ g(t) is the diffusion
46
+ """
47
+ def drift(self, x, t):
48
+ raise NotImplementedError
49
+
50
+ def diffusion(self, t):
51
+ raise NotImplementedError
52
+
53
+ def cum_beta(self, t): # the variance of xt|x0
54
+ raise NotImplementedError
55
+
56
+ def cum_alpha(self, t):
57
+ raise NotImplementedError
58
+
59
+ def snr(self, t): # signal noise ratio
60
+ raise NotImplementedError
61
+
62
+ def nsr(self, t): # noise signal ratio
63
+ raise NotImplementedError
64
+
65
+ def marginal_prob(self, x0, t): # the mean and std of q(xt|x0)
66
+ alpha = self.cum_alpha(t)
67
+ beta = self.cum_beta(t)
68
+ mean = stp(alpha ** 0.5, x0) # E[xt|x0]
69
+ std = beta ** 0.5 # Cov[xt|x0] ** 0.5
70
+ return mean, std
71
+
72
+ def sample(self, x0, t_init=0): # sample from q(xn|x0), where n is uniform
73
+ t = torch.rand(x0.shape[0], device=x0.device) * (1. - t_init) + t_init
74
+ mean, std = self.marginal_prob(x0, t)
75
+ eps = torch.randn_like(x0)
76
+ xt = mean + stp(std, eps)
77
+ return t, eps, xt
78
+
79
+
80
+ class VPSDE(SDE):
81
+ def __init__(self, beta_min=0.1, beta_max=20):
82
+ # 0 <= t <= 1
83
+ self.beta_0 = beta_min
84
+ self.beta_1 = beta_max
85
+
86
+ def drift(self, x, t):
87
+ return -0.5 * stp(self.squared_diffusion(t), x)
88
+
89
+ def diffusion(self, t):
90
+ return self.squared_diffusion(t) ** 0.5
91
+
92
+ def squared_diffusion(self, t): # beta(t)
93
+ return self.beta_0 + t * (self.beta_1 - self.beta_0)
94
+
95
+ def squared_diffusion_integral(self, s, t): # \int_s^t beta(tau) d tau
96
+ return self.beta_0 * (t - s) + (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5
97
+
98
+ def skip_beta(self, s, t): # beta_{t|s}, Cov[xt|xs]=beta_{t|s} I
99
+ return 1. - self.skip_alpha(s, t)
100
+
101
+ def skip_alpha(self, s, t): # alpha_{t|s}, E[xt|xs]=alpha_{t|s}**0.5 xs
102
+ x = -self.squared_diffusion_integral(s, t)
103
+ return x.exp()
104
+
105
+ def cum_beta(self, t):
106
+ return self.skip_beta(0, t)
107
+
108
+ def cum_alpha(self, t):
109
+ return self.skip_alpha(0, t)
110
+
111
+ def nsr(self, t):
112
+ nsr = self.squared_diffusion_integral(0, t).expm1()
113
+ nsr = nsr.clamp(max = 1e6, min = 1e-12)
114
+ return nsr
115
+
116
+ def snr(self, t):
117
+ snr = 1. / self.nsr(t)
118
+ snr = snr.clamp(max = 1e6, min = 1e-12)
119
+ return snr
120
+
121
+ def __str__(self):
122
+ return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
123
+
124
+ def __repr__(self):
125
+ return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'
126
+
127
+
128
+ class VPSDECosine(SDE):
129
+ r"""
130
+ dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
131
+ f(x, t) is the drift
132
+ g(t) is the diffusion
133
+ """
134
+ def __init__(self, s=0.008):
135
+ self.s = s
136
+ self.F = lambda t: torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
137
+ self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2
138
+
139
+ def drift(self, x, t):
140
+ ft = - torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi / 2
141
+ return stp(ft, x)
142
+
143
+ def diffusion(self, t):
144
+ return (torch.tan((t + self.s) / (1 + self.s) * math.pi / 2) / (1 + self.s) * math.pi) ** 0.5
145
+
146
+ def cum_beta(self, t): # the variance of xt|x0
147
+ return 1 - self.cum_alpha(t)
148
+
149
+ def cum_alpha(self, t):
150
+ return self.F(t) / self.F0
151
+
152
+ def snr(self, t): # signal noise ratio
153
+ Ft = self.F(t)
154
+ snr = Ft / (self.F0 - Ft)
155
+ snr = snr.clamp(max = 1e6, min = 1e-12)
156
+ return snr
157
+
158
+ def nsr(self, t): # noise signal ratio
159
+ Ft = self.F(t)
160
+ nsr = self.F0 / Ft - 1
161
+ nsr = nsr.clamp(max = 1e6, min = 1e-12)
162
+ return nsr
163
+
164
+ def __str__(self):
165
+ return 'vpsde_cosine'
166
+
167
+ def __repr__(self):
168
+ return 'vpsde_cosine'
169
+
170
+
171
+ class ScoreModel(object):
172
+ r"""
173
+ The forward process is q(x_[0,T])
174
+ """
175
+
176
+ def __init__(self, nnet: nn.Module, loss_coeffs:list, sde: SDE, using_cfg: bool = False, T=1):
177
+ assert T == 1
178
+ self.nnet = nnet
179
+ self.loss_coeffs = loss_coeffs
180
+ self.sde = sde
181
+ self.T = T
182
+ self.using_cfg = using_cfg
183
+ print(f'ScoreModel with loss_coeffs={loss_coeffs}, sde={sde}, T={T}')
184
+
185
+ def predict(self, xt, t, **kwargs):
186
+ if not isinstance(t, torch.Tensor):
187
+ t = torch.tensor(t)
188
+ t = t.to(xt.device)
189
+ if t.dim() == 0:
190
+ t = duplicate(t, xt.size(0))
191
+ log_snr = self.sde.snr(t).log()
192
+
193
+ return self.nnet(xt, t = t * 999, log_snr = log_snr, **kwargs) # follow SDE
194
+ # return self.nnet(xt, t = t, log_snr = log_snr, **kwargs) # follow SDE
195
+
196
+ def noise_pred(self, xt, t, sampling = True, **kwargs):
197
+ if sampling:
198
+ if self.using_cfg:
199
+ return self.predict(xt, t, **kwargs)
200
+ else:
201
+ return self.predict(xt, t, **kwargs)[-1]
202
+ else:
203
+ return self.predict(xt, t, **kwargs)
204
+
205
+ def score(self, xt, t, **kwargs):
206
+ cum_beta = self.sde.cum_beta(t)
207
+ noise_pred = self.noise_pred(xt, t, sampling = True, **kwargs)
208
+ return stp(-cum_beta.rsqrt(), noise_pred)
209
+
210
+
211
+ class ReverseSDE(object):
212
+ r"""
213
+ dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw
214
+ """
215
+ def __init__(self, score_model):
216
+ self.sde = score_model.sde # the forward sde
217
+ self.score_model = score_model
218
+
219
+ def drift(self, x, t, **kwargs):
220
+ drift = self.sde.drift(x, t) # f(x, t)
221
+ diffusion = self.sde.diffusion(t) # g(t)
222
+ score = self.score_model.score(x, t, **kwargs)
223
+ return drift - stp(diffusion ** 2, score)
224
+
225
+ def diffusion(self, t):
226
+ return self.sde.diffusion(t)
227
+
228
+
229
+ class ODE(object):
230
+ r"""
231
+ dx = [f(x, t) - g(t)^2 s(x, t)] dt
232
+ """
233
+
234
+ def __init__(self, score_model):
235
+ self.sde = score_model.sde # the forward sde
236
+ self.score_model = score_model
237
+
238
+ def drift(self, x, t, **kwargs):
239
+ drift = self.sde.drift(x, t) # f(x, t)
240
+ diffusion = self.sde.diffusion(t) # g(t)
241
+ score = self.score_model.score(x, t, **kwargs)
242
+ return drift - 0.5 * stp(diffusion ** 2, score)
243
+
244
+ def diffusion(self, t):
245
+ return 0
246
+
247
+
248
+ def dct2str(dct):
249
+ return str({k: f'{v:.6g}' for k, v in dct.items()})
250
+
251
+
252
+ @ torch.no_grad()
253
+ def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs):
254
+ r"""
255
+ The Euler Maruyama sampler for reverse SDE / ODE
256
+ See `Score-Based Generative Modeling through Stochastic Differential Equations`
257
+ """
258
+ assert isinstance(rsde, ReverseSDE) or isinstance(rsde, ODE)
259
+ print(f"euler_maruyama with sample_steps={sample_steps}")
260
+ timesteps = np.append(0., np.linspace(eps, T, sample_steps))
261
+ timesteps = torch.tensor(timesteps).to(x_init)
262
+ x = x_init
263
+ if trace is not None:
264
+ trace.append(x)
265
+ for s, t in tqdm(list(zip(timesteps, timesteps[1:]))[::-1], disable=not verbose, desc='euler_maruyama'):
266
+ drift = rsde.drift(x, t, **kwargs)
267
+ diffusion = rsde.diffusion(t)
268
+ dt = s - t
269
+ mean = x + drift * dt
270
+ sigma = diffusion * (-dt).sqrt()
271
+ x = mean + stp(sigma, torch.randn_like(x)) if s != 0 else mean
272
+ if trace is not None:
273
+ trace.append(x)
274
+ statistics = dict(s=s, t=t, sigma=sigma.item())
275
+ logging.debug(dct2str(statistics))
276
+ return x
277
+
278
+
279
+ def LSimple(score_model: ScoreModel, x0, **kwargs):
280
+ t, noise, xt = score_model.sde.sample(x0)
281
+ prediction = score_model.noise_pred(xt, t, sampling = False, **kwargs)
282
+ target = multi_scale_targets(noise, levels = len(prediction), scale_correction = True)
283
+ loss = 0
284
+ for pred, coeff in check_zip(prediction, score_model.loss_coeffs):
285
+ loss = loss + coeff * mos(pred - target[pred.shape[-1]])
286
+ return loss
287
+
288
+
289
+ def odd_multi_scale_targets(target, levels, scale_correction):
290
+ B, C, H, W = target.shape
291
+ targets = {}
292
+ for l in range(levels):
293
+ ratio = int(2 ** l)
294
+ if ratio == 1:
295
+ targets[target.shape[-1]] = target
296
+ continue
297
+ assert (H - 1) % ratio == 0 and (W - 1) % ratio == 0
298
+ KS = ratio + 1
299
+ scale = KS if scale_correction else KS ** 2
300
+ kernel = torch.ones(C, 1, KS, KS, device = target.device) / scale
301
+ downsampled = F.conv2d(target, kernel, stride = ratio, padding = KS // 2, groups = C)
302
+ targets[downsampled.shape[-1]] = downsampled
303
+ return targets
304
+
305
+ def even_multi_scale_targets(target, levels, scale_correction):
306
+ B, C, H, W = target.shape
307
+ targets = {}
308
+ for l in range(levels):
309
+ ratio = int(2 ** l)
310
+ if ratio == 1:
311
+ targets[target.shape[-1]] = target
312
+ continue
313
+ assert H % ratio == 0 and W % ratio == 0
314
+ KS = ratio
315
+ scale = KS if scale_correction else KS ** 2
316
+ kernel = torch.ones(C, 1, KS, KS, device = target.device) / scale
317
+ downsampled = F.conv2d(target, kernel, stride = ratio, groups = C)
318
+ targets[downsampled.shape[-1]] = downsampled
319
+ return targets
320
+
321
+ def multi_scale_targets(target, levels, scale_correction):
322
+ B, C, H, W = target.shape
323
+ if H % 2 == 0:
324
+ return even_multi_scale_targets(target, levels, scale_correction)
325
+ else:
326
+ return odd_multi_scale_targets(target, levels, scale_correction)
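For reference, with the linear schedule beta(t) = beta_0 + t * (beta_1 - beta_0) used by VPSDE, the quantities returned by squared_diffusion_integral, cum_alpha, cum_beta and marginal_prob correspond to the closed-form forward marginal

\[
\bar{\alpha}(t) = \exp\!\Big(-\!\int_0^t \beta(s)\,ds\Big)
               = \exp\!\Big(-\beta_0 t - \tfrac{1}{2}(\beta_1 - \beta_0)\,t^2\Big),
\qquad
q(x_t \mid x_0) = \mathcal{N}\!\big(\sqrt{\bar{\alpha}(t)}\,x_0,\; (1 - \bar{\alpha}(t))\,I\big),
\]

so snr(t) = \bar{\alpha}(t) / (1 - \bar{\alpha}(t)) and nsr(t) is its reciprocal (both clamped in the code for numerical stability).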
tools/clip_score.py ADDED
@@ -0,0 +1,90 @@
1
+ """
2
+ This file computes the CLIP score given an image and text pair.
3
+ """
4
+ import clip
5
+ import torch
6
+ from PIL import Image
7
+ from sklearn.preprocessing import normalize
8
+ from torchvision.transforms import Compose, Normalize, Resize
9
+ import torch
10
+ import numpy as np
11
+
12
+ class ClipSocre:
13
+ def __init__(self,device='cuda', prefix='A photo depicts', weight=1.0): # weight=2.5
14
+ self.device = device
15
+
16
+ self.model, _ = clip.load("ViT-B/32", device=device, jit=False)
17
+ self.model.eval()
18
+
19
+ self.transform = Compose([
20
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
21
+ ])
22
+
23
+ self.prefix = prefix
24
+ if self.prefix[-1] != ' ':
25
+ self.prefix += ' '
26
+
27
+ self.w = weight
28
+
29
+ def extract_all_images(self, images):
30
+ images_input = self.transform(images)
31
+ if self.device == 'cuda':
32
+ images_input = images_input.to(torch.float16)
33
+ image_feature = self.model.encode_image(images_input)
34
+ return image_feature
35
+
36
+ def extract_all_texts(self, texts,need_prefix):
37
+ if need_prefix:
38
+ c_data = clip.tokenize(self.prefix + texts, truncate=True).to(self.device)
39
+ else:
40
+ c_data = clip.tokenize(texts, truncate=True).to(self.device)
41
+ text_feature = self.model.encode_text(c_data)
42
+ return text_feature
43
+
44
+ def get_clip_score(self, img, text, need_prefix=False):
45
+
46
+ img_f = self.extract_all_images(img)
47
+ text_f = self.extract_all_texts(text,need_prefix)
48
+ images = img_f / torch.sqrt(torch.sum(img_f**2, axis=1, keepdims=True))
49
+ candidates = text_f / torch.sqrt(torch.sum(text_f**2, axis=1, keepdims=True))
50
+
51
+ clip_per = self.w * torch.clip(torch.sum(images * candidates, axis=1), 0, None)
52
+
53
+ return clip_per
54
+
55
+ def get_text_clip_score(self, text_1, text_2, need_prefix=False):
56
+ text_1_f = self.extract_all_texts(text_1,need_prefix)
57
+ text_2_f = self.extract_all_texts(text_2,need_prefix)
58
+
59
+ candidates_1 = text_1_f / torch.sqrt(torch.sum(text_1_f**2, axis=1, keepdims=True))
60
+ candidates_2 = text_2_f / torch.sqrt(torch.sum(text_2_f**2, axis=1, keepdims=True))
61
+
62
+ per = self.w * torch.clip(torch.sum(candidates_1 * candidates_2, axis=1), 0, None)
63
+
64
+
65
+ results = 'ClipS : ' + str(format(per.item(),'.4f'))
66
+
67
+ print(results)
68
+
69
+ return per.sum()
70
+
71
+ def get_img_clip_score(self, img_1, img_2, weight = 1):
72
+
73
+ img_f_1 = self.extract_all_images(img_1)
74
+ img_f_2 = self.extract_all_images(img_2)
75
+
76
+ images_1 = img_f_1 / torch.sqrt(torch.sum(img_f_1**2, axis=1, keepdims=True))
77
+ images_2 = img_f_2 / torch.sqrt(torch.sum(img_f_2**2, axis=1, keepdims=True))
78
+
79
+ # per = self.w * torch.clip(torch.sum(images_1 * images_2, axis=1), 0, None)
80
+ per = weight * torch.clip(torch.sum(images_1 * images_2, axis=1), 0, None)
81
+
82
+
83
+ return per.sum()
84
+
85
+
86
+ def calculate_clip_score(self, caption_list, image_unprocessed):
87
+ image_unprocessed = 0.5 * (image_unprocessed + 1.)
88
+ image_unprocessed.clamp_(0., 1.)
89
+ img_resize = Resize((224))(image_unprocessed)
90
+ return self.get_clip_score(img_resize,caption_list)
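A minimal usage sketch for the scorer above, assuming a CUDA device and the openai-clip package from requirements.txt; the image tensor and captions are placeholders:

import torch
from tools.clip_score import ClipSocre

clip_scorer = ClipSocre(device='cuda')  # weight=1.0; the prefix is only applied when need_prefix=True

# calculate_clip_score expects decoder-style images in [-1, 1] with shape (B, 3, H, W).
images = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
captions = ['A dog next to a white cat', 'A tower in the mountains']

scores = clip_scorer.calculate_clip_score(captions, images)
print(scores.shape)  # torch.Size([2]) -- one similarity score per image-caption pair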
tools/fid_score.py ADDED
@@ -0,0 +1,268 @@
1
+ """Calculates the Frechet Inception Distance (FID) to evaluate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ import os
35
+ import pathlib
36
+
37
+ import numpy as np
38
+ import torch
39
+ import torchvision.transforms as TF
40
+ from PIL import Image
41
+ from scipy import linalg
42
+ from torch.nn.functional import adaptive_avg_pool2d
43
+ from torchvision import transforms
44
+
45
+ try:
46
+ from tqdm import tqdm
47
+ except ImportError:
48
+ # If tqdm is not available, provide a mock version of it
49
+ def tqdm(x):
50
+ return x
51
+
52
+ from .inception import InceptionV3
53
+
54
+
55
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
56
+ 'tif', 'tiff', 'webp'}
57
+
58
+
59
+ class ImagePathDataset(torch.utils.data.Dataset):
60
+ def __init__(self, files, transforms=None):
61
+ self.files = files
62
+ self.transforms = transforms
63
+
64
+ def __len__(self):
65
+ return len(self.files)
66
+
67
+ def __getitem__(self, i):
68
+ path = self.files[i]
69
+ img = Image.open(path).convert('RGB')
70
+
71
+ if img.size == (512,512):
72
+ img = img.resize((256, 256))
73
+
74
+ if self.transforms is not None:
75
+ img = self.transforms(img)
76
+ return img
77
+
78
+
79
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
80
+ """Calculates the activations of the pool_3 layer for all images.
81
+
82
+ Params:
83
+ -- files : List of image files paths
84
+ -- model : Instance of inception model
85
+ -- batch_size : Batch size of images for the model to process at once.
86
+ Make sure that the number of samples is a multiple of
87
+ the batch size, otherwise some samples are ignored. This
88
+ behavior is retained to match the original FID score
89
+ implementation.
90
+ -- dims : Dimensionality of features returned by Inception
91
+ -- device : Device to run calculations
92
+ -- num_workers : Number of parallel dataloader workers
93
+
94
+ Returns:
95
+ -- A numpy array of dimension (num images, dims) that contains the
96
+ activations of the given tensor when feeding inception with the
97
+ query tensor.
98
+ """
99
+ model.eval()
100
+
101
+ if batch_size > len(files):
102
+ print(('Warning: batch size is bigger than the data size. '
103
+ 'Setting batch size to data size'))
104
+ batch_size = len(files)
105
+
106
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
107
+ dataloader = torch.utils.data.DataLoader(dataset,
108
+ batch_size=batch_size,
109
+ shuffle=False,
110
+ drop_last=False,
111
+ num_workers=num_workers)
112
+
113
+ pred_arr = np.empty((len(files), dims))
114
+
115
+ start_idx = 0
116
+
117
+ # resizer = transforms.Resize(256) # for clip
118
+
119
+ for batch in tqdm(dataloader):
120
+ batch = batch.to(device)
121
+
122
+ with torch.no_grad():
123
+
124
+ pred = model(batch)[0]
125
+
126
+ # If model output is not scalar, apply global spatial average pooling.
127
+ # This happens if you choose a dimensionality not equal 2048.
128
+ if pred.size(2) != 1 or pred.size(3) != 1:
129
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
130
+
131
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
132
+
133
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
134
+
135
+ start_idx = start_idx + pred.shape[0]
136
+
137
+ return pred_arr
138
+
139
+
140
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
141
+ """Numpy implementation of the Frechet Distance.
142
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
143
+ and X_2 ~ N(mu_2, C_2) is
144
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
145
+
146
+ Stable version by Dougal J. Sutherland.
147
+
148
+ Params:
149
+ -- mu1 : Numpy array containing the activations of a layer of the
150
+ inception net (like returned by the function 'get_predictions')
151
+ for generated samples.
152
+ -- mu2 : The sample mean over activations, precalculated on a
153
+ representative data set.
154
+ -- sigma1: The covariance matrix over activations for generated samples.
155
+ -- sigma2: The covariance matrix over activations, precalculated on a
156
+ representative data set.
157
+
158
+ Returns:
159
+ -- : The Frechet Distance.
160
+ """
161
+
162
+ mu1 = np.atleast_1d(mu1)
163
+ mu2 = np.atleast_1d(mu2)
164
+
165
+ sigma1 = np.atleast_2d(sigma1)
166
+ sigma2 = np.atleast_2d(sigma2)
167
+
168
+ assert mu1.shape == mu2.shape, \
169
+ 'Training and test mean vectors have different lengths'
170
+ assert sigma1.shape == sigma2.shape, \
171
+ 'Training and test covariances have different dimensions'
172
+
173
+ diff = mu1 - mu2
174
+
175
+ # Product might be almost singular
176
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
177
+ if not np.isfinite(covmean).all():
178
+ msg = ('fid calculation produces singular product; '
179
+ 'adding %s to diagonal of cov estimates') % eps
180
+ print(msg)
181
+ offset = np.eye(sigma1.shape[0]) * eps
182
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
183
+
184
+ # Numerical error might give slight imaginary component
185
+ if np.iscomplexobj(covmean):
186
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
187
+ m = np.max(np.abs(covmean.imag))
188
+ raise ValueError('Imaginary component {}'.format(m))
189
+ covmean = covmean.real
190
+
191
+ tr_covmean = np.trace(covmean)
192
+
193
+ return (diff.dot(diff) + np.trace(sigma1)
194
+ + np.trace(sigma2) - 2 * tr_covmean)
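# Quick sanity check of the formula above (illustrative only): for 1-D
# Gaussians it reduces to (mu1 - mu2)^2 + s1 + s2 - 2*sqrt(s1*s2), so
#   calculate_frechet_distance(np.zeros(1), np.ones((1, 1)),
#                              np.ones(1), np.ones((1, 1)))
# returns 1.0 (means one unit apart, identical unit covariances).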
195
+
196
+
197
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
198
+ device='cpu', num_workers=8):
199
+ """Calculation of the statistics used by the FID.
200
+ Params:
201
+ -- files : List of image files paths
202
+ -- model : Instance of inception model
203
+ -- batch_size : The images numpy array is split into batches with
204
+ batch size batch_size. A reasonable batch size
205
+ depends on the hardware.
206
+ -- dims : Dimensionality of features returned by Inception
207
+ -- device : Device to run calculations
208
+ -- num_workers : Number of parallel dataloader workers
209
+
210
+ Returns:
211
+ -- mu : The mean over samples of the activations of the pool_3 layer of
212
+ the inception model.
213
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
214
+ the inception model.
215
+ """
216
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
217
+ mu = np.mean(act, axis=0)
218
+ sigma = np.cov(act, rowvar=False)
219
+ return mu, sigma
220
+
221
+
222
+ def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
223
+ if path.endswith('.npz'):
224
+ with np.load(path) as f:
225
+ m, s = f['mu'][:], f['sigma'][:]
226
+ else:
227
+ path = pathlib.Path(path)
228
+ files = sorted([file for ext in IMAGE_EXTENSIONS
229
+ for file in path.glob('*.{}'.format(ext))])
230
+ m, s = calculate_activation_statistics(files, model, batch_size,
231
+ dims, device, num_workers)
232
+
233
+ return m, s
234
+
235
+
236
+ def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
237
+ if device is None:
238
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
239
+ else:
240
+ device = torch.device(device)
241
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
242
+ model = InceptionV3([block_idx]).to(device)
243
+ m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
244
+ np.savez(out_path, mu=m1, sigma=s1)
245
+
246
+
247
+ def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
248
+ """Calculates the FID of two paths"""
249
+ if device is None:
250
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
251
+ else:
252
+ device = torch.device(device)
253
+
254
+ for p in paths:
255
+ if not os.path.exists(p):
256
+ raise RuntimeError('Invalid path: %s' % p)
257
+
258
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
259
+
260
+ model = InceptionV3([block_idx]).to(device)
261
+
262
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
263
+ dims, device, num_workers)
264
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
265
+ dims, device, num_workers)
266
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
267
+
268
+ return fid_value
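A hedged usage sketch of the two entry points above; the folder and .npz paths below are placeholders, not files shipped with this commit:

from tools.fid_score import save_statistics_of_path, calculate_fid_given_paths

# Pre-compute reference statistics once and cache them as an .npz file.
save_statistics_of_path('path/to/reference_images', 'path/to/ref_stats.npz')

# Either element of the pair may be an image folder or a pre-computed .npz file.
fid = calculate_fid_given_paths(('path/to/ref_stats.npz', 'path/to/generated_images'))
print(f'FID: {fid:.2f}')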
tools/inception.py ADDED
@@ -0,0 +1,328 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = _inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def _inception_v3(*args, **kwargs):
167
+ """Wraps `torchvision.models.inception_v3`
168
+
169
+ Skips default weight initialization if supported by torchvision version.
170
+ See https://github.com/mseitzer/pytorch-fid/issues/28.
171
+ """
172
+ try:
173
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
174
+ except ValueError:
175
+ # Just a caution against weird version strings
176
+ version = (0,)
177
+
178
+ if version >= (0, 6):
179
+ kwargs['init_weights'] = False
180
+
181
+ return torchvision.models.inception_v3(*args, **kwargs)
182
+
183
+
184
+ def fid_inception_v3():
185
+ """Build pretrained Inception model for FID computation
186
+
187
+ The Inception model for FID computation uses a different set of weights
188
+ and has a slightly different structure than torchvision's Inception.
189
+
190
+ This method first constructs torchvision's Inception and then patches the
191
+ necessary parts that are different in the FID Inception model.
192
+ """
193
+ inception = _inception_v3(num_classes=1008,
194
+ aux_logits=False,
195
+ pretrained=False)
196
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203
+ inception.Mixed_7b = FIDInceptionE_1(1280)
204
+ inception.Mixed_7c = FIDInceptionE_2(2048)
205
+
206
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, model_dir="checkpoints", progress=True)
207
+ inception.load_state_dict(state_dict)
208
+ return inception
209
+
210
+
211
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
212
+ """InceptionA block patched for FID computation"""
213
+ def __init__(self, in_channels, pool_features):
214
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
215
+
216
+ def forward(self, x):
217
+ branch1x1 = self.branch1x1(x)
218
+
219
+ branch5x5 = self.branch5x5_1(x)
220
+ branch5x5 = self.branch5x5_2(branch5x5)
221
+
222
+ branch3x3dbl = self.branch3x3dbl_1(x)
223
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225
+
226
+ # Patch: Tensorflow's average pool does not use the padded zeros in
227
+ # its average calculation
228
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229
+ count_include_pad=False)
230
+ branch_pool = self.branch_pool(branch_pool)
231
+
232
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233
+ return torch.cat(outputs, 1)
234
+
235
+
236
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
237
+ """InceptionC block patched for FID computation"""
238
+ def __init__(self, in_channels, channels_7x7):
239
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240
+
241
+ def forward(self, x):
242
+ branch1x1 = self.branch1x1(x)
243
+
244
+ branch7x7 = self.branch7x7_1(x)
245
+ branch7x7 = self.branch7x7_2(branch7x7)
246
+ branch7x7 = self.branch7x7_3(branch7x7)
247
+
248
+ branch7x7dbl = self.branch7x7dbl_1(x)
249
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253
+
254
+ # Patch: Tensorflow's average pool does not use the padded zeros in
255
+ # its average calculation
256
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257
+ count_include_pad=False)
258
+ branch_pool = self.branch_pool(branch_pool)
259
+
260
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261
+ return torch.cat(outputs, 1)
262
+
263
+
264
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265
+ """First InceptionE block patched for FID computation"""
266
+ def __init__(self, in_channels):
267
+ super(FIDInceptionE_1, self).__init__(in_channels)
268
+
269
+ def forward(self, x):
270
+ branch1x1 = self.branch1x1(x)
271
+
272
+ branch3x3 = self.branch3x3_1(x)
273
+ branch3x3 = [
274
+ self.branch3x3_2a(branch3x3),
275
+ self.branch3x3_2b(branch3x3),
276
+ ]
277
+ branch3x3 = torch.cat(branch3x3, 1)
278
+
279
+ branch3x3dbl = self.branch3x3dbl_1(x)
280
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281
+ branch3x3dbl = [
282
+ self.branch3x3dbl_3a(branch3x3dbl),
283
+ self.branch3x3dbl_3b(branch3x3dbl),
284
+ ]
285
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
286
+
287
+ # Patch: Tensorflow's average pool does not use the padded zeros in
288
+ # its average calculation
289
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290
+ count_include_pad=False)
291
+ branch_pool = self.branch_pool(branch_pool)
292
+
293
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294
+ return torch.cat(outputs, 1)
295
+
296
+
297
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298
+ """Second InceptionE block patched for FID computation"""
299
+ def __init__(self, in_channels):
300
+ super(FIDInceptionE_2, self).__init__(in_channels)
301
+
302
+ def forward(self, x):
303
+ branch1x1 = self.branch1x1(x)
304
+
305
+ branch3x3 = self.branch3x3_1(x)
306
+ branch3x3 = [
307
+ self.branch3x3_2a(branch3x3),
308
+ self.branch3x3_2b(branch3x3),
309
+ ]
310
+ branch3x3 = torch.cat(branch3x3, 1)
311
+
312
+ branch3x3dbl = self.branch3x3dbl_1(x)
313
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314
+ branch3x3dbl = [
315
+ self.branch3x3dbl_3a(branch3x3dbl),
316
+ self.branch3x3dbl_3b(branch3x3dbl),
317
+ ]
318
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
319
+
320
+ # Patch: The FID Inception model uses max pooling instead of average
321
+ # pooling. This is likely an error in this specific Inception
322
+ # implementation, as other Inception models use average pooling here
323
+ # (which matches the description in the paper).
324
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325
+ branch_pool = self.branch_pool(branch_pool)
326
+
327
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328
+ return torch.cat(outputs, 1)
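A minimal sketch of how the wrapper above is typically driven to obtain the 2048-d pool3 features used for FID (inputs must already be scaled to [0, 1]; the FID weights are downloaded into ./checkpoints on first use):

import torch
from tools.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average-pooling block
model = InceptionV3([block_idx]).eval()

with torch.no_grad():
    feats = model(torch.rand(8, 3, 299, 299))[0]   # (8, 2048, 1, 1)
feats = feats.squeeze(3).squeeze(2)                # (8, 2048)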
train_t2i.py ADDED
@@ -0,0 +1,328 @@
1
+ import ml_collections
2
+ import torch
3
+ from torch import multiprocessing as mp
4
+ from datasets import get_dataset
5
+ from torchvision.utils import make_grid, save_image
6
+ import utils
7
+ import einops
8
+ from torch.utils._pytree import tree_map
9
+ import accelerate
10
+ from torch.utils.data import DataLoader
11
+ from tqdm.auto import tqdm
12
+ import tempfile
13
+ from absl import logging
14
+ import builtins
15
+ import os
16
+ import wandb
17
+ import numpy as np
18
+ import time
19
+ import random
20
+
21
+ import libs.autoencoder
22
+ from libs.t5 import T5Embedder
23
+ from libs.clip import FrozenCLIPEmbedder
24
+ from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
25
+ from tools.fid_score import calculate_fid_given_paths
26
+ from tools.clip_score import ClipSocre
27
+
28
+
29
+ def train(config):
30
+ if config.get('benchmark', False):
31
+ torch.backends.cudnn.benchmark = True
32
+ torch.backends.cudnn.deterministic = False
33
+
34
+ mp.set_start_method('spawn')
35
+ accelerator = accelerate.Accelerator()
36
+ device = accelerator.device
37
+ accelerate.utils.set_seed(config.seed, device_specific=True)
38
+ logging.info(f'Process {accelerator.process_index} using device: {device}')
39
+
40
+ config.mixed_precision = accelerator.mixed_precision
41
+ config = ml_collections.FrozenConfigDict(config)
42
+
43
+ assert config.train.batch_size % accelerator.num_processes == 0
44
+ mini_batch_size = config.train.batch_size // accelerator.num_processes
45
+
46
+ if accelerator.is_main_process:
47
+ os.makedirs(config.ckpt_root, exist_ok=True)
48
+ os.makedirs(config.sample_dir, exist_ok=True)
49
+ accelerator.wait_for_everyone()
50
+ if accelerator.is_main_process:
51
+ wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
52
+ name=config.hparams, job_type='train', mode='offline')
53
+ utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
54
+ logging.info(config)
55
+ else:
56
+ utils.set_logger(log_level='error')
57
+ builtins.print = lambda *args: None
58
+ logging.info(f'Run on {accelerator.num_processes} devices')
59
+
60
+ dataset = get_dataset(**config.dataset)
61
+ assert os.path.exists(dataset.fid_stat)
62
+
63
+ gpu_model = torch.cuda.get_device_name(torch.cuda.current_device())
64
+ num_workers = 8
65
+
66
+ train_dataset = dataset.get_split(split='train', labeled=True)
67
+ train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
68
+ num_workers=num_workers, pin_memory=True, persistent_workers=True)
69
+
70
+ test_dataset = dataset.get_split(split='test', labeled=True) # for sampling
71
+ test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
72
+ num_workers=num_workers, pin_memory=True, persistent_workers=True)
73
+
74
+ train_state = utils.initialize_train_state(config, device)
75
+ nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
76
+ train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
77
+ lr_scheduler = train_state.lr_scheduler
78
+ train_state.resume(config.ckpt_root)
79
+
80
+ autoencoder = libs.autoencoder.get_model(**config.autoencoder)
81
+ autoencoder.to(device)
82
+
83
+ if config.nnet.model_args.clip_dim == 4096:
84
+ llm = "t5"
85
+ t5 = T5Embedder(device=device)
86
+ elif config.nnet.model_args.clip_dim == 768:
87
+ llm = "clip"
88
+ clip = FrozenCLIPEmbedder()
89
+ clip.eval()
90
+ clip.to(device)
91
+ else:
92
+ raise NotImplementedError
93
+
94
+ ss_empty_context = None
95
+
96
+ ClipSocre_model = ClipSocre(device=device)
97
+
98
+ @ torch.cuda.amp.autocast()
99
+ def encode(_batch):
100
+ return autoencoder.encode(_batch)
101
+
102
+ @ torch.cuda.amp.autocast()
103
+ def decode(_batch):
104
+ return autoencoder.decode(_batch)
105
+
106
+ def get_data_generator():
107
+ while True:
108
+ for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
109
+ yield data
110
+
111
+ data_generator = get_data_generator()
112
+
113
+ def get_context_generator(autoencoder):
114
+ while True:
115
+ for data in test_dataset_loader:
116
+ if len(data) == 5:
117
+ _img, _context, _token_mask, _token, _caption = data
118
+ else:
119
+ _img, _context = data
120
+ _token_mask = None
121
+ _token = None
122
+ _caption = None
123
+
124
+ if len(_img.shape)==5:
125
+ _testbatch_img_blurred = autoencoder.sample(_img[:,1,:])
126
+ yield _context, _token_mask, _token, _caption, _testbatch_img_blurred
127
+ else:
128
+ assert len(_img.shape)==4
129
+ yield _context, _token_mask, _token, _caption, None
130
+
131
+ context_generator = get_context_generator(autoencoder)
132
+
133
+ _flow_matching_model = FlowMatching()
134
+
135
+ def train_step(_batch, _ss_empty_context):
136
+ _metrics = dict()
137
+ optimizer.zero_grad()
138
+
139
+ assert len(_batch)==6
140
+ assert not config.dataset.cfg
141
+ _batch_img = _batch[0]
142
+ _batch_con = _batch[1]
143
+ _batch_mask = _batch[2]
144
+ _batch_token = _batch[3]
145
+ _batch_caption = _batch[4]
146
+ _batch_img_ori = _batch[5]
147
+
148
+ _z = autoencoder.sample(_batch_img)
149
+
150
+ loss, loss_dict = _flow_matching_model(_z, nnet, loss_coeffs=config.loss_coeffs, cond=_batch_con, con_mask=_batch_mask, batch_img_clip=_batch_img_ori, \
151
+ nnet_style=config.nnet.name, text_token=_batch_token, model_config=config.nnet.model_args, all_config=config, training_step=train_state.step)
152
+
153
+ _metrics['loss'] = accelerator.gather(loss.detach()).mean()
154
+ for key in loss_dict.keys():
155
+ _metrics[key] = accelerator.gather(loss_dict[key].detach()).mean()
156
+ accelerator.backward(loss.mean())
157
+ optimizer.step()
158
+ lr_scheduler.step()
159
+ train_state.ema_update(config.get('ema_rate', 0.9999))
160
+ train_state.step += 1
161
+ return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
162
+
163
+ def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token_mask=None, return_clipScore=False, ClipSocre_model=None):
164
+ with torch.no_grad():
165
+ _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)
166
+
167
+ _z_x0, _mu, _log_var = nnet_ema(context, text_encoder = True, shape = _z_gaussian.shape, mask=token_mask)
168
+ _z_init = _z_x0.reshape(_z_gaussian.shape)
169
+
170
+ assert config.sample.scale > 1
171
+ _cfg = config.sample.scale
172
+
173
+ has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
174
+
175
+ ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, step_size_type="step_in_dsigma", guidance_scale=_cfg)
176
+ _z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)
177
+
178
+ image_unprocessed = decode(_z)
179
+
180
+ if return_clipScore:
181
+ clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)
182
+ return image_unprocessed, clip_score
183
+ else:
184
+ return image_unprocessed
185
+
186
+ def eval_step(n_samples, sample_steps):
187
+ logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=ODE_Euler_Flow_Matching_Solver, '
188
+ f'mini_batch_size={config.sample.mini_batch_size}')
189
+
190
+ def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
191
+ _context, _token_mask, _token, _caption, _testbatch_img_blurred = next(context_generator)
192
+ assert _context.size(0) == _n_samples
193
+ assert not return_caption # during training we should not use this
194
+ if return_caption:
195
+ return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask), _caption
196
+ elif return_clipScore:
197
+ return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
198
+ else:
199
+ return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask)
200
+
201
+ with tempfile.TemporaryDirectory() as temp_path:
202
+ path = config.sample.path or temp_path
203
+ if accelerator.is_main_process:
204
+ os.makedirs(path, exist_ok=True)
205
+ clip_score_list = utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
206
+ _fid = 0
207
+ if accelerator.is_main_process:
208
+ _fid = calculate_fid_given_paths((dataset.fid_stat, path))
209
+ _clip_score_list = torch.cat(clip_score_list)
210
+ logging.info(f'step={train_state.step} fid{n_samples}={_fid} clip_score{len(_clip_score_list)} = {_clip_score_list.mean().item()}')
211
+ with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
212
+ print(f'step={train_state.step} fid{n_samples}={_fid} clip_score{len(_clip_score_list)} = {_clip_score_list.mean().item()}', file=f)
213
+ wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
214
+ _fid = torch.tensor(_fid, device=device)
215
+ _fid = accelerator.reduce(_fid, reduction='sum')
216
+
217
+ return _fid.item()
218
+
219
+ logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
220
+
221
+ step_fid = []
222
+ while train_state.step < config.train.n_steps:
223
+ nnet.train()
224
+ batch = tree_map(lambda x: x, next(data_generator))
225
+ metrics = train_step(batch, ss_empty_context)
226
+
227
+ nnet.eval()
228
+ if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
229
+ logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
230
+ logging.info(config.workdir)
231
+ wandb.log(metrics, step=train_state.step)
232
+
233
+ ############# save a grid of sample images
234
+ if train_state.step % config.train.eval_interval == 0:
235
+ torch.cuda.empty_cache()
236
+ logging.info('Save a grid of images...')
237
+ if hasattr(dataset, "token_embedding"):
238
+ contexts = torch.tensor(dataset.token_embedding, device=device)[ : config.train.n_samples_eval]
239
+ token_mask = torch.tensor(dataset.token_mask, device=device)[ : config.train.n_samples_eval]
240
+ elif hasattr(dataset, "contexts"):
241
+ contexts = torch.tensor(dataset.contexts, device=device)[ : config.train.n_samples_eval]
242
+ token_mask = None
243
+ else:
244
+ raise NotImplementedError
245
+ samples = ode_fm_solver_sample(nnet_ema, _n_samples=config.train.n_samples_eval, _sample_steps=50, context=contexts, token_mask=token_mask)
246
+ samples = make_grid(dataset.unpreprocess(samples), 5)
247
+ if accelerator.is_main_process:
248
+ save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
249
+ wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
250
+ accelerator.wait_for_everyone()
251
+ torch.cuda.empty_cache()
252
+
253
+ ############ save checkpoint and evaluate results
254
+ if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
255
+ torch.cuda.empty_cache()
256
+ logging.info(f'Save and eval checkpoint {train_state.step}...')
257
+
258
+ if accelerator.local_process_index == 0:
259
+ train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
260
+ accelerator.wait_for_everyone()
261
+
262
+ fid = eval_step(n_samples=10000, sample_steps=50) # calculate fid of the saved checkpoint
263
+ step_fid.append((train_state.step, fid))
264
+
265
+ torch.cuda.empty_cache()
266
+ accelerator.wait_for_everyone()
267
+
268
+ logging.info(f'Finish fitting, step={train_state.step}')
269
+ logging.info(f'step_fid: {step_fid}')
270
+ step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
271
+ logging.info(f'step_best: {step_best}')
272
+ train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
273
+ del metrics
274
+ accelerator.wait_for_everyone()
275
+ eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
276
+
277
+
278
+
279
+ from absl import flags
280
+ from absl import app
281
+ from ml_collections import config_flags
282
+ import sys
283
+ from pathlib import Path
284
+
285
+
286
+ FLAGS = flags.FLAGS
287
+ config_flags.DEFINE_config_file(
288
+ "config", None, "Training configuration.", lock_config=False)
289
+ flags.mark_flags_as_required(["config"])
290
+ flags.DEFINE_string("workdir", None, "Work unit directory.")
291
+
292
+
293
+ def get_config_name():
294
+ argv = sys.argv
295
+ for i in range(1, len(argv)):
296
+ if argv[i].startswith('--config='):
297
+ return Path(argv[i].split('=')[-1]).stem
298
+
299
+
300
+ def get_hparams():
301
+ argv = sys.argv
302
+ lst = []
303
+ for i in range(1, len(argv)):
304
+ assert '=' in argv[i]
305
+ if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
306
+ hparam, val = argv[i].split('=')
307
+ hparam = hparam.split('.')[-1]
308
+ if hparam.endswith('path'):
309
+ val = Path(val).stem
310
+ lst.append(f'{hparam}={val}')
311
+ hparams = '-'.join(lst)
312
+ if hparams == '':
313
+ hparams = 'default'
314
+ return hparams
315
+
316
+
317
+ def main(argv):
318
+ config = FLAGS.config
319
+ config.config_name = get_config_name()
320
+ config.hparams = get_hparams()
321
+ config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
322
+ config.ckpt_root = os.path.join(config.workdir, 'ckpts')
323
+ config.sample_dir = os.path.join(config.workdir, 'samples')
324
+ train(config)
325
+
326
+
327
+ if __name__ == "__main__":
328
+ app.run(main)
utils.py ADDED
@@ -0,0 +1,274 @@
1
+ """
2
+ Utility helpers shared by the training and sampling scripts: logging setup, optimizer and LR-scheduler construction, EMA, checkpointing (TrainState), and batched sampling to disk.
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ import os
8
+ from tqdm import tqdm
9
+ from torchvision import transforms
10
+ from torchvision.utils import save_image
11
+ from absl import logging
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ import textwrap
14
+
15
+ def save_image_with_caption(image_tensor, caption, filename, font_size=20, font_path='/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf'):
16
+ """
17
+ Save an image with a caption
18
+ """
19
+ image_tensor = image_tensor.clone().detach()
20
+ image_tensor = torch.clamp(image_tensor, min=0, max=1)
21
+ image_pil = transforms.ToPILImage()(image_tensor)
22
+ draw = ImageDraw.Draw(image_pil)
23
+
24
+ font = ImageFont.truetype(font_path, font_size)
25
+ wrap_text = textwrap.wrap(caption, width=len(caption)//4 + 1)
26
+ text_sizes = [draw.textsize(line, font=font) for line in wrap_text]
27
+ max_text_width = max(size[0] for size in text_sizes)
28
+ total_text_height = sum(size[1] for size in text_sizes) + 15
29
+
30
+ new_height = image_pil.height + total_text_height + 25
31
+ new_image = Image.new('RGB', (image_pil.width, new_height), 'white')
32
+ new_image.paste(image_pil, (0, 0))
33
+ current_y = image_pil.height + 5
34
+ draw = ImageDraw.Draw(new_image)
35
+
36
+ for line, size in zip(wrap_text, text_sizes):
37
+ x = (new_image.width - size[0]) / 2
38
+ draw.text((x, current_y), line, font=font, fill='black')
39
+ current_y += size[1] + 5
40
+ new_image.save(filename)
41
+
42
+
43
+ def set_logger(log_level='info', fname=None):
44
+ import logging as _logging
45
+ handler = logging.get_absl_handler()
46
+ formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
47
+ handler.setFormatter(formatter)
48
+ logging.set_verbosity(log_level)
49
+ if fname is not None:
50
+ handler = _logging.FileHandler(fname)
51
+ handler.setFormatter(formatter)
52
+ logging.get_absl_logger().addHandler(handler)
53
+
54
+
55
+ def dct2str(dct):
56
+ return str({k: f'{v:.6g}' for k, v in dct.items()})
57
+
58
+
59
+ def get_nnet(name, **kwargs):
60
+ if name == 'dimr':
61
+ from libs.model.dimr_t2i import MRModel
62
+ return MRModel(kwargs["model_args"])
63
+ elif name == 'dit':
64
+ from libs.model.dit_t2i import DiT_H_2
65
+ return DiT_H_2(kwargs["model_args"])
66
+ else:
67
+ raise NotImplementedError(name)
68
+
69
+
70
+ def set_seed(seed: int):
71
+ if seed is not None:
72
+ torch.manual_seed(seed)
73
+ np.random.seed(seed)
74
+
75
+
76
+ def get_optimizer(params, name, **kwargs):
77
+ if name == 'adam':
78
+ from torch.optim import Adam
79
+ return Adam(params, **kwargs)
80
+ elif name == 'adamw':
81
+ from torch.optim import AdamW
82
+ return AdamW(params, **kwargs)
83
+ else:
84
+ raise NotImplementedError(name)
85
+
86
+
87
+ def customized_lr_scheduler(optimizer, warmup_steps=-1):
88
+ from torch.optim.lr_scheduler import LambdaLR
89
+ def fn(step):
90
+ if warmup_steps > 0:
91
+ return min(step / warmup_steps, 1)
92
+ else:
93
+ return 1
94
+ return LambdaLR(optimizer, fn)
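# Example: with warmup_steps=5000 the returned multiplier ramps linearly from
# 0 to 1 over the first 5000 optimizer steps and then stays at 1 (no decay);
# the default warmup_steps=-1 keeps the multiplier at 1 from the start.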
95
+
96
+
97
+ def get_lr_scheduler(optimizer, name, **kwargs):
98
+ if name == 'customized':
99
+ return customized_lr_scheduler(optimizer, **kwargs)
100
+ elif name == 'cosine':
101
+ from torch.optim.lr_scheduler import CosineAnnealingLR
102
+ return CosineAnnealingLR(optimizer, **kwargs)
103
+ else:
104
+ raise NotImplementedError(name)
105
+
106
+
107
+ def ema(model_dest: nn.Module, model_src: nn.Module, rate):
108
+ param_dict_src = dict(model_src.named_parameters())
109
+ for p_name, p_dest in model_dest.named_parameters():
110
+ p_src = param_dict_src[p_name]
111
+ assert p_src is not p_dest
112
+ p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
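# Example: with rate=0.9999 each call keeps 99.99% of the EMA weight and mixes
# in 0.01% of the current training weight, i.e. p_ema <- 0.9999*p_ema + 0.0001*p;
# ema_update(0), used when the train state is initialised below, copies the
# weights outright.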
113
+
114
+
115
+ class TrainState(object):
116
+ def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
117
+ self.optimizer = optimizer
118
+ self.lr_scheduler = lr_scheduler
119
+ self.step = step
120
+ self.nnet = nnet
121
+ self.nnet_ema = nnet_ema
122
+
123
+ def ema_update(self, rate=0.9999):
124
+ if self.nnet_ema is not None:
125
+ ema(self.nnet_ema, self.nnet, rate)
126
+
127
+ def save(self, path):
128
+ os.makedirs(path, exist_ok=True)
129
+ torch.save(self.step, os.path.join(path, 'step.pth'))
130
+ for key, val in self.__dict__.items():
131
+ if key != 'step' and val is not None:
132
+ torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
133
+
134
+ def load(self, path):
135
+ logging.info(f'load from {path}')
136
+ self.step = torch.load(os.path.join(path, 'step.pth'))
137
+ for key, val in self.__dict__.items():
138
+ if key != 'step' and val is not None:
139
+ val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))
140
+
141
+ def resume(self, ckpt_root, step=None):
142
+ if not os.path.exists(ckpt_root):
143
+ return
144
+ if step is None:
145
+ ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
146
+ if not ckpts:
147
+ return
148
+ steps = map(lambda x: int(x.split(".")[0]), ckpts)
149
+ step = max(steps)
150
+ ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
151
+ logging.info(f'resume from {ckpt_path}')
152
+ self.load(ckpt_path)
153
+
154
+ def to(self, device):
155
+ for key, val in self.__dict__.items():
156
+ if isinstance(val, nn.Module):
157
+ val.to(device)
158
+
159
+
160
+ def trainable_parameters(nnet):
161
+ params_decay = []
162
+ params_nodecay = []
163
+ for name, param in nnet.named_parameters():
164
+ if name.endswith(".nodecay_weight") or name.endswith(".nodecay_bias"):
165
+ params_nodecay.append(param)
166
+ else:
167
+ params_decay.append(param)
168
+ print("params_decay", len(params_decay))
169
+ print("params_nodecay", len(params_nodecay))
170
+ params = [
171
+ {'params': params_decay},
172
+ {'params': params_nodecay, 'weight_decay': 0.0}
173
+ ]
174
+ return params
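# The two groups above are passed to get_optimizer() in initialize_train_state()
# below, so the optimizer applies its configured weight decay to the first group
# and no decay at all to the second.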
175
+
176
+
177
+ def initialize_train_state(config, device):
178
+
179
+ nnet = get_nnet(**config.nnet)
180
+ nnet_ema = get_nnet(**config.nnet)
181
+ nnet_ema.eval()
182
+
183
+ optimizer = get_optimizer(trainable_parameters(nnet), **config.optimizer)
184
+ lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
185
+
186
+ train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
187
+ nnet=nnet, nnet_ema=nnet_ema)
188
+ train_state.ema_update(0)
189
+ train_state.to(device)
190
+ return train_state
191
+
192
+
193
+ def amortize(n_samples, batch_size):
194
+ k = n_samples // batch_size
195
+ r = n_samples % batch_size
196
+ return k * [batch_size] if r == 0 else k * [batch_size] + [r]
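# Example: amortize(10, 4) -> [4, 4, 2]; the sampling loops below draw full
# batches plus one final partial batch so that exactly n_samples images are
# written.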
197
+
198
+
199
+ def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, return_clipScore=False, ClipSocre_model=None, config=None):
200
+ os.makedirs(path, exist_ok=True)
201
+ idx = 0
202
+ batch_size = mini_batch_size * accelerator.num_processes
203
+ clip_score_list = []
204
+
205
+ if return_clipScore:
206
+ assert ClipSocre_model is not None
207
+
208
+ for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
209
+ samples, clip_score = sample_fn(mini_batch_size, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, config=config)
210
+ samples = unpreprocess_fn(samples)
211
+ samples = accelerator.gather(samples.contiguous())[:_batch_size]
212
+ clip_score_list.append(accelerator.gather(clip_score)[:_batch_size])
213
+ if accelerator.is_main_process:
214
+ for sample in samples:
215
+ save_image(sample, os.path.join(path, f"{idx}.png"))
216
+ idx += 1
217
+
218
+ if return_clipScore:
219
+ return clip_score_list
220
+ else:
221
+ return None
222
+
223
+
224
+ def sample2dir_wCLIP(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, return_clipScore=False, ClipSocre_model=None, config=None):
225
+ os.makedirs(path, exist_ok=True)
226
+ idx = 0
227
+ batch_size = mini_batch_size * accelerator.num_processes
228
+ clip_score_list = []
229
+
230
+ if return_clipScore:
231
+ assert ClipSocre_model is not None
232
+
233
+ for _batch_size in amortize(n_samples, batch_size):
234
+ samples, clip_score = sample_fn(mini_batch_size, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, config=config)
235
+ samples = unpreprocess_fn(samples)
236
+ samples = accelerator.gather(samples.contiguous())[:_batch_size]
237
+ clip_score_list.append(accelerator.gather(clip_score)[:_batch_size])
238
+ if accelerator.is_main_process:
239
+ for sample in samples:
240
+ save_image(sample, os.path.join(path, f"{idx}.png"))
241
+ idx += 1
242
+ break
243
+
244
+ if return_clipScore:
245
+ return clip_score_list
246
+ else:
247
+ return None
248
+
249
+
250
+ def sample2dir_wPrompt(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, config=None):
251
+ os.makedirs(path, exist_ok=True)
252
+ idx = 0
253
+ batch_size = mini_batch_size * accelerator.num_processes
254
+
255
+ for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
256
+ samples, samples_caption = sample_fn(mini_batch_size, return_caption=True, config=config)
257
+ samples = unpreprocess_fn(samples)
258
+ samples = accelerator.gather(samples.contiguous())[:_batch_size]
259
+ if accelerator.is_main_process:
260
+ for sample, caption in zip(samples,samples_caption):
261
+ try:
262
+ save_image_with_caption(sample, caption, os.path.join(path, f"{idx}.png"))
263
+ except Exception:
264
+ save_image(sample, os.path.join(path, f"{idx}.png"))
265
+ idx += 1
266
+
267
+
268
+ def grad_norm(model):
269
+ total_norm = 0.
270
+ for p in model.parameters():
271
+ param_norm = p.grad.data.norm(2)
272
+ total_norm += param_norm.item() ** 2
273
+ total_norm = total_norm ** (1. / 2)
274
+ return total_norm