Klayand committed
Commit 65ccd88 · 1 Parent(s): bafc662

submit app.py
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,239 @@
+ import json
+ import numpy as np
+ import math
+ import csv
+ import random
+ import argparse
+ import torch
+ import os
+ import torch.distributed as dist
+ import gradio as gr
+ from PIL import Image
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import spaces
+ from accelerate.utils import set_seed
+
+ from diffusion_pipeline.sd35_pipeline import StableDiffusion3Pipeline, FlowMatchEulerInverseScheduler
+ from diffusion_pipeline.sdxl_pipeline import StableDiffusionXLPipeline
+ from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
+ from diffusers import FlowMatchEulerDiscreteScheduler, DDIMInverseScheduler, DDIMScheduler
+
+ device = torch.device('cuda')
+
+
+ @spaces.GPU
+ def generate_image(
+     model_name,
+     seed,
+     num_steps,
+     guidance_scale,
+     inv_cfg,
+     w2s_guidance,
+     end_timesteps,
+     prompt,
+     method,
+     size,
+ ):
+     try:
+         # Generate an image from the given parameters
+         torch.cuda.empty_cache()
+         dtype = torch.float16
+         set_seed(seed)
+         if model_name == 'sd35':
+             nf4_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch.bfloat16
+             )
+             model_nf4 = SD3Transformer2DModel.from_pretrained(
+                 "stabilityai/stable-diffusion-3.5-large",
+                 subfolder="transformer",
+                 quantization_config=nf4_config,
+                 torch_dtype=torch.bfloat16
+             )
+
+             pipe = StableDiffusion3Pipeline.from_pretrained(
+                 "stabilityai/stable-diffusion-3.5-large",
+                 transformer=model_nf4,
+                 torch_dtype=torch.bfloat16,
+             )
+
+             pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
+             inverse_scheduler = FlowMatchEulerInverseScheduler.from_pretrained(
+                 "stabilityai/stable-diffusion-3.5-large", subfolder='scheduler')
+             pipe.inv_scheduler = inverse_scheduler
+
+         elif model_name == "sdxl":
+             pipe = StableDiffusionXLPipeline.from_pretrained(
+                 "stabilityai/stable-diffusion-xl-base-1.0",
+                 torch_dtype=torch.float16,
+                 variant="fp16",
+                 use_safetensors=True
+             ).to("cuda")
+
+             pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+             inverse_scheduler = DDIMInverseScheduler.from_pretrained(
+                 "stabilityai/stable-diffusion-xl-base-1.0", subfolder='scheduler')
+             pipe.inv_scheduler = inverse_scheduler
+
+         pipe.to(device)
+         pipe.enable_model_cpu_offload()
+
+         # TODO: load noise model
+         if method in ('core', 'z-core'):
+             from diffusion_pipeline.refine_model import PromptSD35Net, PromptSDXLNet
+             from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
+
+             if model_name == 'sd35':
+                 refine_model = PromptSD35Net()
+                 replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
+                 lora_true(refine_model, lora_idx=0)
+
+                 os.makedirs('./weights', exist_ok=True)
+                 if not os.path.exists('./weights/sd35_noise_model.pth'):
+                     os.system('wget https://huggingface.co/sst12345/CoRe2/resolve/main/weights/sd35_noise_model.pth')
+                     os.system('mv sd35_noise_model.pth ./weights/')
+                 checkpoint = torch.load('./weights/sd35_noise_model.pth', map_location='cpu')
+                 refine_model.load_state_dict(checkpoint)
+             elif model_name == 'sdxl':
+                 refine_model = PromptSDXLNet()
+                 replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
+                 lora_true(refine_model, lora_idx=0)
+                 os.makedirs('./weights', exist_ok=True)
+                 if not os.path.exists('./weights/sdxl_noise_model.pth'):
+                     os.system('wget https://huggingface.co/sst12345/CoRe2/resolve/main/weights/sdxl_noise_model.pth')
+                     os.system('mv sdxl_noise_model.pth ./weights/')
+                 checkpoint = torch.load('./weights/sdxl_noise_model.pth', map_location='cpu')
+                 refine_model.load_state_dict(checkpoint)
+
+             print("Loaded LoRA weights successfully")
+             refine_model = refine_model.to(device)
+             refine_model = refine_model.to(torch.bfloat16)
+
+         # Set the latent shape according to the model type
+         if model_name == 'sdxl':
+             shape = (1, 4, size // 8, size // 8)
+         else:
+             shape = (1, 16, size // 8, size // 8)
+
+         start_latents = torch.randn(shape, dtype=dtype).to(device)
+
+         # Run the selected sampling method
+         if model_name == 'sdxl':
+             if method == 'core':
+                 output = pipe.core(
+                     prompt=prompt,
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_steps,
+                     latents=start_latents,
+                     return_dict=False,
+                     refine_model=refine_model,
+                     lora_true=lora_true,
+                     end_timesteps=end_timesteps,
+                     w2s_guidance=w2s_guidance)[0][0]
+             elif method == 'zigzag':
+                 output = pipe.zigzag(
+                     prompt=prompt,
+                     guidance_scale=guidance_scale,
+                     latents=start_latents,
+                     return_dict=False,
+                     num_inference_steps=num_steps,
+                     inv_cfg=inv_cfg)[0][0]
+             elif method == 'z-core':
+                 output = pipe.z_core(
+                     prompt=prompt,
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_steps,
+                     latents=start_latents,
+                     return_dict=False,
+                     refine_model=refine_model,
+                     lora_true=lora_true,
+                     end_timesteps=end_timesteps,
+                     w2s_guidance=w2s_guidance,
+                     inv_cfg=inv_cfg)[0][0]
+             elif method == 'standard':
+                 output = pipe(
+                     prompt=prompt,
+                     guidance_scale=guidance_scale,
+                     latents=start_latents,
+                     return_dict=False,
+                     num_inference_steps=num_steps)[0][0]
+             else:
+                 raise ValueError("Invalid method")
+         else:
+             if method == 'core':
+                 output = pipe.core(
+                     prompt=prompt,
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_steps,
+                     latents=start_latents,
+                     max_sequence_length=512,
+                     return_dict=False,
+                     refine_model=refine_model,
+                     lora_true=lora_true,
+                     end_timesteps=end_timesteps,
+                     w2s_guidance=w2s_guidance)[0][0]
+             elif method == 'zigzag':
+                 output = pipe.zigzag(
+                     prompt=prompt,
+                     max_sequence_length=512,
+                     guidance_scale=guidance_scale,
+                     latents=start_latents,
+                     return_dict=False,
+                     num_inference_steps=num_steps,
+                     inv_cfg=inv_cfg)[0][0]
+             elif method == 'z-core':
+                 output = pipe.z_core(
+                     prompt=prompt,
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_steps,
+                     latents=start_latents,
+                     return_dict=False,
+                     max_sequence_length=512,
+                     refine_model=refine_model,
+                     lora_true=lora_true,
+                     end_timesteps=end_timesteps,
+                     w2s_guidance=w2s_guidance)[0][0]
+             elif method == 'standard':
+                 output = pipe(
+                     prompt=prompt,
+                     guidance_scale=guidance_scale,
+                     latents=start_latents,
+                     return_dict=False,
+                     max_sequence_length=512,
+                     num_inference_steps=num_steps)[0][0]
+             else:
+                 raise ValueError("Invalid method")
+
+         # Save the generated image and return its path
+         output_path = f'{model_name}_{method}.png'
+         output.save(output_path)
+         return output_path
+
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         return None
+
+
+ if __name__ == '__main__':
+     # Build the Gradio interface
+     iface = gr.Interface(
+         fn=generate_image,
+         inputs=[
+             gr.Dropdown(choices=['sdxl', 'sd35'], value='sdxl', label="Model"),  # default model: 'sdxl'
+             gr.Slider(minimum=1, maximum=1000000, value=1, label="Seed"),  # default seed: 1
+             gr.Slider(minimum=1, maximum=100, value=50, label="Inference Steps"),  # default: 50 steps
+             gr.Slider(minimum=1, maximum=10, value=5.5, label="CFG"),  # default CFG: 5.5
+             gr.Slider(minimum=-10, maximum=10, value=-1, label="Inverse CFG"),  # default inverse CFG: -1
+             gr.Slider(minimum=1, maximum=3.5, value=2.5, label="W2S Guidance"),  # default W2S guidance: 2.5
+             gr.Slider(minimum=1, maximum=100, value=50, label="End Timesteps"),  # default end timestep: 50
+             gr.Textbox(label="Prompt"),  # no default prompt
+             gr.Dropdown(choices=['standard', 'core', 'zigzag', 'z-core'], value='core', label="Method"),  # default method: 'core'
+             gr.Slider(minimum=1024, maximum=2048, value=1024, label="Size")  # default size: 1024
+         ],
+         outputs=gr.Image(type="filepath"),  # the function returns a file path
+         title="Image Generation with CoRe^2"
+     )
+     iface.launch()
+
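
For reference, generate_image can also be driven without the Gradio UI. A minimal sketch, assuming the diffusion_pipeline package and a GPU environment as above; the keyword names match the function signature, the values mirror the Interface defaults, and the prompt string is only a stand-in:

    path = generate_image(
        model_name='sdxl',                     # or 'sd35'
        seed=1,
        num_steps=50,
        guidance_scale=5.5,
        inv_cfg=-1,
        w2s_guidance=2.5,
        end_timesteps=50,
        prompt='an astronaut riding a horse',  # stand-in prompt
        method='core',                         # 'standard' | 'core' | 'zigzag' | 'z-core'
        size=1024,
    )
    print(path)  # e.g. 'sdxl_core.png'; None if generation failed
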
diffusion_pipeline/gemma.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, Gemma2ForTokenClassification, BitsAndBytesConfig
+
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ torch.set_float32_matmul_precision("high")
+
+ def repeat_function(xs, max_length=128):
+     new_xs = []
+     for x in xs:
+         if x.shape[1] >= max_length - 1:
+             new_xs.append(x[:, :max_length - 1, :])
+         else:
+             new_xs.append(x)
+     xs = new_xs
+     # Pad each sequence up to max_length with copies of its mean embedding
+     mean_xs = [x.mean(1, keepdim=True).expand(-1, max_length - x.shape[1], -1) for x in xs]
+     xs = [torch.cat([x, mean_x], 1) for mean_x, x in zip(mean_xs, xs)]
+     return xs
+
+ class Gemma2Model(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
+         self.tokenizer_max_length = 128
+         # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+         self.model = Gemma2ForTokenClassification.from_pretrained(
+             "google/gemma-2-2b",
+             # device_map="auto",
+             # quantization_config=quantization_config,
+         ).float()
+         self.model.score = nn.Identity()
+
+     @torch.no_grad()
+     def forward(self, input_prompt):
+         input_prompt = list(input_prompt)
+         outputs = []
+         for _input_prompt in input_prompt:
+             # truncation=True is required for max_length to take effect
+             input_ids = self.tokenizer(_input_prompt, add_special_tokens=False, truncation=True, max_length=77, return_tensors="pt").to("cuda")
+             _outputs = self.model(**input_ids)["logits"]
+             outputs.append(_outputs)
+         outputs = repeat_function(outputs)
+         outputs = torch.cat(outputs, 0)
+         return outputs
+
+ if __name__ == "__main__":
+     model = Gemma2Model().cuda()
+     input_text = ["Write me a poem about Machine Learning.", "Write me a poem about Deep Learning."]
+     print(model(input_text))
+     print(model(input_text)[0].shape)
+     print(model(input_text).shape)
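
repeat_function pads or truncates a list of [batch, length, dim] embedding tensors to a fixed length of 128: any sequence of 127 or more tokens is cut to 127, and every sequence is then padded up to 128 with copies of its own mean embedding. A minimal shape check, assuming only torch (no Gemma download needed):

    import torch
    short = torch.randn(1, 10, 4)   # padded with 118 copies of its mean vector
    long_ = torch.randn(1, 200, 4)  # truncated to 127 tokens, then padded by one
    out = repeat_function([short, long_], max_length=128)
    print(out[0].shape, out[1].shape)  # torch.Size([1, 128, 4]) for both
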
diffusion_pipeline/lora.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(torch.nn.Module):
+     def __init__(self, in_dim, out_dim, rank, alpha):
+         super().__init__()
+         std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
+         self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
+         self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
+         self.alpha = alpha
+
+     def forward(self, x):
+         x = self.alpha * (x @ self.A @ self.B)
+         return x
+
+ class LinearWithLoRA(torch.nn.Module):
+     def __init__(self, linear, rank, alpha,
+                  weak_lora_alpha=0.1, number_of_lora=1):
+         super().__init__()
+         self.linear = linear
+         self.lora = nn.ModuleList([LoRALayer(
+             linear.in_features, linear.out_features, rank, alpha
+         ) for _ in range(number_of_lora)])
+         self.use_lora = True
+         self.lora_idx = 0
+
+     def forward(self, x):
+         if self.use_lora:
+             return self.linear(x) + self.lora[self.lora_idx](x)
+         else:
+             return self.linear(x)
+
+ def replace_linear_with_lora(module, rank=64, alpha=1., tag=0, weak_lora_alpha=0.1, number_of_lora=1):
+     for name, child in module.named_children():
+         if isinstance(child, nn.Linear):
+             setattr(module, name, LinearWithLoRA(child, rank, alpha, weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora))
+         else:
+             replace_linear_with_lora(child, rank, alpha, tag, weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora)
+
+
+ def lora_false(model, lora_idx=0):
+     for name, module in model.named_modules():
+         if isinstance(module, LinearWithLoRA):
+             module.use_lora = False
+             module.lora_idx = lora_idx
+
+ def lora_true(model, lora_idx=0):
+     for name, module in model.named_modules():
+         if isinstance(module, LinearWithLoRA):
+             module.use_lora = True
+             module.lora_idx = lora_idx
+             for i, lora in enumerate(module.lora):
+                 if i != lora_idx:
+                     lora.A.requires_grad = False
+                     lora.B.requires_grad = False
+                     if lora.A.grad is not None:
+                         del lora.A.grad
+                     if lora.B.grad is not None:
+                         del lora.B.grad
+                 else:
+                     lora.A.requires_grad = True
+                     lora.B.requires_grad = True
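
A minimal sketch of how these helpers compose, on a toy model (the names below are illustrative): replace_linear_with_lora swaps every nn.Linear for a LinearWithLoRA holding a bank of adapters, and lora_true / lora_false select or bypass one adapter by index, which is how app.py activates adapter 0 on the refine model. Because each LoRALayer's B matrix is zero-initialized, wrapping leaves the model's outputs unchanged until the adapters are trained:

    import torch
    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    replace_linear_with_lora(toy, rank=4, alpha=1.0, number_of_lora=3)

    lora_true(toy, lora_idx=2)   # route through adapter 2; freeze adapters 0 and 1
    y = toy(torch.randn(1, 8))

    lora_false(toy)              # bypass all adapters: plain linear path again
    y_base = toy(torch.randn(1, 8))
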
diffusion_pipeline/refine_model.py ADDED
@@ -0,0 +1,312 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import os
+ import json
+ import random
+ import math
+ import numpy as np
+ from torch.utils.data import Dataset
+ from transformers import AutoTokenizer
+ from glob import glob
+ from PIL import Image
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ from diffusers.utils import logging
+ from diffusers.models.embeddings import PatchEmbed
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.attention import BasicTransformerBlock
+ from diffusers.models.normalization import AdaLayerNormContinuous
+ from torchvision import transforms
+
+ device = torch.device('cuda')
+
+ def add_hook_to_module(model, module_name):
+     outputs = []
+     def hook(module, input, output):
+         outputs.append(output)
+     module = dict(model.named_modules()).get(module_name)
+     if module is None:
+         raise ValueError(f"can't find module {module_name}")
+     hook_handle = module.register_forward_hook(hook)
+     return hook_handle, outputs
+
+ class PromptSD35Net(nn.Module):
+
+     def __init__(self,
+                  sample_size: int = 128,
+                  patch_size: int = 2,
+                  in_channels: int = 16,
+                  num_layers: int = 8,
+                  attention_head_dim: int = 64,
+                  num_attention_heads: int = 24,
+                  out_channels: int = 16,
+                  pos_embed_max_size: int = 192
+                  ):
+         super().__init__()
+         self.sample_size = sample_size
+         self.patch_size = patch_size
+         self.in_channels = in_channels
+         self.num_layers = num_layers
+         self.attention_head_dim = attention_head_dim
+         self.num_attention_heads = num_attention_heads
+         self.out_channels = out_channels
+         self.pos_embed_max_size = pos_embed_max_size
+         self.inner_dim = self.num_attention_heads * self.attention_head_dim
+
+         self.pos_embed = PatchEmbed(
+             height=self.sample_size,
+             width=self.sample_size,
+             patch_size=self.patch_size,
+             in_channels=self.in_channels,
+             embed_dim=self.inner_dim,
+             pos_embed_max_size=pos_embed_max_size
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.num_attention_heads,
+                     attention_head_dim=self.attention_head_dim,
+                     ff_inner_dim=2 * self.inner_dim  # mult should be 4 by default
+                 )
+                 for i in range(self.num_layers)
+             ]
+         )
+         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.noise_shape = (1, 16, 128, 128)  # (667, 4096)
+         self.pre8_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.pre8_linear2 = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear2 = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear2 = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.last_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         # self.last_linear2 = nn.Sequential(nn.Linear(667, 32))
+         self.skip_connection2 = nn.Linear(4096, 1, bias=False)
+         self.skip_connection = nn.Linear(667, 32, bias=False)
+         self.trans_linear = nn.Linear(666 + 1 + 4096, 1536, bias=False)
+         nn.init.constant_(self.skip_connection.weight.data, 0)
+         nn.init.constant_(self.trans_linear.weight.data, 0)
+         nn.init.constant_(self.pre8_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre8_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear2[-1].weight.data, 0)
+
+     def forward(self, noise: torch.Tensor, _s, _v, _d, _pool_embedding) -> torch.Tensor:
+
+         assert noise is not None
+         _ori_v = _v.clone()
+         _v = torch.stack([torch.diag(_v[jj]) for jj in range(_v.shape[0])], dim=0)
+         positive_embedding = _s.permute(0, 2, 1) @ _v @ _d  # [2, 64, 666] [2, 64] [2, 64, 4096]
+         pool_embedding = _pool_embedding[:, None, :]
+         embedding = torch.cat([positive_embedding, pool_embedding], dim=1)
+         bs = noise.shape[0]
+         height, width = noise.shape[-2:]
+         embed_8 = embedding
+         embed_16 = embedding
+         embed_24 = embedding
+         scale_8 = self.pre8_linear2(embed_8).mean(1)
+         scale_16 = self.pre16_linear2(embed_16).mean(1)
+         scale_24 = self.pre24_linear2(embed_24).mean(1)
+         embed_8 = self.pre8_linear(embed_8).mean(1)
+         embed_16 = self.pre16_linear(embed_16).mean(1)
+         embed_24 = self.pre24_linear(embed_24).mean(1)
+         embed_last = self.last_linear(embedding).mean(1)
+         embed_trans = self.trans_linear(torch.cat([_s, _ori_v[..., None], _d], dim=2)).mean(1)
+         skip_embedding = self.skip_connection(self.skip_connection2(embedding).permute(0, 2, 1)).permute(0, 2, 1)
+         scale_skip, embed_skip = skip_embedding.chunk(2, dim=1)
+
+         ori_noise = noise * (scale_skip[..., None]) + embed_skip[..., None]
+         noise = self.pos_embed(noise)
+         noise = noise * (1 + scale_8[:, None, :] + embed_trans[:, None, :]) + embed_8[:, None, :]
+         scale_list = [scale_16, scale_24]
+         embed_list = [embed_16, embed_24]
+         for _ii, block in enumerate(self.transformer_blocks):
+             noise = block(noise)
+             if len(scale_list) != 0 and len(embed_list) != 0:
+                 noise = noise * (1 + scale_list[int(_ii // 4)][:, None, :] + embed_trans[:, None, :]) + embed_list[int(_ii // 4)][:, None, :]
+
+         hidden_states = noise
+         hidden_states = self.norm_out(hidden_states, embed_last)
+         hidden_states = self.proj_out(hidden_states)
+
+         # unpatchify
+         patch_size = self.patch_size
+         height = height // patch_size
+         width = width // patch_size
+
+         hidden_states = hidden_states.reshape(
+             shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+         )
+         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+         output = hidden_states.reshape(
+             shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+         )
+         return output + ori_noise
+
+     def weak_load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True, assign: bool = False):
+         return load_filtered_state_dict(self, state_dict)
+
+ class PromptSDXLNet(nn.Module):
+
+     def __init__(self,
+                  sample_size: int = 128,
+                  patch_size: int = 2,
+                  in_channels: int = 4,
+                  num_layers: int = 4,
+                  attention_head_dim: int = 64,
+                  num_attention_heads: int = 24,
+                  out_channels: int = 4,
+                  pos_embed_max_size: int = 192
+                  ):
+         super().__init__()
+         self.sample_size = sample_size
+         self.patch_size = patch_size
+         self.in_channels = in_channels
+         self.num_layers = num_layers
+         self.attention_head_dim = attention_head_dim
+         self.num_attention_heads = num_attention_heads
+         self.out_channels = out_channels
+         self.pos_embed_max_size = pos_embed_max_size
+         self.inner_dim = self.num_attention_heads * self.attention_head_dim
+
+         self.pos_embed = PatchEmbed(
+             height=self.sample_size,
+             width=self.sample_size,
+             patch_size=self.patch_size,
+             in_channels=self.in_channels,
+             embed_dim=self.inner_dim,
+             pos_embed_max_size=pos_embed_max_size
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.num_attention_heads,
+                     attention_head_dim=self.attention_head_dim,
+                     ff_inner_dim=2 * self.inner_dim  # mult should be 4 by default
+                 )
+                 for i in range(self.num_layers)
+             ]
+         )
+         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.noise_shape = (1, 4, 128, 128)
+         self.pre8_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.pre8_linear2 = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear2 = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear2 = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.last_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         # self.last_linear2 = nn.Sequential(nn.Linear(667, 32))
+         self.skip_connection2 = nn.Linear(2048, 1, bias=False)
+         self.skip_connection = nn.Linear(154 + 1, 8, bias=False)
+         self.trans_linear = nn.Linear(154 + 1 + 2048, 1536, bias=False)
+         self.pool_prompt_linear = nn.Linear(2560, 2048, bias=False)
+         nn.init.constant_(self.skip_connection.weight.data, 0)
+         nn.init.constant_(self.trans_linear.weight.data, 0)
+         nn.init.constant_(self.pre8_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre8_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear2[-1].weight.data, 0)
+
+     def forward(self, noise: torch.Tensor, _s, _v, _d, _pool_embedding) -> torch.Tensor:
+
+         assert noise is not None
+         _ori_v = _v.clone()
+         _v = torch.stack([torch.diag(_v[jj]) for jj in range(_v.shape[0])], dim=0)
+         positive_embedding = _s.permute(0, 2, 1) @ _v @ _d  # [2, 64, 154] [2, 64] [2, 64, 2048]
+         pool_embedding = self.pool_prompt_linear(_pool_embedding[:, None, :])
+         embedding = torch.cat([positive_embedding, pool_embedding], dim=1)
+         bs = noise.shape[0]
+         height, width = noise.shape[-2:]
+         embed_8 = embedding
+         embed_16 = embedding
+         embed_24 = embedding
+         scale_8 = self.pre8_linear2(embed_8).mean(1)
+         scale_16 = self.pre16_linear2(embed_16).mean(1)
+         scale_24 = self.pre24_linear2(embed_24).mean(1)
+         embed_8 = self.pre8_linear(embed_8).mean(1)
+         embed_16 = self.pre16_linear(embed_16).mean(1)
+         embed_24 = self.pre24_linear(embed_24).mean(1)
+         embed_last = self.last_linear(embedding).mean(1)
+         embed_trans = self.trans_linear(torch.cat([_s, _ori_v[..., None], _d], dim=2)).mean(1)
+         skip_embedding = self.skip_connection(self.skip_connection2(embedding).permute(0, 2, 1)).permute(0, 2, 1)
+         scale_skip, embed_skip = skip_embedding.chunk(2, dim=1)
+
+         ori_noise = noise * (scale_skip[..., None]) + embed_skip[..., None]
+         noise = self.pos_embed(noise)
+         noise = noise * (1 + scale_8[:, None, :] + embed_trans[:, None, :]) + embed_8[:, None, :]
+         scale_list = [scale_16, scale_24]
+         embed_list = [embed_16, embed_24]
+         for _ii, block in enumerate(self.transformer_blocks):
+             noise = block(noise)
+             if len(scale_list) != 0 and len(embed_list) != 0:
+                 noise = noise * (1 + scale_list[int(_ii // 4)][:, None, :] + embed_trans[:, None, :]) + embed_list[int(_ii // 4)][:, None, :]
+
+         hidden_states = noise
+         hidden_states = self.norm_out(hidden_states, embed_last)
+         hidden_states = self.proj_out(hidden_states)
+
+         # unpatchify
+         patch_size = self.patch_size
+         height = height // patch_size
+         width = width // patch_size
+
+         hidden_states = hidden_states.reshape(
+             shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+         )
+         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+         output = hidden_states.reshape(
+             shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+         )
+         return output + ori_noise
+
+     def weak_load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True, assign: bool = False):
+         return load_filtered_state_dict(self, state_dict)
+
+
+ def load_filtered_state_dict(model, state_dict):
+     model_state_dict = model.state_dict()
+     filtered_state_dict = {}
+     for k, v in state_dict.items():
+         if k in model_state_dict:
+             if model_state_dict[k].size() == v.size():
+                 filtered_state_dict[k] = v
+             else:
+                 print(f"Skipping {k}: shape mismatch ({model_state_dict[k].size()} vs {v.size()})")
+         else:
+             print(f"Skipping {k}: not found in model's state_dict.")
+     model.load_state_dict(filtered_state_dict, strict=False)
+     return model
+
+ def custom_collate_fn_2_0(batch):
+     noise_pred_texts, prompts, noise_preds, max_scores = zip(*batch)
+
+     noise_pred_texts = torch.stack(noise_pred_texts)
+     noise_preds = torch.stack(noise_preds)
+     max_scores = torch.stack(max_scores)
+
+     return noise_pred_texts, prompts, noise_preds, max_scores
+
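
A shape sketch for PromptSD35Net.forward with random stand-ins for the prompt factors; the dimensions follow the shape comments inside forward ([B, 64, 666] for _s, [B, 64] for _v, [B, 64, 4096] for _d) and the 128x128 SD3.5 latent recorded in noise_shape. The network is freshly initialized here, so this checks shapes only, not trained behavior:

    import torch

    net = PromptSD35Net().eval()
    noise = torch.randn(1, 16, 128, 128)  # matches net.noise_shape
    _s = torch.randn(1, 64, 666)
    _v = torch.randn(1, 64)
    _d = torch.randn(1, 64, 4096)
    pool = torch.randn(1, 4096)

    with torch.no_grad():
        out = net(noise, _s, _v, _d, pool)
    print(out.shape)  # torch.Size([1, 16, 128, 128]), same layout as the input noise
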
diffusion_pipeline/sd35_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusion_pipeline/sdxl_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ diffusers
+ transformers
+ einops
+ wandb
+ accelerate
+ pandas
+ imageio
+ gradio
+ imageio-ffmpeg
+ omegaconf
+ spaces
+ torch==2.4 --index-url https://download.pytorch.org/whl/cu124
+ torchaudio==2.4 --index-url https://download.pytorch.org/whl/cu124
+ torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124