wuwenxu.01 committed on
Commit def2fd8 · 1 Parent(s): e8e76e7

feat: filter move app code from github
.gitignore ADDED
@@ -0,0 +1,178 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+ # User config files
177
+ .vscode/
178
+ output/
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: UNO FLUX
3
- emoji: 📊
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
@@ -9,6 +9,22 @@ app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-4.0
11
  short_description: Generate customized images using text and multiple images
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: UNO FLUX
3
+ emoji: ⚡️
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
 
9
  pinned: false
10
  license: cc-by-nc-4.0
11
  short_description: Generate customized images using text and multiple images
12
+ models:
13
+ - black-forest-labs/FLUX.1-dev
14
+ - bytedance-research/UNO
15
  ---
16
 
17
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
18
+
19
+ ## 📄 Disclaimer
20
+
21
+ We open-source this project for academic research. The vast majority of images
22
+ used in this project are either generated or licensed. If you have any concerns,
23
+ please contact us, and we will promptly remove any inappropriate content.
24
+ Our code is released under the Apache 2.0 License, while our models are under
25
+ the CC BY-NC 4.0 License. Any models related to [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
26
+ base model must adhere to the original licensing terms.
27
+ This research aims to advance the field of generative AI. Users are free to
28
+ create images using this tool, provided they comply with local laws and exercise
29
+ responsible usage. The developers are not liable for any misuse of the tool by users.
30
+
app.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+
17
+ import gradio as gr
18
+ import torch
19
+
20
+ from uno.flux.pipeline import UNOPipeline
21
+
22
+
23
+ def create_demo(
24
+ model_type: str,
25
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
26
+ offload: bool = False,
27
+ ):
28
+ pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
29
+
30
+ with gr.Blocks() as demo:
31
+ gr.Markdown("# UNO by UNO team")
32
+ with gr.Row():
33
+ with gr.Column():
34
+ prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
35
+ with gr.Row():
36
+ image_prompt1 = gr.Image(label="ref img1", visible=True, interactive=True, type="pil")
37
+ image_prompt2 = gr.Image(label="ref img2", visible=True, interactive=True, type="pil")
38
+ image_prompt3 = gr.Image(label="ref img3", visible=True, interactive=True, type="pil")
39
+ image_prompt4 = gr.Image(label="ref img4", visible=True, interactive=True, type="pil")
40
+
41
+ with gr.Row():
42
+ with gr.Column():
43
+ ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
44
+ with gr.Column():
45
+ gr.Markdown("📌 **The recommended ref scale** depends on the number of reference images.\n")
46
+ gr.Markdown(" 1->512 / 2->320 / 3...n->256")
47
+
48
+ with gr.Row():
49
+ with gr.Column():
50
+ width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
51
+ height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
52
+ with gr.Column():
53
+ gr.Markdown("📌 The model was trained at 512x512 resolution.\n")
54
+ gr.Markdown(
55
+ "Sizes closer to 512 are more stable,"
56
+ " while larger sizes give better visual detail but are less stable"
57
+ )
58
+
59
+ with gr.Accordion("Generation Options", open=False):
60
+ with gr.Row():
61
+ num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
62
+ guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
63
+ seed = gr.Number(-1, label="Seed (-1 for random)")
64
+
65
+ generate_btn = gr.Button("Generate")
66
+
67
+ with gr.Column():
68
+ output_image = gr.Image(label="Generated Image")
69
+ download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
70
+
71
+
72
+ inputs = [
73
+ prompt, width, height, guidance, num_steps,
74
+ seed, ref_long_side, image_prompt1, image_prompt2, image_prompt3, image_prompt4
75
+ ]
76
+ generate_btn.click(
77
+ fn=pipeline.gradio_generate,
78
+ inputs=inputs,
79
+ outputs=[output_image, download_btn],
80
+ )
81
+
82
+ return demo
83
+
84
+ if __name__ == "__main__":
85
+ from typing import Literal
86
+
87
+ from transformers import HfArgumentParser
88
+
89
+ @dataclasses.dataclass
90
+ class AppArgs:
91
+ name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
92
+ device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
93
+ offload: bool = dataclasses.field(
94
+ default=False,
95
+ metadata={"help": "If True, sequentially offload the models (AE, DiT, text encoder) to CPU when not in use."}
96
+ )
97
+ port: int = 7860
98
+
99
+ parser = HfArgumentParser([AppArgs])
100
+ args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
101
+ args = args_tuple[0]
102
+
103
+ demo = create_demo(args.name, args.device, args.offload)
104
+ demo.launch(server_port=args.port)
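
For local use outside the Space UI, a minimal sketch that builds and launches the demo directly, bypassing `HfArgumentParser`. It assumes the repo layout above and that the FLUX.1-dev and UNO weights can be fetched by `UNOPipeline`; the flag values are illustrative.

```python
# Hedged usage sketch: construct and launch the Gradio demo defined in app.py.
import torch
from app import create_demo

if __name__ == "__main__":
    demo = create_demo(
        model_type="flux-dev",
        device="cuda" if torch.cuda.is_available() else "cpu",
        offload=True,   # sequentially offload AE / DiT / text encoders to CPU when idle
    )
    demo.launch(server_port=7860)
```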
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ einops==0.8.0
2
+ transformers==4.43.3
3
+ huggingface-hub
4
+ diffusers==0.30.1
5
+ sentencepiece==0.2.0
6
+ gradio==5.22.0
7
+
8
+ --extra-index-url https://download.pytorch.org/whl/cu124
9
+ torch==2.4.0
10
+ torchvision==0.19.0
uno/dataset/uno.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import os
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torchvision.transforms.functional as TVF
22
+ from torch.utils.data import DataLoader, Dataset
23
+ from torchvision.transforms import Compose, Normalize, ToTensor
+ from PIL import Image  # needed for Image.open in __getitem__ below
24
+
25
+ def bucket_images(images: list[torch.Tensor], resolution: int = 512):
26
+ bucket_override=[
27
+ # h w
28
+ (256, 768),
29
+ (320, 768),
30
+ (320, 704),
31
+ (384, 640),
32
+ (448, 576),
33
+ (512, 512),
34
+ (576, 448),
35
+ (640, 384),
36
+ (704, 320),
37
+ (768, 320),
38
+ (768, 256)
39
+ ]
40
+ bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
41
+ bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
42
+
43
+ aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
44
+ mean_aspect_ratio = np.mean(aspect_ratios)
45
+
46
+ new_h, new_w = bucket_override[0]
47
+ min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
48
+ for h, w in bucket_override:
49
+ aspect_diff = np.abs(h / w - mean_aspect_ratio)
50
+ if aspect_diff < min_aspect_diff:
51
+ min_aspect_diff = aspect_diff
52
+ new_h, new_w = h, w
53
+
54
+ images = [TVF.resize(image, (new_h, new_w)) for image in images]
55
+ images = torch.stack(images, dim=0)
56
+ return images
57
+
58
+ class FluxPairedDatasetV2(Dataset):
59
+ def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
60
+ super().__init__()
61
+ self.json_file = json_file
62
+ self.resolution = resolution
63
+ self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
64
+ self.image_root = os.path.dirname(json_file)
65
+
66
+ with open(self.json_file, "rt") as f:
67
+ self.data_dicts = json.load(f)
68
+
69
+ self.transform = Compose([
70
+ ToTensor(),
71
+ Normalize([0.5], [0.5]),
72
+ ])
73
+
74
+ def __getitem__(self, idx):
75
+ data_dict = self.data_dicts[idx]
76
+ image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
77
+ txt = data_dict["prompt"]
78
+ image_tgt_path = data_dict.get("image_tgt_path", None)
79
+ ref_imgs = [
80
+ Image.open(os.path.join(self.image_root, path)).convert("RGB")
81
+ for path in image_paths
82
+ ]
83
+ ref_imgs = [self.transform(img) for img in ref_imgs]
84
+ img = None
85
+ if image_tgt_path is not None:
86
+ img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
87
+ img = self.transform(img)
88
+
89
+ return {
90
+ "img": img,
91
+ "txt": txt,
92
+ "ref_imgs": ref_imgs,
93
+ }
94
+
95
+ def __len__(self):
96
+ return len(self.data_dicts)
97
+
98
+ def collate_fn(self, batch):
99
+ img = [data["img"] for data in batch]
100
+ txt = [data["txt"] for data in batch]
101
+ ref_imgs = [data["ref_imgs"] for data in batch]
102
+ assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
103
+
104
+ n_ref = len(ref_imgs[0])
105
+
106
+ img = bucket_images(img, self.resolution)
107
+ ref_imgs_new = []
108
+ for i in range(n_ref):
109
+ ref_imgs_i = [refs[i] for refs in ref_imgs]
110
+ ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
111
+ ref_imgs_new.append(ref_imgs_i)
112
+
113
+ return {
114
+ "txt": txt,
115
+ "img": img,
116
+ "ref_imgs": ref_imgs_new,
117
+ }
118
+
119
+ if __name__ == '__main__':
120
+ import argparse
121
+ from pprint import pprint
122
+ parser = argparse.ArgumentParser()
123
+ # parser.add_argument("--json_file", type=str, required=True)
124
+ parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
125
+ args = parser.parse_args()
126
+ dataset = FluxPairedDatasetV2(args.json_file, 512)
127
+ dataloader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
128
+
129
+ for i, data_dict in enumerate(dataloader):
130
+ pprint(i)
131
+ pprint(data_dict)
132
+ breakpoint()
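
For reference, a minimal sketch of the JSON layout `FluxPairedDatasetV2` reads, inferred from `__getitem__` above; the image file names are hypothetical.

```python
# Hypothetical layout of datasets/fake_train_data.json (the default path in the demo block above).
# Image paths are resolved relative to the JSON file's directory.
import json

records = [
    {
        "prompt": "a plush toy on a wooden desk",
        "image_paths": ["refs/toy_0.png", "refs/toy_1.png"],  # reference images
        "image_tgt_path": "targets/toy_scene.png",            # optional training target
    },
    # "image_path" (singular) is also accepted for a single reference image.
]

with open("datasets/fake_train_data.json", "w") as f:
    json.dump(records, f, indent=2)
```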
uno/flux/math.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import Tensor
19
+
20
+
21
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
22
+ q, k = apply_rope(q, k, pe)
23
+
24
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
25
+ x = rearrange(x, "B H L D -> B L (H D)")
26
+
27
+ return x
28
+
29
+
30
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
31
+ assert dim % 2 == 0
32
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
33
+ omega = 1.0 / (theta**scale)
34
+ out = torch.einsum("...n,d->...nd", pos, omega)
35
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
36
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
37
+ return out.float()
38
+
39
+
40
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
41
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
42
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
43
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
44
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
45
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
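
As a quick sanity check of the RoPE helpers above, a toy sketch; the shapes are arbitrary, and the `pe` layout mirrors what `EmbedND` in `modules/layers.py` produces.

```python
# Shape check for rope/apply_rope/attention with invented dimensions.
import torch
from uno.flux.math import rope, attention

B, H, L, D = 1, 2, 8, 16                            # batch, heads, tokens, head_dim (D must be even)
pos = torch.arange(L, dtype=torch.float64)[None]    # (B, L) positional ids
pe = rope(pos, D, theta=10_000).unsqueeze(1)        # (B, 1, L, D/2, 2, 2), as EmbedND returns

q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
out = attention(q, k, v, pe=pe)                     # RoPE on q/k, SDPA, heads folded back
print(out.shape)                                    # torch.Size([1, 8, 32]) == (B, L, H * D)
```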
uno/flux/model.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from torch import Tensor, nn
20
+
21
+ from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
22
+
23
+
24
+ @dataclass
25
+ class FluxParams:
26
+ in_channels: int
27
+ vec_in_dim: int
28
+ context_in_dim: int
29
+ hidden_size: int
30
+ mlp_ratio: float
31
+ num_heads: int
32
+ depth: int
33
+ depth_single_blocks: int
34
+ axes_dim: list[int]
35
+ theta: int
36
+ qkv_bias: bool
37
+ guidance_embed: bool
38
+
39
+
40
+ class Flux(nn.Module):
41
+ """
42
+ Transformer model for flow matching on sequences.
43
+ """
44
+ _supports_gradient_checkpointing = True
45
+
46
+ def __init__(self, params: FluxParams):
47
+ super().__init__()
48
+
49
+ self.params = params
50
+ self.in_channels = params.in_channels
51
+ self.out_channels = self.in_channels
52
+ if params.hidden_size % params.num_heads != 0:
53
+ raise ValueError(
54
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
55
+ )
56
+ pe_dim = params.hidden_size // params.num_heads
57
+ if sum(params.axes_dim) != pe_dim:
58
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
59
+ self.hidden_size = params.hidden_size
60
+ self.num_heads = params.num_heads
61
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
62
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
63
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
64
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
65
+ self.guidance_in = (
66
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
67
+ )
68
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
69
+
70
+ self.double_blocks = nn.ModuleList(
71
+ [
72
+ DoubleStreamBlock(
73
+ self.hidden_size,
74
+ self.num_heads,
75
+ mlp_ratio=params.mlp_ratio,
76
+ qkv_bias=params.qkv_bias,
77
+ )
78
+ for _ in range(params.depth)
79
+ ]
80
+ )
81
+
82
+ self.single_blocks = nn.ModuleList(
83
+ [
84
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
85
+ for _ in range(params.depth_single_blocks)
86
+ ]
87
+ )
88
+
89
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
90
+ self.gradient_checkpointing = False
91
+
92
+ def _set_gradient_checkpointing(self, module, value=False):
93
+ if hasattr(module, "gradient_checkpointing"):
94
+ module.gradient_checkpointing = value
95
+
96
+ @property
97
+ def attn_processors(self):
98
+ # set recursively
99
+ processors = {} # type: dict[str, nn.Module]
100
+
101
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
102
+ if hasattr(module, "set_processor"):
103
+ processors[f"{name}.processor"] = module.processor
104
+
105
+ for sub_name, child in module.named_children():
106
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
107
+
108
+ return processors
109
+
110
+ for name, module in self.named_children():
111
+ fn_recursive_add_processors(name, module, processors)
112
+
113
+ return processors
114
+
115
+ def set_attn_processor(self, processor):
116
+ r"""
117
+ Sets the attention processor to use to compute attention.
118
+
119
+ Parameters:
120
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
121
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
122
+ for **all** `Attention` layers.
123
+
124
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
125
+ processor. This is strongly recommended when setting trainable attention processors.
126
+
127
+ """
128
+ count = len(self.attn_processors.keys())
129
+
130
+ if isinstance(processor, dict) and len(processor) != count:
131
+ raise ValueError(
132
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
133
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
134
+ )
135
+
136
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
137
+ if hasattr(module, "set_processor"):
138
+ if not isinstance(processor, dict):
139
+ module.set_processor(processor)
140
+ else:
141
+ module.set_processor(processor.pop(f"{name}.processor"))
142
+
143
+ for sub_name, child in module.named_children():
144
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
145
+
146
+ for name, module in self.named_children():
147
+ fn_recursive_attn_processor(name, module, processor)
148
+
149
+ def forward(
150
+ self,
151
+ img: Tensor,
152
+ img_ids: Tensor,
153
+ txt: Tensor,
154
+ txt_ids: Tensor,
155
+ timesteps: Tensor,
156
+ y: Tensor,
157
+ guidance: Tensor | None = None,
158
+ ref_img: Tensor | None = None,
159
+ ref_img_ids: Tensor | None = None,
160
+ ) -> Tensor:
161
+ if img.ndim != 3 or txt.ndim != 3:
162
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
163
+
164
+ # running on sequences img
165
+ img = self.img_in(img)
166
+ vec = self.time_in(timestep_embedding(timesteps, 256))
167
+ if self.params.guidance_embed:
168
+ if guidance is None:
169
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
170
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
171
+ vec = vec + self.vector_in(y)
172
+ txt = self.txt_in(txt)
173
+
174
+ ids = torch.cat((txt_ids, img_ids), dim=1)
175
+
176
+ # concat ref_img/img
177
+ img_end = img.shape[1]
178
+ if ref_img is not None:
179
+ if isinstance(ref_img, tuple) or isinstance(ref_img, list):
180
+ img_in = [img] + [self.img_in(ref) for ref in ref_img]
181
+ img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
182
+ img = torch.cat(img_in, dim=1)
183
+ ids = torch.cat(img_ids, dim=1)
184
+ else:
185
+ img = torch.cat((img, self.img_in(ref_img)), dim=1)
186
+ ids = torch.cat((ids, ref_img_ids), dim=1)
187
+ pe = self.pe_embedder(ids)
188
+
189
+ for index_block, block in enumerate(self.double_blocks):
190
+ if self.training and self.gradient_checkpointing:
191
+ img, txt = torch.utils.checkpoint.checkpoint(
192
+ block,
193
+ img=img,
194
+ txt=txt,
195
+ vec=vec,
196
+ pe=pe,
197
+ use_reentrant=False,
198
+ )
199
+ else:
200
+ img, txt = block(
201
+ img=img,
202
+ txt=txt,
203
+ vec=vec,
204
+ pe=pe
205
+ )
206
+
207
+ img = torch.cat((txt, img), 1)
208
+ for block in self.single_blocks:
209
+ if self.training and self.gradient_checkpointing:
210
+ img = torch.utils.checkpoint.checkpoint(
211
+ block,
212
+ img, vec=vec, pe=pe,
213
+ use_reentrant=False
214
+ )
215
+ else:
216
+ img = block(img, vec=vec, pe=pe)
217
+ img = img[:, txt.shape[1] :, ...]
218
+ # index img
219
+ img = img[:, :img_end, ...]
220
+
221
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
222
+ return img
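
To make the forward interface above concrete, a toy-sized sketch; the parameter values are invented for a fast shape check and are not the flux-dev configuration.

```python
# Toy-sized shape check for Flux.forward (invented config, NOT flux-dev).
import torch
from uno.flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=16, vec_in_dim=32, context_in_dim=32, hidden_size=64,
    mlp_ratio=4.0, num_heads=4, depth=1, depth_single_blocks=1,
    axes_dim=[4, 6, 6], theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params).eval()

B, L_img, L_txt = 1, 16, 8
img = torch.randn(B, L_img, params.in_channels)
txt = torch.randn(B, L_txt, params.context_in_dim)
img_ids = torch.zeros(B, L_img, 3)   # one id per positional axis (len(axes_dim) == 3)
txt_ids = torch.zeros(B, L_txt, 3)

with torch.no_grad():
    out = model(img, img_ids, txt, txt_ids,
                timesteps=torch.rand(B), y=torch.randn(B, params.vec_in_dim))
print(out.shape)  # torch.Size([1, 16, 16]) == (B, L_img, in_channels)
```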
uno/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from einops import rearrange
20
+ from torch import Tensor, nn
21
+
22
+
23
+ @dataclass
24
+ class AutoEncoderParams:
25
+ resolution: int
26
+ in_channels: int
27
+ ch: int
28
+ out_ch: int
29
+ ch_mult: list[int]
30
+ num_res_blocks: int
31
+ z_channels: int
32
+ scale_factor: float
33
+ shift_factor: float
34
+
35
+
36
+ def swish(x: Tensor) -> Tensor:
37
+ return x * torch.sigmoid(x)
38
+
39
+
40
+ class AttnBlock(nn.Module):
41
+ def __init__(self, in_channels: int):
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+
45
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
46
+
47
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
51
+
52
+ def attention(self, h_: Tensor) -> Tensor:
53
+ h_ = self.norm(h_)
54
+ q = self.q(h_)
55
+ k = self.k(h_)
56
+ v = self.v(h_)
57
+
58
+ b, c, h, w = q.shape
59
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
60
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
61
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
62
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
63
+
64
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ return x + self.proj_out(self.attention(x))
68
+
69
+
70
+ class ResnetBlock(nn.Module):
71
+ def __init__(self, in_channels: int, out_channels: int):
72
+ super().__init__()
73
+ self.in_channels = in_channels
74
+ out_channels = in_channels if out_channels is None else out_channels
75
+ self.out_channels = out_channels
76
+
77
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
79
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
80
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81
+ if self.in_channels != self.out_channels:
82
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
83
+
84
+ def forward(self, x):
85
+ h = x
86
+ h = self.norm1(h)
87
+ h = swish(h)
88
+ h = self.conv1(h)
89
+
90
+ h = self.norm2(h)
91
+ h = swish(h)
92
+ h = self.conv2(h)
93
+
94
+ if self.in_channels != self.out_channels:
95
+ x = self.nin_shortcut(x)
96
+
97
+ return x + h
98
+
99
+
100
+ class Downsample(nn.Module):
101
+ def __init__(self, in_channels: int):
102
+ super().__init__()
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
105
+
106
+ def forward(self, x: Tensor):
107
+ pad = (0, 1, 0, 1)
108
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
109
+ x = self.conv(x)
110
+ return x
111
+
112
+
113
+ class Upsample(nn.Module):
114
+ def __init__(self, in_channels: int):
115
+ super().__init__()
116
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
117
+
118
+ def forward(self, x: Tensor):
119
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
120
+ x = self.conv(x)
121
+ return x
122
+
123
+
124
+ class Encoder(nn.Module):
125
+ def __init__(
126
+ self,
127
+ resolution: int,
128
+ in_channels: int,
129
+ ch: int,
130
+ ch_mult: list[int],
131
+ num_res_blocks: int,
132
+ z_channels: int,
133
+ ):
134
+ super().__init__()
135
+ self.ch = ch
136
+ self.num_resolutions = len(ch_mult)
137
+ self.num_res_blocks = num_res_blocks
138
+ self.resolution = resolution
139
+ self.in_channels = in_channels
140
+ # downsampling
141
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
142
+
143
+ curr_res = resolution
144
+ in_ch_mult = (1,) + tuple(ch_mult)
145
+ self.in_ch_mult = in_ch_mult
146
+ self.down = nn.ModuleList()
147
+ block_in = self.ch
148
+ for i_level in range(self.num_resolutions):
149
+ block = nn.ModuleList()
150
+ attn = nn.ModuleList()
151
+ block_in = ch * in_ch_mult[i_level]
152
+ block_out = ch * ch_mult[i_level]
153
+ for _ in range(self.num_res_blocks):
154
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
155
+ block_in = block_out
156
+ down = nn.Module()
157
+ down.block = block
158
+ down.attn = attn
159
+ if i_level != self.num_resolutions - 1:
160
+ down.downsample = Downsample(block_in)
161
+ curr_res = curr_res // 2
162
+ self.down.append(down)
163
+
164
+ # middle
165
+ self.mid = nn.Module()
166
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167
+ self.mid.attn_1 = AttnBlock(block_in)
168
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
169
+
170
+ # end
171
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
172
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
173
+
174
+ def forward(self, x: Tensor) -> Tensor:
175
+ # downsampling
176
+ hs = [self.conv_in(x)]
177
+ for i_level in range(self.num_resolutions):
178
+ for i_block in range(self.num_res_blocks):
179
+ h = self.down[i_level].block[i_block](hs[-1])
180
+ if len(self.down[i_level].attn) > 0:
181
+ h = self.down[i_level].attn[i_block](h)
182
+ hs.append(h)
183
+ if i_level != self.num_resolutions - 1:
184
+ hs.append(self.down[i_level].downsample(hs[-1]))
185
+
186
+ # middle
187
+ h = hs[-1]
188
+ h = self.mid.block_1(h)
189
+ h = self.mid.attn_1(h)
190
+ h = self.mid.block_2(h)
191
+ # end
192
+ h = self.norm_out(h)
193
+ h = swish(h)
194
+ h = self.conv_out(h)
195
+ return h
196
+
197
+
198
+ class Decoder(nn.Module):
199
+ def __init__(
200
+ self,
201
+ ch: int,
202
+ out_ch: int,
203
+ ch_mult: list[int],
204
+ num_res_blocks: int,
205
+ in_channels: int,
206
+ resolution: int,
207
+ z_channels: int,
208
+ ):
209
+ super().__init__()
210
+ self.ch = ch
211
+ self.num_resolutions = len(ch_mult)
212
+ self.num_res_blocks = num_res_blocks
213
+ self.resolution = resolution
214
+ self.in_channels = in_channels
215
+ self.ffactor = 2 ** (self.num_resolutions - 1)
216
+
217
+ # compute in_ch_mult, block_in and curr_res at lowest res
218
+ block_in = ch * ch_mult[self.num_resolutions - 1]
219
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
220
+ self.z_shape = (1, z_channels, curr_res, curr_res)
221
+
222
+ # z to block_in
223
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
224
+
225
+ # middle
226
+ self.mid = nn.Module()
227
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
228
+ self.mid.attn_1 = AttnBlock(block_in)
229
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
230
+
231
+ # upsampling
232
+ self.up = nn.ModuleList()
233
+ for i_level in reversed(range(self.num_resolutions)):
234
+ block = nn.ModuleList()
235
+ attn = nn.ModuleList()
236
+ block_out = ch * ch_mult[i_level]
237
+ for _ in range(self.num_res_blocks + 1):
238
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
239
+ block_in = block_out
240
+ up = nn.Module()
241
+ up.block = block
242
+ up.attn = attn
243
+ if i_level != 0:
244
+ up.upsample = Upsample(block_in)
245
+ curr_res = curr_res * 2
246
+ self.up.insert(0, up) # prepend to get consistent order
247
+
248
+ # end
249
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
250
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
251
+
252
+ def forward(self, z: Tensor) -> Tensor:
253
+ # z to block_in
254
+ h = self.conv_in(z)
255
+
256
+ # middle
257
+ h = self.mid.block_1(h)
258
+ h = self.mid.attn_1(h)
259
+ h = self.mid.block_2(h)
260
+
261
+ # upsampling
262
+ for i_level in reversed(range(self.num_resolutions)):
263
+ for i_block in range(self.num_res_blocks + 1):
264
+ h = self.up[i_level].block[i_block](h)
265
+ if len(self.up[i_level].attn) > 0:
266
+ h = self.up[i_level].attn[i_block](h)
267
+ if i_level != 0:
268
+ h = self.up[i_level].upsample(h)
269
+
270
+ # end
271
+ h = self.norm_out(h)
272
+ h = swish(h)
273
+ h = self.conv_out(h)
274
+ return h
275
+
276
+
277
+ class DiagonalGaussian(nn.Module):
278
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
279
+ super().__init__()
280
+ self.sample = sample
281
+ self.chunk_dim = chunk_dim
282
+
283
+ def forward(self, z: Tensor) -> Tensor:
284
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
285
+ if self.sample:
286
+ std = torch.exp(0.5 * logvar)
287
+ return mean + std * torch.randn_like(mean)
288
+ else:
289
+ return mean
290
+
291
+
292
+ class AutoEncoder(nn.Module):
293
+ def __init__(self, params: AutoEncoderParams):
294
+ super().__init__()
295
+ self.encoder = Encoder(
296
+ resolution=params.resolution,
297
+ in_channels=params.in_channels,
298
+ ch=params.ch,
299
+ ch_mult=params.ch_mult,
300
+ num_res_blocks=params.num_res_blocks,
301
+ z_channels=params.z_channels,
302
+ )
303
+ self.decoder = Decoder(
304
+ resolution=params.resolution,
305
+ in_channels=params.in_channels,
306
+ ch=params.ch,
307
+ out_ch=params.out_ch,
308
+ ch_mult=params.ch_mult,
309
+ num_res_blocks=params.num_res_blocks,
310
+ z_channels=params.z_channels,
311
+ )
312
+ self.reg = DiagonalGaussian()
313
+
314
+ self.scale_factor = params.scale_factor
315
+ self.shift_factor = params.shift_factor
316
+
317
+ def encode(self, x: Tensor) -> Tensor:
318
+ z = self.reg(self.encoder(x))
319
+ z = self.scale_factor * (z - self.shift_factor)
320
+ return z
321
+
322
+ def decode(self, z: Tensor) -> Tensor:
323
+ z = z / self.scale_factor + self.shift_factor
324
+ return self.decoder(z)
325
+
326
+ def forward(self, x: Tensor) -> Tensor:
327
+ return self.decode(self.encode(x))
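
A similar toy-sized round-trip shape check for the autoencoder; the parameter values are invented and are not the FLUX VAE configuration.

```python
# Toy AutoEncoder round trip (invented params, NOT the FLUX VAE config).
import torch
from uno.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=64, in_channels=3, ch=32, out_ch=3, ch_mult=[1, 2],
    num_res_blocks=1, z_channels=4, scale_factor=1.0, shift_factor=0.0,
)
ae = AutoEncoder(params).eval()

x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    z = ae.encode(x)       # (1, 4, 32, 32): spatial /2**(len(ch_mult)-1), channels = z_channels
    x_rec = ae.decode(z)   # (1, 3, 64, 64)
print(z.shape, x_rec.shape)
```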
uno/flux/modules/conditioner.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from torch import Tensor, nn
17
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
18
+ T5Tokenizer)
19
+
20
+
21
+ class HFEmbedder(nn.Module):
22
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
23
+ super().__init__()
24
+ self.is_clip = version.startswith("openai")
25
+ self.max_length = max_length
26
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
27
+
28
+ if self.is_clip:
29
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
30
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
31
+ else:
32
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
33
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
34
+
35
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
36
+
37
+ def forward(self, text: list[str]) -> Tensor:
38
+ batch_encoding = self.tokenizer(
39
+ text,
40
+ truncation=True,
41
+ max_length=self.max_length,
42
+ return_length=False,
43
+ return_overflowing_tokens=False,
44
+ padding="max_length",
45
+ return_tensors="pt",
46
+ )
47
+
48
+ outputs = self.hf_module(
49
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
50
+ attention_mask=None,
51
+ output_hidden_states=False,
52
+ )
53
+ return outputs[self.output_key]
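
A hedged usage sketch for `HFEmbedder`; the checkpoint id and `max_length` below are assumptions commonly paired with FLUX, not values taken from this diff.

```python
# Sketch: pooled CLIP text features through HFEmbedder (checkpoint id is an assumption).
import torch
from uno.flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.float32)
with torch.no_grad():
    vec = clip(["handsome woman in the city"])   # default prompt from app.py above
print(vec.shape)  # (1, 768): pooler_output, since the version string starts with "openai"
```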
uno/flux/modules/layers.py ADDED
@@ -0,0 +1,435 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ from einops import rearrange
21
+ from torch import Tensor, nn
22
+
23
+ from ..math import attention, rope
24
+ import torch.nn.functional as F
25
+
26
+ class EmbedND(nn.Module):
27
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
28
+ super().__init__()
29
+ self.dim = dim
30
+ self.theta = theta
31
+ self.axes_dim = axes_dim
32
+
33
+ def forward(self, ids: Tensor) -> Tensor:
34
+ n_axes = ids.shape[-1]
35
+ emb = torch.cat(
36
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
37
+ dim=-3,
38
+ )
39
+
40
+ return emb.unsqueeze(1)
41
+
42
+
43
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
44
+ """
45
+ Create sinusoidal timestep embeddings.
46
+ :param t: a 1-D Tensor of N indices, one per batch element.
47
+ These may be fractional.
48
+ :param dim: the dimension of the output.
49
+ :param max_period: controls the minimum frequency of the embeddings.
50
+ :return: an (N, D) Tensor of positional embeddings.
51
+ """
52
+ t = time_factor * t
53
+ half = dim // 2
54
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
55
+ t.device
56
+ )
57
+
58
+ args = t[:, None].float() * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
62
+ if torch.is_floating_point(t):
63
+ embedding = embedding.to(t)
64
+ return embedding
65
+
66
+
67
+ class MLPEmbedder(nn.Module):
68
+ def __init__(self, in_dim: int, hidden_dim: int):
69
+ super().__init__()
70
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
71
+ self.silu = nn.SiLU()
72
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
73
+
74
+ def forward(self, x: Tensor) -> Tensor:
75
+ return self.out_layer(self.silu(self.in_layer(x)))
76
+
77
+
78
+ class RMSNorm(torch.nn.Module):
79
+ def __init__(self, dim: int):
80
+ super().__init__()
81
+ self.scale = nn.Parameter(torch.ones(dim))
82
+
83
+ def forward(self, x: Tensor):
84
+ x_dtype = x.dtype
85
+ x = x.float()
86
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
87
+ return (x * rrms).to(dtype=x_dtype) * self.scale
88
+
89
+
90
+ class QKNorm(torch.nn.Module):
91
+ def __init__(self, dim: int):
92
+ super().__init__()
93
+ self.query_norm = RMSNorm(dim)
94
+ self.key_norm = RMSNorm(dim)
95
+
96
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
97
+ q = self.query_norm(q)
98
+ k = self.key_norm(k)
99
+ return q.to(v), k.to(v)
100
+
101
+ class LoRALinearLayer(nn.Module):
102
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
103
+ super().__init__()
104
+
105
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
106
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
107
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
108
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
109
+ self.network_alpha = network_alpha
110
+ self.rank = rank
111
+
112
+ nn.init.normal_(self.down.weight, std=1 / rank)
113
+ nn.init.zeros_(self.up.weight)
114
+
115
+ def forward(self, hidden_states):
116
+ orig_dtype = hidden_states.dtype
117
+ dtype = self.down.weight.dtype
118
+
119
+ down_hidden_states = self.down(hidden_states.to(dtype))
120
+ up_hidden_states = self.up(down_hidden_states)
121
+
122
+ if self.network_alpha is not None:
123
+ up_hidden_states *= self.network_alpha / self.rank
124
+
125
+ return up_hidden_states.to(orig_dtype)
126
+
127
+ class FLuxSelfAttnProcessor:
128
+ def __call__(self, attn, x, pe, **attention_kwargs):
129
+ qkv = attn.qkv(x)
130
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
131
+ q, k = attn.norm(q, k, v)
132
+ x = attention(q, k, v, pe=pe)
133
+ x = attn.proj(x)
134
+ return x
135
+
136
+ class LoraFluxAttnProcessor(nn.Module):
137
+
138
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
139
+ super().__init__()
140
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
141
+ self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
142
+ self.lora_weight = lora_weight
143
+
144
+
145
+ def __call__(self, attn, x, pe, **attention_kwargs):
146
+ qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
147
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
148
+ q, k = attn.norm(q, k, v)
149
+ x = attention(q, k, v, pe=pe)
150
+ x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
151
+ return x
152
+
153
+ class SelfAttention(nn.Module):
154
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
155
+ super().__init__()
156
+ self.num_heads = num_heads
157
+ head_dim = dim // num_heads
158
+
159
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
160
+ self.norm = QKNorm(head_dim)
161
+ self.proj = nn.Linear(dim, dim)
162
+ def forward(self):
163
+ pass  # unused: attention is computed in the block processors via attn.qkv / attn.norm / attn.proj
164
+
165
+
166
+ @dataclass
167
+ class ModulationOut:
168
+ shift: Tensor
169
+ scale: Tensor
170
+ gate: Tensor
171
+
172
+
173
+ class Modulation(nn.Module):
174
+ def __init__(self, dim: int, double: bool):
175
+ super().__init__()
176
+ self.is_double = double
177
+ self.multiplier = 6 if double else 3
178
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
179
+
180
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
181
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
182
+
183
+ return (
184
+ ModulationOut(*out[:3]),
185
+ ModulationOut(*out[3:]) if self.is_double else None,
186
+ )
187
+
188
+ class DoubleStreamBlockLoraProcessor(nn.Module):
189
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
190
+ super().__init__()
191
+ self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
192
+ self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
193
+ self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
194
+ self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
195
+ self.lora_weight = lora_weight
196
+
197
+ def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
198
+ img_mod1, img_mod2 = attn.img_mod(vec)
199
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
200
+
201
+ # prepare image for attention
202
+ img_modulated = attn.img_norm1(img)
203
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
204
+ img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
205
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
206
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
207
+
208
+ # prepare txt for attention
209
+ txt_modulated = attn.txt_norm1(txt)
210
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
211
+ txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
212
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
213
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
214
+
215
+ # run actual attention
216
+ q = torch.cat((txt_q, img_q), dim=2)
217
+ k = torch.cat((txt_k, img_k), dim=2)
218
+ v = torch.cat((txt_v, img_v), dim=2)
219
+
220
+ attn1 = attention(q, k, v, pe=pe)
221
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
222
+
223
+ # calculate the img bloks
224
+ img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
225
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
226
+
227
+ # calculate the txt bloks
228
+ txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
229
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
230
+ return img, txt
231
+
232
+ class DoubleStreamBlockProcessor:
233
+ def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
234
+ img_mod1, img_mod2 = attn.img_mod(vec)
235
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
236
+
237
+ # prepare image for attention
238
+ img_modulated = attn.img_norm1(img)
239
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
240
+ img_qkv = attn.img_attn.qkv(img_modulated)
241
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
242
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
243
+
244
+ # prepare txt for attention
245
+ txt_modulated = attn.txt_norm1(txt)
246
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
247
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
248
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
249
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
250
+
251
+ # run actual attention
252
+ q = torch.cat((txt_q, img_q), dim=2)
253
+ k = torch.cat((txt_k, img_k), dim=2)
254
+ v = torch.cat((txt_v, img_v), dim=2)
255
+
256
+ attn1 = attention(q, k, v, pe=pe)
257
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
258
+
259
+ # calculate the img bloks
260
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
261
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
262
+
263
+ # calculate the txt bloks
264
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
265
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
266
+ return img, txt
267
+
268
+ class DoubleStreamBlock(nn.Module):
269
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
270
+ super().__init__()
271
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
272
+ self.num_heads = num_heads
273
+ self.hidden_size = hidden_size
274
+ self.head_dim = hidden_size // num_heads
275
+
276
+ self.img_mod = Modulation(hidden_size, double=True)
277
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
278
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
279
+
280
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
281
+ self.img_mlp = nn.Sequential(
282
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
283
+ nn.GELU(approximate="tanh"),
284
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
285
+ )
286
+
287
+ self.txt_mod = Modulation(hidden_size, double=True)
288
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
289
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
290
+
291
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
292
+ self.txt_mlp = nn.Sequential(
293
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
294
+ nn.GELU(approximate="tanh"),
295
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
296
+ )
297
+ processor = DoubleStreamBlockProcessor()
298
+ self.set_processor(processor)
299
+
300
+ def set_processor(self, processor) -> None:
301
+ self.processor = processor
302
+
303
+ def get_processor(self):
304
+ return self.processor
305
+
306
+ def forward(
307
+ self,
308
+ img: Tensor,
309
+ txt: Tensor,
310
+ vec: Tensor,
311
+ pe: Tensor,
312
+ image_proj: Tensor = None,
313
+ ip_scale: float = 1.0,
314
+ ) -> tuple[Tensor, Tensor]:
315
+ if image_proj is None:
316
+ return self.processor(self, img, txt, vec, pe)
317
+ else:
318
+ return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
319
+
320
+
321
+ class SingleStreamBlockLoraProcessor(nn.Module):
322
+ def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
323
+ super().__init__()
324
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
325
+ self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)  # 15360 = hidden_size + mlp_hidden_dim (3072 + 4 * 3072) of the flux-dev single blocks
326
+ self.lora_weight = lora_weight
327
+
328
+ def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
329
+
330
+ mod, _ = attn.modulation(vec)
331
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
332
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
333
+ qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
334
+
335
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
336
+ q, k = attn.norm(q, k, v)
337
+
338
+ # compute attention
339
+ attn_1 = attention(q, k, v, pe=pe)
340
+
341
+ # compute activation in mlp stream, cat again and run second linear layer
342
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
343
+ output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
344
+ output = x + mod.gate * output
345
+ return output
346
+
347
+
348
+ class SingleStreamBlockProcessor:
349
+ def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
350
+
351
+ mod, _ = attn.modulation(vec)
352
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
353
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
354
+
355
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
356
+ q, k = attn.norm(q, k, v)
357
+
358
+ # compute attention
359
+ attn_1 = attention(q, k, v, pe=pe)
360
+
361
+ # compute activation in mlp stream, cat again and run second linear layer
362
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
363
+ output = x + mod.gate * output
364
+ return output
365
+
366
+ class SingleStreamBlock(nn.Module):
367
+ """
368
+ A DiT block with parallel linear layers as described in
369
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ hidden_size: int,
375
+ num_heads: int,
376
+ mlp_ratio: float = 4.0,
377
+ qk_scale: float | None = None,
378
+ ):
379
+ super().__init__()
380
+ self.hidden_dim = hidden_size
381
+ self.num_heads = num_heads
382
+ self.head_dim = hidden_size // num_heads
383
+ self.scale = qk_scale or self.head_dim**-0.5
384
+
385
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
386
+ # qkv and mlp_in
387
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
388
+ # proj and mlp_out
389
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
390
+
391
+ self.norm = QKNorm(self.head_dim)
392
+
393
+ self.hidden_size = hidden_size
394
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
395
+
396
+ self.mlp_act = nn.GELU(approximate="tanh")
397
+ self.modulation = Modulation(hidden_size, double=False)
398
+
399
+ processor = SingleStreamBlockProcessor()
400
+ self.set_processor(processor)
401
+
402
+
403
+ def set_processor(self, processor) -> None:
404
+ self.processor = processor
405
+
406
+ def get_processor(self):
407
+ return self.processor
408
+
409
+ def forward(
410
+ self,
411
+ x: Tensor,
412
+ vec: Tensor,
413
+ pe: Tensor,
414
+ image_proj: Tensor | None = None,
415
+ ip_scale: float = 1.0,
416
+ ) -> Tensor:
417
+ if image_proj is None:
418
+ return self.processor(self, x, vec, pe)
419
+ else:
420
+ return self.processor(self, x, vec, pe, image_proj, ip_scale)
421
+
422
+
423
+
424
+ class LastLayer(nn.Module):
425
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
426
+ super().__init__()
427
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
428
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
429
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
430
+
431
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
432
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
433
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
434
+ x = self.linear(x)
435
+ return x
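As a quick orientation for how LastLayer is used, here is a minimal shape-check sketch. The sizes are assumptions chosen to match the flux-dev configuration in uno/flux/util.py further down (hidden_size=3072, 64 packed latent channels); patch_size=1 is assumed for the final projection.

import torch

from uno.flux.modules.layers import LastLayer

layer = LastLayer(hidden_size=3072, patch_size=1, out_channels=64)
x = torch.randn(1, 1024, 3072)   # (batch, packed latent tokens, hidden)
vec = torch.randn(1, 3072)       # pooled conditioning vector
out = layer(x, vec)              # shift/scale from vec, then project: (1, 1024, 64)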
uno/flux/pipeline.py ADDED
@@ -0,0 +1,324 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from typing import Literal
18
+
19
+ import torch
20
+ from einops import rearrange
21
+ from PIL import ExifTags, Image
22
+ import torchvision.transforms.functional as TVF
23
+
24
+ from uno.flux.modules.layers import (
25
+ DoubleStreamBlockLoraProcessor,
26
+ DoubleStreamBlockProcessor,
27
+ SingleStreamBlockLoraProcessor,
28
+ SingleStreamBlockProcessor,
29
+ )
30
+ from uno.flux.sampling import denoise, get_noise, get_schedule, prepare, prepare_multi_ip, unpack
31
+ from uno.flux.util import (
32
+ get_lora_rank,
33
+ load_ae,
34
+ load_checkpoint,
35
+ load_clip,
36
+ load_flow_model,
37
+ load_flow_model_only_lora,
38
+ load_flow_model_quintized,
39
+ load_t5,
40
+ )
41
+
42
+
43
+ def find_nearest_scale(image_h, image_w, predefined_scales):
44
+ """
45
+ Given an image's height and width, find the closest predefined scale.
46
+
47
+ :param image_h: image height
48
+ :param image_w: image width
49
+ :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
50
+ :return: the closest predefined scale (h, w)
51
+ """
52
+ # compute the aspect ratio of the input image
53
+ image_ratio = image_h / image_w
54
+
55
+ # track the smallest difference and the closest scale seen so far
56
+ min_diff = float('inf')
57
+ nearest_scale = None
58
+
59
+ # iterate over all predefined scales and pick the one whose aspect ratio is closest to the input
60
+ for scale_h, scale_w in predefined_scales:
61
+ predefined_ratio = scale_h / scale_w
62
+ diff = abs(predefined_ratio - image_ratio)
63
+
64
+ if diff < min_diff:
65
+ min_diff = diff
66
+ nearest_scale = (scale_h, scale_w)
67
+
68
+ return nearest_scale
69
+
70
+ def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
71
+ # get the width and height of the original image
72
+ image_w, image_h = raw_image.size
73
+
74
+ # compute the new long and short sides
75
+ if image_w >= image_h:
76
+ new_w = long_size
77
+ new_h = int((long_size / image_w) * image_h)
78
+ else:
79
+ new_h = long_size
80
+ new_w = int((long_size / image_h) * image_w)
81
+
82
+ # resize proportionally to the new width and height
83
+ raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
84
+ target_w = new_w // 16 * 16
85
+ target_h = new_h // 16 * 16
86
+
87
+ # compute the crop origin for a center crop
88
+ left = (new_w - target_w) // 2
89
+ top = (new_h - target_h) // 2
90
+ right = left + target_w
91
+ bottom = top + target_h
92
+
93
+ # apply the center crop
94
+ raw_image = raw_image.crop((left, top, right, bottom))
95
+
96
+ # convert to RGB mode
97
+ raw_image = raw_image.convert("RGB")
98
+ return raw_image
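A small usage sketch of preprocess_ref (the solid-colour test image and its size are stand-ins): the long side is scaled to long_size, the short side follows proportionally, and both are then center-cropped down to multiples of 16.

from PIL import Image

from uno.flux.pipeline import preprocess_ref

ref = Image.new("RGB", (800, 600), "gray")   # hypothetical 800x600 reference
out = preprocess_ref(ref, long_size=512)
print(out.size)                              # (512, 384): 800 -> 512, 600 -> 384, both multiples of 16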
99
+
100
+ class UNOPipeline:
101
+ def __init__(
102
+ self,
103
+ model_type: str,
104
+ device: torch.device,
105
+ offload: bool = False,
106
+ only_lora: bool = False,
107
+ lora_rank: int = 16
108
+ ):
109
+ self.device = device
110
+ self.offload = offload
111
+ self.model_type = model_type
112
+
113
+ self.clip = load_clip(self.device)
114
+ self.t5 = load_t5(self.device, max_length=512)
115
+ self.ae = load_ae(model_type, device="cpu" if offload else self.device)
116
+ if "fp8" in model_type:
117
+ self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
118
+ elif only_lora:
119
+ self.model = load_flow_model_only_lora(
120
+ model_type, device="cpu" if offload else self.device, lora_rank=lora_rank
121
+ )
122
+ else:
123
+ self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
124
+
125
+
126
+ def load_ckpt(self, ckpt_path):
127
+ if ckpt_path is not None:
128
+ from safetensors.torch import load_file as load_sft
129
+ print("Loading checkpoint to replace old keys")
130
+ # load_sft doesn't support torch.device
131
+ if ckpt_path.endswith('safetensors'):
132
+ sd = load_sft(ckpt_path, device='cpu')
133
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
134
+ else:
135
+ dit_state = torch.load(ckpt_path, map_location='cpu')
136
+ sd = {}
137
+ for k in dit_state.keys():
138
+ sd[k.replace('module.','')] = dit_state[k]
139
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
140
+ self.model.to(str(self.device))
141
+ print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")
142
+
143
+ def set_lora(self, local_path: str | None = None, repo_id: str | None = None,
144
+ name: str | None = None, lora_weight: float = 0.7):
145
+ checkpoint = load_checkpoint(local_path, repo_id, name)
146
+ self.update_model_with_lora(checkpoint, lora_weight)
147
+
148
+ def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
149
+ checkpoint = load_checkpoint(
150
+ None, self.hf_lora_collection, self.lora_types_to_names[lora_type]  # note: hf_lora_collection / lora_types_to_names must be provided on the instance
151
+ )
152
+ self.update_model_with_lora(checkpoint, lora_weight)
153
+
154
+ def update_model_with_lora(self, checkpoint, lora_weight):
155
+ rank = get_lora_rank(checkpoint)
156
+ lora_attn_procs = {}
157
+
158
+ for name, _ in self.model.attn_processors.items():
159
+ lora_state_dict = {}
160
+ for k in checkpoint.keys():
161
+ if name in k:
162
+ lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
163
+
164
+ if len(lora_state_dict):
165
+ if name.startswith("single_blocks"):
166
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
167
+ else:
168
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
169
+ lora_attn_procs[name].load_state_dict(lora_state_dict)
170
+ lora_attn_procs[name].to(self.device)
171
+ else:
172
+ if name.startswith("single_blocks"):
173
+ lora_attn_procs[name] = SingleStreamBlockProcessor()
174
+ else:
175
+ lora_attn_procs[name] = DoubleStreamBlockProcessor()
176
+
177
+ self.model.set_attn_processor(lora_attn_procs)
178
+
179
+
180
+ def __call__(
181
+ self,
182
+ prompt: str,
183
+ width: int = 512,
184
+ height: int = 512,
185
+ guidance: float = 4,
186
+ num_steps: int = 50,
187
+ seed: int = 123456789,
188
+ true_gs: float = 3,
189
+ neg_prompt: str = '',
190
+ neg_image_prompt: Image.Image | None = None,
191
+ timestep_to_start_cfg: int = 0,
192
+ **kwargs
193
+ ):
194
+ width = 16 * (width // 16)
195
+ height = 16 * (height // 16)
196
+
197
+ return self.forward(
198
+ prompt,
199
+ width,
200
+ height,
201
+ guidance,
202
+ num_steps,
203
+ seed,
204
+ timestep_to_start_cfg=timestep_to_start_cfg,
205
+ true_gs=true_gs,
206
+ neg_prompt=neg_prompt,
207
+ **kwargs
208
+ )
209
+
210
+ @torch.inference_mode()
211
+ def gradio_generate(
212
+ self,
213
+ prompt: str,
214
+ width: int,
215
+ height: int,
216
+ guidance: float,
217
+ num_steps: int,
218
+ seed: int,
219
+ ref_long_side: int,
220
+ image_prompt1: Image.Image,
221
+ image_prompt2: Image.Image,
222
+ image_prompt3: Image.Image,
223
+ image_prompt4: Image.Image,
224
+ ):
225
+ ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
226
+ ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
227
+ ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
228
+
229
+ seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
230
+
231
+ img = self(prompt=prompt, width=width, height=height, guidance=guidance,
232
+ num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)
233
+
234
+ filename = f"output/gradio/{seed}_{prompt[:20]}.png"
235
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
236
+ exif_data = Image.Exif()
237
+ exif_data[ExifTags.Base.Make] = "UNO"
238
+ exif_data[ExifTags.Base.Model] = self.model_type
239
+ info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
240
+ exif_data[ExifTags.Base.ImageDescription] = info
241
+ img.save(filename, format="png", exif=exif_data)
242
+ return img, filename
243
+
244
+ @torch.inference_mode()
245
+ def forward(
246
+ self,
247
+ prompt: str,
248
+ width: int,
249
+ height: int,
250
+ guidance: float,
251
+ num_steps: int,
252
+ seed: int,
253
+ timestep_to_start_cfg: int = 1e5,  # TODO: unused, remove
254
+ true_gs: float = 3.5,
255
+ neg_prompt: str = "",
256
+ ref_imgs: list[Image.Image] | None = None,
257
+ pe: Literal['d', 'h', 'w', 'o'] = 'd',
258
+ ):
259
+ x = get_noise(
260
+ 1, height, width, device=self.device,
261
+ dtype=torch.bfloat16, seed=seed
262
+ )
263
+ timesteps = get_schedule(
264
+ num_steps,
265
+ (width // 8) * (height // 8) // (16 * 16),
266
+ shift=True,
267
+ )
268
+ if self.offload:
269
+ self.ae.encoder = self.ae.encoder.to(self.device)
270
+ x_1_refs = [
271
+ self.ae.encode(
272
+ (TVF.to_tensor(ref_img) * 2.0 - 1.0)
273
+ .unsqueeze(0).to(self.device, torch.float32)
274
+ ).to(torch.bfloat16)
275
+ for ref_img in ref_imgs
276
+ ]
277
+
278
+ if self.offload:
279
+ self.ae.encoder = self.offload_model_to_cpu(self.ae.encoder)
280
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
281
+ inp_cond = prepare_multi_ip(
282
+ t5=self.t5, clip=self.clip,
283
+ img=x,
284
+ prompt=prompt, ref_imgs=x_1_refs, pe=pe
285
+ )
286
+ neg_inp_cond = prepare_multi_ip(
287
+ t5=self.t5, clip=self.clip,
288
+ img=x,
289
+ prompt=neg_prompt, ref_imgs=x_1_refs, pe=pe
290
+ )
291
+
292
+ if self.offload:
293
+ self.offload_model_to_cpu(self.t5, self.clip)
294
+ self.model = self.model.to(self.device)
295
+
296
+ x = denoise(
297
+ self.model,
298
+ **inp_cond,
299
+ timesteps=timesteps,
300
+ guidance=guidance,
301
+ timestep_to_start_cfg=timestep_to_start_cfg,
302
+ neg_txt=neg_inp_cond['txt'],
303
+ neg_txt_ids=neg_inp_cond['txt_ids'],
304
+ neg_vec=neg_inp_cond['vec'],
305
+ true_gs=true_gs,
306
+ )
307
+
308
+ if self.offload:
309
+ self.offload_model_to_cpu(self.model)
310
+ self.ae.decoder.to(x.device)
311
+ x = unpack(x.float(), height, width)
312
+ x = self.ae.decode(x)
313
+ self.offload_model_to_cpu(self.ae.decoder)
314
+
315
+ x1 = x.clamp(-1, 1)
316
+ x1 = rearrange(x1[-1], "c h w -> h w c")
317
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
318
+ return output_img
319
+
320
+ def offload_model_to_cpu(self, *models):
321
+ if not self.offload: return
322
+ for model in models:
323
+ model.cpu()
324
+ torch.cuda.empty_cache()
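Putting the pieces together, a minimal end-to-end sketch of using UNOPipeline. The prompt, reference path, and CUDA device are assumptions; with only_lora=True the FLUX weights plus the UNO dit_lora checkpoint are fetched from the Hub as in load_flow_model_only_lora.

import torch
from PIL import Image

from uno.flux.pipeline import UNOPipeline, preprocess_ref

pipeline = UNOPipeline("flux-dev", device=torch.device("cuda"), offload=True, only_lora=True)

ref = preprocess_ref(Image.open("assets/ref.png"), long_size=512)   # hypothetical path
image = pipeline(
    prompt="a plush toy on a wooden desk",
    width=512, height=512,
    guidance=4, num_steps=25, seed=42,
    ref_imgs=[ref],                # forwarded to forward() through **kwargs
)
image.save("uno_result.png")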
uno/flux/sampling.py ADDED
@@ -0,0 +1,271 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Literal
18
+
19
+ import torch
20
+ from einops import rearrange, repeat
21
+ from torch import Tensor
22
+ from tqdm import tqdm
23
+
24
+ from .model import Flux
25
+ from .modules.conditioner import HFEmbedder
26
+
27
+
28
+ def get_noise(
29
+ num_samples: int,
30
+ height: int,
31
+ width: int,
32
+ device: torch.device,
33
+ dtype: torch.dtype,
34
+ seed: int,
35
+ ):
36
+ return torch.randn(
37
+ num_samples,
38
+ 16,
39
+ # allow for packing
40
+ 2 * math.ceil(height / 16),
41
+ 2 * math.ceil(width / 16),
42
+ device=device,
43
+ dtype=dtype,
44
+ generator=torch.Generator(device=device).manual_seed(seed),
45
+ )
46
+
47
+
48
+ def prepare(
49
+ t5: HFEmbedder,
50
+ clip: HFEmbedder,
51
+ img: Tensor,
52
+ prompt: str | list[str],
53
+ ref_img: Tensor | None = None,
54
+ pe: Literal['d', 'h', 'w', 'o'] = 'd'
55
+ ) -> dict[str, Tensor]:
56
+ assert pe in ['d', 'h', 'w', 'o']
57
+ bs, c, h, w = img.shape
58
+ if bs == 1 and not isinstance(prompt, str):
59
+ bs = len(prompt)
60
+
61
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
62
+ if img.shape[0] == 1 and bs > 1:
63
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
64
+
65
+ img_ids = torch.zeros(h // 2, w // 2, 3)
66
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
67
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
68
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
69
+
70
+ if ref_img is not None:
71
+ _, _, ref_h, ref_w = ref_img.shape
72
+ ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
73
+ if ref_img.shape[0] == 1 and bs > 1:
74
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
75
+ ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
76
+ # offset the reference image ids in height/width by the main image's maxima
77
+ h_offset = h // 2 if pe in {'d', 'h'} else 0
78
+ w_offset = w // 2 if pe in {'d', 'w'} else 0
79
+ ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
80
+ ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
81
+ ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
82
+
83
+ if isinstance(prompt, str):
84
+ prompt = [prompt]
85
+ txt = t5(prompt)
86
+ if txt.shape[0] == 1 and bs > 1:
87
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
88
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
89
+
90
+ vec = clip(prompt)
91
+ if vec.shape[0] == 1 and bs > 1:
92
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
93
+
94
+ if ref_img is not None:
95
+ return {
96
+ "img": img,
97
+ "img_ids": img_ids.to(img.device),
98
+ "ref_img": ref_img,
99
+ "ref_img_ids": ref_img_ids.to(img.device),
100
+ "txt": txt.to(img.device),
101
+ "txt_ids": txt_ids.to(img.device),
102
+ "vec": vec.to(img.device),
103
+ }
104
+ else:
105
+ return {
106
+ "img": img,
107
+ "img_ids": img_ids.to(img.device),
108
+ "txt": txt.to(img.device),
109
+ "txt_ids": txt_ids.to(img.device),
110
+ "vec": vec.to(img.device),
111
+ }
112
+
113
+ def prepare_multi_ip(
114
+ t5: HFEmbedder,
115
+ clip: HFEmbedder,
116
+ img: Tensor,
117
+ prompt: str | list[str],
118
+ ref_imgs: list[Tensor] | None = None,
119
+ pe: Literal['d', 'h', 'w', 'o'] = 'd'
120
+ ) -> dict[str, Tensor]:
121
+ assert pe in ['d', 'h', 'w', 'o']
122
+ bs, c, h, w = img.shape
123
+ if bs == 1 and not isinstance(prompt, str):
124
+ bs = len(prompt)
125
+
126
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
127
+ if img.shape[0] == 1 and bs > 1:
128
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
129
+
130
+ img_ids = torch.zeros(h // 2, w // 2, 3)
131
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
132
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
133
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
134
+
135
+ ref_img_ids = []
136
+ ref_imgs_list = []
137
+ pe_shift_w, pe_shift_h = w // 2, h // 2
138
+ for ref_img in ref_imgs:
139
+ _, _, ref_h1, ref_w1 = ref_img.shape
140
+ ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
141
+ if ref_img.shape[0] == 1 and bs > 1:
142
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
143
+ ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
144
+ # offset the reference image ids in height/width by the current pe shift
145
+ h_offset = pe_shift_h if pe in {'d', 'h'} else 0
146
+ w_offset = pe_shift_w if pe in {'d', 'w'} else 0
147
+ ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
148
+ ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
149
+ ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
150
+ ref_img_ids.append(ref_img_ids1)
151
+ ref_imgs_list.append(ref_img)
152
+
153
+ # update the pe shift by this reference's extent
154
+ pe_shift_h += ref_h1 // 2
155
+ pe_shift_w += ref_w1 // 2
156
+
157
+ if isinstance(prompt, str):
158
+ prompt = [prompt]
159
+ txt = t5(prompt)
160
+ if txt.shape[0] == 1 and bs > 1:
161
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
162
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
163
+
164
+ vec = clip(prompt)
165
+ if vec.shape[0] == 1 and bs > 1:
166
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
167
+
168
+ return {
169
+ "img": img,
170
+ "img_ids": img_ids.to(img.device),
171
+ "ref_img": tuple(ref_imgs_list),
172
+ "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
173
+ "txt": txt.to(img.device),
174
+ "txt_ids": txt_ids.to(img.device),
175
+ "vec": vec.to(img.device),
176
+ }
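To make the id bookkeeping concrete, a worked sketch under assumed sizes: a 512x512 target and two 512x512 references with pe='d'. Latents are 1/8 of the pixel size and packed 2x2, so each image owns a 32x32 grid of position ids.

h_ids = w_ids = (512 // 8) // 2            # 32 ids per axis for a 512x512 image
pe_shift_h, pe_shift_w = h_ids, w_ids      # start just past the target's own ids
ref1_offset = (pe_shift_h, pe_shift_w)     # first reference shifted by (32, 32)
pe_shift_h += 32
pe_shift_w += 32
ref2_offset = (pe_shift_h, pe_shift_w)     # second reference shifted by (64, 64)
# so no two images share RoPE coordinates along the height/width axes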
177
+
178
+
179
+ def time_shift(mu: float, sigma: float, t: Tensor):
180
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
181
+
182
+
183
+ def get_lin_function(
184
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
185
+ ):
186
+ m = (y2 - y1) / (x2 - x1)
187
+ b = y1 - m * x1
188
+ return lambda x: m * x + b
189
+
190
+
191
+ def get_schedule(
192
+ num_steps: int,
193
+ image_seq_len: int,
194
+ base_shift: float = 0.5,
195
+ max_shift: float = 1.15,
196
+ shift: bool = True,
197
+ ) -> list[float]:
198
+ # extra step for zero
199
+ timesteps = torch.linspace(1, 0, num_steps + 1)
200
+
201
+ # shifting the schedule to favor high timesteps for higher signal images
202
+ if shift:
203
+ # estimate mu via linear interpolation between two reference points
204
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
205
+ timesteps = time_shift(mu, 1.0, timesteps)
206
+
207
+ return timesteps.tolist()
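As a worked example of the shift (values are illustrative): for image_seq_len = 1024 the linear fit gives mu of roughly 0.63, and time_shift warps the interior of the schedule upward while leaving the endpoints fixed.

import math

from uno.flux.sampling import get_lin_function

mu = get_lin_function(y1=0.5, y2=1.15)(1024)            # ~= 0.63
t = 0.5
shifted = math.exp(mu) / (math.exp(mu) + (1 / t - 1))   # ~= 0.65 with sigma = 1.0
# t = 1 maps to 1 and t -> 0 maps to 0, so only the interior of the schedule moves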
208
+
209
+
210
+ def denoise(
211
+ model: Flux,
212
+ # model input
213
+ img: Tensor,
214
+ img_ids: Tensor,
215
+ txt: Tensor,
216
+ txt_ids: Tensor,
217
+ vec: Tensor,
218
+ neg_txt: Tensor,
219
+ neg_txt_ids: Tensor,
220
+ neg_vec: Tensor,
221
+ # sampling parameters
222
+ timesteps: list[float],
223
+ guidance: float = 4.0,
224
+ true_gs: float = 1,
225
+ timestep_to_start_cfg: int = 0,
226
+ ref_img: Tensor | None = None,
227
+ ref_img_ids: Tensor | None = None,
228
+ ):
229
+ i = 0
230
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
231
+ for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
232
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
233
+ pred = model(
234
+ img=img,
235
+ img_ids=img_ids,
236
+ ref_img=ref_img,
237
+ ref_img_ids=ref_img_ids,
238
+ txt=txt,
239
+ txt_ids=txt_ids,
240
+ y=vec,
241
+ timesteps=t_vec,
242
+ guidance=guidance_vec
243
+ )
244
+ if i >= timestep_to_start_cfg:
245
+ # not tested
246
+ neg_pred = model(
247
+ img=img,
248
+ img_ids=img_ids,
249
+ ref_img=ref_img, # TODO: neg img embedding
250
+ ref_img_ids=ref_img_ids,
251
+ txt=neg_txt,
252
+ txt_ids=neg_txt_ids,
253
+ y=neg_vec,
254
+ timesteps=t_vec,
255
+ guidance=guidance_vec,
256
+ )
257
+ pred = neg_pred + true_gs * (pred - neg_pred)
258
+ img = img + (t_prev - t_curr) * pred
259
+ i += 1
260
+ return img
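The guidance mix above is the standard classifier-free-guidance extrapolation, followed by a plain Euler step on the flow; with illustrative scalars in place of tensors:

true_gs = 3.0
pred, neg_pred = 1.0, 0.4                         # stand-ins for per-element velocities
guided = neg_pred + true_gs * (pred - neg_pred)   # 0.4 + 3.0 * 0.6 = 2.2
step = (0.9 - 1.0) * guided                       # (t_prev - t_curr) * pred = -0.22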
261
+
262
+
263
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
264
+ return rearrange(
265
+ x,
266
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
267
+ h=math.ceil(height / 16),
268
+ w=math.ceil(width / 16),
269
+ ph=2,
270
+ pw=2,
271
+ )
uno/flux/util.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ import json
21
+ import numpy as np
22
+ from huggingface_hub import hf_hub_download
23
+ from safetensors import safe_open
24
+ from safetensors.torch import load_file as load_sft
25
+
26
+ from .model import Flux, FluxParams
27
+ from .modules.autoencoder import AutoEncoder, AutoEncoderParams
28
+ from .modules.conditioner import HFEmbedder
29
+
30
+ import re
31
+ from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
32
+ def load_model(ckpt, device='cpu'):
33
+ if ckpt.endswith('safetensors'):
34
+ from safetensors import safe_open
35
+ pl_sd = {}
36
+ with safe_open(ckpt, framework="pt", device=device) as f:
37
+ for k in f.keys():
38
+ pl_sd[k] = f.get_tensor(k)
39
+ else:
40
+ pl_sd = torch.load(ckpt, map_location=device)
41
+ return pl_sd
42
+
43
+ def load_safetensors(path):
44
+ tensors = {}
45
+ with safe_open(path, framework="pt", device="cpu") as f:
46
+ for key in f.keys():
47
+ tensors[key] = f.get_tensor(key)
48
+ return tensors
49
+
50
+ def get_lora_rank(checkpoint):
51
+ for k in checkpoint.keys():
52
+ if k.endswith(".down.weight"):
53
+ return checkpoint[k].shape[0]
54
+
55
+ def load_checkpoint(local_path, repo_id, name):
56
+ if local_path is not None:
57
+ if '.safetensors' in local_path:
58
+ print(f"Loading .safetensors checkpoint from {local_path}")
59
+ checkpoint = load_safetensors(local_path)
60
+ else:
61
+ print(f"Loading checkpoint from {local_path}")
62
+ checkpoint = torch.load(local_path, map_location='cpu')
63
+ elif repo_id is not None and name is not None:
64
+ print(f"Loading checkpoint {name} from repo id {repo_id}")
65
+ checkpoint = load_from_repo_id(repo_id, name)
66
+ else:
67
+ raise ValueError(
68
+ "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
69
+ )
70
+ return checkpoint
71
+
72
+
73
+ def c_crop(image):
74
+ width, height = image.size
75
+ new_size = min(width, height)
76
+ left = (width - new_size) / 2
77
+ top = (height - new_size) / 2
78
+ right = (width + new_size) / 2
79
+ bottom = (height + new_size) / 2
80
+ return image.crop((left, top, right, bottom))
81
+
82
+ def pad64(x):
83
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
84
+
85
+ def HWC3(x):
86
+ assert x.dtype == np.uint8
87
+ if x.ndim == 2:
88
+ x = x[:, :, None]
89
+ assert x.ndim == 3
90
+ H, W, C = x.shape
91
+ assert C == 1 or C == 3 or C == 4
92
+ if C == 3:
93
+ return x
94
+ if C == 1:
95
+ return np.concatenate([x, x, x], axis=2)
96
+ if C == 4:
97
+ color = x[:, :, 0:3].astype(np.float32)
98
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
99
+ y = color * alpha + 255.0 * (1.0 - alpha)
100
+ y = y.clip(0, 255).astype(np.uint8)
101
+ return y
102
+
103
+ @dataclass
104
+ class ModelSpec:
105
+ params: FluxParams
106
+ ae_params: AutoEncoderParams
107
+ ckpt_path: str | None
108
+ ae_path: str | None
109
+ repo_id: str | None
110
+ repo_flow: str | None
111
+ repo_ae: str | None
112
+ repo_id_ae: str | None
113
+
114
+
115
+ configs = {
116
+ "flux-dev": ModelSpec(
117
+ repo_id="black-forest-labs/FLUX.1-dev",
118
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
119
+ repo_flow="flux1-dev.safetensors",
120
+ repo_ae="ae.safetensors",
121
+ ckpt_path=os.getenv("FLUX_DEV"),
122
+ params=FluxParams(
123
+ in_channels=64,
124
+ vec_in_dim=768,
125
+ context_in_dim=4096,
126
+ hidden_size=3072,
127
+ mlp_ratio=4.0,
128
+ num_heads=24,
129
+ depth=19,
130
+ depth_single_blocks=38,
131
+ axes_dim=[16, 56, 56],
132
+ theta=10_000,
133
+ qkv_bias=True,
134
+ guidance_embed=True,
135
+ ),
136
+ ae_path=os.getenv("AE"),
137
+ ae_params=AutoEncoderParams(
138
+ resolution=256,
139
+ in_channels=3,
140
+ ch=128,
141
+ out_ch=3,
142
+ ch_mult=[1, 2, 4, 4],
143
+ num_res_blocks=2,
144
+ z_channels=16,
145
+ scale_factor=0.3611,
146
+ shift_factor=0.1159,
147
+ ),
148
+ ),
149
+ "flux-dev-fp8": ModelSpec(
150
+ repo_id="XLabs-AI/flux-dev-fp8",
151
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
152
+ repo_flow="flux-dev-fp8.safetensors",
153
+ repo_ae="ae.safetensors",
154
+ ckpt_path=os.getenv("FLUX_DEV_FP8"),
155
+ params=FluxParams(
156
+ in_channels=64,
157
+ vec_in_dim=768,
158
+ context_in_dim=4096,
159
+ hidden_size=3072,
160
+ mlp_ratio=4.0,
161
+ num_heads=24,
162
+ depth=19,
163
+ depth_single_blocks=38,
164
+ axes_dim=[16, 56, 56],
165
+ theta=10_000,
166
+ qkv_bias=True,
167
+ guidance_embed=True,
168
+ ),
169
+ ae_path=os.getenv("AE"),
170
+ ae_params=AutoEncoderParams(
171
+ resolution=256,
172
+ in_channels=3,
173
+ ch=128,
174
+ out_ch=3,
175
+ ch_mult=[1, 2, 4, 4],
176
+ num_res_blocks=2,
177
+ z_channels=16,
178
+ scale_factor=0.3611,
179
+ shift_factor=0.1159,
180
+ ),
181
+ ),
182
+ "flux-schnell": ModelSpec(
183
+ repo_id="black-forest-labs/FLUX.1-schnell",
184
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
185
+ repo_flow="flux1-schnell.safetensors",
186
+ repo_ae="ae.safetensors",
187
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
188
+ params=FluxParams(
189
+ in_channels=64,
190
+ vec_in_dim=768,
191
+ context_in_dim=4096,
192
+ hidden_size=3072,
193
+ mlp_ratio=4.0,
194
+ num_heads=24,
195
+ depth=19,
196
+ depth_single_blocks=38,
197
+ axes_dim=[16, 56, 56],
198
+ theta=10_000,
199
+ qkv_bias=True,
200
+ guidance_embed=False,
201
+ ),
202
+ ae_path=os.getenv("AE"),
203
+ ae_params=AutoEncoderParams(
204
+ resolution=256,
205
+ in_channels=3,
206
+ ch=128,
207
+ out_ch=3,
208
+ ch_mult=[1, 2, 4, 4],
209
+ num_res_blocks=2,
210
+ z_channels=16,
211
+ scale_factor=0.3611,
212
+ shift_factor=0.1159,
213
+ ),
214
+ ),
215
+ }
216
+
217
+
218
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
219
+ if len(missing) > 0 and len(unexpected) > 0:
220
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
221
+ print("\n" + "-" * 79 + "\n")
222
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
223
+ elif len(missing) > 0:
224
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
225
+ elif len(unexpected) > 0:
226
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
227
+
228
+ def load_from_repo_id(repo_id, checkpoint_name):
229
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
230
+ sd = load_sft(ckpt_path, device='cpu')
231
+ return sd
232
+
233
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
234
+ # Loading Flux
235
+ print("Init model")
236
+ ckpt_path = configs[name].ckpt_path
237
+ if (
238
+ ckpt_path is None
239
+ and configs[name].repo_id is not None
240
+ and configs[name].repo_flow is not None
241
+ and hf_download
242
+ ):
243
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
244
+
245
+ with torch.device("meta" if ckpt_path is not None else device):
246
+ model = Flux(configs[name].params).to(torch.bfloat16)
247
+
248
+ if ckpt_path is not None:
249
+ print("Loading checkpoint")
250
+ # load_sft doesn't support torch.device
251
+ sd = load_model(ckpt_path, device=str(device))
252
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
253
+ print_load_warning(missing, unexpected)
254
+ return model
255
+
256
+ def load_flow_model_only_lora(
257
+ name: str,
258
+ device: str | torch.device = "cuda",
259
+ hf_download: bool = True,
260
+ lora_rank: int = 16
261
+ ):
262
+ # Loading Flux
263
+ print("Init model")
264
+ ckpt_path = configs[name].ckpt_path
265
+ if (
266
+ ckpt_path is None
267
+ and configs[name].repo_id is not None
268
+ and configs[name].repo_flow is not None
269
+ and hf_download
270
+ ):
271
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
272
+
273
+ if hf_download:
274
+ lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
275
+ else:
276
+ lora_ckpt_path = os.environ.get("LORA", None)
277
+
278
+ with torch.device("meta" if ckpt_path is not None else device):
279
+ model = Flux(configs[name].params)
280
+
281
+
282
+ model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)
283
+
284
+ if ckpt_path is not None:
285
+ print("Loading lora")
286
+ lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\
287
+ else torch.load(lora_ckpt_path, map_location='cpu')
288
+
289
+ print("Loading main checkpoint")
290
+ # load_sft doesn't support torch.device
291
+
292
+ if ckpt_path.endswith('safetensors'):
293
+ sd = load_sft(ckpt_path, device=str(device))
294
+ sd.update(lora_sd)
295
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
296
+ else:
297
+ dit_state = torch.load(ckpt_path, map_location='cpu')
298
+ sd = {}
299
+ for k in dit_state.keys():
300
+ sd[k.replace('module.','')] = dit_state[k]
301
+ sd.update(lora_sd)
302
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
303
+ model.to(str(device))
304
+ print_load_warning(missing, unexpected)
305
+ return model
306
+
307
+
308
+ def set_lora(
309
+ model: Flux,
310
+ lora_rank: int,
311
+ double_blocks_indices: list[int] | None = None,
312
+ single_blocks_indices: list[int] | None = None,
313
+ device: str | torch.device = "cpu",
314
+ ) -> Flux:
315
+ double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
316
+ single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
317
+ else single_blocks_indices
318
+
319
+ lora_attn_procs = {}
320
+ with torch.device(device):
321
+ for name, attn_processor in model.attn_processors.items():
322
+ match = re.search(r'\.(\d+)\.', name)
323
+ if match:
324
+ layer_index = int(match.group(1))
325
+
326
+ if name.startswith("double_blocks") and layer_index in double_blocks_indices:
327
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
328
+ elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
329
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
330
+ else:
331
+ lora_attn_procs[name] = attn_processor
332
+ model.set_attn_processor(lora_attn_procs)
333
+ return model
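A small sketch of applying set_lora selectively (indices, rank, and the meta device are arbitrary choices here): LoRA processors are attached only to the listed blocks, and every other block keeps its existing processor.

import torch

from uno.flux.model import Flux
from uno.flux.util import configs, set_lora

with torch.device("meta"):                 # weight-less module shells, as in the loaders above
    model = Flux(configs["flux-dev"].params)

model = set_lora(
    model,
    lora_rank=16,
    double_blocks_indices=[0, 1, 2, 3],    # only the first four double-stream blocks
    single_blocks_indices=[],              # leave all single-stream blocks alone
    device="meta",
)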
334
+
335
+
336
+ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
337
+ # Loading Flux
338
+ from optimum.quanto import requantize
339
+ print("Init model")
340
+ ckpt_path = configs[name].ckpt_path
341
+ if (
342
+ ckpt_path is None
343
+ and configs[name].repo_id is not None
344
+ and configs[name].repo_flow is not None
345
+ and hf_download
346
+ ):
347
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
348
+ json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
349
+
350
+
351
+ model = Flux(configs[name].params).to(torch.bfloat16)
352
+
353
+ print("Loading checkpoint")
354
+ # load_sft doesn't support torch.device
355
+ sd = load_sft(ckpt_path, device='cpu')
356
+ with open(json_path, "r") as f:
357
+ quantization_map = json.load(f)
358
+ print("Start a quantization process...")
359
+ requantize(model, sd, quantization_map, device=device)
360
+ print("Model is quantized!")
361
+ return model
362
+
363
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
364
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
365
+ return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
366
+
367
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
368
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
369
+
370
+
371
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
372
+ ckpt_path = configs[name].ae_path
373
+ if (
374
+ ckpt_path is None
375
+ and configs[name].repo_id is not None
376
+ and configs[name].repo_ae is not None
377
+ and hf_download
378
+ ):
379
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
380
+
381
+ # Loading the autoencoder
382
+ print("Init AE")
383
+ with torch.device("meta" if ckpt_path is not None else device):
384
+ ae = AutoEncoder(configs[name].ae_params)
385
+
386
+ if ckpt_path is not None:
387
+ sd = load_sft(ckpt_path, device=str(device))
388
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
389
+ print_load_warning(missing, unexpected)
390
+ return ae
uno/utils/convert_yaml_to_args_file.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import yaml
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--yaml", type=str, required=True)
20
+ parser.add_argument("--arg", type=str, required=True)
21
+ args = parser.parse_args()
22
+
23
+
24
+ with open(args.yaml, "r") as f:
25
+ data = yaml.safe_load(f)
26
+
27
+ with open(args.arg, "w") as f:
28
+ for k, v in data.items():
29
+ if isinstance(v, list):
30
+ v = list(map(str, v))
31
+ v = " ".join(v)
32
+ if v is None:
33
+ continue
34
+ print(f"--{k} {v}", end=" ", file=f)