Manireddy1508 committed on
Commit 0bc84fc · verified · 1 Parent(s): e73f637

Upload 10 files

Files changed (10)
  1. .env +1 -0
  2. .gitattributes +37 -35
  3. .gitignore +183 -0
  4. .gitmodules +3 -0
  5. LICENSE +201 -0
  6. README.md +144 -13
  7. app.py +186 -0
  8. inference.py +151 -0
  9. requirements.txt +16 -0
  10. train.py +482 -0
.env ADDED
@@ -0,0 +1 @@
1
+ OPENAI_API_KEY = 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
.gitattributes CHANGED
@@ -1,35 +1,37 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,183 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+ # User config files
177
+ .vscode/
178
+ output/
179
+
180
+ # ckpt
181
+ *.bin
182
+ *.pt
183
+ *.pth
.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "datasets/dreambooth"]
2
+ path = datasets/dreambooth
3
+ url = https://github.com/google/dreambooth.git
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,144 @@
1
- ---
2
- title: Uno Final
3
- emoji: 👁
4
- colorFrom: green
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.25.2
8
- app_file: app.py
9
- pinned: false
10
- short_description: uno-final
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <h3 align="center">
2
+ <img src="assets/logo.png" alt="Logo" style="vertical-align: middle; width: 40px; height: 40px;">
3
+ Less-to-More Generalization: Unlocking More Controllability by In-Context Generation
4
+ </h3>
5
+
6
+ <p align="center">
7
+ <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
8
+ <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
9
+ <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
10
+ <a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
11
+ <a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
12
+ </p>
13
+
14
+ ><p align="center"> <span style="color:#137cf3; font-family: Gill Sans">Shaojin Wu</span>, <span style="color:#137cf3; font-family: Gill Sans">Mengqi Huang</span><sup>*</sup>, <span style="color:#137cf3; font-family: Gill Sans">Wenxu Wu</span>, <span style="color:#137cf3; font-family: Gill Sans">Yufeng Cheng</span>, <span style="color:#137cf3; font-family: Gill Sans">Fei Ding</span><sup>+</sup>, <span style="color:#137cf3; font-family: Gill Sans">Qian He</span> <br>
15
+ ><span style="font-size: 16px">Intelligent Creation Team, ByteDance</span></p>
16
+
17
+ <p align="center">
18
+ <img src="./assets/teaser.jpg" width=95% height=95%
19
+ class="center">
20
+ </p>
21
+
22
+ ## 🔥 News
23
+ - [04/2025] 🔥 Added fp8 mode as the primary low-VRAM option, a gift for consumer-grade GPU users. Peak GPU memory usage is now ~16 GB. Further inference optimizations may follow later.
24
+ - [04/2025] 🔥 The [demo](https://huggingface.co/spaces/bytedance-research/UNO-FLUX) of UNO is released.
25
+ - [04/2025] 🔥 The [training code](https://github.com/bytedance/UNO), [inference code](https://github.com/bytedance/UNO), and [model](https://huggingface.co/bytedance-research/UNO) of UNO are released.
26
+ - [04/2025] 🔥 The [project page](https://bytedance.github.io/UNO) of UNO is created.
27
+ - [04/2025] 🔥 The arXiv [paper](https://arxiv.org/abs/2504.02160) of UNO is released.
28
+
29
+ ## 📖 Introduction
30
+ In this study, we propose a highly consistent data synthesis pipeline to tackle the challenges of subject-driven generation. The pipeline harnesses the intrinsic in-context generation capabilities of diffusion transformers to produce high-consistency multi-subject paired data. On top of it, we introduce UNO, a multi-image-conditioned subject-to-image model iteratively trained from a text-to-image model, built on progressive cross-modal alignment and universal rotary position embedding. Extensive experiments show that our method achieves high consistency while ensuring controllability in both single-subject and multi-subject driven generation.
31
+
32
+
33
+ ## ⚡️ Quick Start
34
+
35
+ ### 🔧 Requirements and Installation
36
+
37
+ Install the requirements
38
+ ```bash
39
+ ## create a virtual environment with python >= 3.10 and <= 3.12, e.g.:
40
+ # python -m venv uno_env
41
+ # source uno_env/bin/activate
42
+ # then install
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ Then download the checkpoints in one of the following ways:
47
+ 1. Directly run the inference scripts; the checkpoints will be downloaded automatically by the `hf_hub_download` calls in the code into `$HF_HOME` (default: `~/.cache/huggingface`).
48
+ 2. Use `huggingface-cli download <repo name>` to download `black-forest-labs/FLUX.1-dev`, `xlabs-ai/xflux_text_encoders`, `openai/clip-vit-large-patch14`, and `bytedance-research/UNO`, then run the inference scripts. You can download only the checkpoints you need, which speeds up setup and saves disk space; e.g. for `black-forest-labs/FLUX.1-dev` use `huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors` and `huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors`, ignoring the text encoders in the `black-forest-labs/FLUX.1-dev` repo (they are only there for `diffusers` usage). All of the checkpoints together take about 37 GB of disk space.
49
+ 3. Use `huggingface-cli download <repo name> --local-dir <LOCAL_DIR>` to download all the checkpoints mentioned in 2 to the directories you want. Then set the environment variables `AE`, `FLUX_DEV` (or `FLUX_DEV_FP8` if you use fp8 mode), `T5`, `CLIP`, and `LORA` to the corresponding paths, as shown in the sketch after this list. Finally, run the inference scripts.
50
+ 4. **If you already have some of the checkpoints**, set the environment variables `AE`, `FLUX_DEV`, `T5`, `CLIP`, and `LORA` to the corresponding paths, then run the inference scripts.
51
+
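+ For options 3 and 4, here is a minimal Python sketch of the same idea, assuming the repo and file names listed above (the `T5`, `CLIP`, and `LORA` paths are illustrative placeholders):
+
+ ```python
+ # Sketch only: fetch the two FLUX.1-dev files named above with huggingface_hub and point
+ # the environment variables the scripts read at the downloaded paths before launching them.
+ import os
+ from huggingface_hub import hf_hub_download
+
+ os.environ["FLUX_DEV"] = hf_hub_download("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors")
+ os.environ["AE"] = hf_hub_download("black-forest-labs/FLUX.1-dev", "ae.safetensors")
+ # The remaining variables point at your local copies of the other repos, for example:
+ # os.environ["T5"] = "/path/to/xflux_text_encoders"          # placeholder path
+ # os.environ["CLIP"] = "/path/to/clip-vit-large-patch14"     # placeholder path
+ # os.environ["LORA"] = "/path/to/UNO/dit_lora.safetensors"   # placeholder path
+ ```
+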
52
+ ### 🌟 Gradio Demo
53
+
54
+ ```bash
55
+ python app.py
56
+ ```
57
+
58
+ **For low VRAM usage**, pass the `--offload` and `--name flux-dev-fp8` args. Peak memory usage will be around 16 GB. For reference, end-to-end inference takes about 40 s to 1 min on an RTX 3090 in fp8 + offload mode.
59
+
60
+ ```bash
61
+ python app.py --offload --name flux-dev-fp8
62
+ ```
63
+
64
+
65
+ ### ✍️ Inference
66
+ Start from the examples below to explore and spark your creativity. ✨
67
+ ```bash
68
+ python inference.py --prompt "A clock on the beach is under a red sun umbrella" --image_paths "assets/clock.png" --width 704 --height 704
69
+ python inference.py --prompt "The figurine is in the crystal ball" --image_paths "assets/figurine.png" "assets/crystal_ball.png" --width 704 --height 704
70
+ python inference.py --prompt "The logo is printed on the cup" --image_paths "assets/cat_cafe.png" "assets/cup.png" --width 704 --height 704
71
+ ```
72
+
73
+ Optional preparation: if you want to run inference on dreambench for the first time, clone the `dreambench` submodule to download the dataset.
74
+
75
+ ```bash
76
+ git submodule update --init
77
+ ```
78
+ Then run the following scripts:
79
+ ```bash
80
+ # evaluated on dreambench
81
+ ## for single-subject
82
+ python inference.py --eval_json_path ./datasets/dreambench_singleip.json
83
+ ## for multi-subject
84
+ python inference.py --eval_json_path ./datasets/dreambench_multiip.json
85
+ ```
86
+
87
+
88
+
89
+ ### 🚄 Training
90
+
91
+ ```bash
92
+ accelerate launch train.py
93
+ ```
94
+
95
+
96
+ ### 📌 Tips and Notes
97
+ We integrate single-subject and multi-subject generation within a unified model. For single-subject scenarios, the longest side of the reference image is set to 512 by default, while for multi-subject scenarios it is set to 320. UNO demonstrates remarkable flexibility across various aspect ratios thanks to its training on a multi-scale dataset. Although it was trained with 512-pixel buckets, it can also handle other resolutions such as 568 and 704.
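+
+ A minimal sketch of this default reference sizing (the helper below is illustrative and not the repo's `preprocess_ref`, which handles the actual resizing in `inference.py`):
+
+ ```python
+ # Pick 512 for a single reference subject, 320 when several are given, then scale the
+ # image so its longest side matches that target, as described above.
+ from PIL import Image
+
+ def default_ref_size(num_refs: int) -> int:
+     return 512 if num_refs == 1 else 320
+
+ def resize_longest_side(img: Image.Image, target: int) -> Image.Image:
+     scale = target / max(img.size)
+     return img.resize((round(img.width * scale), round(img.height * scale)))
+ ```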
98
+
99
+ UNO excels in subject-driven generation but has room for improvement in generalization due to dataset constraints. We are actively developing an enhanced model—stay tuned for updates. Your feedback is valuable, so please feel free to share any suggestions.
100
+
101
+ ## 🎨 Application Scenarios
102
+ <p align="center">
103
+ <img src="./assets/simplecase.jpg" width=95% height=95%
104
+ class="center">
105
+ </p>
106
+
107
+ ## 📄 Disclaimer
108
+ <p>
109
+ We open-source this project for academic research. The vast majority of images
110
+ used in this project are either generated or licensed. If you have any concerns,
111
+ please contact us, and we will promptly remove any inappropriate content.
112
+ Our code is released under the Apache 2.0 License, while our models are under
113
+ the CC BY-NC 4.0 License. Any models related to <a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">FLUX.1-dev</a>
114
+ base model must adhere to the original licensing terms.
115
+ <br><br>This research aims to advance the field of generative AI. Users are free to
116
+ create images using this tool, provided they comply with local laws and exercise
117
+ responsible usage. The developers are not liable for any misuse of the tool by users.</p>
118
+
119
+ ## 🚀 Updates
120
+ For the purpose of fostering research and the open-source community, we plan to open-source the entire project, encompassing training, inference, weights, etc. Thank you for your patience and support! 🌟
121
+ - [x] Release github repo.
122
+ - [x] Release inference code.
123
+ - [x] Release training code.
124
+ - [x] Release model checkpoints.
125
+ - [x] Release arXiv paper.
126
+ - [x] Release huggingface space demo.
127
+ - [ ] Release in-context data generation pipelines.
128
+
129
+ ## Related resources
130
+
131
+ - [https://github.com/jax-explorer/ComfyUI-UNO](https://github.com/jax-explorer/ComfyUI-UNO) a ComfyUI node implementation of UNO by jax-explorer.
132
+
133
+ ## Citation
134
+ If UNO is helpful, please help to ⭐ the repo.
135
+
136
+ If you find this project useful for your research, please consider citing our paper:
137
+ ```bibtex
138
+ @article{wu2025less,
139
+ title={Less-to-More Generalization: Unlocking More Controllability by In-Context Generation},
140
+ author={Wu, Shaojin and Huang, Mengqi and Wu, Wenxu and Cheng, Yufeng and Ding, Fei and He, Qian},
141
+ journal={arXiv preprint arXiv:2504.02160},
142
+ year={2025}
143
+ }
144
+ ```
app.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import json
17
+ from pathlib import Path
18
+ import gradio as gr
19
+ import torch
20
+ import openai
21
+ import os
22
+
23
+ from uno.flux.pipeline import UNOPipeline
24
+ from uno.utils.prompt_enhancer import enhance_prompt_with_chatgpt
25
+
26
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
27
+
28
+ openai.api_key = os.getenv("OPENAI_API_KEY")
29
+
30
+ def get_examples(examples_dir: str = "assets/examples") -> list:
31
+ examples = Path(examples_dir)
32
+ ans = []
33
+ for example in examples.iterdir():
34
+ if not example.is_dir():
35
+ continue
36
+ with open(example / "config.json") as f:
37
+ example_dict = json.load(f)
38
+
39
+ example_list = [example_dict["useage"], example_dict["prompt"]]
40
+
41
+ for key in ["image_ref1", "image_ref2", "image_ref3", "image_ref4"]:
42
+ example_list.append(str(example / example_dict[key]) if key in example_dict else None)
43
+
44
+ example_list.append(example_dict["seed"])
45
+ ans.append(example_list)
46
+ return ans
47
+
48
+ def create_demo(model_type: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
49
+ pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
50
+
51
+ with gr.Blocks() as demo:
52
+ gr.Markdown("# UNO by UNO team")
53
+ gr.Markdown(
54
+ """
55
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
56
+ <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
57
+ <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
58
+ <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
59
+ <a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
60
+ <a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
61
+ </div>
62
+ """
63
+ )
64
+
65
+ with gr.Row():
66
+ with gr.Column():
67
+ prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
68
+ with gr.Row():
69
+ image_prompt1 = gr.Image(label="Ref Img1", type="pil")
70
+ image_prompt2 = gr.Image(label="Ref Img2", type="pil")
71
+ image_prompt3 = gr.Image(label="Ref Img3", type="pil")
72
+ image_prompt4 = gr.Image(label="Ref Img4", type="pil")
73
+
74
+ with gr.Row():
75
+ with gr.Column():
76
+ width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
77
+ height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
78
+ with gr.Column():
79
+ gr.Markdown("📌 Trained on 512x512. Larger size = better quality, but less stable.")
80
+
81
+ with gr.Accordion("Advanced Options", open=False):
82
+ with gr.Row():
83
+ num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
84
+ guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance")
85
+ seed = gr.Number(-1, label="Seed (-1 for random)")
86
+ num_outputs = gr.Slider(1, 5, 5, step=1, label="Number of Enhanced Prompts / Images")
87
+
88
+ generate_btn = gr.Button("Generate Enhanced Images")
89
+
90
+ with gr.Column():
91
+ outputs = []
92
+ for i in range(5):
93
+ outputs.append(gr.Image(label=f"Image {i+1}"))
94
+ outputs.append(gr.Textbox(label=f"Enhanced Prompt {i+1}"))
95
+
96
+ def run_generation(prompt, width, height, guidance, num_steps, seed,
97
+ img1, img2, img3, img4, num_outputs):
98
+ uploaded_images = [img for img in [img1, img2, img3, img4] if img is not None]
99
+
100
+ print(f"\n📥 [DEBUG] User prompt: {prompt}")
101
+ prompts = enhance_prompt_with_chatgpt(
102
+ user_prompt=prompt,
103
+ num_prompts=num_outputs,
104
+ reference_images=uploaded_images
105
+ )
106
+
107
+ print(f"\n🧠 [DEBUG] Final Prompt List (len={len(prompts)}):")
108
+ for idx, p in enumerate(prompts):
109
+ print(f" [{idx+1}] {p}")
110
+
111
+ while len(prompts) < num_outputs:
112
+ prompts.append(prompt)
113
+
114
+ results = []
115
+ for i in range(num_outputs):
116
+ try:
117
+ seed_val = int(seed) if seed != -1 else torch.randint(0, 10**8, (1,)).item()
118
+ print(f"🧪 [DEBUG] Using seed: {seed_val} for image {i+1}")
119
+ gen_image, _ = pipeline.gradio_generate(
120
+ prompt=prompts[i],
121
+ width=width,
122
+ height=height,
123
+ guidance=guidance,
124
+ num_steps=num_steps,
125
+ seed=seed_val,
126
+ image_prompt1=img1,
127
+ image_prompt2=img2,
128
+ image_prompt3=img3,
129
+ image_prompt4=img4,
130
+ )
131
+ print(f"✅ [DEBUG] Image {i+1} generated using prompt: {prompts[i]}")
132
+ results.append(gen_image)
133
+ results.append(prompts[i])
134
+ except Exception as e:
135
+ print(f"❌ [ERROR] Failed to generate image {i+1}: {e}")
136
+ results.append(None)
137
+ results.append(f"⚠️ Failed to generate: {e}")
138
+
139
+ # Pad to 10 outputs: 5 image + prompt pairs
140
+ while len(results) < 10:
141
+ results.append(None if len(results) % 2 == 0 else "")
142
+
143
+ return results
144
+
145
+ generate_btn.click(
146
+ fn=run_generation,
147
+ inputs=[
148
+ prompt, width, height, guidance, num_steps,
149
+ seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4, num_outputs
150
+ ],
151
+ outputs=outputs
152
+ )
153
+
154
+ example_text = gr.Text("", visible=False, label="Case For:")
155
+ examples = get_examples("./assets/examples")
156
+
157
+ gr.Examples(
158
+ examples=examples,
159
+ inputs=[
160
+ example_text, prompt,
161
+ image_prompt1, image_prompt2, image_prompt3, image_prompt4,
162
+ seed, outputs[0]
163
+ ],
164
+ )
165
+
166
+ return demo
167
+
168
+ if __name__ == "__main__":
169
+ from typing import Literal
170
+ from transformers import HfArgumentParser
171
+
172
+ @dataclasses.dataclass
173
+ class AppArgs:
174
+ name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
175
+ device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
176
+ offload: bool = dataclasses.field(
177
+ default=False,
178
+ metadata={"help": "If True, sequentially offload unused models to CPU"}
179
+ )
180
+ port: int = 7860
181
+
182
+ parser = HfArgumentParser([AppArgs])
183
+ args = parser.parse_args_into_dataclasses()[0]
184
+
185
+ demo = create_demo(args.name, args.device, args.offload)
186
+ demo.launch(server_port=args.port)
inference.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import dataclasses
17
+ from typing import Literal
18
+ from accelerate import Accelerator
19
+ from transformers import HfArgumentParser
20
+ from PIL import Image
21
+ import json
22
+ import openai
23
+
24
+ from uno.flux.pipeline import UNOPipeline, preprocess_ref
25
+ from uno.utils.prompt_enhancer import enhance_prompt_with_chatgpt
26
+
27
+ openai.api_key = os.getenv("OPENAI_API_KEY")
28
+
29
+ def horizontal_concat(images):
30
+ widths, heights = zip(*(img.size for img in images))
31
+ total_width = sum(widths)
32
+ max_height = max(heights)
33
+ new_im = Image.new('RGB', (total_width, max_height))
34
+ x_offset = 0
35
+ for img in images:
36
+ new_im.paste(img, (x_offset, 0))
37
+ x_offset += img.size[0]
38
+ return new_im
39
+
40
+ @dataclasses.dataclass
41
+ class InferenceArgs:
42
+ prompt: str | None = None
43
+ image_paths: list[str] | None = None
44
+ eval_json_path: str | None = None
45
+ offload: bool = False
46
+ num_images_per_prompt: int = 1
47
+ model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
48
+ width: int = 512
49
+ height: int = 512
50
+ ref_size: int = -1
51
+ num_steps: int = 25
52
+ guidance: float = 4
53
+ seed: int = 3407
54
+ save_path: str = "output/inference"
55
+ only_lora: bool = True
56
+ concat_refs: bool = False
57
+ lora_rank: int = 512
58
+ data_resolution: int = 512
59
+ pe: Literal['d', 'h', 'w', 'o'] = 'd'
60
+
61
+ def main(args: InferenceArgs):
62
+ accelerator = Accelerator()
63
+ pipeline = UNOPipeline(
64
+ args.model_type,
65
+ accelerator.device,
66
+ args.offload,
67
+ only_lora=args.only_lora,
68
+ lora_rank=args.lora_rank
69
+ )
70
+
71
+ assert args.prompt is not None or args.eval_json_path is not None, \
72
+ "Please provide either prompt or eval_json_path"
73
+
74
+ if args.eval_json_path:
75
+ with open(args.eval_json_path, "rt") as f:
76
+ data_dicts = json.load(f)
77
+ data_root = os.path.dirname(args.eval_json_path)
78
+ else:
79
+ data_root = "./"
80
+ data_dicts = [{"prompt": args.prompt, "image_paths": args.image_paths}]
81
+
82
+ for i, data_dict in enumerate(data_dicts):
83
+ try:
84
+ ref_imgs = [
85
+ Image.open(os.path.join(data_root, img_path))
86
+ for img_path in data_dict["image_paths"]
87
+ ]
88
+ except Exception as e:
89
+ print(f"❌ [ERROR] Failed to load reference images: {e}")
90
+ continue
91
+
92
+ if args.ref_size == -1:
93
+ args.ref_size = 512 if len(ref_imgs) == 1 else 320
94
+ ref_imgs = [preprocess_ref(img, args.ref_size) for img in ref_imgs]
95
+
96
+ print(f"\n🔧 [DEBUG] Enhancing prompt: '{data_dict['prompt']}'")
97
+ enhanced_prompts = enhance_prompt_with_chatgpt(
98
+ user_prompt=data_dict["prompt"],
99
+ num_prompts=args.num_images_per_prompt,
100
+ reference_images=ref_imgs
101
+ )
102
+
103
+ # Pad if needed
104
+ while len(enhanced_prompts) < args.num_images_per_prompt:
105
+ print(f"⚠️ [DEBUG] Padding prompts: returning user prompt as fallback.")
106
+ enhanced_prompts.append(data_dict["prompt"])
107
+
108
+ for j in range(args.num_images_per_prompt):
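+ # Round-robin sharding: each accelerator process only renders the (i, j) jobs assigned to its rank.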
109
+ if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index:
110
+ continue
111
+
112
+ prompt_j = enhanced_prompts[j]
113
+ print(f"\n--- Generating image [{i}_{j}] ---")
114
+ print(f"Enhanced Prompt: {prompt_j}")
115
+ print(f"Image paths: {data_dict['image_paths']}")
116
+ print(f"Seed: {args.seed + j}")
117
+ print(f"Resolution: {args.width}x{args.height}")
118
+ print("------------------------------")
119
+
120
+ try:
121
+ image_gen = pipeline(
122
+ prompt=prompt_j,
123
+ width=args.width,
124
+ height=args.height,
125
+ guidance=args.guidance,
126
+ num_steps=args.num_steps,
127
+ seed=args.seed + j,
128
+ ref_imgs=ref_imgs,
129
+ pe=args.pe,
130
+ )
131
+
132
+ if args.concat_refs:
133
+ image_gen = horizontal_concat([image_gen, *ref_imgs])
134
+
135
+ os.makedirs(args.save_path, exist_ok=True)
136
+ image_gen.save(os.path.join(args.save_path, f"{i}_{j}.png"))
137
+
138
+ # Save generation context
139
+ args_dict = vars(args)
140
+ args_dict['prompt'] = prompt_j
141
+ args_dict['image_paths'] = data_dict["image_paths"]
142
+ with open(os.path.join(args.save_path, f"{i}_{j}.json"), 'w') as f:
143
+ json.dump(args_dict, f, indent=4)
144
+
145
+ except Exception as e:
146
+ print(f"❌ [ERROR] Failed to generate or save image {i}_{j}: {e}")
147
+
148
+ if __name__ == "__main__":
149
+ parser = HfArgumentParser([InferenceArgs])
150
+ args = parser.parse_args_into_dataclasses()[0]
151
+ main(args)
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch==2.4.0
3
+ torchvision==0.19.0
4
+
5
+ accelerate==1.1.1
6
+ deepspeed==0.14.4
7
+ einops==0.8.0
8
+ transformers==4.43.3
9
+ huggingface-hub
10
+ diffusers==0.30.1
11
+ sentencepiece==0.2.0
12
+ gradio==5.22.0
13
+
14
+
15
+ openai>=1.14.0
16
+ python-dotenv>=1.0.1
train.py ADDED
@@ -0,0 +1,482 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import gc
17
+ import itertools
18
+ import logging
19
+ import os
20
+ import random
21
+ from copy import deepcopy
22
+ from typing import TYPE_CHECKING, Literal
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import transformers
27
+ from accelerate import Accelerator, DeepSpeedPlugin
28
+ from accelerate.logging import get_logger
29
+ from accelerate.utils import set_seed
30
+ from diffusers.optimization import get_scheduler
31
+ from einops import rearrange
32
+ from PIL import Image
33
+ from safetensors.torch import load_file
34
+ from torch.utils.data import DataLoader
35
+ from tqdm import tqdm
36
+
37
+ from uno.dataset.uno import FluxPairedDatasetV2
38
+ from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
39
+ from uno.flux.util import load_ae, load_clip, load_flow_model, load_t5, set_lora
40
+
41
+ if TYPE_CHECKING:
42
+ from uno.flux.model import Flux
43
+ from uno.flux.modules.autoencoder import AutoEncoder
44
+ from uno.flux.modules.conditioner import HFEmbedder
45
+
46
+ logger = get_logger(__name__)
47
+
48
+ def get_models(name: str, device, offload: bool=False):
49
+ t5 = load_t5(device, max_length=512)
50
+ clip = load_clip(device)
51
+ model = load_flow_model(name, device="cpu")
52
+ vae = load_ae(name, device="cpu" if offload else device)
53
+ return model, vae, t5, clip
54
+
55
+ def inference(
56
+ batch: dict,
57
+ model: "Flux", t5: "HFEmbedder", clip: "HFEmbedder", ae: "AutoEncoder",
58
+ accelerator: Accelerator,
59
+ seed: int = 0,
60
+ pe: Literal["d", "h", "w", "o"] = "d"
61
+ ) -> Image.Image:
62
+ ref_imgs = batch["ref_imgs"]
63
+ prompt = batch["txt"]
64
+ neg_prompt = ''
65
+ width, height = 512, 512
66
+ num_steps = 25
67
+ x = get_noise(
68
+ 1, height, width,
69
+ device=accelerator.device,
70
+ dtype=torch.bfloat16,
71
+ seed=seed + accelerator.process_index
72
+ )
73
+ timesteps = get_schedule(
74
+ num_steps,
75
+ (width // 8) * (height // 8) // (16 * 16),
76
+ shift=True,
77
+ )
78
+ with torch.no_grad():
79
+ ref_imgs = [
80
+ ae.encode(ref_img_.to(accelerator.device, torch.float32)).to(torch.bfloat16)
81
+ for ref_img_ in ref_imgs
82
+ ]
83
+ inp_cond = prepare_multi_ip(
84
+ t5=t5, clip=clip, img=x, prompt=prompt,
85
+ ref_imgs=ref_imgs,
86
+ pe=pe
87
+ )
88
+ neg_inp_cond = prepare_multi_ip(
89
+ t5=t5, clip=clip, img=x, prompt=neg_prompt,
90
+ ref_imgs=ref_imgs,
91
+ pe=pe
92
+ )
93
+
94
+ x = denoise(
95
+ model,
96
+ **inp_cond,
97
+ timesteps=timesteps,
98
+ guidance=4,
99
+ timestep_to_start_cfg=30,
100
+ neg_txt=neg_inp_cond['txt'],
101
+ neg_txt_ids=neg_inp_cond['txt_ids'],
102
+ neg_vec=neg_inp_cond['vec'],
103
+ true_gs=3.5,
104
+ image_proj=None,
105
+ neg_image_proj=None,
106
+ ip_scale=1,
107
+ neg_ip_scale=1
108
+ )
109
+
110
+ x = unpack(x.float(), height, width)
111
+ x = ae.decode(x)
112
+
113
+ x1 = x.clamp(-1, 1)
114
+ x1 = rearrange(x1[-1], "c h w -> h w c")
115
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
116
+
117
+ return output_img
118
+
119
+
120
+ def resume_from_checkpoint(
121
+ resume_from_checkpoint: str | None | Literal["latest"],
122
+ project_dir: str,
123
+ accelerator: Accelerator,
124
+ dit: "Flux",
125
+ optimizer: torch.optim.Optimizer,
126
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
127
+ dit_ema_dict: dict | None = None,
128
+ ) -> tuple["Flux", torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler, dict | None, int]:
129
+ # Potentially load in the weights and states from a previous save
130
+ if resume_from_checkpoint is None:
131
+ return dit, optimizer, lr_scheduler, dit_ema_dict, 0
132
+
133
+ if resume_from_checkpoint == "latest":
134
+ # Get the most recent checkpoint
135
+ dirs = os.listdir(project_dir)
136
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
137
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
138
+ if len(dirs) == 0:
139
+ accelerator.print(
140
+ f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
141
+ )
142
+ return dit, optimizer, lr_scheduler, dit_ema_dict, 0
143
+ path = dirs[-1]
144
+ else:
145
+ path = os.path.basename(resume_from_checkpoint)
146
+
147
+
148
+ accelerator.print(f"Resuming from checkpoint {path}")
149
+ lora_state = load_file(os.path.join(project_dir, path, 'dit_lora.safetensors'), device=accelerator.device)
150
+ unwarp_dit = accelerator.unwrap_model(dit)
151
+ unwarp_dit.load_state_dict(lora_state, strict=False)
152
+ if dit_ema_dict is not None:
153
+ dit_ema_dict = load_file(
154
+ os.path.join(project_dir, path, 'dit_lora_ema.safetensors'),
155
+ device=accelerator.device
156
+ )
157
+ if dit is not unwarp_dit:
158
+ dit_ema_dict = {f"module.{k}": v for k, v in dit_ema_dict.items() if k in unwarp_dit.state_dict()}
159
+
160
+ global_step = int(path.split("-")[1])
161
+
162
+ return dit, optimizer, lr_scheduler, dit_ema_dict, global_step
163
+
164
+ @dataclasses.dataclass
165
+ class TrainArgs:
166
+ ## accelerator
167
+ project_dir: str | None = None
168
+ mixed_precision: Literal["no", "fp16", "bf16"] = "bf16"
169
+ gradient_accumulation_steps: int = 1
170
+ seed: int = 42
171
+ wandb_project_name: str | None = None
172
+ wandb_run_name: str | None = None
173
+
174
+ ## model
175
+ model_name: Literal["flux", "flux-schnell"] = "flux"
176
+ lora_rank: int = 512
177
+ double_blocks_indices: list[int] | None = dataclasses.field(
178
+ default=None,
179
+ metadata={"help": "Indices of double blocks to apply LoRA. None means all double blocks."}
180
+ )
181
+ single_blocks_indices: list[int] | None = dataclasses.field(
182
+ default=None,
183
+ metadata={"help": "Indices of single blocks to apply LoRA. None means all single blocks."}
184
+ )
185
+ pe: Literal["d", "h", "w", "o"] = "d"
186
+ gradient_checkpoint: bool = False
187
+
188
+ ## ema
189
+ ema: bool = False
190
+ ema_interval: int = 1
191
+ ema_decay: float = 0.99
192
+
193
+
194
+ ## optimizer
195
+ learning_rate: float = 1e-2
196
+ adam_betas: list[float] = dataclasses.field(default_factory=lambda: [0.9, 0.999])
197
+ adam_eps: float = 1e-8
198
+ adam_weight_decay: float = 0.01
+ max_grad_norm: float = 1.0  # used by clip_grad_norm_ during training; 1.0 is an assumed, common default
199
+
200
+ ## lr_scheduler
201
+ lr_scheduler: str = "constant"
202
+ lr_warmup_steps: int = 100
203
+ max_train_steps: int = 100000
204
+
205
+ ## dataloader
206
+ train_data_json: str = "datasets/dreambench_singleip.json" # TODO: change to your own dataset, or use the data synthesis pipeline coming in the future. Stay tuned.
207
+ batch_size: int = 1
208
+ text_dropout: float = 0.1
209
+ resolution: int = 512
210
+ resolution_ref: int | None = None
211
+
212
+ eval_data_json: str = "datasets/dreambench_singleip.json"
213
+ eval_batch_size: int = 1
214
+
215
+ ## misc
216
+ resume_from_checkpoint: str | None | Literal["latest"] = None
217
+ checkpointing_steps: int = 1000
218
+
219
+ def main(
220
+ args: TrainArgs,
221
+ ):
222
+ ## accelerator
223
+ deepspeed_plugins = {
224
+ "dit": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero2_config.json'),
225
+ "t5": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json'),
226
+ "clip": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json')
227
+ }
228
+ accelerator = Accelerator(
229
+ project_dir=args.project_dir,
230
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
231
+ mixed_precision=args.mixed_precision,
232
+ deepspeed_plugins=deepspeed_plugins,
233
+ log_with="wandb",
234
+ )
235
+ set_seed(args.seed, device_specific=True)
236
+ accelerator.init_trackers(
237
+ project_name=args.wandb_project_name,
238
+ config=args.__dict__,
239
+ init_kwargs={
240
+ "wandb": {
241
+ "name": args.wandb_run_name,
242
+ "dir": accelerator.project_dir,
243
+ },
244
+ },
245
+ )
246
+ weight_dtype = {
247
+ "fp16": torch.float16,
248
+ "bf16": torch.bfloat16,
249
+ "no": torch.float32,
250
+ }.get(accelerator.mixed_precision, torch.float32)
251
+
252
+ ## logger
253
+ logging.basicConfig(
254
+ format=f"[RANK {accelerator.process_index}] " + "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
255
+ datefmt="%m/%d/%Y %H:%M:%S",
256
+ level=logging.INFO,
257
+ force=True
258
+ )
259
+ logger.info(accelerator.state)
260
+ logger.info("Training script launched", main_process_only=False)
261
+
262
+ ## model
263
+ dit, vae, t5, clip = get_models(
264
+ name=args.model_name,
265
+ device=accelerator.device,
266
+ )
267
+
268
+ vae.requires_grad_(False)
269
+ t5.requires_grad_(False)
270
+ clip.requires_grad_(False)
271
+
272
+ dit.requires_grad_(False)
273
+ dit = set_lora(dit, args.lora_rank, args.double_blocks_indices, args.single_blocks_indices, accelerator.device)
274
+ dit.train()
275
+ dit.gradient_checkpointing = args.gradient_checkpoint
276
+
277
+ ## optimizer and lr scheduler
278
+ optimizer = torch.optim.AdamW(
279
+ [p for p in dit.parameters() if p.requires_grad],
280
+ lr=args.learning_rate,
281
+ betas=args.adam_betas,
282
+ weight_decay=args.adam_weight_decay,
283
+ eps=args.adam_eps,
284
+ )
285
+
286
+ lr_scheduler = get_scheduler(
287
+ args.lr_scheduler,
288
+ optimizer=optimizer,
289
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
290
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
291
+ )
292
+
293
+ # dataloader
294
+ dataset = FluxPairedDatasetV2(
295
+ data_json=args.train_data_json,
296
+ resolution=args.resolution, resolution_ref=args.resolution_ref
297
+ )
298
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
299
+ eval_dataset = FluxPairedDatasetV2(
300
+ data_json=args.eval_data_json,
301
+ resolution=args.resolution, resolution_ref=args.resolution_ref
302
+ )
303
+ eval_dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size, shuffle=False)
304
+
305
+ dataloader = accelerator.prepare_data_loader(dataloader)
306
+ eval_dataloader = accelerator.prepare_data_loader(eval_dataloader)
307
+ dataloader = itertools.cycle(dataloader) # as infinite fetch data loader
308
+
309
+ ## parallel
310
+ dit = accelerator.prepare_model(dit)
311
+ optimizer = accelerator.prepare_optimizer(optimizer)
312
+ lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
313
+ accelerator.state.select_deepspeed_plugin("t5")
314
+ t5 = accelerator.prepare_model(t5)
315
+ accelerator.state.select_deepspeed_plugin("clip")
316
+ clip = accelerator.prepare_model(clip)
317
+
318
+ ## ema
319
+ dit_ema_dict = {
320
+ k: deepcopy(v).requires_grad_(False) for k, v in dit.named_parameters() if v.requires_grad
321
+ } if args.ema else None
322
+
323
+ ## resume
324
+ (
325
+ dit,
326
+ optimizer,
327
+ lr_scheduler,
328
+ dit_ema_dict,
329
+ global_step
330
+ ) = resume_from_checkpoint(
331
+ args.resume_from_checkpoint,
332
+ project_dir=args.project_dir,
333
+ accelerator=accelerator,
334
+ dit=dit,
335
+ optimizer=optimizer,
336
+ lr_scheduler=lr_scheduler,
337
+ dit_ema_dict=dit_ema_dict
338
+ )
339
+
340
+ ## noise scheduler
341
+ timesteps = get_schedule(
342
+ 999,
343
+ (args.resolution // 8) * (args.resolution // 8) // 4,
344
+ shift=True,
345
+ )
346
+ timesteps = torch.tensor(timesteps, device=accelerator.device)
347
+ total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
348
+
349
+ logger.info("***** Running training *****")
350
+ logger.info(f" Instantaneous batch size per device = {args.batch_size}")
351
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
352
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
353
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
354
+ logger.info(f" Total validation prompts = {len(eval_dataloader)}")
355
+
356
+ progress_bar = tqdm(
357
+ range(0, args.max_train_steps),
358
+ initial=global_step,
359
+ desc="Steps",
360
+ total=args.max_train_steps,
361
+ disable=not accelerator.is_local_main_process,
362
+ )
363
+
364
+ train_loss = 0.0
365
+ while global_step < (args.max_train_steps):
366
+ batch = next(dataloader)
367
+ prompts = [txt_ if random.random() > args.text_dropout else "" for txt_ in batch["txt"]]
368
+ img = batch["img"]
369
+ ref_imgs = batch["ref_imgs"]
370
+
371
+ with torch.no_grad():
372
+ x_1 = vae.encode(img.to(accelerator.device).to(torch.float32))
373
+ x_ref = [vae.encode(ref_img.to(accelerator.device).to(torch.float32)) for ref_img in ref_imgs]
374
+ inp = prepare_multi_ip(t5=t5, clip=clip, img=x_1, prompt=prompts, ref_imgs=tuple(x_ref), pe=args.pe)
375
+ x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
376
+ x_ref = [rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for x in x_ref]
377
+
378
+ bs = img.shape[0]
379
+ t = torch.randint(0, 1000, (bs,), device=accelerator.device)
380
+ t = timesteps[t]
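+ # Flow-matching setup: x_t below linearly interpolates the clean latents x_1 toward noise x_0,
+ # and the model is trained to predict the velocity (x_0 - x_1), the MSE target further down.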
381
+ x_0 = torch.randn_like(x_1, device=accelerator.device)
382
+ x_t = (1 - t[:, None, None]) * x_1 + t[:, None, None] * x_0
383
+ guidance_vec = torch.full((x_t.shape[0],), 1, device=x_t.device, dtype=x_t.dtype)
384
+
385
+ with accelerator.accumulate(dit):
386
+ # Predict the noise residual and compute loss
387
+ model_pred = dit(
388
+ img=x_t.to(weight_dtype),
389
+ img_ids=inp['img_ids'].to(weight_dtype),
390
+ ref_img=[x.to(weight_dtype) for x in x_ref],
391
+ ref_img_ids=[ref_img_id.to(weight_dtype) for ref_img_id in inp['ref_img_ids']],
392
+ txt=inp['txt'].to(weight_dtype),
393
+ txt_ids=inp['txt_ids'].to(weight_dtype),
394
+ y=inp['vec'].to(weight_dtype),
395
+ timesteps=t.to(weight_dtype),
396
+ guidance=guidance_vec.to(weight_dtype)
397
+ )
398
+
399
+ loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")
400
+
401
+ # Gather the losses across all processes for logging (if we use distributed training).
402
+ avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
403
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
404
+
405
+ # Backpropagate
406
+ accelerator.backward(loss)
407
+ if accelerator.sync_gradients:
408
+ accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm)
409
+ optimizer.step()
410
+ lr_scheduler.step()
411
+ optimizer.zero_grad()
412
+
413
+ # Checks if the accelerator has performed an optimization step behind the scenes
414
+ if accelerator.sync_gradients:
415
+ progress_bar.update(1)
416
+ global_step += 1
417
+ accelerator.log({"train_loss": train_loss}, step=global_step)
418
+ train_loss = 0.0
419
+
420
+ if accelerator.sync_gradients and dit_ema_dict is not None and global_step % args.ema_interval == 0:
421
+ src_dict = dit.state_dict()
422
+ for tgt_name in dit_ema_dict:
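+ # lerp_ with weight (1 - ema_decay): ema <- ema_decay * ema + (1 - ema_decay) * current weights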
423
+ dit_ema_dict[tgt_name].data.lerp_(src_dict[tgt_name].to(dit_ema_dict[tgt_name]), 1 - args.ema_decay)
424
+
425
+ if accelerator.sync_gradients and accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
426
+ logger.info(f"saving checkpoint in {global_step=}")
427
+ save_path = os.path.join(args.project_dir, f"checkpoint-{global_step}")
428
+ os.makedirs(save_path, exist_ok=True)
429
+
430
+ # save
431
+ accelerator.wait_for_everyone()
432
+ unwrapped_model = accelerator.unwrap_model(dit)
433
+ unwrapped_model_state = unwrapped_model.state_dict()
+ # state_dict() detaches tensors (requires_grad is always False), so select the LoRA weights by name instead
+ trainable_names = {k for k, p in unwrapped_model.named_parameters() if p.requires_grad}
+ unwrapped_model_state = {k: v for k, v in unwrapped_model_state.items() if k in trainable_names}
435
+
436
+ accelerator.save(
437
+ unwrapped_model_state,
438
+ os.path.join(save_path, 'dit_lora.safetensors'),
439
+ safe_serialization=True
440
+ )
441
+ unwrapped_opt = accelerator.unwrap_model(optimizer)
442
+ accelerator.save(unwrapped_opt.state_dict(), os.path.join(save_path, 'optimizer.bin'))
443
+ logger.info(f"Saved state to {save_path}")
444
+
445
+ if args.ema:
446
+ accelerator.save(
447
+ {k.split("module.")[-1]: v for k, v in dit_ema_dict.items()},
448
+ os.path.join(save_path, 'dit_lora_ema.safetensors')
449
+ )
450
+
451
+ # validate
452
+ dit.eval()
453
+ torch.set_grad_enabled(False)
454
+ for i, batch in enumerate(eval_dataloader):
455
+ result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
456
+ accelerator.log({f"eval_gen_{i}": result}, step=global_step)
457
+
458
+
459
+ if args.ema:
460
+ original_state_dict = dit.state_dict()
461
+ dit.load_state_dict(dit_ema_dict, strict=False)
462
+ for i, batch in enumerate(eval_dataloader):
463
+ result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
464
+ accelerator.log({f"eval_ema_gen_{i}": result}, step=global_step)
465
+ dit.load_state_dict(original_state_dict, strict=False)
466
+
467
+ torch.cuda.empty_cache()
468
+ gc.collect()
469
+ torch.set_grad_enabled(True)
470
+ dit.train()
471
+ accelerator.wait_for_everyone()
472
+
473
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
474
+ progress_bar.set_postfix(**logs)
475
+
476
+ accelerator.wait_for_everyone()
477
+ accelerator.end_training()
478
+
479
+ if __name__ == "__main__":
480
+ parser = transformers.HfArgumentParser([TrainArgs])
481
+ args_tuple = parser.parse_args_into_dataclasses(args_file_flag="--config")
482
+ main(*args_tuple)