Narayana02 committed on
Commit fdcddaa · verified · 1 Parent(s): 519fa07

Upload 11 files

app.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import json
17
+ from pathlib import Path
18
+
19
+ import gradio as gr
20
+ import torch
21
+ import spaces
22
+
23
+ from uno.flux.pipeline import UNOPipeline
24
+
25
+ def get_examples(examples_dir: str = "assets/examples") -> list:
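+ # Build example rows from each folder under assets/examples: usage note, prompt, up to four reference image paths (None when missing), and seed, read from the folder's config.json.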
26
+ examples = Path(examples_dir)
27
+ ans = []
28
+ for example in examples.iterdir():
29
+ if not example.is_dir():
30
+ continue
31
+ with open(example / "config.json") as f:
32
+ example_dict = json.load(f)
33
+
34
+
35
+ example_list = []
36
+
37
+ example_list.append(example_dict["useage"]) # case for
38
+ example_list.append(example_dict["prompt"]) # prompt
39
+
40
+ for key in ["image_ref1", "image_ref2", "image_ref3", "image_ref4"]:
41
+ if key in example_dict:
42
+ example_list.append(str(example / example_dict[key]))
43
+ else:
44
+ example_list.append(None)
45
+
46
+ example_list.append(example_dict["seed"])
47
+
48
+ ans.append(example_list)
49
+ return ans
50
+
51
+
52
+ def create_demo(
53
+ model_type: str,
54
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
55
+ offload: bool = False,
56
+ ):
57
+ pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
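+ # Request a ZeroGPU allocation (duration in seconds) for each generate call when running on Hugging Face Spaces.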
58
+ pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)
59
+
60
+
61
+ badges_text = r"""
62
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
63
+ <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
64
+ <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
65
+ <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
66
+ <a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
67
+ <a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
68
+ </div>
69
+ """.strip()
70
+
71
+ with gr.Blocks() as demo:
72
+ gr.Markdown(f"# UNO by UNO team")
73
+ gr.Markdown(badges_text)
74
+ with gr.Row():
75
+ with gr.Column():
76
+ prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
77
+ with gr.Row():
78
+ image_prompt1 = gr.Image(label="Ref Img1", visible=True, interactive=True, type="pil")
79
+ image_prompt2 = gr.Image(label="Ref Img2", visible=True, interactive=True, type="pil")
80
+ image_prompt3 = gr.Image(label="Ref Img3", visible=True, interactive=True, type="pil")
81
+ image_prompt4 = gr.Image(label="Ref img4", visible=True, interactive=True, type="pil")
82
+
83
+ with gr.Row():
84
+ with gr.Column():
85
+ width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
86
+ height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
87
+ with gr.Column():
88
+ gr.Markdown("📌 The model trained on 512x512 resolution.\n")
89
+ gr.Markdown(
90
+ "The size closer to 512 is more stable,"
91
+ " and the higher size gives a better visual effect but is less stable"
92
+ )
93
+
94
+ with gr.Accordion("Advanced Options", open=False):
95
+ with gr.Row():
96
+ num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
97
+ guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
98
+ seed = gr.Number(-1, label="Seed (-1 for random)")
99
+
100
+ generate_btn = gr.Button("Generate")
101
+
102
+ with gr.Column():
103
+ output_image = gr.Image(label="Generated Image")
104
+ download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
105
+
106
+
107
+ inputs = [
108
+ prompt, width, height, guidance, num_steps,
109
+ seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
110
+ ]
111
+ generate_btn.click(
112
+ fn=pipeline.gradio_generate,
113
+ inputs=inputs,
114
+ outputs=[output_image, download_btn],
115
+ )
116
+
117
+ example_text = gr.Text("", visible=False, label="Case For:")
118
+ examples = get_examples("./assets/examples")
119
+
120
+ gr.Examples(
121
+ examples=examples,
122
+ inputs=[
123
+ example_text, prompt,
124
+ image_prompt1, image_prompt2, image_prompt3, image_prompt4,
125
+ seed, output_image
126
+ ],
127
+ )
128
+
129
+ return demo
130
+
131
+ if __name__ == "__main__":
132
+ from typing import Literal
133
+
134
+ from transformers import HfArgumentParser
135
+
136
+ @dataclasses.dataclass
137
+ class AppArgs:
138
+ name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
139
+ device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
140
+ offload: bool = dataclasses.field(
141
+ default=False,
142
+ metadata={"help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."}
143
+ )
144
+ port: int = 7860
145
+
146
+ parser = HfArgumentParser([AppArgs])
147
+ args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
148
+ args = args_tuple[0]
149
+
150
+ demo = create_demo(args.name, args.device, args.offload)
151
+ demo.launch(server_port=args.port, ssr_mode=False)
uno/dataset/uno.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import os
18
+
19
+ import numpy as np
+ from PIL import Image  # required for Image.open in __getitem__; this import was missing
20
+ import torch
21
+ import torchvision.transforms.functional as TVF
22
+ from torch.utils.data import DataLoader, Dataset
23
+ from torchvision.transforms import Compose, Normalize, ToTensor
24
+
25
+ def bucket_images(images: list[torch.Tensor], resolution: int = 512):
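+ # Pick the aspect-ratio bucket closest to the batch's mean aspect ratio, resize every image to it, and stack them into one tensor.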
26
+ bucket_override=[
27
+ # h w
28
+ (256, 768),
29
+ (320, 768),
30
+ (320, 704),
31
+ (384, 640),
32
+ (448, 576),
33
+ (512, 512),
34
+ (576, 448),
35
+ (640, 384),
36
+ (704, 320),
37
+ (768, 320),
38
+ (768, 256)
39
+ ]
40
+ bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
41
+ bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
42
+
43
+ aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
44
+ mean_aspect_ratio = np.mean(aspect_ratios)
45
+
46
+ new_h, new_w = bucket_override[0]
47
+ min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
48
+ for h, w in bucket_override:
49
+ aspect_diff = np.abs(h / w - mean_aspect_ratio)
50
+ if aspect_diff < min_aspect_diff:
51
+ min_aspect_diff = aspect_diff
52
+ new_h, new_w = h, w
53
+
54
+ images = [TVF.resize(image, (new_h, new_w)) for image in images]
55
+ images = torch.stack(images, dim=0)
56
+ return images
57
+
58
+ class FluxPairedDatasetV2(Dataset):
59
+ def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
60
+ super().__init__()
61
+ self.json_file = json_file
62
+ self.resolution = resolution
63
+ self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
64
+ self.image_root = os.path.dirname(json_file)
65
+
66
+ with open(self.json_file, "rt") as f:
67
+ self.data_dicts = json.load(f)
68
+
69
+ self.transform = Compose([
70
+ ToTensor(),
71
+ Normalize([0.5], [0.5]),
72
+ ])
73
+
74
+ def __getitem__(self, idx):
75
+ data_dict = self.data_dicts[idx]
76
+ image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
77
+ txt = data_dict["prompt"]
78
+ image_tgt_path = data_dict.get("image_tgt_path", None)
79
+ ref_imgs = [
80
+ Image.open(os.path.join(self.image_root, path)).convert("RGB")
81
+ for path in image_paths
82
+ ]
83
+ ref_imgs = [self.transform(img) for img in ref_imgs]
84
+ img = None
85
+ if image_tgt_path is not None:
86
+ img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
87
+ img = self.transform(img)
88
+
89
+ return {
90
+ "img": img,
91
+ "txt": txt,
92
+ "ref_imgs": ref_imgs,
93
+ }
94
+
95
+ def __len__(self):
96
+ return len(self.data_dicts)
97
+
98
+ def collate_fn(self, batch):
99
+ img = [data["img"] for data in batch]
100
+ txt = [data["txt"] for data in batch]
101
+ ref_imgs = [data["ref_imgs"] for data in batch]
102
+ assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
103
+
104
+ n_ref = len(ref_imgs[0])
105
+
106
+ img = bucket_images(img, self.resolution)
107
+ ref_imgs_new = []
108
+ for i in range(n_ref):
109
+ ref_imgs_i = [refs[i] for refs in ref_imgs]
110
+ ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
111
+ ref_imgs_new.append(ref_imgs_i)
112
+
113
+ return {
114
+ "txt": txt,
115
+ "img": img,
116
+ "ref_imgs": ref_imgs_new,
117
+ }
118
+
119
+ if __name__ == '__main__':
120
+ import argparse
121
+ from pprint import pprint
122
+ parser = argparse.ArgumentParser()
123
+ # parser.add_argument("--json_file", type=str, required=True)
124
+ parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
125
+ args = parser.parse_args()
126
+ dataset = FluxPairedDatasetV2(args.json_file, 512)
127
+ dataloader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
128
+
129
+ for i, data_dict in enumerate(dataloader):
130
+ pprint(i)
131
+ pprint(data_dict)
132
+ breakpoint()
uno/flux/math.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import Tensor
19
+
20
+
21
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
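+ # Rotate q/k with the rotary position embeddings in pe, run scaled dot-product attention, then merge heads back into the channel dimension.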
22
+ q, k = apply_rope(q, k, pe)
23
+
24
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
25
+ x = rearrange(x, "B H L D -> B L (H D)")
26
+
27
+ return x
28
+
29
+
30
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
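+ # Build, for each position, a 2x2 rotation matrix per frequency omega_i = theta**(-2i/dim), used as rotary position embeddings.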
31
+ assert dim % 2 == 0
32
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
33
+ omega = 1.0 / (theta**scale)
34
+ out = torch.einsum("...n,d->...nd", pos, omega)
35
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
36
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
37
+ return out.float()
38
+
39
+
40
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
41
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
42
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
43
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
44
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
45
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
uno/flux/model.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from torch import Tensor, nn
20
+
21
+ from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
22
+
23
+
24
+ @dataclass
25
+ class FluxParams:
26
+ in_channels: int
27
+ vec_in_dim: int
28
+ context_in_dim: int
29
+ hidden_size: int
30
+ mlp_ratio: float
31
+ num_heads: int
32
+ depth: int
33
+ depth_single_blocks: int
34
+ axes_dim: list[int]
35
+ theta: int
36
+ qkv_bias: bool
37
+ guidance_embed: bool
38
+
39
+
40
+ class Flux(nn.Module):
41
+ """
42
+ Transformer model for flow matching on sequences.
43
+ """
44
+ _supports_gradient_checkpointing = True
45
+
46
+ def __init__(self, params: FluxParams):
47
+ super().__init__()
48
+
49
+ self.params = params
50
+ self.in_channels = params.in_channels
51
+ self.out_channels = self.in_channels
52
+ if params.hidden_size % params.num_heads != 0:
53
+ raise ValueError(
54
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
55
+ )
56
+ pe_dim = params.hidden_size // params.num_heads
57
+ if sum(params.axes_dim) != pe_dim:
58
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
59
+ self.hidden_size = params.hidden_size
60
+ self.num_heads = params.num_heads
61
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
62
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
63
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
64
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
65
+ self.guidance_in = (
66
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
67
+ )
68
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
69
+
70
+ self.double_blocks = nn.ModuleList(
71
+ [
72
+ DoubleStreamBlock(
73
+ self.hidden_size,
74
+ self.num_heads,
75
+ mlp_ratio=params.mlp_ratio,
76
+ qkv_bias=params.qkv_bias,
77
+ )
78
+ for _ in range(params.depth)
79
+ ]
80
+ )
81
+
82
+ self.single_blocks = nn.ModuleList(
83
+ [
84
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
85
+ for _ in range(params.depth_single_blocks)
86
+ ]
87
+ )
88
+
89
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
90
+ self.gradient_checkpointing = False
91
+
92
+ def _set_gradient_checkpointing(self, module, value=False):
93
+ if hasattr(module, "gradient_checkpointing"):
94
+ module.gradient_checkpointing = value
95
+
96
+ @property
97
+ def attn_processors(self):
98
+ # set recursively
99
+ processors = {} # type: dict[str, nn.Module]
100
+
101
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
102
+ if hasattr(module, "set_processor"):
103
+ processors[f"{name}.processor"] = module.processor
104
+
105
+ for sub_name, child in module.named_children():
106
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
107
+
108
+ return processors
109
+
110
+ for name, module in self.named_children():
111
+ fn_recursive_add_processors(name, module, processors)
112
+
113
+ return processors
114
+
115
+ def set_attn_processor(self, processor):
116
+ r"""
117
+ Sets the attention processor to use to compute attention.
118
+
119
+ Parameters:
120
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
121
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
122
+ for **all** `Attention` layers.
123
+
124
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
125
+ processor. This is strongly recommended when setting trainable attention processors.
126
+
127
+ """
128
+ count = len(self.attn_processors.keys())
129
+
130
+ if isinstance(processor, dict) and len(processor) != count:
131
+ raise ValueError(
132
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
133
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
134
+ )
135
+
136
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
137
+ if hasattr(module, "set_processor"):
138
+ if not isinstance(processor, dict):
139
+ module.set_processor(processor)
140
+ else:
141
+ module.set_processor(processor.pop(f"{name}.processor"))
142
+
143
+ for sub_name, child in module.named_children():
144
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
145
+
146
+ for name, module in self.named_children():
147
+ fn_recursive_attn_processor(name, module, processor)
148
+
149
+ def forward(
150
+ self,
151
+ img: Tensor,
152
+ img_ids: Tensor,
153
+ txt: Tensor,
154
+ txt_ids: Tensor,
155
+ timesteps: Tensor,
156
+ y: Tensor,
157
+ guidance: Tensor | None = None,
158
+ ref_img: Tensor | None = None,
159
+ ref_img_ids: Tensor | None = None,
160
+ ) -> Tensor:
161
+ if img.ndim != 3 or txt.ndim != 3:
162
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
163
+
164
+ # running on sequences img
165
+ img = self.img_in(img)
166
+ vec = self.time_in(timestep_embedding(timesteps, 256))
167
+ if self.params.guidance_embed:
168
+ if guidance is None:
169
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
170
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
171
+ vec = vec + self.vector_in(y)
172
+ txt = self.txt_in(txt)
173
+
174
+ ids = torch.cat((txt_ids, img_ids), dim=1)
175
+
176
+ # concat ref_img/img
177
+ img_end = img.shape[1]
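+ # Reference-image tokens are appended after the noisy image tokens; they pass through the same blocks and are sliced off again via img_end below.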
178
+ if ref_img is not None:
179
+ if isinstance(ref_img, tuple) or isinstance(ref_img, list):
180
+ img_in = [img] + [self.img_in(ref) for ref in ref_img]
181
+ img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
182
+ img = torch.cat(img_in, dim=1)
183
+ ids = torch.cat(img_ids, dim=1)
184
+ else:
185
+ img = torch.cat((img, self.img_in(ref_img)), dim=1)
186
+ ids = torch.cat((ids, ref_img_ids), dim=1)
187
+ pe = self.pe_embedder(ids)
188
+
189
+ for index_block, block in enumerate(self.double_blocks):
190
+ if self.training and self.gradient_checkpointing:
191
+ img, txt = torch.utils.checkpoint.checkpoint(
192
+ block,
193
+ img=img,
194
+ txt=txt,
195
+ vec=vec,
196
+ pe=pe,
197
+ use_reentrant=False,
198
+ )
199
+ else:
200
+ img, txt = block(
201
+ img=img,
202
+ txt=txt,
203
+ vec=vec,
204
+ pe=pe
205
+ )
206
+
207
+ img = torch.cat((txt, img), 1)
208
+ for block in self.single_blocks:
209
+ if self.training and self.gradient_checkpointing:
210
+ img = torch.utils.checkpoint.checkpoint(
211
+ block,
212
+ img, vec=vec, pe=pe,
213
+ use_reentrant=False
214
+ )
215
+ else:
216
+ img = block(img, vec=vec, pe=pe)
217
+ img = img[:, txt.shape[1] :, ...]
218
+ # index img
219
+ img = img[:, :img_end, ...]
220
+
221
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
222
+ return img
uno/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ from einops import rearrange
20
+ from torch import Tensor, nn
21
+
22
+
23
+ @dataclass
24
+ class AutoEncoderParams:
25
+ resolution: int
26
+ in_channels: int
27
+ ch: int
28
+ out_ch: int
29
+ ch_mult: list[int]
30
+ num_res_blocks: int
31
+ z_channels: int
32
+ scale_factor: float
33
+ shift_factor: float
34
+
35
+
36
+ def swish(x: Tensor) -> Tensor:
37
+ return x * torch.sigmoid(x)
38
+
39
+
40
+ class AttnBlock(nn.Module):
41
+ def __init__(self, in_channels: int):
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+
45
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
46
+
47
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
51
+
52
+ def attention(self, h_: Tensor) -> Tensor:
53
+ h_ = self.norm(h_)
54
+ q = self.q(h_)
55
+ k = self.k(h_)
56
+ v = self.v(h_)
57
+
58
+ b, c, h, w = q.shape
59
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
60
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
61
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
62
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
63
+
64
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ return x + self.proj_out(self.attention(x))
68
+
69
+
70
+ class ResnetBlock(nn.Module):
71
+ def __init__(self, in_channels: int, out_channels: int):
72
+ super().__init__()
73
+ self.in_channels = in_channels
74
+ out_channels = in_channels if out_channels is None else out_channels
75
+ self.out_channels = out_channels
76
+
77
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
79
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
80
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81
+ if self.in_channels != self.out_channels:
82
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
83
+
84
+ def forward(self, x):
85
+ h = x
86
+ h = self.norm1(h)
87
+ h = swish(h)
88
+ h = self.conv1(h)
89
+
90
+ h = self.norm2(h)
91
+ h = swish(h)
92
+ h = self.conv2(h)
93
+
94
+ if self.in_channels != self.out_channels:
95
+ x = self.nin_shortcut(x)
96
+
97
+ return x + h
98
+
99
+
100
+ class Downsample(nn.Module):
101
+ def __init__(self, in_channels: int):
102
+ super().__init__()
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
105
+
106
+ def forward(self, x: Tensor):
107
+ pad = (0, 1, 0, 1)
108
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
109
+ x = self.conv(x)
110
+ return x
111
+
112
+
113
+ class Upsample(nn.Module):
114
+ def __init__(self, in_channels: int):
115
+ super().__init__()
116
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
117
+
118
+ def forward(self, x: Tensor):
119
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
120
+ x = self.conv(x)
121
+ return x
122
+
123
+
124
+ class Encoder(nn.Module):
125
+ def __init__(
126
+ self,
127
+ resolution: int,
128
+ in_channels: int,
129
+ ch: int,
130
+ ch_mult: list[int],
131
+ num_res_blocks: int,
132
+ z_channels: int,
133
+ ):
134
+ super().__init__()
135
+ self.ch = ch
136
+ self.num_resolutions = len(ch_mult)
137
+ self.num_res_blocks = num_res_blocks
138
+ self.resolution = resolution
139
+ self.in_channels = in_channels
140
+ # downsampling
141
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
142
+
143
+ curr_res = resolution
144
+ in_ch_mult = (1,) + tuple(ch_mult)
145
+ self.in_ch_mult = in_ch_mult
146
+ self.down = nn.ModuleList()
147
+ block_in = self.ch
148
+ for i_level in range(self.num_resolutions):
149
+ block = nn.ModuleList()
150
+ attn = nn.ModuleList()
151
+ block_in = ch * in_ch_mult[i_level]
152
+ block_out = ch * ch_mult[i_level]
153
+ for _ in range(self.num_res_blocks):
154
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
155
+ block_in = block_out
156
+ down = nn.Module()
157
+ down.block = block
158
+ down.attn = attn
159
+ if i_level != self.num_resolutions - 1:
160
+ down.downsample = Downsample(block_in)
161
+ curr_res = curr_res // 2
162
+ self.down.append(down)
163
+
164
+ # middle
165
+ self.mid = nn.Module()
166
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167
+ self.mid.attn_1 = AttnBlock(block_in)
168
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
169
+
170
+ # end
171
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
172
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
173
+
174
+ def forward(self, x: Tensor) -> Tensor:
175
+ # downsampling
176
+ hs = [self.conv_in(x)]
177
+ for i_level in range(self.num_resolutions):
178
+ for i_block in range(self.num_res_blocks):
179
+ h = self.down[i_level].block[i_block](hs[-1])
180
+ if len(self.down[i_level].attn) > 0:
181
+ h = self.down[i_level].attn[i_block](h)
182
+ hs.append(h)
183
+ if i_level != self.num_resolutions - 1:
184
+ hs.append(self.down[i_level].downsample(hs[-1]))
185
+
186
+ # middle
187
+ h = hs[-1]
188
+ h = self.mid.block_1(h)
189
+ h = self.mid.attn_1(h)
190
+ h = self.mid.block_2(h)
191
+ # end
192
+ h = self.norm_out(h)
193
+ h = swish(h)
194
+ h = self.conv_out(h)
195
+ return h
196
+
197
+
198
+ class Decoder(nn.Module):
199
+ def __init__(
200
+ self,
201
+ ch: int,
202
+ out_ch: int,
203
+ ch_mult: list[int],
204
+ num_res_blocks: int,
205
+ in_channels: int,
206
+ resolution: int,
207
+ z_channels: int,
208
+ ):
209
+ super().__init__()
210
+ self.ch = ch
211
+ self.num_resolutions = len(ch_mult)
212
+ self.num_res_blocks = num_res_blocks
213
+ self.resolution = resolution
214
+ self.in_channels = in_channels
215
+ self.ffactor = 2 ** (self.num_resolutions - 1)
216
+
217
+ # compute in_ch_mult, block_in and curr_res at lowest res
218
+ block_in = ch * ch_mult[self.num_resolutions - 1]
219
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
220
+ self.z_shape = (1, z_channels, curr_res, curr_res)
221
+
222
+ # z to block_in
223
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
224
+
225
+ # middle
226
+ self.mid = nn.Module()
227
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
228
+ self.mid.attn_1 = AttnBlock(block_in)
229
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
230
+
231
+ # upsampling
232
+ self.up = nn.ModuleList()
233
+ for i_level in reversed(range(self.num_resolutions)):
234
+ block = nn.ModuleList()
235
+ attn = nn.ModuleList()
236
+ block_out = ch * ch_mult[i_level]
237
+ for _ in range(self.num_res_blocks + 1):
238
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
239
+ block_in = block_out
240
+ up = nn.Module()
241
+ up.block = block
242
+ up.attn = attn
243
+ if i_level != 0:
244
+ up.upsample = Upsample(block_in)
245
+ curr_res = curr_res * 2
246
+ self.up.insert(0, up) # prepend to get consistent order
247
+
248
+ # end
249
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
250
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
251
+
252
+ def forward(self, z: Tensor) -> Tensor:
253
+ # z to block_in
254
+ h = self.conv_in(z)
255
+
256
+ # middle
257
+ h = self.mid.block_1(h)
258
+ h = self.mid.attn_1(h)
259
+ h = self.mid.block_2(h)
260
+
261
+ # upsampling
262
+ for i_level in reversed(range(self.num_resolutions)):
263
+ for i_block in range(self.num_res_blocks + 1):
264
+ h = self.up[i_level].block[i_block](h)
265
+ if len(self.up[i_level].attn) > 0:
266
+ h = self.up[i_level].attn[i_block](h)
267
+ if i_level != 0:
268
+ h = self.up[i_level].upsample(h)
269
+
270
+ # end
271
+ h = self.norm_out(h)
272
+ h = swish(h)
273
+ h = self.conv_out(h)
274
+ return h
275
+
276
+
277
+ class DiagonalGaussian(nn.Module):
278
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
279
+ super().__init__()
280
+ self.sample = sample
281
+ self.chunk_dim = chunk_dim
282
+
283
+ def forward(self, z: Tensor) -> Tensor:
284
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
285
+ if self.sample:
286
+ std = torch.exp(0.5 * logvar)
287
+ return mean + std * torch.randn_like(mean)
288
+ else:
289
+ return mean
290
+
291
+
292
+ class AutoEncoder(nn.Module):
293
+ def __init__(self, params: AutoEncoderParams):
294
+ super().__init__()
295
+ self.encoder = Encoder(
296
+ resolution=params.resolution,
297
+ in_channels=params.in_channels,
298
+ ch=params.ch,
299
+ ch_mult=params.ch_mult,
300
+ num_res_blocks=params.num_res_blocks,
301
+ z_channels=params.z_channels,
302
+ )
303
+ self.decoder = Decoder(
304
+ resolution=params.resolution,
305
+ in_channels=params.in_channels,
306
+ ch=params.ch,
307
+ out_ch=params.out_ch,
308
+ ch_mult=params.ch_mult,
309
+ num_res_blocks=params.num_res_blocks,
310
+ z_channels=params.z_channels,
311
+ )
312
+ self.reg = DiagonalGaussian()
313
+
314
+ self.scale_factor = params.scale_factor
315
+ self.shift_factor = params.shift_factor
316
+
317
+ def encode(self, x: Tensor) -> Tensor:
318
+ z = self.reg(self.encoder(x))
319
+ z = self.scale_factor * (z - self.shift_factor)
320
+ return z
321
+
322
+ def decode(self, z: Tensor) -> Tensor:
323
+ z = z / self.scale_factor + self.shift_factor
324
+ return self.decoder(z)
325
+
326
+ def forward(self, x: Tensor) -> Tensor:
327
+ return self.decode(self.encode(x))
uno/flux/modules/conditioner.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from torch import Tensor, nn
17
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
18
+ T5Tokenizer)
19
+
20
+
21
+ class HFEmbedder(nn.Module):
22
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
23
+ super().__init__()
24
+ self.is_clip = version.startswith("openai")
25
+ self.max_length = max_length
26
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
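+ # CLIP contributes a single pooled vector per prompt, while T5 contributes the full sequence of token-level hidden states.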
27
+
28
+ if self.is_clip:
29
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
30
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
31
+ else:
32
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
33
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
34
+
35
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
36
+
37
+ def forward(self, text: list[str]) -> Tensor:
38
+ batch_encoding = self.tokenizer(
39
+ text,
40
+ truncation=True,
41
+ max_length=self.max_length,
42
+ return_length=False,
43
+ return_overflowing_tokens=False,
44
+ padding="max_length",
45
+ return_tensors="pt",
46
+ )
47
+
48
+ outputs = self.hf_module(
49
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
50
+ attention_mask=None,
51
+ output_hidden_states=False,
52
+ )
53
+ return outputs[self.output_key]
uno/flux/modules/layers.py ADDED
@@ -0,0 +1,435 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ from einops import rearrange
21
+ from torch import Tensor, nn
22
+
23
+ from ..math import attention, rope
24
+ import torch.nn.functional as F
25
+
26
+ class EmbedND(nn.Module):
27
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
28
+ super().__init__()
29
+ self.dim = dim
30
+ self.theta = theta
31
+ self.axes_dim = axes_dim
32
+
33
+ def forward(self, ids: Tensor) -> Tensor:
34
+ n_axes = ids.shape[-1]
35
+ emb = torch.cat(
36
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
37
+ dim=-3,
38
+ )
39
+
40
+ return emb.unsqueeze(1)
41
+
42
+
43
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
44
+ """
45
+ Create sinusoidal timestep embeddings.
46
+ :param t: a 1-D Tensor of N indices, one per batch element.
47
+ These may be fractional.
48
+ :param dim: the dimension of the output.
49
+ :param max_period: controls the minimum frequency of the embeddings.
50
+ :return: an (N, D) Tensor of positional embeddings.
51
+ """
52
+ t = time_factor * t
53
+ half = dim // 2
54
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
55
+ t.device
56
+ )
57
+
58
+ args = t[:, None].float() * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
62
+ if torch.is_floating_point(t):
63
+ embedding = embedding.to(t)
64
+ return embedding
65
+
66
+
67
+ class MLPEmbedder(nn.Module):
68
+ def __init__(self, in_dim: int, hidden_dim: int):
69
+ super().__init__()
70
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
71
+ self.silu = nn.SiLU()
72
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
73
+
74
+ def forward(self, x: Tensor) -> Tensor:
75
+ return self.out_layer(self.silu(self.in_layer(x)))
76
+
77
+
78
+ class RMSNorm(torch.nn.Module):
79
+ def __init__(self, dim: int):
80
+ super().__init__()
81
+ self.scale = nn.Parameter(torch.ones(dim))
82
+
83
+ def forward(self, x: Tensor):
84
+ x_dtype = x.dtype
85
+ x = x.float()
86
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
87
+ return (x * rrms).to(dtype=x_dtype) * self.scale
88
+
89
+
90
+ class QKNorm(torch.nn.Module):
91
+ def __init__(self, dim: int):
92
+ super().__init__()
93
+ self.query_norm = RMSNorm(dim)
94
+ self.key_norm = RMSNorm(dim)
95
+
96
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
97
+ q = self.query_norm(q)
98
+ k = self.key_norm(k)
99
+ return q.to(v), k.to(v)
100
+
101
+ class LoRALinearLayer(nn.Module):
102
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
103
+ super().__init__()
104
+
105
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
106
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
107
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
108
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
109
+ self.network_alpha = network_alpha
110
+ self.rank = rank
111
+
112
+ nn.init.normal_(self.down.weight, std=1 / rank)
113
+ nn.init.zeros_(self.up.weight)
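+ # up is zero-initialized so the LoRA branch starts as a no-op and only deviates as training updates it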
114
+
115
+ def forward(self, hidden_states):
116
+ orig_dtype = hidden_states.dtype
117
+ dtype = self.down.weight.dtype
118
+
119
+ down_hidden_states = self.down(hidden_states.to(dtype))
120
+ up_hidden_states = self.up(down_hidden_states)
121
+
122
+ if self.network_alpha is not None:
123
+ up_hidden_states *= self.network_alpha / self.rank
124
+
125
+ return up_hidden_states.to(orig_dtype)
126
+
127
+ class FLuxSelfAttnProcessor:
128
+ def __call__(self, attn, x, pe, **attention_kwargs):
129
+ qkv = attn.qkv(x)
130
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)  # num_heads lives on the attention module, not the processor
131
+ q, k = attn.norm(q, k, v)
132
+ x = attention(q, k, v, pe=pe)
133
+ x = attn.proj(x)
134
+ return x
135
+
136
+ class LoraFluxAttnProcessor(nn.Module):
137
+
138
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
139
+ super().__init__()
140
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
141
+ self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
142
+ self.lora_weight = lora_weight
143
+
144
+
145
+ def __call__(self, attn, x, pe, **attention_kwargs):
146
+ qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
147
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)  # num_heads lives on the attention module, not the processor
148
+ q, k = attn.norm(q, k, v)
149
+ x = attention(q, k, v, pe=pe)
150
+ x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
151
+ return x
152
+
153
+ class SelfAttention(nn.Module):
154
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
155
+ super().__init__()
156
+ self.num_heads = num_heads
157
+ head_dim = dim // num_heads
158
+
159
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
160
+ self.norm = QKNorm(head_dim)
161
+ self.proj = nn.Linear(dim, dim)
162
+ def forward(self):
163
+ pass
164
+
165
+
166
+ @dataclass
167
+ class ModulationOut:
168
+ shift: Tensor
169
+ scale: Tensor
170
+ gate: Tensor
171
+
172
+
173
+ class Modulation(nn.Module):
174
+ def __init__(self, dim: int, double: bool):
175
+ super().__init__()
176
+ self.is_double = double
177
+ self.multiplier = 6 if double else 3
178
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
179
+
180
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
181
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
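+ # each modulation branch consumes a (shift, scale, gate) triple; double=True yields two branches (six chunks), otherwise one (three chunks)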
182
+
183
+ return (
184
+ ModulationOut(*out[:3]),
185
+ ModulationOut(*out[3:]) if self.is_double else None,
186
+ )
187
+
188
+ class DoubleStreamBlockLoraProcessor(nn.Module):
189
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
190
+ super().__init__()
191
+ self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
192
+ self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
193
+ self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
194
+ self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
195
+ self.lora_weight = lora_weight
196
+
197
+ def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
198
+ img_mod1, img_mod2 = attn.img_mod(vec)
199
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
200
+
201
+ # prepare image for attention
202
+ img_modulated = attn.img_norm1(img)
203
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
204
+ img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
205
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
206
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
207
+
208
+ # prepare txt for attention
209
+ txt_modulated = attn.txt_norm1(txt)
210
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
211
+ txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
212
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
213
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
214
+
215
+ # run actual attention
216
+ q = torch.cat((txt_q, img_q), dim=2)
217
+ k = torch.cat((txt_k, img_k), dim=2)
218
+ v = torch.cat((txt_v, img_v), dim=2)
219
+
220
+ attn1 = attention(q, k, v, pe=pe)
221
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
222
+
223
+ # calculate the img blocks
224
+ img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
225
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
226
+
227
+ # calculate the txt blocks
228
+ txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
229
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
230
+ return img, txt
231
+
232
+ class DoubleStreamBlockProcessor:
233
+ def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
234
+ img_mod1, img_mod2 = attn.img_mod(vec)
235
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
236
+
237
+ # prepare image for attention
238
+ img_modulated = attn.img_norm1(img)
239
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
240
+ img_qkv = attn.img_attn.qkv(img_modulated)
241
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
242
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
243
+
244
+ # prepare txt for attention
245
+ txt_modulated = attn.txt_norm1(txt)
246
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
247
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
248
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
249
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
250
+
251
+ # run actual attention
252
+ q = torch.cat((txt_q, img_q), dim=2)
253
+ k = torch.cat((txt_k, img_k), dim=2)
254
+ v = torch.cat((txt_v, img_v), dim=2)
255
+
256
+ attn1 = attention(q, k, v, pe=pe)
257
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
258
+
259
+ # calculate the img blocks
260
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
261
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
262
+
263
+ # calculate the txt blocks
264
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
265
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
266
+ return img, txt
267
+
268
+ class DoubleStreamBlock(nn.Module):
269
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
270
+ super().__init__()
271
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
272
+ self.num_heads = num_heads
273
+ self.hidden_size = hidden_size
274
+ self.head_dim = hidden_size // num_heads
275
+
276
+ self.img_mod = Modulation(hidden_size, double=True)
277
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
278
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
279
+
280
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
281
+ self.img_mlp = nn.Sequential(
282
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
283
+ nn.GELU(approximate="tanh"),
284
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
285
+ )
286
+
287
+ self.txt_mod = Modulation(hidden_size, double=True)
288
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
289
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
290
+
291
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
292
+ self.txt_mlp = nn.Sequential(
293
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
294
+ nn.GELU(approximate="tanh"),
295
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
296
+ )
297
+ processor = DoubleStreamBlockProcessor()
298
+ self.set_processor(processor)
299
+
300
+ def set_processor(self, processor) -> None:
301
+ self.processor = processor
302
+
303
+ def get_processor(self):
304
+ return self.processor
305
+
306
+ def forward(
307
+ self,
308
+ img: Tensor,
309
+ txt: Tensor,
310
+ vec: Tensor,
311
+ pe: Tensor,
312
+ image_proj: Tensor = None,
313
+ ip_scale: float = 1.0,
314
+ ) -> tuple[Tensor, Tensor]:
315
+ if image_proj is None:
316
+ return self.processor(self, img, txt, vec, pe)
317
+ else:
318
+ return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
319
+
320
+
321
+ class SingleStreamBlockLoraProcessor(nn.Module):
322
+ def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
323
+ super().__init__()
324
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
325
+ self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
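+ # 15360 presumably equals hidden_size + mlp_hidden_dim for FLUX (3072 + 4 * 3072), i.e. the concatenated attention/MLP features that feed attn.linear2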
326
+ self.lora_weight = lora_weight
327
+
328
+ def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
329
+
330
+ mod, _ = attn.modulation(vec)
331
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
332
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
333
+ qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
334
+
335
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
336
+ q, k = attn.norm(q, k, v)
337
+
338
+ # compute attention
339
+ attn_1 = attention(q, k, v, pe=pe)
340
+
341
+ # compute activation in mlp stream, cat again and run second linear layer
342
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
343
+ output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
344
+ output = x + mod.gate * output
345
+ return output
346
+
347
+
348
+ class SingleStreamBlockProcessor:
349
+ def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
350
+
351
+ mod, _ = attn.modulation(vec)
352
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
353
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
354
+
355
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
356
+ q, k = attn.norm(q, k, v)
357
+
358
+ # compute attention
359
+ attn_1 = attention(q, k, v, pe=pe)
360
+
361
+ # compute activation in mlp stream, cat again and run second linear layer
362
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
363
+ output = x + mod.gate * output
364
+ return output
365
+
366
+ class SingleStreamBlock(nn.Module):
367
+ """
368
+ A DiT block with parallel linear layers as described in
369
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ hidden_size: int,
375
+ num_heads: int,
376
+ mlp_ratio: float = 4.0,
377
+ qk_scale: float | None = None,
378
+ ):
379
+ super().__init__()
380
+ self.hidden_dim = hidden_size
381
+ self.num_heads = num_heads
382
+ self.head_dim = hidden_size // num_heads
383
+ self.scale = qk_scale or self.head_dim**-0.5
384
+
385
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
386
+ # qkv and mlp_in
387
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
388
+ # proj and mlp_out
389
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
390
+
391
+ self.norm = QKNorm(self.head_dim)
392
+
393
+ self.hidden_size = hidden_size
394
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
395
+
396
+ self.mlp_act = nn.GELU(approximate="tanh")
397
+ self.modulation = Modulation(hidden_size, double=False)
398
+
399
+ processor = SingleStreamBlockProcessor()
400
+ self.set_processor(processor)
401
+
402
+
403
+ def set_processor(self, processor) -> None:
404
+ self.processor = processor
405
+
406
+ def get_processor(self):
407
+ return self.processor
408
+
409
+ def forward(
410
+ self,
411
+ x: Tensor,
412
+ vec: Tensor,
413
+ pe: Tensor,
414
+ image_proj: Tensor | None = None,
415
+ ip_scale: float = 1.0,
416
+ ) -> Tensor:
417
+ if image_proj is None:
418
+ return self.processor(self, x, vec, pe)
419
+ else:
420
+ return self.processor(self, x, vec, pe, image_proj, ip_scale)
421
+
422
+
423
+
424
+ class LastLayer(nn.Module):
425
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
426
+ super().__init__()
427
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
428
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
429
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
430
+
431
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
432
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
433
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
434
+ x = self.linear(x)
435
+ return x
uno/flux/pipeline.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from typing import Literal
18
+
19
+ import torch
20
+ from einops import rearrange
21
+ from PIL import ExifTags, Image
22
+ import torchvision.transforms.functional as TVF
23
+
24
+ from uno.flux.modules.layers import (
25
+ DoubleStreamBlockLoraProcessor,
26
+ DoubleStreamBlockProcessor,
27
+ SingleStreamBlockLoraProcessor,
28
+ SingleStreamBlockProcessor,
29
+ )
30
+ from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
31
+ from uno.flux.util import (
32
+ get_lora_rank,
33
+ load_ae,
34
+ load_checkpoint,
35
+ load_clip,
36
+ load_flow_model,
37
+ load_flow_model_only_lora,
38
+ load_flow_model_quintized,
39
+ load_t5,
40
+ )
41
+
42
+
43
+ def find_nearest_scale(image_h, image_w, predefined_scales):
44
+ """
45
+ Given an image's height and width, find the closest predefined scale.
46
+
47
+ :param image_h: image height
48
+ :param image_w: image width
49
+ :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
50
+ :return: the closest predefined scale (h, w)
51
+ """
52
+ # compute the aspect ratio of the input image
53
+ image_ratio = image_h / image_w
54
+
55
+ # track the smallest difference and the closest scale found so far
56
+ min_diff = float('inf')
57
+ nearest_scale = None
58
+
59
+ # iterate over all predefined scales and keep the one whose aspect ratio is closest to the input image's
60
+ for scale_h, scale_w in predefined_scales:
61
+ predefined_ratio = scale_h / scale_w
62
+ diff = abs(predefined_ratio - image_ratio)
63
+
64
+ if diff < min_diff:
65
+ min_diff = diff
66
+ nearest_scale = (scale_h, scale_w)
67
+
68
+ return nearest_scale
69
+
70
+ def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
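+ # Resize the reference image so its long side equals long_size, then center-crop both sides down to multiples of 16 and convert to RGB.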
71
+ # get the width and height of the original image
72
+ image_w, image_h = raw_image.size
73
+
74
+ # compute the long and short sides
75
+ if image_w >= image_h:
76
+ new_w = long_size
77
+ new_h = int((long_size / image_w) * image_h)
78
+ else:
79
+ new_h = long_size
80
+ new_w = int((long_size / image_h) * image_w)
81
+
82
+ # resize proportionally to the new width and height
83
+ raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
84
+ target_w = new_w // 16 * 16
85
+ target_h = new_h // 16 * 16
86
+
87
+ # compute the crop coordinates for a center crop
88
+ left = (new_w - target_w) // 2
89
+ top = (new_h - target_h) // 2
90
+ right = left + target_w
91
+ bottom = top + target_h
92
+
93
+ # Apply the center crop
94
+ raw_image = raw_image.crop((left, top, right, bottom))
95
+
96
+ # Convert to RGB mode
97
+ raw_image = raw_image.convert("RGB")
98
+ return raw_image
99
+
100
+ class UNOPipeline:
101
+ def __init__(
102
+ self,
103
+ model_type: str,
104
+ device: torch.device,
105
+ offload: bool = False,
106
+ only_lora: bool = False,
107
+ lora_rank: int = 16
108
+ ):
109
+ self.device = device
110
+ self.offload = offload
111
+ self.model_type = model_type
112
+
113
+ self.clip = load_clip(self.device)
114
+ self.t5 = load_t5(self.device, max_length=512)
115
+ self.ae = load_ae(model_type, device="cpu" if offload else self.device)
116
+ if "fp8" in model_type:
117
+ self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
118
+ elif only_lora:
119
+ self.model = load_flow_model_only_lora(
120
+ model_type, device="cpu" if offload else self.device, lora_rank=lora_rank
121
+ )
122
+ else:
123
+ self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
124
+
125
+
126
+ def load_ckpt(self, ckpt_path):
127
+ if ckpt_path is not None:
128
+ from safetensors.torch import load_file as load_sft
129
+ print("Loading checkpoint to replace old keys")
130
+ # load_sft doesn't support torch.device
131
+ if ckpt_path.endswith('safetensors'):
132
+ sd = load_sft(ckpt_path, device='cpu')
133
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
134
+ else:
135
+ dit_state = torch.load(ckpt_path, map_location='cpu')
136
+ sd = {}
137
+ for k in dit_state.keys():
138
+ sd[k.replace('module.','')] = dit_state[k]
139
+ missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
140
+ self.model.to(str(self.device))
141
+ print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
142
+
143
+ def set_lora(self, local_path: str = None, repo_id: str = None,
144
+ name: str = None, lora_weight: float = 0.7):
145
+ checkpoint = load_checkpoint(local_path, repo_id, name)
146
+ self.update_model_with_lora(checkpoint, lora_weight)
147
+
148
+ def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
149
+ checkpoint = load_checkpoint(
150
+ None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
151
+ )
152
+ self.update_model_with_lora(checkpoint, lora_weight)
153
+
154
+ def update_model_with_lora(self, checkpoint, lora_weight):
155
+ rank = get_lora_rank(checkpoint)
156
+ lora_attn_procs = {}
157
+
158
+ for name, _ in self.model.attn_processors.items():
159
+ lora_state_dict = {}
160
+ for k in checkpoint.keys():
161
+ if name in k:
162
+ lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
163
+
164
+ if len(lora_state_dict):
165
+ if name.startswith("single_blocks"):
166
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
167
+ else:
168
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
169
+ lora_attn_procs[name].load_state_dict(lora_state_dict)
170
+ lora_attn_procs[name].to(self.device)
171
+ else:
172
+ if name.startswith("single_blocks"):
173
+ lora_attn_procs[name] = SingleStreamBlockProcessor()
174
+ else:
175
+ lora_attn_procs[name] = DoubleStreamBlockProcessor()
176
+
177
+ self.model.set_attn_processor(lora_attn_procs)
178
+
179
+
180
+ def __call__(
181
+ self,
182
+ prompt: str,
183
+ width: int = 512,
184
+ height: int = 512,
185
+ guidance: float = 4,
186
+ num_steps: int = 50,
187
+ seed: int = 123456789,
188
+ **kwargs
189
+ ):
190
+ width = 16 * (width // 16)
191
+ height = 16 * (height // 16)
192
+
193
+ return self.forward(
194
+ prompt,
195
+ width,
196
+ height,
197
+ guidance,
198
+ num_steps,
199
+ seed,
200
+ **kwargs
201
+ )
202
+
203
+ @torch.inference_mode()
204
+ def gradio_generate(
205
+ self,
206
+ prompt: str,
207
+ width: int,
208
+ height: int,
209
+ guidance: float,
210
+ num_steps: int,
211
+ seed: int,
212
+ image_prompt1: Image.Image,
213
+ image_prompt2: Image.Image,
214
+ image_prompt3: Image.Image,
215
+ image_prompt4: Image.Image,
216
+ ):
217
+ ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
218
+ ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
219
+ ref_long_side = 512 if len(ref_imgs) <= 1 else 320
220
+ ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
221
+
222
+ seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
223
+
224
+ img = self(prompt=prompt, width=width, height=height, guidance=guidance,
225
+ num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)
226
+
227
+ filename = f"output/gradio/{seed}_{prompt[:20]}.png"
228
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
229
+ exif_data = Image.Exif()
230
+ exif_data[ExifTags.Base.Make] = "UNO"
231
+ exif_data[ExifTags.Base.Model] = self.model_type
232
+ info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
233
+ exif_data[ExifTags.Base.ImageDescription] = info
234
+ img.save(filename, format="png", exif=exif_data)
235
+ return img, filename
236
+
237
+ @torch.inference_mode()
238
+ def forward(
239
+ self,
240
+ prompt: str,
241
+ width: int,
242
+ height: int,
243
+ guidance: float,
244
+ num_steps: int,
245
+ seed: int,
246
+ ref_imgs: list[Image.Image] | None = None,
247
+ pe: Literal['d', 'h', 'w', 'o'] = 'd',
248
+ ):
249
+ x = get_noise(
250
+ 1, height, width, device=self.device,
251
+ dtype=torch.bfloat16, seed=seed
252
+ )
253
+ timesteps = get_schedule(
254
+ num_steps,
255
+ (width // 8) * (height // 8) // (16 * 16),
256
+ shift=True,
257
+ )
258
+ if self.offload:
259
+ self.ae.encoder = self.ae.encoder.to(self.device)
260
+ x_1_refs = [
261
+ self.ae.encode(
262
+ (TVF.to_tensor(ref_img) * 2.0 - 1.0)
263
+ .unsqueeze(0).to(self.device, torch.float32)
264
+ ).to(torch.bfloat16)
265
+ for ref_img in ref_imgs
266
+ ]
267
+
268
+ if self.offload:
269
+ self.ae.encoder = self.offload_model_to_cpu(self.ae.encoder)
270
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
271
+ inp_cond = prepare_multi_ip(
272
+ t5=self.t5, clip=self.clip,
273
+ img=x,
274
+ prompt=prompt, ref_imgs=x_1_refs, pe=pe
275
+ )
276
+
277
+ if self.offload:
278
+ self.offload_model_to_cpu(self.t5, self.clip)
279
+ self.model = self.model.to(self.device)
280
+
281
+ x = denoise(
282
+ self.model,
283
+ **inp_cond,
284
+ timesteps=timesteps,
285
+ guidance=guidance,
286
+ )
287
+
288
+ if self.offload:
289
+ self.offload_model_to_cpu(self.model)
290
+ self.ae.decoder.to(x.device)
291
+ x = unpack(x.float(), height, width)
292
+ x = self.ae.decode(x)
293
+ self.offload_model_to_cpu(self.ae.decoder)
294
+
295
+ x1 = x.clamp(-1, 1)
296
+ x1 = rearrange(x1[-1], "c h w -> h w c")
297
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
298
+ return output_img
299
+
300
+ def offload_model_to_cpu(self, *models):
301
+ if not self.offload: return
302
+ for model in models:
303
+ model.cpu()
304
+ torch.cuda.empty_cache()
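The pipeline above wires the text encoders, autoencoder and flow model together; __call__ snaps width and height to multiples of 16 and forwards to the sampling loop. A minimal driving sketch, assuming the checkpoints resolve through the loaders in uno/flux/util.py; paths, prompt and the lora_rank value are illustrative (the rank passed must match the LoRA checkpoint actually loaded):

    import torch
    from PIL import Image

    from uno.flux.pipeline import UNOPipeline, preprocess_ref

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # only_lora=True loads the base FLUX weights plus the UNO DiT LoRA; lora_rank=512 is an assumption
    pipeline = UNOPipeline("flux-dev", device, offload=False, only_lora=True, lora_rank=512)

    # Reference images are resized so the long side is 512 (320 when several refs are used)
    # and center-cropped to multiples of 16 before VAE encoding
    ref = preprocess_ref(Image.open("ref.png"), long_size=512)   # hypothetical input file

    img = pipeline(
        prompt="a plush toy on a wooden desk",
        width=512, height=512,
        guidance=4, num_steps=25, seed=42,
        ref_imgs=[ref],
    )
    img.save("out.png")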
uno/flux/sampling.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Literal
18
+
19
+ import torch
20
+ from einops import rearrange, repeat
21
+ from torch import Tensor
22
+ from tqdm import tqdm
23
+
24
+ from .model import Flux
25
+ from .modules.conditioner import HFEmbedder
26
+
27
+
28
+ def get_noise(
29
+ num_samples: int,
30
+ height: int,
31
+ width: int,
32
+ device: torch.device,
33
+ dtype: torch.dtype,
34
+ seed: int,
35
+ ):
36
+ return torch.randn(
37
+ num_samples,
38
+ 16,
39
+ # allow for packing
40
+ 2 * math.ceil(height / 16),
41
+ 2 * math.ceil(width / 16),
42
+ device=device,
43
+ dtype=dtype,
44
+ generator=torch.Generator(device=device).manual_seed(seed),
45
+ )
46
+
47
+
48
+ def prepare(
49
+ t5: HFEmbedder,
50
+ clip: HFEmbedder,
51
+ img: Tensor,
52
+ prompt: str | list[str],
53
+ ref_img: None | Tensor=None,
54
+ pe: Literal['d', 'h', 'w', 'o'] ='d'
55
+ ) -> dict[str, Tensor]:
56
+ assert pe in ['d', 'h', 'w', 'o']
57
+ bs, c, h, w = img.shape
58
+ if bs == 1 and not isinstance(prompt, str):
59
+ bs = len(prompt)
60
+
61
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
62
+ if img.shape[0] == 1 and bs > 1:
63
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
64
+
65
+ img_ids = torch.zeros(h // 2, w // 2, 3)
66
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
67
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
68
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
69
+
70
+ if ref_img is not None:
71
+ _, _, ref_h, ref_w = ref_img.shape
72
+ ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
73
+ if ref_img.shape[0] == 1 and bs > 1:
74
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
75
+ ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
76
+ # img id分别在宽高偏移各自最大值
77
+ h_offset = h // 2 if pe in {'d', 'h'} else 0
78
+ w_offset = w // 2 if pe in {'d', 'w'} else 0
79
+ ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
80
+ ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
81
+ ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
82
+
83
+ if isinstance(prompt, str):
84
+ prompt = [prompt]
85
+ txt = t5(prompt)
86
+ if txt.shape[0] == 1 and bs > 1:
87
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
88
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
89
+
90
+ vec = clip(prompt)
91
+ if vec.shape[0] == 1 and bs > 1:
92
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
93
+
94
+ if ref_img is not None:
95
+ return {
96
+ "img": img,
97
+ "img_ids": img_ids.to(img.device),
98
+ "ref_img": ref_img,
99
+ "ref_img_ids": ref_img_ids.to(img.device),
100
+ "txt": txt.to(img.device),
101
+ "txt_ids": txt_ids.to(img.device),
102
+ "vec": vec.to(img.device),
103
+ }
104
+ else:
105
+ return {
106
+ "img": img,
107
+ "img_ids": img_ids.to(img.device),
108
+ "txt": txt.to(img.device),
109
+ "txt_ids": txt_ids.to(img.device),
110
+ "vec": vec.to(img.device),
111
+ }
112
+
113
+ def prepare_multi_ip(
114
+ t5: HFEmbedder,
115
+ clip: HFEmbedder,
116
+ img: Tensor,
117
+ prompt: str | list[str],
118
+ ref_imgs: list[Tensor] | None = None,
119
+ pe: Literal['d', 'h', 'w', 'o'] = 'd'
120
+ ) -> dict[str, Tensor]:
121
+ assert pe in ['d', 'h', 'w', 'o']
122
+ bs, c, h, w = img.shape
123
+ if bs == 1 and not isinstance(prompt, str):
124
+ bs = len(prompt)
125
+
126
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
127
+ if img.shape[0] == 1 and bs > 1:
128
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
129
+
130
+ img_ids = torch.zeros(h // 2, w // 2, 3)
131
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
132
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
133
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
134
+
135
+ ref_img_ids = []
136
+ ref_imgs_list = []
137
+ pe_shift_w, pe_shift_h = w // 2, h // 2
138
+ for ref_img in ref_imgs:
139
+ _, _, ref_h1, ref_w1 = ref_img.shape
140
+ ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
141
+ if ref_img.shape[0] == 1 and bs > 1:
142
+ ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
143
+ ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
144
+ # Offset the reference image ids along height/width by the current maximum id (depending on pe mode)
145
+ h_offset = pe_shift_h if pe in {'d', 'h'} else 0
146
+ w_offset = pe_shift_w if pe in {'d', 'w'} else 0
147
+ ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
148
+ ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
149
+ ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
150
+ ref_img_ids.append(ref_img_ids1)
151
+ ref_imgs_list.append(ref_img)
152
+
153
+ # Update the positional-embedding shifts for the next reference image
154
+ pe_shift_h += ref_h1 // 2
155
+ pe_shift_w += ref_w1 // 2
156
+
157
+ if isinstance(prompt, str):
158
+ prompt = [prompt]
159
+ txt = t5(prompt)
160
+ if txt.shape[0] == 1 and bs > 1:
161
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
162
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
163
+
164
+ vec = clip(prompt)
165
+ if vec.shape[0] == 1 and bs > 1:
166
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
167
+
168
+ return {
169
+ "img": img,
170
+ "img_ids": img_ids.to(img.device),
171
+ "ref_img": tuple(ref_imgs_list),
172
+ "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
173
+ "txt": txt.to(img.device),
174
+ "txt_ids": txt_ids.to(img.device),
175
+ "vec": vec.to(img.device),
176
+ }
177
+
178
+
179
+ def time_shift(mu: float, sigma: float, t: Tensor):
180
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
181
+
182
+
183
+ def get_lin_function(
184
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
185
+ ):
186
+ m = (y2 - y1) / (x2 - x1)
187
+ b = y1 - m * x1
188
+ return lambda x: m * x + b
189
+
190
+
191
+ def get_schedule(
192
+ num_steps: int,
193
+ image_seq_len: int,
194
+ base_shift: float = 0.5,
195
+ max_shift: float = 1.15,
196
+ shift: bool = True,
197
+ ) -> list[float]:
198
+ # extra step for zero
199
+ timesteps = torch.linspace(1, 0, num_steps + 1)
200
+
201
+ # shifting the schedule to favor high timesteps for higher signal images
202
+ if shift:
203
+ # eastimate mu based on linear estimation between two points
204
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
205
+ timesteps = time_shift(mu, 1.0, timesteps)
206
+
207
+ return timesteps.tolist()
208
+
209
+
210
+ def denoise(
211
+ model: Flux,
212
+ # model input
213
+ img: Tensor,
214
+ img_ids: Tensor,
215
+ txt: Tensor,
216
+ txt_ids: Tensor,
217
+ vec: Tensor,
218
+ # sampling parameters
219
+ timesteps: list[float],
220
+ guidance: float = 4.0,
221
+ ref_img: Tensor=None,
222
+ ref_img_ids: Tensor=None,
223
+ ):
224
+ i = 0
225
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
226
+ for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
227
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
228
+ pred = model(
229
+ img=img,
230
+ img_ids=img_ids,
231
+ ref_img=ref_img,
232
+ ref_img_ids=ref_img_ids,
233
+ txt=txt,
234
+ txt_ids=txt_ids,
235
+ y=vec,
236
+ timesteps=t_vec,
237
+ guidance=guidance_vec
238
+ )
239
+ img = img + (t_prev - t_curr) * pred
240
+ i += 1
241
+ return img
242
+
243
+
244
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
245
+ return rearrange(
246
+ x,
247
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
248
+ h=math.ceil(height / 16),
249
+ w=math.ceil(width / 16),
250
+ ph=2,
251
+ pw=2,
252
+ )
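get_schedule above builds a linear grid of num_steps + 1 timesteps from 1 to 0 and, when shift is enabled, warps it with time_shift using a mu estimated linearly from the image sequence length; denoise then takes plain Euler steps, img += (t_prev - t_curr) * pred, along that grid. A small self-contained check of the schedule endpoints (the sequence length here is illustrative):

    from uno.flux.sampling import get_schedule

    # 50 denoising steps for an illustrative sequence length of 1024 packed tokens
    ts = get_schedule(num_steps=50, image_seq_len=1024, shift=True)
    print(len(ts))          # 51 values (extra endpoint for t = 0)
    print(ts[0], ts[-1])    # 1.0 ... 0.0, intermediate values pushed toward 1 by the shift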
uno/flux/util.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ import json
21
+ import numpy as np
22
+ from huggingface_hub import hf_hub_download
23
+ from safetensors import safe_open
24
+ from safetensors.torch import load_file as load_sft
25
+
26
+ from .model import Flux, FluxParams
27
+ from .modules.autoencoder import AutoEncoder, AutoEncoderParams
28
+ from .modules.conditioner import HFEmbedder
29
+
30
+ import re
31
+ from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
32
+ def load_model(ckpt, device='cpu'):
33
+ if ckpt.endswith('safetensors'):
34
+ from safetensors import safe_open
35
+ pl_sd = {}
36
+ with safe_open(ckpt, framework="pt", device=device) as f:
37
+ for k in f.keys():
38
+ pl_sd[k] = f.get_tensor(k)
39
+ else:
40
+ pl_sd = torch.load(ckpt, map_location=device)
41
+ return pl_sd
42
+
43
+ def load_safetensors(path):
44
+ tensors = {}
45
+ with safe_open(path, framework="pt", device="cpu") as f:
46
+ for key in f.keys():
47
+ tensors[key] = f.get_tensor(key)
48
+ return tensors
49
+
50
+ def get_lora_rank(checkpoint):
51
+ for k in checkpoint.keys():
52
+ if k.endswith(".down.weight"):
53
+ return checkpoint[k].shape[0]
54
+
55
+ def load_checkpoint(local_path, repo_id, name):
56
+ if local_path is not None:
57
+ if '.safetensors' in local_path:
58
+ print(f"Loading .safetensors checkpoint from {local_path}")
59
+ checkpoint = load_safetensors(local_path)
60
+ else:
61
+ print(f"Loading checkpoint from {local_path}")
62
+ checkpoint = torch.load(local_path, map_location='cpu')
63
+ elif repo_id is not None and name is not None:
64
+ print(f"Loading checkpoint {name} from repo id {repo_id}")
65
+ checkpoint = load_from_repo_id(repo_id, name)
66
+ else:
67
+ raise ValueError(
68
+ "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
69
+ )
70
+ return checkpoint
71
+
72
+
73
+ def c_crop(image):
74
+ width, height = image.size
75
+ new_size = min(width, height)
76
+ left = (width - new_size) / 2
77
+ top = (height - new_size) / 2
78
+ right = (width + new_size) / 2
79
+ bottom = (height + new_size) / 2
80
+ return image.crop((left, top, right, bottom))
81
+
82
+ def pad64(x):
83
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
84
+
85
+ def HWC3(x):
86
+ assert x.dtype == np.uint8
87
+ if x.ndim == 2:
88
+ x = x[:, :, None]
89
+ assert x.ndim == 3
90
+ H, W, C = x.shape
91
+ assert C == 1 or C == 3 or C == 4
92
+ if C == 3:
93
+ return x
94
+ if C == 1:
95
+ return np.concatenate([x, x, x], axis=2)
96
+ if C == 4:
97
+ color = x[:, :, 0:3].astype(np.float32)
98
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
99
+ y = color * alpha + 255.0 * (1.0 - alpha)
100
+ y = y.clip(0, 255).astype(np.uint8)
101
+ return y
102
+
103
+ @dataclass
104
+ class ModelSpec:
105
+ params: FluxParams
106
+ ae_params: AutoEncoderParams
107
+ ckpt_path: str | None
108
+ ae_path: str | None
109
+ repo_id: str | None
110
+ repo_flow: str | None
111
+ repo_ae: str | None
112
+ repo_id_ae: str | None
113
+
114
+
115
+ configs = {
116
+ "flux-dev": ModelSpec(
117
+ repo_id="black-forest-labs/FLUX.1-dev",
118
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
119
+ repo_flow="flux1-dev.safetensors",
120
+ repo_ae="ae.safetensors",
121
+ ckpt_path=os.getenv("FLUX_DEV"),
122
+ params=FluxParams(
123
+ in_channels=64,
124
+ vec_in_dim=768,
125
+ context_in_dim=4096,
126
+ hidden_size=3072,
127
+ mlp_ratio=4.0,
128
+ num_heads=24,
129
+ depth=19,
130
+ depth_single_blocks=38,
131
+ axes_dim=[16, 56, 56],
132
+ theta=10_000,
133
+ qkv_bias=True,
134
+ guidance_embed=True,
135
+ ),
136
+ ae_path=os.getenv("AE"),
137
+ ae_params=AutoEncoderParams(
138
+ resolution=256,
139
+ in_channels=3,
140
+ ch=128,
141
+ out_ch=3,
142
+ ch_mult=[1, 2, 4, 4],
143
+ num_res_blocks=2,
144
+ z_channels=16,
145
+ scale_factor=0.3611,
146
+ shift_factor=0.1159,
147
+ ),
148
+ ),
149
+ "flux-dev-fp8": ModelSpec(
150
+ repo_id="XLabs-AI/flux-dev-fp8",
151
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
152
+ repo_flow="flux-dev-fp8.safetensors",
153
+ repo_ae="ae.safetensors",
154
+ ckpt_path=os.getenv("FLUX_DEV_FP8"),
155
+ params=FluxParams(
156
+ in_channels=64,
157
+ vec_in_dim=768,
158
+ context_in_dim=4096,
159
+ hidden_size=3072,
160
+ mlp_ratio=4.0,
161
+ num_heads=24,
162
+ depth=19,
163
+ depth_single_blocks=38,
164
+ axes_dim=[16, 56, 56],
165
+ theta=10_000,
166
+ qkv_bias=True,
167
+ guidance_embed=True,
168
+ ),
169
+ ae_path=os.getenv("AE"),
170
+ ae_params=AutoEncoderParams(
171
+ resolution=256,
172
+ in_channels=3,
173
+ ch=128,
174
+ out_ch=3,
175
+ ch_mult=[1, 2, 4, 4],
176
+ num_res_blocks=2,
177
+ z_channels=16,
178
+ scale_factor=0.3611,
179
+ shift_factor=0.1159,
180
+ ),
181
+ ),
182
+ "flux-schnell": ModelSpec(
183
+ repo_id="black-forest-labs/FLUX.1-schnell",
184
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
185
+ repo_flow="flux1-schnell.safetensors",
186
+ repo_ae="ae.safetensors",
187
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
188
+ params=FluxParams(
189
+ in_channels=64,
190
+ vec_in_dim=768,
191
+ context_in_dim=4096,
192
+ hidden_size=3072,
193
+ mlp_ratio=4.0,
194
+ num_heads=24,
195
+ depth=19,
196
+ depth_single_blocks=38,
197
+ axes_dim=[16, 56, 56],
198
+ theta=10_000,
199
+ qkv_bias=True,
200
+ guidance_embed=False,
201
+ ),
202
+ ae_path=os.getenv("AE"),
203
+ ae_params=AutoEncoderParams(
204
+ resolution=256,
205
+ in_channels=3,
206
+ ch=128,
207
+ out_ch=3,
208
+ ch_mult=[1, 2, 4, 4],
209
+ num_res_blocks=2,
210
+ z_channels=16,
211
+ scale_factor=0.3611,
212
+ shift_factor=0.1159,
213
+ ),
214
+ ),
215
+ }
216
+
217
+
218
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
219
+ if len(missing) > 0 and len(unexpected) > 0:
220
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
221
+ print("\n" + "-" * 79 + "\n")
222
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
223
+ elif len(missing) > 0:
224
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
225
+ elif len(unexpected) > 0:
226
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
227
+
228
+ def load_from_repo_id(repo_id, checkpoint_name):
229
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
230
+ sd = load_sft(ckpt_path, device='cpu')
231
+ return sd
232
+
233
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
234
+ # Loading Flux
235
+ print("Init model")
236
+ ckpt_path = configs[name].ckpt_path
237
+ if (
238
+ ckpt_path is None
239
+ and configs[name].repo_id is not None
240
+ and configs[name].repo_flow is not None
241
+ and hf_download
242
+ ):
243
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
244
+
245
+ with torch.device("meta" if ckpt_path is not None else device):
246
+ model = Flux(configs[name].params).to(torch.bfloat16)
247
+
248
+ if ckpt_path is not None:
249
+ print("Loading checkpoint")
250
+ # load_sft doesn't support torch.device
251
+ sd = load_model(ckpt_path, device=str(device))
252
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
253
+ print_load_warning(missing, unexpected)
254
+ return model
255
+
256
+ def load_flow_model_only_lora(
257
+ name: str,
258
+ device: str | torch.device = "cuda",
259
+ hf_download: bool = True,
260
+ lora_rank: int = 16
261
+ ):
262
+ # Loading Flux
263
+ print("Init model")
264
+ ckpt_path = configs[name].ckpt_path
265
+ if (
266
+ ckpt_path is None
267
+ and configs[name].repo_id is not None
268
+ and configs[name].repo_flow is not None
269
+ and hf_download
270
+ ):
271
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
272
+
273
+ if hf_download:
274
+ # lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
275
+ try:
276
+ lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
277
+ except Exception:
278
+ lora_ckpt_path = os.environ.get("LORA", None)
279
+ else:
280
+ lora_ckpt_path = os.environ.get("LORA", None)
281
+
282
+ with torch.device("meta" if ckpt_path is not None else device):
283
+ model = Flux(configs[name].params)
284
+
285
+
286
+ model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)
287
+
288
+ if ckpt_path is not None:
289
+ print("Loading lora")
290
+ lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\
291
+ else torch.load(lora_ckpt_path, map_location='cpu')
292
+
293
+ print("Loading main checkpoint")
294
+ # load_sft doesn't support torch.device
295
+
296
+ if ckpt_path.endswith('safetensors'):
297
+ sd = load_sft(ckpt_path, device=str(device))
298
+ sd.update(lora_sd)
299
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
300
+ else:
301
+ dit_state = torch.load(ckpt_path, map_location='cpu')
302
+ sd = {}
303
+ for k in dit_state.keys():
304
+ sd[k.replace('module.','')] = dit_state[k]
305
+ sd.update(lora_sd)
306
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
307
+ model.to(str(device))
308
+ print_load_warning(missing, unexpected)
309
+ return model
310
+
311
+
312
+ def set_lora(
313
+ model: Flux,
314
+ lora_rank: int,
315
+ double_blocks_indices: list[int] | None = None,
316
+ single_blocks_indices: list[int] | None = None,
317
+ device: str | torch.device = "cpu",
318
+ ) -> Flux:
319
+ double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
320
+ single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
321
+ else single_blocks_indices
322
+
323
+ lora_attn_procs = {}
324
+ with torch.device(device):
325
+ for name, attn_processor in model.attn_processors.items():
326
+ match = re.search(r'\.(\d+)\.', name)
327
+ if match:
328
+ layer_index = int(match.group(1))
329
+
330
+ if name.startswith("double_blocks") and layer_index in double_blocks_indices:
331
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
332
+ elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
333
+ lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
334
+ else:
335
+ lora_attn_procs[name] = attn_processor
336
+ model.set_attn_processor(lora_attn_procs)
337
+ return model
338
+
339
+
340
+ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
341
+ # Loading Flux
342
+ from optimum.quanto import requantize
343
+ print("Init model")
344
+ ckpt_path = configs[name].ckpt_path
345
+ if (
346
+ ckpt_path is None
347
+ and configs[name].repo_id is not None
348
+ and configs[name].repo_flow is not None
349
+ and hf_download
350
+ ):
351
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
352
+ json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
353
+
354
+
355
+ model = Flux(configs[name].params).to(torch.bfloat16)
356
+
357
+ print("Loading checkpoint")
358
+ # load_sft doesn't support torch.device
359
+ sd = load_sft(ckpt_path, device='cpu')
360
+ with open(json_path, "r") as f:
361
+ quantization_map = json.load(f)
362
+ print("Start a quantization process...")
363
+ requantize(model, sd, quantization_map, device=device)
364
+ print("Model is quantized!")
365
+ return model
366
+
367
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
368
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
369
+ version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
370
+ return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
371
+
372
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
373
+ version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
374
+ return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
375
+
376
+
377
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
378
+ ckpt_path = configs[name].ae_path
379
+ if (
380
+ ckpt_path is None
381
+ and configs[name].repo_id is not None
382
+ and configs[name].repo_ae is not None
383
+ and hf_download
384
+ ):
385
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
386
+
387
+ # Loading the autoencoder
388
+ print("Init AE")
389
+ with torch.device("meta" if ckpt_path is not None else device):
390
+ ae = AutoEncoder(configs[name].ae_params)
391
+
392
+ if ckpt_path is not None:
393
+ sd = load_sft(ckpt_path, device=str(device))
394
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
395
+ print_load_warning(missing, unexpected)
396
+ return ae
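The loaders above resolve checkpoints either from local paths given by environment variables (FLUX_DEV, FLUX_DEV_FP8, FLUX_SCHNELL, AE, LORA) or from the Hugging Face repos named in configs; T5 and CLIP select encoder versions via the T5 and CLIP variables. A minimal sketch of composing the loaders by hand, with hypothetical local paths; note that configs reads the environment at import time, so set the variables before importing uno.flux.util:

    import os
    import torch

    # set before importing uno.flux.util, because configs reads these at import time
    os.environ["FLUX_DEV"] = "/path/to/flux1-dev.safetensors"   # hypothetical local paths
    os.environ["AE"] = "/path/to/ae.safetensors"

    from uno.flux.util import load_ae, load_clip, load_flow_model, load_t5

    device = "cuda" if torch.cuda.is_available() else "cpu"
    t5 = load_t5(device, max_length=512)
    clip = load_clip(device)
    ae = load_ae("flux-dev", device=device)
    model = load_flow_model("flux-dev", device=device)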
uno/utils/convert_yaml_to_args_file.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import yaml
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--yaml", type=str, required=True)
20
+ parser.add_argument("--arg", type=str, required=True)
21
+ args = parser.parse_args()
22
+
23
+
24
+ with open(args.yaml, "r") as f:
25
+ data = yaml.safe_load(f)
26
+
27
+ with open(args.arg, "w") as f:
28
+ for k, v in data.items():
29
+ if isinstance(v, list):
30
+ v = list(map(str, v))
31
+ v = " ".join(v)
32
+ if v is None:
33
+ continue
34
+ print(f"--{k} {v}", end=" ", file=f)
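The script above flattens a YAML mapping into a single line of CLI-style arguments: list values are joined with spaces, None values are skipped, and every other entry becomes "--key value". An illustrative run (file names and keys are hypothetical):

    # config.yaml
    #   lr: 0.0001
    #   batch_size: 4
    #   ref_sizes: [320, 512]
    #   resume: null
    #
    # python uno/utils/convert_yaml_to_args_file.py --yaml config.yaml --arg train.args
    #
    # train.args then holds a single line:
    #   --lr 0.0001 --batch_size 4 --ref_sizes 320 512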