ramimu committed
Commit 1c72248 · verified · 1 Parent(s): 8c226bf

Upload 586 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +28 -0
  2. ai-toolkit/.github/FUNDING.yml +2 -0
  3. ai-toolkit/.github/ISSUE_TEMPLATE/bug_report.md +19 -0
  4. ai-toolkit/.github/ISSUE_TEMPLATE/config.yml +5 -0
  5. ai-toolkit/.gitignore +183 -0
  6. ai-toolkit/.gitmodules +0 -0
  7. ai-toolkit/.vscode/launch.json +28 -0
  8. ai-toolkit/FAQ.md +10 -0
  9. ai-toolkit/LICENSE +21 -0
  10. ai-toolkit/README.md +342 -0
  11. ai-toolkit/__pycache__/info.cpython-312.pyc +0 -0
  12. ai-toolkit/assets/VAE_test1.jpg +3 -0
  13. ai-toolkit/assets/glif.svg +40 -0
  14. ai-toolkit/assets/lora_ease_ui.png +3 -0
  15. ai-toolkit/build_and_push_docker +29 -0
  16. ai-toolkit/config/examples/extract.example.yml +75 -0
  17. ai-toolkit/config/examples/generate.example.yaml +60 -0
  18. ai-toolkit/config/examples/mod_lora_scale.yaml +48 -0
  19. ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
  20. ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
  21. ai-toolkit/config/examples/train_flex_redux.yaml +112 -0
  22. ai-toolkit/config/examples/train_full_fine_tune_flex.yaml +107 -0
  23. ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml +99 -0
  24. ai-toolkit/config/examples/train_lora_chroma_24gb.yaml +97 -0
  25. ai-toolkit/config/examples/train_lora_flex_24gb.yaml +101 -0
  26. ai-toolkit/config/examples/train_lora_flux_24gb.yaml +96 -0
  27. ai-toolkit/config/examples/train_lora_flux_rami.yaml +78 -0
  28. ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
  29. ai-toolkit/config/examples/train_lora_hidream_48.yaml +112 -0
  30. ai-toolkit/config/examples/train_lora_lumina.yaml +96 -0
  31. ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml +97 -0
  32. ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml +101 -0
  33. ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml +90 -0
  34. ai-toolkit/config/examples/train_slider.example.yml +230 -0
  35. ai-toolkit/docker-compose.yml +25 -0
  36. ai-toolkit/docker/Dockerfile +77 -0
  37. ai-toolkit/docker/start.sh +70 -0
  38. ai-toolkit/extensions/example/ExampleMergeModels.py +129 -0
  39. ai-toolkit/extensions/example/__init__.py +25 -0
  40. ai-toolkit/extensions/example/__pycache__/ExampleMergeModels.cpython-312.pyc +0 -0
  41. ai-toolkit/extensions/example/__pycache__/__init__.cpython-312.pyc +0 -0
  42. ai-toolkit/extensions/example/config/config.example.yaml +48 -0
  43. ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
  44. ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
  45. ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
  46. ai-toolkit/extensions_built_in/advanced_generator/__init__.py +59 -0
  47. ai-toolkit/extensions_built_in/advanced_generator/__pycache__/Img2ImgGenerator.cpython-312.pyc +0 -0
  48. ai-toolkit/extensions_built_in/advanced_generator/__pycache__/PureLoraGenerator.cpython-312.pyc +0 -0
  49. ai-toolkit/extensions_built_in/advanced_generator/__pycache__/ReferenceGenerator.cpython-312.pyc +0 -0
  50. ai-toolkit/extensions_built_in/advanced_generator/__pycache__/__init__.cpython-312.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,31 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/assets/lora_ease_ui.png filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/assets/VAE_test1.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image1.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image10.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image11.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image12.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image13.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image14.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image15.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image16.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image17.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image18.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image19.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image2.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image20.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image21.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image22.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image26.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image3.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image4.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image5.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image6.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image7.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image8.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/images/image9.jpg filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/jobs/process/__pycache__/BaseSDTrainProcess.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/toolkit/__pycache__/dataloader_mixins.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+ ai-toolkit/toolkit/__pycache__/stable_diffusion_model.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
ai-toolkit/.github/FUNDING.yml ADDED
@@ -0,0 +1,2 @@
+ github: [ostris]
+ patreon: ostris
ai-toolkit/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,19 @@
+ ---
+ name: Bug Report
+ about: For bugs only. Not for feature requests or questions.
+ title: ''
+ labels: ''
+ assignees: ''
+ ---
+
+ ## This is for bugs only
+
+ Did you already ask [in the discord](https://discord.gg/VXmU2f5WEU)?
+
+ Yes/No
+
+ You verified that this is a bug and not a feature request or question by asking [in the discord](https://discord.gg/VXmU2f5WEU)?
+
+ Yes/No
+
+ ## Describe the bug
ai-toolkit/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,5 @@
+ blank_issues_enabled: false
+ contact_links:
+   - name: Ask in the Discord BEFORE opening an issue
+     url: https://discord.gg/VXmU2f5WEU
+     about: Please ask in the discord before opening a github issue.
ai-toolkit/.gitignore ADDED
@@ -0,0 +1,183 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ /env.sh
+ /models
+ /datasets
+ /custom/*
+ !/custom/.gitkeep
+ /.tmp
+ /venv.bkp
+ /venv.*
+ /config/*
+ !/config/examples
+ !/config/_PUT_YOUR_CONFIGS_HERE).txt
+ /output/*
+ !/output/.gitkeep
+ /extensions/*
+ !/extensions/example
+ /temp
+ /wandb
+ .vscode/settings.json
+ .DS_Store
+ ._.DS_Store
+ aitk_db.db
+ /notes.md
File without changes
ai-toolkit/.vscode/launch.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "version": "0.2.0",
+   "configurations": [
+     {
+       "name": "Run current config",
+       "type": "python",
+       "request": "launch",
+       "program": "${workspaceFolder}/run.py",
+       "args": [
+         "${file}"
+       ],
+       "env": {
+         "CUDA_LAUNCH_BLOCKING": "1",
+         "DEBUG_TOOLKIT": "1"
+       },
+       "console": "integratedTerminal",
+       "justMyCode": false
+     },
+     {
+       "name": "Python: Debug Current File",
+       "type": "python",
+       "request": "launch",
+       "program": "${file}",
+       "console": "integratedTerminal",
+       "justMyCode": false
+     },
+   ]
+ }
ai-toolkit/FAQ.md ADDED
@@ -0,0 +1,10 @@
+ # FAQ
+
+ WIP. Will continue to add things as they are needed.
+
+ ## FLUX.1 Training
+
+ #### How much VRAM is required to train a lora on FLUX.1?
+
+ 24GB minimum is required.
+
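As a quick sanity check against that 24GB figure, here is a minimal sketch (not part of the toolkit; it only assumes PyTorch, which the installation steps already require) that reports the VRAM of each visible GPU:

```python
# Hypothetical helper, not part of ai-toolkit: report total VRAM per visible GPU
# and flag anything under the 24GB minimum mentioned in the FAQ.
import torch

MIN_GB = 24

if not torch.cuda.is_available():
    print("No CUDA device visible; FLUX.1 training needs an Nvidia GPU.")
else:
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        total_gb = props.total_memory / 1024**3
        status = "ok" if total_gb >= MIN_GB else f"below the {MIN_GB}GB minimum"
        print(f"cuda:{i} {props.name}: {total_gb:.1f} GB ({status})")
```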
ai-toolkit/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Ostris, LLC
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
ai-toolkit/README.md ADDED
@@ -0,0 +1,342 @@
+ # AI Toolkit by Ostris
+
+ ## Support My Work
+
+ If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps!
+
+ [Become a sponsor on GitHub](https://github.com/orgs/ostris) or [support me on Patreon](https://www.patreon.com/ostris).
+
+ Thank you to all my current supporters!
+
+ _Last updated: 2025-04-04_
+
+ ### GitHub Sponsors
+
+ <a href="https://github.com/josephrocca" title="josephrocca"><img src="https://avatars.githubusercontent.com/u/1167575?u=92d92921b4cb5c8c7e225663fed53c4b41897736&v=4" width="50" height="50" alt="josephrocca" style="border-radius:50%;display:inline-block;"></a> <a href="https://github.com/replicate" title="Replicate"><img src="https://avatars.githubusercontent.com/u/60410876?v=4" width="50" height="50" alt="Replicate" style="border-radius:50%;display:inline-block;"></a> <a href="https://github.com/weights-ai" title="Weights"><img src="https://avatars.githubusercontent.com/u/185568492?v=4" width="50" height="50" alt="Weights" style="border-radius:50%;display:inline-block;"></a>
+
+ ### Patreon Supporters
+
+ <a href="None" title="Aaron Amortegui"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/81275465/1e4148fe9c47452b838949d02dd9a70f/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=uzJzkUq9rte3wx8wDLjGAgvSoxdtZcAnH7HctDhdYEo%3D" width="50" height="50" alt="Aaron Amortegui" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Abraham Irawan"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/60995694/92e0e8f336eb4a5bb8d99b940247d1d1/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=pj6Tm8XRdpGJcAEdnCakqYSNiSjoAYjvZescX7d0ic0%3D" width="50" height="50" alt="Abraham Irawan" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Al H"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/570742/4ceb33453a5a4745b430a216aba9280f/eyJ3IjoyMDB9/1.jpg?token-time=2145916800&token-hash=wUzsI5cO5Evp2ukIGdSgBbvKeYgv5LSOQMa6Br33Rrs%3D" width="50" height="50" alt="Al H" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Albert Bukoski"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/44568304/a9d83a0e786b41b4bdada150f7c9271c/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=SBphTD654nwr-OTrvIBIJBEQho7GE2PtRre8nyaG1Fk%3D" width="50" height="50" alt="Albert Bukoski" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Armin Behjati"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/93348210/5c650f32a0bc481d80900d2674528777/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=PpXK9B_iy288annlNdLOexhiQHbTftPEDeCh-sTQ2KA%3D" width="50" height="50" alt="Armin Behjati" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Arvin Flores"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/49304261/d0a730de1c3349e585c49288b9f419c6/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=C2BMZ3ci-Ty2nhnSwKZqsR-5hOGsUNDYcvXps0Geq9w%3D" width="50" height="50" alt="Arvin Flores" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Austin Robinson"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/164958178/4eb7a37baa0541bab7a091f2b14615b7/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=_aaum7fBJAGaJhMBhlR8vqYavDhExdVxmO9mwd3_XMw%3D" width="50" height="50" alt="Austin Robinson" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Ben Ward"><img src="https://c8.patreon.com/3/200/5048649" width="50" height="50" alt="Ben Ward" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Bharat Prabhakar"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/134129880/680c7e14cd1a4d1a9face921fb010f88/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=vNKojv67krNqx7gdpKBX1R_stX2TkMRYvRc0xZrbY6s%3D" width="50" height="50" alt="Bharat Prabhakar" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Bnp"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/130338124/f904a3bb76cd4588ac8d8f595c6cb486/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=k-inISRUtYDu9q7fNAKc3S2S7qcaw26fr1pj7PqU28Q%3D" width="50" height="50" alt="Bnp" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="clement Delangue"><img src="https://c8.patreon.com/3/200/33158543" width="50" height="50" alt="clement Delangue" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Cosmosis"><img src="https://c8.patreon.com/3/200/70218846" width="50" height="50" alt="Cosmosis" 
style="border-radius:50%;display:inline-block;"></a> <a href="None" title="David Garrido"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/27288932/6c35d2d961ee4e14a7a368c990791315/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=dpFFssZXZM_KZMKQhl3uDwwusdFw1c_v9x_ChJU7_zc%3D" width="50" height="50" alt="David Garrido" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Doron Adler"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/82763/f99cc484361d4b9d94fe4f0814ada303/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=BpwC020pR3TRZ4r0RSCiSIOh-jmatkrpy1h2XU4sGa4%3D" width="50" height="50" alt="Doron Adler" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Eli Slugworth"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/54890369/45cea21d82974c78bf43956de7fb0e12/eyJ3IjoyMDB9/2.jpeg?token-time=2145916800&token-hash=IK6OT6UpusHgdaC4y8IhK5XxXiP5TuLy3vjvgL77Fho%3D" width="50" height="50" alt="Eli Slugworth" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="EmmanuelMr18"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/98811435/3a3632d1795b4c2b9f8f0270f2f6a650/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=93w8RMxwXlcM4X74t03u6P5_SrKvlm1IpjnD2SzVpJk%3D" width="50" height="50" alt="EmmanuelMr18" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Fagem X"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/113207022/d4a67cc113e84fb69032bef71d068720/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=mu-tIg88VwoQdgLEOmxuVkhVm9JT59DdnHXJstmkkLU%3D" width="50" height="50" alt="Fagem X" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="George Gostyshev"><img src="https://c8.patreon.com/3/200/2410522" width="50" height="50" alt="George Gostyshev" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Gili Ben Shahar"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/83054970/13de6cb103ad41a5841edf549e66cd51/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=wU_Eke9VYcfI40FAQvdEV84Xspqlo5VSiafLqhg_FOE%3D" width="50" height="50" alt="Gili Ben Shahar" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="HestoySeghuro ."><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/30931983/54ab4e4ceab946e79a6418d205f9ed51/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=LBmsSsMQZhO6yRZ_YyRwTgE6a7BVWrGNsAVveLXHXR0%3D" width="50" height="50" alt="HestoySeghuro ." 
style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Jack Blakely"><img src="https://c8.patreon.com/3/200/4105384" width="50" height="50" alt="Jack Blakely" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Jack English"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/45562978/0de33cf52ec642ae8a2f612cddec4ca6/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=hSAvaD4phiLcF0pvX7FP0juI5NQWCon-_TZSNpJzQJg%3D" width="50" height="50" alt="Jack English" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Jason"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/150257013/0e9e333d30294eef9f4d6821166966d8/eyJ3IjoyMDB9/2.png?token-time=2145916800&token-hash=hPH_rp5L5OJ9ZMS1wZfpVXDB4lRv2GHpV6r8Jmbmqww%3D" width="50" height="50" alt="Jason" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Jean-Tristan Marin"><img src="https://c8.patreon.com/3/200/27791680" width="50" height="50" alt="Jean-Tristan Marin" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Jodh Singh"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/131773947/eda3405aa582437db4582fce908c8739/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=S4Bh0sMqTNmJlo3uRr7co5d_kxvBjITemDTfi_1KrCA%3D" width="50" height="50" alt="Jodh Singh" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="John Dopamine"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/103077711/bb215761cc004e80bd9cec7d4bcd636d/eyJ3IjoyMDB9/2.jpeg?token-time=2145916800&token-hash=zvtBie29rRTKTXvAA2KhOI-l3mSMk9xxr-mg_CksLtc%3D" width="50" height="50" alt="John Dopamine" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Joseph Rocca"><img src="https://c8.patreon.com/3/200/93304" width="50" height="50" alt="Joseph Rocca" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Julian Tsependa"><img src="https://c8.patreon.com/3/200/494309" width="50" height="50" alt="Julian Tsependa" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Kasım Açıkbaş"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/31471379/0a887513ee314a1c86d0b6f8792e9795/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=DJMZs3rDlS0fCM_ahm95FAbjleM_L0gsO9qAPzqd0nA%3D" width="50" height="50" alt="Kasım Açıkbaş" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Kelevra"><img src="https://c8.patreon.com/3/200/5602036" width="50" height="50" alt="Kelevra" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Kristjan Retter"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/152118848/3b15a43d71714552b5ed1c9f84e66adf/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=IEKE18CBHVZ3k-08UD7Dkb7HbiFHb84W0FATdLMI0Dg%3D" width="50" height="50" alt="Kristjan Retter" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="L D"><img src="https://c8.patreon.com/3/200/358350" width="50" height="50" alt="L D" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Lukas"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/140599287/cff037fb93804af28bc3a4f1e91154f8/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=vkscmpmFoM5wq7GnsLmOEgNhvyXe-774kNGNqD0wurE%3D" width="50" height="50" alt="Lukas" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Marko jak"><img 
src="https://c10.patreonusercontent.com/4/patreon-media/p/user/159203973/36c817f941ac4fa18103a4b8c0cb9cae/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=9toslDfsO14QyaOiu6vIf--d4marBsWCZWN3gdPqbIU%3D" width="50" height="50" alt="Marko jak" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Michael Levine"><img src="https://c8.patreon.com/3/200/22809690" width="50" height="50" alt="Michael Levine" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Miguel Lara"><img src="https://c8.patreon.com/3/200/83319230" width="50" height="50" alt="Miguel Lara" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Misch Strotz"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/8654302/b0f5ebedc62a47c4b56222693e1254e9/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=lpeicIh1_S-3Ji3W27gyiRB7iXurp8Bx8HAzDHftOuo%3D" width="50" height="50" alt="Misch Strotz" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Mohamed Oumoumad"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/2298192/1228b69bd7d7481baf3103315183250d/eyJ3IjoyMDB9/1.jpg?token-time=2145916800&token-hash=1B7dbXy_gAcPT9WXBesLhs7z_9APiz2k1Wx4Vml_-8Q%3D" width="50" height="50" alt="Mohamed Oumoumad" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="nitish PNR"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/120239481/49b1ce70d3d24704b8ec34de24ec8f55/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=Dv1NPKwdv9QT8fhYYwbGnQIvfiyqTUlh52bjDW1vYxY%3D" width="50" height="50" alt="nitish PNR" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Noctre"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/99036356/7ae9c4d80e604e739b68cca12ee2ed01/eyJ3IjoyMDB9/3.png?token-time=2145916800&token-hash=zK0dHe6A937WtNlrGdefoXFTPPzHUCfn__23HP8-Ui0%3D" width="50" height="50" alt="Noctre" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Patron"><img src="https://c8.patreon.com/3/200/8449560" width="50" height="50" alt="Patron" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Plaidam"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/338551/e8f257d8d3dd46c38272b391a5785948/eyJ3IjoyMDB9/1.jpg?token-time=2145916800&token-hash=GLom1rGgOZjBeO7I1OnjiIgWmjl6PO9ZjBB8YTvc7AM%3D" width="50" height="50" alt="Plaidam" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Prasanth Veerina"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/162524101/81a72689c3754ac5b9e38612ce5ce914/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=3XLSlLFCWAQ-0wd2_vZMikyotdQNSzKOjoyeoJiZEw0%3D" width="50" height="50" alt="Prasanth Veerina" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="RayHell"><img src="https://c8.patreon.com/3/200/24653779" width="50" height="50" alt="RayHell" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Steve Hanff"><img src="https://c8.patreon.com/3/200/548524" width="50" height="50" alt="Steve Hanff" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Steven Simmons"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/163426977/fc3941c79e894fef985d9f5440255313/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=TjwllfKCd_Ftt1C2wFYdcOdJZxyuPaRpEbKjrfzk0Zw%3D" width="50" height="50" alt="Steven Simmons" style="border-radius:50%;display:inline-block;"></a> 
<a href="None" title="Sören "><img src="https://c8.patreon.com/3/200/4541423" width="50" height="50" alt="Sören " style="border-radius:50%;display:inline-block;"></a> <a href="None" title="the biitz"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/76566911/6485eaf5ec6249a7b524ee0b979372f0/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=S1QK78ief5byQU7tB_reqnw4V2zhW_cpwTqHThk-tGc%3D" width="50" height="50" alt="the biitz" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="The Local Lab"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/141098579/1a9f0a1249d447a7a0df718a57343912/eyJ3IjoyMDB9/2.png?token-time=2145916800&token-hash=Rd_AjZGhMATVkZDf8E95ILc0n93gvvFWe1Ig0_dxwf4%3D" width="50" height="50" alt="The Local Lab" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Trent Hunter"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/31950857/c567dc648f6144be9f6234946df05da2/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=3Vx4R1eOfD4X_ZPPd40MsZ-3lyknLM35XmaHRELnWjM%3D" width="50" height="50" alt="Trent Hunter" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Ultimate Golf Archives"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/96561218/b0694642d13a49faa75aec9762ff2aeb/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=sLQXomYm1iMYpknvGwKQ49f30TKQ0B1R2W3EZfCJqr8%3D" width="50" height="50" alt="Ultimate Golf Archives" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Un Defined"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/155963250/6f8fd7075c3b4247bfeb054ba49172d6/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=twmKs4mADF_h7bKh5jBuigYVScMeaeHv2pEPin9K0Dg%3D" width="50" height="50" alt="Un Defined" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Vladimir Sotnikov"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/161471720/dd330b4036d44a5985ed5985c12a5def/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=qkRvrEc5gLPxaXxLvcvbYv1W1lcmOoTwhj4A9Cq5BxQ%3D" width="50" height="50" alt="Vladimir Sotnikov" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Wesley Reitzfeld"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/110407414/30f9e9d88ef945ddb0f47fd23a8cbac2/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=QQRWOkMyOfDBERHn4O8N2wMB32zeiIEsydVTbSNUw-I%3D" width="50" height="50" alt="Wesley Reitzfeld" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="william tatum"><img src="https://c8.patreon.com/3/200/83034" width="50" height="50" alt="william tatum" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Zack Abrams"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/32633822/1ab5612efe80417cbebfe91e871fc052/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=RHYMcjr0UGIYw5FBrUfJdKMGuoYWhBQlLIykccEFJvo%3D" width="50" height="50" alt="Zack Abrams" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Zoltán-Csaba Nyiró"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/162398691/89d78d89eecb4d6b981ce8c3c6a3d4b8/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=SWhI-0jGpY6Nc_bUQeXz4pa9DRURi9VnnnJ3Mxjg1po%3D" width="50" height="50" alt="Zoltán-Csaba Nyiró" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="Алексей Наумов"><img 
src="https://c10.patreonusercontent.com/4/patreon-media/p/user/26019082/6ba968129e284c869069b261c875ae02/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=Jz-Kk9l8RIvGMNcaGXuN8_vaY3G435lFmtJtFZA3OCs%3D" width="50" height="50" alt="Алексей Наумов" style="border-radius:50%;display:inline-block;"></a> <a href="None" title="עומר מכלוף"><img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/97985240/3d1d0e6905d045aba713e8132cab4a30/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=pG3X2m-py2lRYI2aoJiXI47_4ArD78ZHdSm6jCAHA_w%3D" width="50" height="50" alt="עומר מכלוף" style="border-radius:50%;display:inline-block;"></a>
+
+
+ ---
+
+
+
+
+ ## Installation
+
+ Requirements:
+ - python >3.10
+ - Nvidia GPU with enough ram to do what you need
+ - python venv
+ - git
+
+
+ Linux:
+ ```bash
+ git clone https://github.com/ostris/ai-toolkit.git
+ cd ai-toolkit
+ python3 -m venv venv
+ source venv/bin/activate
+ # install torch first
+ pip3 install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
+ pip3 install -r requirements.txt
+ ```
+
+ Windows:
+ ```bash
+ git clone https://github.com/ostris/ai-toolkit.git
+ cd ai-toolkit
+ python -m venv venv
+ .\venv\Scripts\activate
+ pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
+ pip install -r requirements.txt
+ ```
+
+
+ # AI Toolkit UI
+
+ <img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">
+
+ The AI Toolkit UI is a web interface for the AI Toolkit. It lets you easily start, stop, and monitor jobs, train models with a few clicks, and set an access token to prevent unauthorized use, so it is mostly safe to run on an exposed server.
+
+ ## Running the UI
+
+ Requirements:
+ - Node.js > 18
+
+ The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below
+ will install / update the UI and its dependencies and start the UI.
+
+ ```bash
+ cd ui
+ npm run build_and_start
+ ```
+
+ You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.
+
+ ## Securing the UI
+
+ If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token.
+ You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to a super secure password. This token will be required to access
+ the UI. You can set this when starting the UI like so:
+
+ ```bash
+ # Linux
+ AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start
+
+ # Windows
+ set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start
+
+ # Windows Powershell
+ $env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start
+ ```
+
+
+ ## FLUX.1 Training
+
+ ### Tutorial
+
+ To get started quickly, check out [@araminta_k](https://x.com/araminta_k)'s tutorial on [Finetuning Flux Dev on a 3090](https://www.youtube.com/watch?v=HzGW_Kyermg) with 24GB VRAM.
+
+
+ ### Requirements
+ You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
+ your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
+ the model on CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL,
+ but there are some reports of a bug when running on Windows natively.
+ I have only tested on Linux for now. This is still extremely experimental
+ and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.
+
+ ### FLUX.1-dev
+
+ FLUX.1-dev has a non-commercial license, which means anything you train will inherit the
+ non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
+ Otherwise, this will fail. Here are the required steps to set up a license.
+
+ 1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+ 2. Make a file named `.env` in the root of this folder
+ 3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here`
+
+ ### FLUX.1-schnell
+
+ FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want and it does not require a HF_TOKEN to train.
+ However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter).
+ It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended.
+
+ To use it, you just need to add the assistant to the `model` section of your config file like so:
+
+ ```yaml
+ model:
+   name_or_path: "black-forest-labs/FLUX.1-schnell"
+   assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"
+   is_flux: true
+   quantize: true
+ ```
+
+ You also need to adjust your sample steps since schnell does not require as many:
+
+ ```yaml
+ sample:
+   guidance_scale: 1 # schnell does not do guidance
+   sample_steps: 4 # 1 - 4 works well
+ ```
+
+ ### Training
+ 1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
+ 2. Edit the file following the comments in the file
+ 3. Run the file like so `python run.py config/whatever_you_want.yml`
+
+ A folder with the name and the training folder from the config file will be created when you start. It will have all
+ checkpoints and images in it. You can stop the training at any time using ctrl+c and when you resume, it will pick back up
+ from the last checkpoint.
+
+ IMPORTANT. If you press ctrl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving.
+
+ ### Need help?
+
+ Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
+ and ask for help there. However, please refrain from PMing me directly with general questions or support requests. Ask in the discord
+ and I will answer when I can.
+
+ ## Gradio UI
+
+ To get started training locally with a custom UI, once you have followed the steps above and `ai-toolkit` is installed:
+
+ ```bash
+ cd ai-toolkit #in case you are not yet in the ai-toolkit folder
+ huggingface-cli login #provide a `write` token to publish your LoRA at the end
+ python flux_train_ui.py
+ ```
+
+ You will instantiate a UI that will let you upload your images, caption them, train and publish your LoRA
+ ![image](assets/lora_ease_ui.png)
+
+
+ ## Training in RunPod
+ Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04**
+ > You need a minimum of 24GB VRAM, pick a GPU by your preference.
+
+ #### Example config ($0.5/hr):
+ - 1x A40 (48 GB VRAM)
+ - 19 vCPU 100 GB RAM
+
+ #### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples):
+ - ~120 GB Disk
+ - ~120 GB Pod Volume
+ - Start Jupyter Notebook
+
+ ### 1. Setup
+ ```
+ git clone https://github.com/ostris/ai-toolkit.git
+ cd ai-toolkit
+ git submodule update --init --recursive
+ python -m venv venv
+ source venv/bin/activate
+ pip install torch
+ pip install -r requirements.txt
+ pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
+ ```
+ ### 2. Upload your dataset
+ - Create a new folder in the root, name it `dataset` or whatever you like.
+ - Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder.
+
+ ### 3. Log into Hugging Face with an Access Token
+ - Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
+ - Run ```huggingface-cli login``` and paste your token.
+
+ ### 4. Training
+ - Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```.
+ - Edit the config following the comments in the file.
+ - Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```.
+ - Run the file: ```python run.py config/whatever_you_want.yml```.
+
+ ### Screenshot from RunPod
+ <img width="1728" alt="RunPod Training Screenshot" src="https://github.com/user-attachments/assets/53a1b8ef-92fa-4481-81a7-bde45a14a7b5">
+
+ ## Training in Modal
+
+ ### 1. Setup
+ #### ai-toolkit:
+ ```
+ git clone https://github.com/ostris/ai-toolkit.git
+ cd ai-toolkit
+ git submodule update --init --recursive
+ python -m venv venv
+ source venv/bin/activate
+ pip install torch
+ pip install -r requirements.txt
+ pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
+ ```
+ #### Modal:
+ - Run `pip install modal` to install the modal Python package.
+ - Run `modal setup` to authenticate (if this doesn’t work, try `python -m modal setup`).
+
+ #### Hugging Face:
+ - Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
+ - Run `huggingface-cli login` and paste your token.
+
+ ### 2. Upload your dataset
+ - Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`.
+
+ ### 3. Configs
+ - Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
+ - Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>**.
+
+ ### 4. Edit run_modal.py
+ - Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:
+
+ ```
+ code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
+ ```
+ - Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_.
+
+ ### 5. Training
+ - Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
+ - You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
+ - Models, samples and optimizer will be stored in `Storage > flux-lora-models`.
+
+ ### 6. Saving the model
+ - Check contents of the volume by running `modal volume ls flux-lora-models`.
+ - Download the content by running `modal volume get flux-lora-models your-model-name`.
+ - Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.
+
+ ### Screenshot from Modal
+
+ <img width="1728" alt="Modal Training Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">
+
+ ---
+
+ ## Dataset Preparation
+
+ Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
+ formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
+ but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
+ You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically
+ replaced.
+
+ Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
+ The loader will automatically resize them and can handle varying aspect ratios.
+
+
+ ## Training Specific Layers
+
+ To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers
+ used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
+ network kwargs like so:
+
+ ```yaml
+ network:
+   type: "lora"
+   linear: 128
+   linear_alpha: 128
+   network_kwargs:
+     only_if_contains:
+       - "transformer.single_transformer_blocks.7.proj_out"
+       - "transformer.single_transformer_blocks.20.proj_out"
+ ```
+
+ The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
+ the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights.
+ For instance to only train the `single_transformer` for FLUX.1, you can use the following:
+
+ ```yaml
+ network:
+   type: "lora"
+   linear: 128
+   linear_alpha: 128
+   network_kwargs:
+     only_if_contains:
+       - "transformer.single_transformer_blocks."
+ ```
+
+ You can also exclude layers by their names by using the `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks:
+
+
+ ```yaml
+ network:
+   type: "lora"
+   linear: 128
+   linear_alpha: 128
+   network_kwargs:
+     ignore_if_contains:
+       - "transformer.single_transformer_blocks."
+ ```
+
+ `ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
+ it will be ignored.
+
+ ## LoKr Training
+
+ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so:
+
+ ```yaml
+ network:
+   type: "lokr"
+   lokr_full_rank: true
+   lokr_factor: 8
+ ```
+
+ Everything else should work the same including layer targeting.
+
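The Dataset Preparation section of the README above describes the expected layout: jpg/jpeg/png images, a same-named `.txt` caption per image, and an optional `[trigger]` placeholder that gets swapped for `trigger_word`. A hypothetical pre-flight check along those lines (a standalone sketch, not part of ai-toolkit; the folder path and trigger word are placeholders):

```python
# Hypothetical pre-flight check, not part of ai-toolkit: verify a dataset folder
# follows the layout described in the README's Dataset Preparation section
# (jpg/jpeg/png images, each paired with a same-named .txt caption).
from pathlib import Path

SUPPORTED = {".jpg", ".jpeg", ".png"}

def check_dataset(folder: str, trigger_word: str | None = None) -> None:
    root = Path(folder)
    for image in sorted(p for p in root.iterdir() if p.suffix.lower() in SUPPORTED):
        caption_file = image.with_suffix(".txt")
        if not caption_file.exists():
            print(f"missing caption: {caption_file.name}")
            continue
        caption = caption_file.read_text(encoding="utf-8").strip()
        # Mirrors the documented behavior: [trigger] is swapped for trigger_word.
        if trigger_word:
            caption = caption.replace("[trigger]", trigger_word)
        print(f"{image.name}: {caption[:60]}")
    for stray in root.glob("*.webp"):
        print(f"warning: {stray.name} is webp, which the README says currently has issues")

check_dataset("/workspace/ai-toolkit/your-dataset", trigger_word="p3r5on")
```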
ai-toolkit/__pycache__/info.cpython-312.pyc ADDED
Binary file (374 Bytes).
 
ai-toolkit/assets/VAE_test1.jpg ADDED

Git LFS Details

  • SHA256: 879fcb537d039408d7aada297b7397420132684f0106edacc1205fb5cc839476
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
ai-toolkit/assets/glif.svg ADDED
ai-toolkit/assets/lora_ease_ui.png ADDED

Git LFS Details

  • SHA256: f647b9fe90cc96db2aa84d1cb25a73b60ffcc5394822f99e9dac27d373f89d79
  • Pointer size: 131 Bytes
  • Size of remote file: 349 kB
ai-toolkit/build_and_push_docker ADDED
@@ -0,0 +1,29 @@
+ #!/usr/bin/env bash
+
+ # Extract version from version.py
+ if [ -f "version.py" ]; then
+     VERSION=$(python3 -c "from version import VERSION; print(VERSION)")
+     echo "Building version: $VERSION"
+ else
+     echo "Error: version.py not found. Please create a version.py file with VERSION defined."
+     exit 1
+ fi
+
+ echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
+ echo "Building version: $VERSION and latest"
+ # wait 2 seconds
+ sleep 2
+
+ # Build the image with cache busting
+ docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
+
+ # Tag with version and latest
+ docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
+ docker tag aitoolkit:$VERSION ostris/aitoolkit:latest
+
+ # Push both tags
+ echo "Pushing images to Docker Hub..."
+ docker push ostris/aitoolkit:$VERSION
+ docker push ostris/aitoolkit:latest
+
+ echo "Successfully built and pushed ostris/aitoolkit:$VERSION and ostris/aitoolkit:latest"
ai-toolkit/config/examples/extract.example.yml ADDED
@@ -0,0 +1,75 @@
+ ---
+ # this is in yaml format. You can use json if you prefer
+ # I like both but yaml is easier to read and write
+ # plus it has comments which is nice for documentation
+ job: extract # tells the runner what to do
+ config:
+   # the name will be used to create a folder in the output folder
+   # it will also replace any [name] token in the rest of this config
+   name: name_of_your_model
+   # can be hugging face model, a .ckpt, or a .safetensors
+   base_model: "/path/to/base/model.safetensors"
+   # can be hugging face model, a .ckpt, or a .safetensors
+   extract_model: "/path/to/model/to/extract/trained.safetensors"
+   # we will create a folder here with the name above. This will create /path/to/output/folder/name_of_your_model
+   output_folder: "/path/to/output/folder"
+   is_v2: false
+   dtype: fp16 # saved dtype
+   device: cpu # cpu, cuda:0, etc
+
+   # processes can be chained like this to run multiple in a row
+   # they must all use the same models above, but great for testing different
+   # sizes and types of extractions. It is much faster as we already have the models loaded
+   process:
+     # process 1
+     - type: locon # locon or lora (locon is lycoris)
+       filename: "[name]_64_32.safetensors" # will be put in output folder
+       dtype: fp16
+       mode: fixed
+       linear: 64
+       conv: 32
+
+     # process 2
+     - type: locon
+       output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
+       mode: ratio
+       linear: 0.2
+       conv: 0.2
+
+     # process 3
+     - type: locon
+       filename: "[name]_ratio_02.safetensors"
+       mode: quantile
+       linear: 0.5
+       conv: 0.5
+
+     # process 4
+     - type: lora # traditional lora extraction (lierla) with linear layers only
+       filename: "[name]_4.safetensors"
+       mode: fixed # fixed, ratio, quantile supported for lora as well
+       linear: 4 # lora dim or rank
+       # no conv for lora
+
+     # process 5
+     - type: lora
+       filename: "[name]_q05.safetensors"
+       mode: quantile
+       linear: 0.5
+
+ # you can put any information you want here, and it will be saved in the model
+ # the below is an example. I recommend doing trigger words at a minimum
+ # in the metadata. The software will include this plus some other information
+ meta:
+   name: "[name]" # [name] gets replaced with the name above
+   description: A short description of your model
+   trigger_words:
+     - put
+     - trigger
+     - words
+     - here
+   version: '0.1'
+   creator:
+     name: Your Name
+     website: https://yourwebsite.com
+   any: All meta data above is arbitrary, it can be whatever you want.
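The `fixed`, `ratio`, and `quantile` modes above control how large the extracted LoRA/LoCon rank is. As a rough illustration of how those three modes can be interpreted in terms of the singular values of a weight difference (a standalone sketch, not the extract job's actual implementation):

```python
# Standalone sketch, not ai-toolkit code: choose a decomposition rank for a
# weight delta (trained - base) under the three modes named in the config.
import numpy as np

def pick_rank(delta: np.ndarray, mode: str, value: float) -> int:
    s = np.linalg.svd(delta, compute_uv=False)  # singular values, descending
    max_rank = len(s)
    if mode == "fixed":                         # e.g. linear: 64
        return min(int(value), max_rank)
    if mode == "ratio":                         # e.g. linear: 0.2 of full rank
        return max(1, int(round(value * max_rank)))
    if mode == "quantile":                      # e.g. linear: 0.5 of singular-value mass
        energy = np.cumsum(s) / s.sum()
        return int(np.searchsorted(energy, value) + 1)
    raise ValueError(f"unknown mode: {mode}")

rng = np.random.default_rng(0)
delta = rng.standard_normal((320, 320)) @ np.diag(np.logspace(0, -3, 320)) @ rng.standard_normal((320, 320))
for mode, value in [("fixed", 64), ("ratio", 0.2), ("quantile", 0.5)]:
    print(mode, pick_rank(delta, mode, value))
```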
ai-toolkit/config/examples/generate.example.yaml ADDED
@@ -0,0 +1,60 @@
+ ---
+
+ job: generate # tells the runner what to do
+ config:
+   name: "generate" # this is not really used anywhere currently but required by runner
+   process:
+     # process 1
+     - type: to_folder # process images to a folder
+       output_folder: "output/gen"
+       device: cuda:0 # cpu, cuda:0, etc
+       generate:
+         # these are your defaults you can override most of them with flags
+         sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
+         width: 1024
+         height: 1024
+         neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
+         seed: -1 # -1 is random
+         guidance_scale: 7
+         sample_steps: 20
+         ext: ".png" # .png, .jpg, .jpeg, .webp
+
+         # here are the flags you can use for prompts. Always start with
+         # your prompt first then add these flags after. You can use as many
+         # like
+         # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
+         # we will try to support all sd-scripts flags where we can
+
+         # FROM SD-SCRIPTS
+         # --n Treat everything until the next option as a negative prompt.
+         # --w Specify the width of the generated image.
+         # --h Specify the height of the generated image.
+         # --d Specify the seed for the generated image.
+         # --l Specify the CFG scale for the generated image.
+         # --s Specify the number of steps during generation.
+
+         # OURS and some QOL additions
+         # --p2 Prompt for the second text encoder (SDXL only)
+         # --n2 Negative prompt for the second text encoder (SDXL only)
+         # --gr Specify the guidance rescale for the generated image (SDXL only)
+         # --seed Specify the seed for the generated image same as --d
+         # --cfg Specify the CFG scale for the generated image same as --l
+         # --steps Specify the number of steps during generation same as --s
+
+         prompt_file: false # if true a txt file will be created next to images with prompt strings used
+         # prompts can also be a path to a text file with one prompt per line
+         # prompts: "/path/to/prompts.txt"
+         prompts:
+           - "photo of batman"
+           - "photo of superman"
+           - "photo of spiderman"
+           - "photo of a superhero --n batman superman spiderman"
+
+       model:
+         # huggingface name, relative from project path, or absolute path to .safetensors or .ckpt
+         # name_or_path: "runwayml/stable-diffusion-v1-5"
+         name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
+         is_v2: false # for v2 models
+         is_v_pred: false # for v-prediction models (most v2 models)
+         is_xl: false # for SDXL models
+         dtype: bf16
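The comment block above documents inline prompt flags such as `--n`, `--w`, `--h`, `--cfg`, and `--steps`. A hedged sketch of how such a prompt string can be split into the base prompt and its overrides (not the toolkit's parser, just an illustration of the syntax):

```python
# Standalone sketch, not ai-toolkit's parser: split a prompt string like
# "photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20"
# into the positive prompt and a dict of flag overrides.
import re

def parse_prompt(line: str) -> tuple[str, dict[str, str]]:
    parts = re.split(r"\s--(\w+)\s*", line)
    prompt = parts[0].strip()
    flags = {parts[i]: parts[i + 1].strip() for i in range(1, len(parts) - 1, 2)}
    return prompt, flags

prompt, flags = parse_prompt(
    "photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20"
)
print(prompt)  # photo of a baseball
print(flags)   # {'n': 'painting, ugly', 'w': '1024', 'h': '1024', 'seed': '42', 'cfg': '7', 'steps': '20'}
```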
ai-toolkit/config/examples/mod_lora_scale.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: mod
3
+ config:
4
+ name: name_of_your_model_v1
5
+ process:
6
+ - type: rescale_lora
7
+ # path to your current lora model
8
+ input_path: "/path/to/lora/lora.safetensors"
9
+ # output path for your new lora model, can be the same as input_path to replace
10
+ output_path: "/path/to/lora/output_lora_v1.safetensors"
11
+ # replaces meta with the meta below (plus minimum meta fields)
12
+ # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
13
+ replace_meta: true
14
+ # how to adjust, we can scale the up_down weights or the alpha
15
+ # up_down is the default and probably the best, they will both net the same outputs
16
+ # would only affect rare NaN cases and maybe merging with old merge tools
17
+ scale_target: 'up_down'
18
+ # precision to save, fp16 is the default and standard
19
+ save_dtype: fp16
20
+ # current_weight is the ideal weight you use as a multiplier when using the lora
21
+ # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
22
+ # you can do negatives here too if you want to flip the lora
23
+ current_weight: 6.0
24
+ # target_weight is the ideal weight you use as a multiplier when using the lora
25
+ # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
26
+ # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
27
+ target_weight: 1.0
28
+
29
+ # base model for the lora
30
+ # this is just used to add meta so automatic111 knows which model it is for
31
+ # assume v1.5 if these are not set
32
+ is_xl: false
33
+ is_v2: false
34
+ meta:
35
+ # this is only used if you set replace_meta to true above
36
+ name: "[name]" # [name] gets replaced with the name above
37
+ description: A short description of your lora
38
+ trigger_words:
39
+ - put
40
+ - trigger
41
+ - words
42
+ - here
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
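For intuition on the rescale math: making strength 1.0 behave like the old 6.0 means baking the ratio current_weight / target_weight into the saved weights, and with scale_target 'up_down' that factor is split across the up and down matrices (since delta_W = up @ down). Below is a minimal sketch of that idea, assuming sd-scripts style key names (lora_up / lora_down) and ignoring alpha tensors; it is not the toolkit's rescale_lora implementation.

# Illustrative math only, not ai-toolkit's rescale_lora code.
# new_delta = (current_weight / target_weight) * old_delta, with the factor split
# evenly across lora_up and lora_down because delta_W = up @ down.
import math

def rescale_up_down(state_dict: dict, current_weight: float, target_weight: float) -> dict:
    factor = current_weight / target_weight            # e.g. 6.0 / 1.0 = 6.0
    per_matrix = math.sqrt(abs(factor))                # sqrt(f) * sqrt(f) = f
    sign = 1.0 if factor >= 0 else -1.0                # a negative factor flips the lora
    out = {}
    for name, tensor in state_dict.items():
        if "lora_up" in name:                          # key names assumed (sd-scripts convention)
            out[name] = tensor * (per_matrix * sign)   # carry the sign on one side only
        elif "lora_down" in name:
            out[name] = tensor * per_matrix
        else:
            out[name] = tensor                         # alphas and metadata left untouched here
    return out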
ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need this on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
65
+ # place it like "/root/ai-toolkit/FLUX.1-dev"
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
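To make the trigger_word and caption_dropout_rate comments above concrete, here is a rough sketch of caption preprocessing along those lines. It is illustrative only, not the toolkit's dataloader, and where exactly the trigger word gets added is an assumption here.

# Illustrative sketch of the captioning behaviour described above (not the toolkit's dataloader).
import random

def prepare_caption(caption, trigger_word=None, dropout_rate=0.05):
    if random.random() < dropout_rate:        # caption_dropout_rate: drop the caption this fraction of the time
        return ""
    if trigger_word:
        if "[trigger]" in caption:            # explicit placeholder gets substituted
            caption = caption.replace("[trigger]", trigger_word)
        elif trigger_word not in caption:     # otherwise add it if missing (prepended here for illustration)
            caption = f"{trigger_word}, {caption}"
    return caption

print(prepare_caption("a [trigger] riding a bike", "p3r5on", dropout_rate=0.0))  # "a p3r5on riding a bike"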
ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need the on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
65
+ # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
ai-toolkit/config/examples/train_flex_redux.yaml ADDED
@@ -0,0 +1,112 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_redux_finetune_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ adapter:
14
+ type: "redux"
15
+ # you can finetune an existing adapter or start from scratch. Set to null to start from scratch
16
+ name_or_path: '/local/path/to/redux_adapter_to_finetune.safetensors'
17
+ # name_or_path: null
18
+ # image_encoder_path: 'google/siglip-so400m-patch14-384' # Flux.1 redux adapter
19
+ image_encoder_path: 'google/siglip2-so400m-patch16-512' # Flex.1 512 redux adapter
20
+ # image_encoder_arch: 'siglip' # for Flux.1
21
+ image_encoder_arch: 'siglip2'
22
+ # You need a control input for each sample. Best to do squares for both images
23
+ test_img_path:
24
+ - "/path/to/x_01.jpg"
25
+ - "/path/to/x_02.jpg"
26
+ - "/path/to/x_03.jpg"
27
+ - "/path/to/x_04.jpg"
28
+ - "/path/to/x_05.jpg"
29
+ - "/path/to/x_06.jpg"
30
+ - "/path/to/x_07.jpg"
31
+ - "/path/to/x_08.jpg"
32
+ - "/path/to/x_09.jpg"
33
+ - "/path/to/x_10.jpg"
34
+ clip_layer: 'last_hidden_state'
35
+ train: true
36
+ save:
37
+ dtype: bf16 # precision to save
38
+ save_every: 250 # save every this many steps
39
+ max_step_saves_to_keep: 4
40
+ datasets:
41
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
42
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
43
+ # images will automatically be resized and bucketed into the resolution specified
44
+ # on windows, escape back slashes with another backslash so
45
+ # "C:\\path\\to\\images\\folder"
46
+ - folder_path: "/path/to/images/folder"
47
+ # clip_image_path is the directory containing your control images. They must have the same filename as their training image (the extension does not matter)
48
+ # for normal redux, we are just recreating the same image, so you can use the same folder path above
49
+ clip_image_path: "/path/to/control/images/folder"
50
+ caption_ext: "txt"
51
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
52
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
53
+ train:
54
+ # this is what I used for the 24GB card, but feel free to adjust
55
+ # total batch size is 6 here
56
+ batch_size: 3
57
+ gradient_accumulation: 2
58
+
59
+ # captions are not needed for this training; we cache a blank prompt and rely on the vision encoder
60
+ unload_text_encoder: true
61
+
62
+ loss_type: "mse"
63
+ train_unet: true
64
+ train_text_encoder: false
65
+ steps: 4000000 # I set this very high and stop when I like the results
66
+ content_or_style: balanced # content, style, balanced
67
+ gradient_checkpointing: true
68
+ noise_scheduler: "flowmatch" # or "ddpm", "lms", "euler_a"
69
+ timestep_type: "flux_shift"
70
+ optimizer: "adamw8bit"
71
+ lr: 1e-4
72
+
73
+ # this is for Flex.1, comment this out for FLUX.1-dev
74
+ bypass_guidance_embedding: true
75
+
76
+ dtype: bf16
77
+ ema_config:
78
+ use_ema: true
79
+ ema_decay: 0.99
80
+ model:
81
+ name_or_path: "ostris/Flex.1-alpha"
82
+ is_flux: true
83
+ quantize: true
84
+ text_encoder_bits: 8
85
+ sample:
86
+ sampler: "flowmatch" # must match train.noise_scheduler
87
+ sample_every: 250 # sample every this many steps
88
+ width: 1024
89
+ height: 1024
90
+ # I leave half blank to test prompt and unprompted
91
+ prompts:
92
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
93
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
94
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
95
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
96
+ - "a bear building a log cabin in the snow covered mountains"
97
+ - ""
98
+ - ""
99
+ - ""
100
+ - ""
101
+ - ""
102
+ neg: ""
103
+ seed: 42
104
+ walk_seed: true
105
+ guidance_scale: 4
106
+ sample_steps: 25
107
+ network_multiplier: 1.0
108
+
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
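Because the control images in clip_image_path are matched to training images purely by filename stem (the extension does not matter), it can help to sanity-check the pairing before a long run. A small illustrative helper for that follows; it is not part of ai-toolkit, and the folder paths are just the placeholders from the config above.

# Illustrative helper (not part of ai-toolkit): list training images that have no
# control image with the same filename stem in clip_image_path.
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def find_unmatched(folder_path, clip_image_path):
    train_stems = {p.stem for p in Path(folder_path).iterdir() if p.suffix.lower() in IMAGE_EXTS}
    control_stems = {p.stem for p in Path(clip_image_path).iterdir() if p.suffix.lower() in IMAGE_EXTS}
    return sorted(train_stems - control_stems)

missing = find_unmatched("/path/to/images/folder", "/path/to/control/images/folder")
print(f"{len(missing)} training images have no matching control image:", missing[:10])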
ai-toolkit/config/examples/train_full_fine_tune_flex.yaml ADDED
@@ -0,0 +1,107 @@
1
+ ---
2
+ # This configuration requires 48GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_flex_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
37
+ bypass_guidance_embedding: true
38
+
39
+ # can be 'sigmoid', 'linear', or 'lognorm_blend'
40
+ timestep_type: 'sigmoid'
41
+
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flex
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adafactor"
49
+ lr: 3e-5
50
+
51
+ # Parameter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
52
+ # 0.1 means 10% of parameters are active at each step. Only works with adafactor
53
+
54
+ # do_paramiter_swapping: true
55
+ # paramiter_swapping_factor: 0.9
56
+
57
+ # uncomment this to skip the pre training sample
58
+ # skip_first_sample: true
59
+ # uncomment to completely disable sampling
60
+ # disable_sampling: true
61
+
62
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
63
+ ema_config:
64
+ use_ema: true
65
+ ema_decay: 0.99
66
+
67
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
68
+ dtype: bf16
69
+ model:
70
+ # huggingface model name or path
71
+ name_or_path: "ostris/Flex.1-alpha"
72
+ is_flux: true # flex is flux architecture
73
+ # full finetuning quantized models is a crapshoot and results in subpar outputs
74
+ # quantize: true
75
+ # you can quantize just the T5 text encoder here to save vram
76
+ quantize_te: true
77
+ # only train the transformer blocks
78
+ only_if_contains:
79
+ - "transformer.transformer_blocks."
80
+ - "transformer.single_transformer_blocks."
81
+ sample:
82
+ sampler: "flowmatch" # must match train.noise_scheduler
83
+ sample_every: 250 # sample every this many steps
84
+ width: 1024
85
+ height: 1024
86
+ prompts:
87
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
88
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
89
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
90
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
91
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
92
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
93
+ - "a bear building a log cabin in the snow covered mountains"
94
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
95
+ - "hipster man with a beard, building a chair, in a wood shop"
96
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
97
+ - "a man holding a sign that says, 'this is a sign'"
98
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
99
+ neg: "" # not used on flex
100
+ seed: 42
101
+ walk_seed: true
102
+ guidance_scale: 4
103
+ sample_steps: 25
104
+ # you can add any additional meta info here. [name] is replaced with config name at top
105
+ meta:
106
+ name: "[name]"
107
+ version: '1.0'
ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml ADDED
@@ -0,0 +1,99 @@
1
+ ---
2
+ # This configuration requires 24GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+
37
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
38
+ timestep_type: 'lumina2_shift'
39
+
40
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
41
+ gradient_accumulation: 1
42
+ train_unet: true
43
+ train_text_encoder: false # probably won't work with lumina2
44
+ gradient_checkpointing: true # need this on unless you have a ton of vram
45
+ noise_scheduler: "flowmatch" # for training only
46
+ optimizer: "adafactor"
47
+ lr: 3e-5
48
+
49
+ # Parameter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
50
+ # 0.1 means 10% of parameters are active at each step. Only works with adafactor
51
+
52
+ # do_paramiter_swapping: true
53
+ # paramiter_swapping_factor: 0.9
54
+
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
61
+ # ema_config:
62
+ # use_ema: true
63
+ # ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
70
+ is_lumina2: true # lumina2 architecture
71
+ # you can quantize just the Gemma2 text encoder here to save vram
72
+ quantize_te: true
73
+ sample:
74
+ sampler: "flowmatch" # must match train.noise_scheduler
75
+ sample_every: 250 # sample every this many steps
76
+ width: 1024
77
+ height: 1024
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
81
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
82
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
83
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
84
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
85
+ - "a bear building a log cabin in the snow covered mountains"
86
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
87
+ - "hipster man with a beard, building a chair, in a wood shop"
88
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
89
+ - "a man holding a sign that says, 'this is a sign'"
90
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
91
+ neg: ""
92
+ seed: 42
93
+ walk_seed: true
94
+ guidance_scale: 4.0
95
+ sample_steps: 25
96
+ # you can add any additional meta info here. [name] is replaced with config name at top
97
+ meta:
98
+ name: "[name]"
99
+ version: '1.0'
ai-toolkit/config/examples/train_lora_chroma_24gb.yaml ADDED
@@ -0,0 +1,97 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_chroma_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with chroma
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for chroma, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # Download whichever model you prefer from the Chroma repo
66
+ # https://huggingface.co/lodestones/Chroma/tree/main
67
+ # point to it here.
68
+ name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors"
69
+ arch: "chroma"
70
+ quantize: true # run 8bit mixed precision
71
+ sample:
72
+ sampler: "flowmatch" # must match train.noise_scheduler
73
+ sample_every: 250 # sample every this many steps
74
+ width: 1024
75
+ height: 1024
76
+ prompts:
77
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
78
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
79
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
80
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
81
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
82
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
83
+ - "a bear building a log cabin in the snow covered mountains"
84
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
85
+ - "hipster man with a beard, building a chair, in a wood shop"
86
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
87
+ - "a man holding a sign that says, 'this is a sign'"
88
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
89
+ neg: "" # negative prompt, optional
90
+ seed: 42
91
+ walk_seed: true
92
+ guidance_scale: 4
93
+ sample_steps: 25
94
+ # you can add any additional meta info here. [name] is replaced with config name at top
95
+ meta:
96
+ name: "[name]"
97
+ version: '1.0'
ai-toolkit/config/examples/train_lora_flex_24gb.yaml ADDED
@@ -0,0 +1,101 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
43
+ bypass_guidance_embedding: true
44
+
45
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
46
+ gradient_accumulation: 1
47
+ train_unet: true
48
+ train_text_encoder: false # probably won't work with flex
49
+ gradient_checkpointing: true # need this on unless you have a ton of vram
50
+ noise_scheduler: "flowmatch" # for training only
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ # uncomment this to skip the pre training sample
54
+ # skip_first_sample: true
55
+ # uncomment to completely disable sampling
56
+ # disable_sampling: true
57
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
58
+ # linear_timesteps: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
61
+ ema_config:
62
+ use_ema: true
63
+ ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "ostris/Flex.1-alpha"
70
+ is_flux: true
71
+ quantize: true # run 8bit mixed precision
72
+ quantize_kwargs:
73
+ exclude:
74
+ - "*time_text_embed*" # exclude the time text embedder from quantization
75
+ sample:
76
+ sampler: "flowmatch" # must match train.noise_scheduler
77
+ sample_every: 250 # sample every this many steps
78
+ width: 1024
79
+ height: 1024
80
+ prompts:
81
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
82
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
83
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
84
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
85
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
86
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
87
+ - "a bear building a log cabin in the snow covered mountains"
88
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
89
+ - "hipster man with a beard, building a chair, in a wood shop"
90
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
91
+ - "a man holding a sign that says, 'this is a sign'"
92
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
93
+ neg: "" # not used on flex
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 4
97
+ sample_steps: 25
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
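The exclude entries under quantize_kwargs are glob-style patterns matched against module names, so anything matching "*time_text_embed*" is skipped when deciding what to quantize. Below is a rough sketch of that kind of filtering; it is not the toolkit's quantization code, just the pattern-matching idea.

# Illustrative only: glob-style exclude patterns applied to module names.
import fnmatch

def modules_to_quantize(module_names, exclude_patterns):
    return [
        name for name in module_names
        if not any(fnmatch.fnmatch(name, pat) for pat in exclude_patterns)
    ]

names = [
    "transformer.time_text_embed.timestep_embedder.linear_1",
    "transformer.transformer_blocks.0.attn.to_q",
]
print(modules_to_quantize(names, ["*time_text_embed*"]))
# only the transformer block layer is left to quantize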
ai-toolkit/config/examples/train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
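The ema_config block keeps an exponential moving average of the trained weights: each step the EMA copy keeps ema_decay of its value and takes (1 - ema_decay) of the live weights, which is what smooths out learning. Shown below as a generic EMA update for intuition, not ai-toolkit's internals.

# Generic EMA weight update (for intuition, not ai-toolkit's internals).
def ema_update(ema_params, live_params, decay=0.99):
    for key, live in live_params.items():
        ema_params[key] = decay * ema_params[key] + (1.0 - decay) * live
    return ema_params

ema = {"w": 0.0}
for step in range(3):
    ema = ema_update(ema, {"w": 1.0}, decay=0.99)
print(ema)  # drifts slowly toward 1.0 (about 0.0297 after 3 steps)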
ai-toolkit/config/examples/train_lora_flux_rami.yaml ADDED
@@ -0,0 +1,78 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: "flux_lora_rami_v1"
5
+ process:
6
+ - type: 'sd_trainer'
7
+ training_folder: "output_flux_lora_rami"
8
+ trigger_word: "rami murad"
9
+ device: cuda:0
10
+
11
+ network:
12
+ type: "lora"
13
+ linear: 16
14
+ linear_alpha: 16
15
+
16
+ save:
17
+ dtype: float16
18
+ save_every: 250
19
+ max_step_saves_to_keep: 4
20
+ push_to_hub: false
21
+
22
+ datasets:
23
+ - folder_path: "ai-toolkit/images"
24
+ caption_ext: "txt"
25
+ caption_dropout_rate: 0.05
26
+ shuffle_tokens: false
27
+ cache_latents_to_disk: true
28
+ resolution: [1024]
29
+
30
+ train:
31
+ batch_size: 1
32
+ bypass_guidance_embedding: true
33
+ steps: 3000
34
+ gradient_accumulation: 1
35
+ train_unet: true
36
+ train_text_encoder: false
37
+ gradient_checkpointing: true
38
+ noise_scheduler: "flowmatch"
39
+ optimizer: "adamw8bit"
40
+ lr: 1e-4
41
+ dtype: fp16
42
+ disable_sampling: true
43
+
44
+ ema_config:
45
+ use_ema: true
46
+ ema_decay: 0.99
47
+
48
+ model:
49
+ name_or_path: "black-forest-labs/FLUX.1-dev"
50
+ is_flux: true
51
+ load_in_8bit: true
52
+ quantize: true
53
+ quantize_kwargs:
54
+ exclude:
55
+ - "*time_text_embed*"
56
+
57
+ sample:
58
+ sampler: "flowmatch"
59
+ sample_every: 250
60
+ width: 1024
61
+ height: 1024
62
+ prompts:
63
+ - "[trigger] smiling in front of a white background, headshot, studio lighting"
64
+ - "[trigger] wearing a suit, standing in a futuristic city, cinematic lighting"
65
+ - "[trigger] in a medieval outfit, standing in front of a castle"
66
+ - "[trigger] sitting at a wooden desk, writing in a notebook"
67
+ - "[trigger] relaxing at the beach during sunset, soft light"
68
+ - "[trigger] on stage giving a TED talk, spotlight"
69
+ - "[trigger] in a forest with sunbeams shining through the trees"
70
+ neg: ""
71
+ seed: 42
72
+ walk_seed: true
73
+ guidance_scale: 4
74
+ sample_steps: 25
75
+
76
+ meta:
77
+ name: "[name]"
78
+ version: '1.0'
ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
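As a reminder of why guidance_scale: 1 effectively turns guidance off for schnell: with standard classifier-free guidance the combined prediction is uncond + scale * (cond - uncond), which collapses to the plain conditional prediction at scale 1. The generic CFG formula is shown below for intuition; actual sampler details differ.

# Generic classifier-free guidance combine step (for intuition; samplers differ).
def cfg_combine(uncond_pred, cond_pred, guidance_scale):
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)

print(cfg_combine(0.2, 0.8, 1.0))  # 0.8 -> just the conditional prediction, i.e. no guidance
print(cfg_combine(0.2, 0.8, 4.0))  # 2.6 -> pushed further toward the prompt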
ai-toolkit/config/examples/train_lora_hidream_48.yaml ADDED
@@ -0,0 +1,112 @@
1
+ # HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
2
+ # It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
3
+ # I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
4
+ # HiDream has a mixture of experts that may require special training considerations that I have not
5
+ # implemented properly yet. The current implementation seems to work well for LoRA training, but
6
+ # may not be effective for longer training runs. The implementation could change in future updates
7
+ # so your results may vary when this happens.
8
+
9
+ ---
10
+ job: extension
11
+ config:
12
+ # this name will be the folder and filename name
13
+ name: "my_first_hidream_lora_v1"
14
+ process:
15
+ - type: 'sd_trainer'
16
+ # root folder to save training sessions/samples/weights
17
+ training_folder: "output"
18
+ # uncomment to see performance stats in the terminal every N steps
19
+ # performance_log_every: 1000
20
+ device: cuda:0
21
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
22
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
23
+ # trigger_word: "p3r5on"
24
+ network:
25
+ type: "lora"
26
+ linear: 32
27
+ linear_alpha: 32
28
+ network_kwargs:
29
+ # it is probably best to ignore the mixture of experts since only 2 are active per block. It works if you activate it, but I wouldn't.
30
+ # proper training of it is not fully implemented
31
+ ignore_if_contains:
32
+ - "ff_i.experts"
33
+ - "ff_i.gate"
34
+ save:
35
+ dtype: bfloat16 # precision to save
36
+ save_every: 250 # save every this many steps
37
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
38
+ datasets:
39
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
40
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
41
+ # images will automatically be resized and bucketed into the resolution specified
42
+ # on windows, escape back slashes with another backslash so
43
+ # "C:\\path\\to\\images\\folder"
44
+ - folder_path: "/path/to/images/folder"
45
+ caption_ext: "txt"
46
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
47
+ resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions
48
+ train:
49
+ batch_size: 1
50
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
51
+ gradient_accumulation_steps: 1
52
+ train_unet: true
53
+ train_text_encoder: false # won't work with hidream
54
+ gradient_checkpointing: true # need this on unless you have a ton of vram
55
+ noise_scheduler: "flowmatch" # for training only
56
+ timestep_type: shift # sigmoid, shift, linear
57
+ optimizer: "adamw8bit"
58
+ lr: 2e-4
59
+ # uncomment this to skip the pre training sample
60
+ # skip_first_sample: true
61
+ # uncomment to completely disable sampling
62
+ # disable_sampling: true
63
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
64
+ # linear_timesteps: true
65
+
66
+ # ema will smooth out learning, but could slow it down. Defaults off
67
+ ema_config:
68
+ use_ema: false
69
+ ema_decay: 0.99
70
+
71
+ # will probably need this if gpu supports it for hidream, other dtypes may not work correctly
72
+ dtype: bf16
73
+ model:
74
+ # the transformer will get grabbed from this hf repo
75
+ # warning ONLY train on Full. The dev and fast models are distilled and will break
76
+ name_or_path: "HiDream-ai/HiDream-I1-Full"
77
+ # the extras will be grabbed from this hf repo. (text encoder, vae)
78
+ extras_name_or_path: "HiDream-ai/HiDream-I1-Full"
79
+ arch: "hidream"
80
+ # both need to be quantized to train on 48GB currently
81
+ quantize: true
82
+ quantize_te: true
83
+ model_kwargs:
84
+ # llama is a gated model. It defaults to the unsloth version, but you can set the llama path here
85
+ llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct"
86
+ sample:
87
+ sampler: "flowmatch" # must match train.noise_scheduler
88
+ sample_every: 250 # sample every this many steps
89
+ width: 1024
90
+ height: 1024
91
+ prompts:
92
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
93
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
94
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
95
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
96
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
97
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
98
+ - "a bear building a log cabin in the snow covered mountains"
99
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
100
+ - "hipster man with a beard, building a chair, in a wood shop"
101
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
102
+ - "a man holding a sign that says, 'this is a sign'"
103
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
104
+ neg: ""
105
+ seed: 42
106
+ walk_seed: true
107
+ guidance_scale: 4
108
+ sample_steps: 25
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
ai-toolkit/config/examples/train_lora_lumina.yaml ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ # This configuration requires 20GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: bf16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
25
+ save_format: 'diffusers' # 'diffusers'
26
+ datasets:
27
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
28
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
29
+ # images will automatically be resized and bucketed into the resolution specified
30
+ # on windows, escape back slashes with another backslash so
31
+ # "C:\\path\\to\\images\\folder"
32
+ - folder_path: "/path/to/images/folder"
33
+ caption_ext: "txt"
34
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
35
+ shuffle_tokens: false # shuffle caption order, split by commas
36
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
37
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
38
+ train:
39
+ batch_size: 1
40
+
41
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
42
+ timestep_type: 'lumina2_shift'
43
+
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with lumina2
48
+ gradient_checkpointing: true # need the on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
67
+ is_lumina2: true # lumina2 architecture
68
+ # you can quantize just the Gemma2 text encoder here to save vram
69
+ quantize_te: true
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: ""
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4.0
92
+ sample_steps: 25
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
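
The dataset comments in the config above describe how captions are paired with images (image2.jpg / image2.txt), dropped out a small percentage of the time, and optionally shuffled on commas. A minimal sketch of that behavior, using a hypothetical load_caption helper for illustration only (the toolkit's actual dataloader lives in its own module and may differ in details):

import os
import random

def load_caption(image_path: str, caption_ext: str = "txt",
                 caption_dropout_rate: float = 0.05,
                 shuffle_tokens: bool = False) -> str:
    # caption file shares the image's base name: image2.jpg -> image2.txt
    caption_path = os.path.splitext(image_path)[0] + "." + caption_ext
    caption = ""
    if os.path.exists(caption_path):
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
    # drop the caption entirely a small percentage of the time (caption_dropout_rate)
    if random.random() < caption_dropout_rate:
        return ""
    if shuffle_tokens:
        # split on commas, shuffle the chunks, and rejoin
        tokens = [t.strip() for t in caption.split(",") if t.strip()]
        random.shuffle(tokens)
        caption = ", ".join(tokens)
    return caption
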
ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml ADDED
@@ -0,0 +1,97 @@
1
+ ---
2
+ # NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_sd3l_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
26
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
27
+ # hf_repo_id: your-username/your-model-slug
28
+ # hf_private: true #whether the repo is private or public
29
+ datasets:
30
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
31
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
32
+ # images will automatically be resized and bucketed into the resolution specified
33
+ # on windows, escape back slashes with another backslash so
34
+ # "C:\\path\\to\\images\\folder"
35
+ - folder_path: "/path/to/images/folder"
36
+ caption_ext: "txt"
37
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
38
+ shuffle_tokens: false # shuffle caption order, split by commas
39
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
40
+ resolution: [ 1024 ]
41
+ train:
42
+ batch_size: 1
43
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
44
+ gradient_accumulation_steps: 1
45
+ train_unet: true
46
+ train_text_encoder: false # May not fully work with SD3 yet
47
+ gradient_checkpointing: true # need this on unless you have a ton of vram
48
+ noise_scheduler: "flowmatch"
49
+ timestep_type: "linear" # linear or sigmoid
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
57
+ # linear_timesteps: true
58
+
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+
64
+ # you will probably need bf16 here if your gpu supports it; other dtypes may not work correctly with sd3
65
+ dtype: bf16
66
+ model:
67
+ # huggingface model name or path
68
+ name_or_path: "stabilityai/stable-diffusion-3.5-large"
69
+ is_v3: true
70
+ quantize: true # run 8bit mixed precision
71
+ sample:
72
+ sampler: "flowmatch" # must match train.noise_scheduler
73
+ sample_every: 250 # sample every this many steps
74
+ width: 1024
75
+ height: 1024
76
+ prompts:
77
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
78
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
79
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
80
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
81
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
82
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
83
+ - "a bear building a log cabin in the snow covered mountains"
84
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
85
+ - "hipster man with a beard, building a chair, in a wood shop"
86
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
87
+ - "a man holding a sign that says, 'this is a sign'"
88
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
89
+ neg: ""
90
+ seed: 42
91
+ walk_seed: true
92
+ guidance_scale: 4
93
+ sample_steps: 25
94
+ # you can add any additional meta info here. [name] is replaced with config name at top
95
+ meta:
96
+ name: "[name]"
97
+ version: '1.0'
ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml ADDED
@@ -0,0 +1,101 @@
1
+ # IMPORTANT: The Wan2.1 14B model is huge. This config should work on 24GB GPUs. It cannot
2
+ # support keeping the text encoder on GPU while training with 24GB, so it is only good
3
+ # for training on a single prompt, for example a person with a trigger word.
4
+ # To train on captions, you need more vram for now.
5
+ ---
6
+ job: extension
7
+ config:
8
+ # this name will be the folder and filename name
9
+ name: "my_first_wan21_14b_lora_v1"
10
+ process:
11
+ - type: 'sd_trainer'
12
+ # root folder to save training sessions/samples/weights
13
+ training_folder: "output"
14
+ # uncomment to see performance stats in the terminal every N steps
15
+ # performance_log_every: 1000
16
+ device: cuda:0
17
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
18
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
19
+ # this is probably needed for 24GB cards when offloading TE to CPU
20
+ trigger_word: "p3r5on"
21
+ network:
22
+ type: "lora"
23
+ linear: 32
24
+ linear_alpha: 32
25
+ save:
26
+ dtype: float16 # precision to save
27
+ save_every: 250 # save every this many steps
28
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
29
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
30
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
31
+ # hf_repo_id: your-username/your-model-slug
32
+ # hf_private: true #whether the repo is private or public
33
+ datasets:
34
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
35
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
36
+ # images will automatically be resized and bucketed into the resolution specified
37
+ # on windows, escape back slashes with another backslash so
38
+ # "C:\\path\\to\\images\\folder"
39
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
40
+ # it works well for characters, but not as well for "actions"
41
+ - folder_path: "/path/to/images/folder"
42
+ caption_ext: "txt"
43
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
44
+ shuffle_tokens: false # shuffle caption order, split by commas
45
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
46
+ resolution: [ 632 ] # will be around 480p
47
+ train:
48
+ batch_size: 1
49
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
50
+ gradient_accumulation: 1
51
+ train_unet: true
52
+ train_text_encoder: false # probably won't work with wan
53
+ gradient_checkpointing: true # need this on unless you have a ton of vram
54
+ noise_scheduler: "flowmatch" # for training only
55
+ timestep_type: 'sigmoid'
56
+ optimizer: "adamw8bit"
57
+ lr: 1e-4
58
+ optimizer_params:
59
+ weight_decay: 1e-4
60
+ # uncomment this to skip the pre training sample
61
+ # skip_first_sample: true
62
+ # uncomment to completely disable sampling
63
+ # disable_sampling: true
64
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
65
+ ema_config:
66
+ use_ema: true
67
+ ema_decay: 0.99
68
+ dtype: bf16
69
+ # required for 24GB cards
70
+ # this will encode your trigger word and use those embeddings for every image in the dataset
71
+ unload_text_encoder: true
72
+ model:
73
+ # huggingface model name or path
74
+ name_or_path: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
75
+ arch: 'wan21'
76
+ # these settings will save as much vram as possible
77
+ quantize: true
78
+ quantize_te: true
79
+ low_vram: true
80
+ sample:
81
+ sampler: "flowmatch"
82
+ sample_every: 250 # sample every this many steps
83
+ width: 832
84
+ height: 480
85
+ num_frames: 40
86
+ fps: 15
87
+ # samples take a long time, so use them sparingly
88
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
89
+ prompts:
90
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
91
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
92
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
93
+ neg: ""
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 5
97
+ sample_steps: 30
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
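
The trigger_word comments above say the word is added to captions when it is missing, or substituted wherever [trigger] appears. A small sketch of that logic for illustration only; the exact placement of an added trigger word is an assumption, and the toolkit may insert it differently:

def apply_trigger_word(caption: str, trigger_word: str) -> str:
    if "[trigger]" in caption:
        # explicit placeholder is swapped for the trigger word
        return caption.replace("[trigger]", trigger_word)
    if trigger_word not in caption:
        # otherwise the trigger word is added (shown here as a prefix)
        return f"{trigger_word}, {caption}" if caption else trigger_word
    return caption

print(apply_trigger_word("photo of [trigger] at the beach", "p3r5on"))
# -> "photo of p3r5on at the beach"
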
ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml ADDED
@@ -0,0 +1,90 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_wan21_1b_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 32
19
+ linear_alpha: 32
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
35
+ # it works well for characters, but not as well for "actions"
36
+ - folder_path: "/path/to/images/folder"
37
+ caption_ext: "txt"
38
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
39
+ shuffle_tokens: false # shuffle caption order, split by commas
40
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
41
+ resolution: [ 632 ] # will be around 480p
42
+ train:
43
+ batch_size: 1
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with wan
48
+ gradient_checkpointing: true # need the on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ timestep_type: 'sigmoid'
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ optimizer_params:
54
+ weight_decay: 1e-4
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
67
+ arch: 'wan21'
68
+ quantize_te: true # saves vram
69
+ sample:
70
+ sampler: "flowmatch"
71
+ sample_every: 250 # sample every this many steps
72
+ width: 832
73
+ height: 480
74
+ num_frames: 40
75
+ fps: 15
76
+ # samples take a long time. so use them sparingly
77
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
81
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
82
+ neg: ""
83
+ seed: 42
84
+ walk_seed: true
85
+ guidance_scale: 5
86
+ sample_steps: 30
87
+ # you can add any additional meta info here. [name] is replaced with config name at top
88
+ meta:
89
+ name: "[name]"
90
+ version: '1.0'
ai-toolkit/config/examples/train_slider.example.yml ADDED
@@ -0,0 +1,230 @@
1
+ ---
2
+ # This is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to write
4
+ # Plus it has comments which is nice for documentation
5
+ # This is the config I use on my sliders, It is solid and tested
6
+ job: train
7
+ config:
8
+ # the name will be used to create a folder in the output folder
9
+ # it will also replace any [name] token in the rest of this config
10
+ name: detail_slider_v1
11
+ # folder will be created with name above in folder below
12
+ # it can be relative to the project root or absolute
13
+ training_folder: "output/LoRA"
14
+ device: cuda:0 # cpu, cuda:0, etc
15
+ # for tensorboard logging, we will make a subfolder for this job
16
+ log_dir: "output/.tensorboard"
17
+ # you can stack processes for other jobs, but it is not tested with sliders
18
+ # just use one for now
19
+ process:
20
+ - type: slider # tells runner to run the slider process
21
+ # network is the LoRA network for a slider; I recommend leaving this as is
22
+ network:
23
+ # network type lierla is traditional LoRA that works everywhere, only linear layers
24
+ type: "lierla"
25
+ # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
26
+ linear: 8
27
+ linear_alpha: 4 # Do about half of rank
28
+ # training config
29
+ train:
30
+ # this is also used in sampling. Stick with ddpm unless you know what you are doing
31
+ noise_scheduler: "ddpm" # or "lms", "euler_a"
32
+ # how many steps to train. More is not always better. I rarely go over 1000
33
+ steps: 500
34
+ # I have had good results with 4e-4 to 1e-4 at 500 steps
35
+ lr: 2e-4
36
+ # enables gradient checkpoint, saves vram, leave it on
37
+ gradient_checkpointing: true
38
+ # train the unet. I recommend leaving this true
39
+ train_unet: true
40
+ # train the text encoder. I don't recommend this unless you have a special use case
41
+ # for sliders we are adjusting representation of the concept (unet),
42
+ # not the description of it (text encoder)
43
+ train_text_encoder: false
44
+ # same as from sd-scripts, not fully tested but should speed up training
45
+ min_snr_gamma: 5.0
46
+ # just leave unless you know what you are doing
47
+ # also supports "dadaptation" but set lr to 1 if you use that,
48
+ # but it learns too fast and I don't recommend it
49
+ optimizer: "adamw"
50
+ # only constant for now
51
+ lr_scheduler: "constant"
52
+ # we randomly denoise a random number of steps from 1 to this number
53
+ # while training. Just leave it
54
+ max_denoising_steps: 40
55
+ # works great at 1. I do 1 even with my 4090.
56
+ # higher may not work right with newer single batch stacking code anyway
57
+ batch_size: 1
58
+ # bf16 works best if your GPU supports it (modern)
59
+ dtype: bf16 # fp32, bf16, fp16
60
+ # if you have it, use it. It is faster and better
61
+ # torch 2.0 doesn't need xformers anymore; only use it if you are on a lower version
62
+ # xformers: true
63
+ # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
64
+ # although, the way we train sliders is comparative, so it probably won't work anyway
65
+ noise_offset: 0.0
66
+ # noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
67
+
68
+ # the model to train the LoRA network on
69
+ model:
70
+ # huggingface name, path relative to the project root, or absolute path to .safetensors or .ckpt
71
+ name_or_path: "runwayml/stable-diffusion-v1-5"
72
+ is_v2: false # for v2 models
73
+ is_v_pred: false # for v-prediction models (most v2 models)
74
+ # has some issues with the dual text encoder and the way we train sliders
75
+ # it works, but weights probably need to be higher to see the effect.
76
+ is_xl: false # for SDXL models
77
+
78
+ # saving config
79
+ save:
80
+ dtype: float16 # precision to save. I recommend float16
81
+ save_every: 50 # save every this many steps
82
+ # this will remove step counts more than this number
83
+ # allows you to save more often in case of a crash without filling up your drive
84
+ max_step_saves_to_keep: 2
85
+
86
+ # sampling config
87
+ sample:
88
+ # must match train.noise_scheduler, this is not used here
89
+ # but may be in future and in other processes
90
+ sampler: "ddpm"
91
+ # sample every this many steps
92
+ sample_every: 20
93
+ # image size
94
+ width: 512
95
+ height: 512
96
+ # prompts to use for sampling. Do as many as you want, but it slows down training
97
+ # pick ones that will best represent the concept you are trying to adjust
98
+ # allows some flags after the prompt
99
+ # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
100
+ # slide are good tests. will inherit sample.network_multiplier if not set
101
+ # --n [string] # negative prompt, will inherit sample.neg if not set
102
+ # Only 75 tokens allowed currently
103
+ # I like to do a wide positive and negative spread so I can see a good range and stop
104
+ # early if the network is breaking down
105
+ prompts:
106
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
107
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
108
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
109
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
110
+ - "a golden retriever sitting on a leather couch, --m -5"
111
+ - "a golden retriever sitting on a leather couch --m -3"
112
+ - "a golden retriever sitting on a leather couch --m 3"
113
+ - "a golden retriever sitting on a leather couch --m 5"
114
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
115
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
116
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
117
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
118
+ # negative prompt used on all prompts above as default if they don't have one
119
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
120
+ # seed for sampling. 42 is the answer for everything
121
+ seed: 42
122
+ # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
123
+ # will start over on next sample_every so s1 is always seed
124
+ # works well if you use same prompt but want different results
125
+ walk_seed: false
126
+ # cfg scale (4 to 10 is good)
127
+ guidance_scale: 7
128
+ # sampler steps (20 to 30 is good)
129
+ sample_steps: 20
130
+ # default network multiplier for all prompts
131
+ # since we are training a slider, I recommend overriding this with --m [number]
132
+ # in the prompts above to get both sides of the slider
133
+ network_multiplier: 1.0
134
+
135
+ # logging information
136
+ logging:
137
+ log_every: 10 # log every this many steps
138
+ use_wandb: false # not supported yet
139
+ verbose: false # probably don't need this unless you are debugging
140
+
141
+ # slider training config, best for last
142
+ slider:
143
+ # resolutions to train on. [ width, height ]. This is less important for sliders
144
+ # as we are not teaching the model anything it doesn't already know
145
+ # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
146
+ # and [ 1024, 1024 ] for sd_xl
147
+ # you can do as many as you want here
148
+ resolutions:
149
+ - [ 512, 512 ]
150
+ # - [ 512, 768 ]
151
+ # - [ 768, 768 ]
152
+ # slider training uses 4 combined steps for a single round. This will do it in one gradient
153
+ # step. It is highly optimized and shouldn't take any more vram than doing it without,
154
+ # since we break down batches for gradient accumulation now, so just leave it on.
155
+ batch_full_slide: true
156
+ # These are the concepts to train on. You can do as many as you want here,
157
+ # but they can conflict with or outweigh each other. Other than experimenting, I recommend
158
+ # just doing one for good results
159
+ targets:
160
+ # target_class is the base concept we are adjusting the representation of
161
+ # for example, if we are adjusting the representation of a person, we would use "person"
162
+ # if we are adjusting the representation of a cat, we would use "cat" It is not
163
+ # a keyword necessarily but what the model understands the concept to represent.
164
+ # "person" will affect men, women, children, etc but will not affect cats, dogs, etc
165
+ # it is the models base general understanding of the concept and everything it represents
166
+ # you can leave it blank to affect everything. In this example, we are adjusting
167
+ # detail, so we will leave it blank to affect everything
168
+ - target_class: ""
169
+ # positive is the prompt for the positive side of the slider.
170
+ # It is the concept that will be excited and amplified in the model when we slide the slider
171
+ # to the positive side and forgotten / inverted when we slide
172
+ # the slider to the negative side. It is generally best to include the target_class in
173
+ # the prompt. You want it to be the extreme of what you want to train on. For example,
174
+ # if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
175
+ # as the prompt. Not just "fat person"
176
+ # max 75 tokens for now
177
+ positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
178
+ # negative is the prompt for the negative side of the slider and works the same as positive
179
+ # it does not necessarily work the same as a negative prompt when generating images
180
+ # these need to be polar opposites.
181
+ # max 76 tokens for now
182
+ negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
183
+ # the loss for this target is multiplied by this number.
184
+ # if you are doing more than one target it may be good to set less important ones
185
+ # to a lower number like 0.1 so they don't outweigh the primary target
186
+ weight: 1.0
187
+ # shuffle the prompts split by the comma. We will run every combination randomly
188
+ # this will make the LoRA more robust. You probably want this on unless prompt order
189
+ # is important for some reason
190
+ shuffle: true
191
+
192
+
193
+ # anchors are prompts that we will try to hold on to while training the slider
194
+ # these are NOT necessary and can prevent the slider from converging if not done right
195
+ # leave them off if you are having issues, but they can help lock the network
196
+ # on certain concepts to help prevent catastrophic forgetting
197
+ # you want these to generate an image that is not your target_class, but close to it
198
+ # is fine as long as it does not directly overlap it.
199
+ # For example, if you are training on a person smiling,
200
+ # you could use "a person with a face mask" as an anchor. It is a person, the image is the same
201
+ # regardless if they are smiling or not, however, the closer the concept is to the target_class
202
+ # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
203
+ # for close concepts, you want to be closer to 0.1 or 0.2
204
+ # these will slow down training. I am leaving them off for the demo
205
+
206
+ # anchors:
207
+ # - prompt: "a woman"
208
+ # neg_prompt: "animal"
209
+ # # the multiplier applied to the LoRA when this is run.
210
+ # # higher will give it more weight but also help keep the lora from collapsing
211
+ # multiplier: 1.0
212
+ # - prompt: "a man"
213
+ # neg_prompt: "animal"
214
+ # multiplier: 1.0
215
+ # - prompt: "a person"
216
+ # neg_prompt: "animal"
217
+ # multiplier: 1.0
218
+
219
+ # You can put any information you want here, and it will be saved in the model.
220
+ # The below is an example, but you can put your grocery list in it if you want.
221
+ # It is saved in the model so be aware of that. The software will include this
222
+ # plus some other information for you automatically
223
+ meta:
224
+ # [name] gets replaced with the name above
225
+ name: "[name]"
226
+ # version: '1.0'
227
+ # creator:
228
+ # name: Your Name
229
+ # email: your@email.com
230
+ # website: https://your.website
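
The sample comments above describe per-prompt flags: --m overrides the network multiplier and --n overrides the negative prompt, with sample.network_multiplier and sample.neg as fallbacks. A rough sketch of that convention, as a hypothetical parser rather than the toolkit's own code:

def parse_sample_prompt(raw: str, default_multiplier: float = 1.0, default_neg: str = ""):
    # split off an optional "--n <negative prompt>" suffix first
    head, _, neg_part = raw.partition("--n")
    neg = neg_part.strip() if neg_part else default_neg
    multiplier = default_multiplier
    if "--m" in head:
        head, m_str = head.split("--m", 1)
        multiplier = float(m_str.strip().split()[0])
    return head.strip(), multiplier, neg

prompt, mult, neg = parse_sample_prompt(
    "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5")
# prompt -> "a woman in a coffee shop, black hat, blonde hair, blue jacket"
# mult -> -5.0, neg -> "" (falls back to sample.neg)
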
ai-toolkit/docker-compose.yml ADDED
@@ -0,0 +1,25 @@
1
+ version: "3.8"
2
+
3
+ services:
4
+ ai-toolkit:
5
+ image: ostris/aitoolkit:latest
6
+ restart: unless-stopped
7
+ ports:
8
+ - "8675:8675"
9
+ volumes:
10
+ - ~/.cache/huggingface/hub:/root/.cache/huggingface/hub
11
+ - ./aitk_db.db:/app/ai-toolkit/aitk_db.db
12
+ - ./datasets:/app/ai-toolkit/datasets
13
+ - ./output:/app/ai-toolkit/output
14
+ - ./config:/app/ai-toolkit/config
15
+ environment:
16
+ - AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-password}
17
+ - NODE_ENV=production
18
+ - TZ=UTC
19
+ deploy:
20
+ resources:
21
+ reservations:
22
+ devices:
23
+ - driver: nvidia
24
+ count: all
25
+ capabilities: [gpu]
ai-toolkit/docker/Dockerfile ADDED
@@ -0,0 +1,77 @@
1
+ FROM nvidia/cuda:12.6.3-devel-ubuntu22.04
2
+
3
+ LABEL authors="jaret"
4
+
5
+ # Set noninteractive to avoid timezone prompts
6
+ ENV DEBIAN_FRONTEND=noninteractive
7
+
8
+ # Install dependencies
9
+ RUN apt-get update && apt-get install --no-install-recommends -y \
10
+ git \
11
+ curl \
12
+ build-essential \
13
+ cmake \
14
+ wget \
15
+ python3.10 \
16
+ python3-pip \
17
+ python3-dev \
18
+ python3-setuptools \
19
+ python3-wheel \
20
+ python3-venv \
21
+ ffmpeg \
22
+ tmux \
23
+ htop \
24
+ nvtop \
25
+ python3-opencv \
26
+ openssh-client \
27
+ openssh-server \
28
+ openssl \
29
+ rsync \
30
+ unzip \
31
+ && apt-get clean \
32
+ && rm -rf /var/lib/apt/lists/*
33
+
34
+ # Install nodejs
35
+ WORKDIR /tmp
36
+ RUN curl -sL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh && \
37
+ bash nodesource_setup.sh && \
38
+ apt-get update && \
39
+ apt-get install -y nodejs && \
40
+ apt-get clean && \
41
+ rm -rf /var/lib/apt/lists/*
42
+
43
+ WORKDIR /app
44
+
45
+ # Set aliases for python and pip
46
+ RUN ln -s /usr/bin/python3 /usr/bin/python
47
+
48
+ # install pytorch before cache bust to avoid redownloading pytorch
49
+ RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
50
+
51
+ # Fix cache busting by moving CACHEBUST to right before git clone
52
+ ARG CACHEBUST=1234
53
+ RUN echo "Cache bust: ${CACHEBUST}" && \
54
+ git clone https://github.com/ostris/ai-toolkit.git && \
55
+ cd ai-toolkit
56
+
57
+ WORKDIR /app/ai-toolkit
58
+
59
+ # Install Python dependencies
60
+ RUN pip install --no-cache-dir -r requirements.txt && \
61
+ pip install flash-attn --no-build-isolation --no-cache-dir
62
+
63
+ # Build UI
64
+ WORKDIR /app/ai-toolkit/ui
65
+ RUN npm install && \
66
+ npm run build && \
67
+ npm run update_db
68
+
69
+ # Expose the UI port
70
+ EXPOSE 8675
71
+
72
+ WORKDIR /
73
+
74
+ COPY docker/start.sh /start.sh
75
+ RUN chmod +x /start.sh
76
+
77
+ CMD ["/start.sh"]
ai-toolkit/docker/start.sh ADDED
@@ -0,0 +1,70 @@
1
+ #!/bin/bash
2
+ set -e # Exit the script if any statement returns a non-true return value
3
+
4
+ # ref https://github.com/runpod/containers/blob/main/container-template/start.sh
5
+
6
+ # ---------------------------------------------------------------------------- #
7
+ # Function Definitions #
8
+ # ---------------------------------------------------------------------------- #
9
+
10
+
11
+ # Setup ssh
12
+ setup_ssh() {
13
+ if [[ $PUBLIC_KEY ]]; then
14
+ echo "Setting up SSH..."
15
+ mkdir -p ~/.ssh
16
+ echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
17
+ chmod 700 -R ~/.ssh
18
+
19
+ if [ ! -f /etc/ssh/ssh_host_rsa_key ]; then
20
+ ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -q -N ''
21
+ echo "RSA key fingerprint:"
22
+ ssh-keygen -lf /etc/ssh/ssh_host_rsa_key.pub
23
+ fi
24
+
25
+ if [ ! -f /etc/ssh/ssh_host_dsa_key ]; then
26
+ ssh-keygen -t dsa -f /etc/ssh/ssh_host_dsa_key -q -N ''
27
+ echo "DSA key fingerprint:"
28
+ ssh-keygen -lf /etc/ssh/ssh_host_dsa_key.pub
29
+ fi
30
+
31
+ if [ ! -f /etc/ssh/ssh_host_ecdsa_key ]; then
32
+ ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -q -N ''
33
+ echo "ECDSA key fingerprint:"
34
+ ssh-keygen -lf /etc/ssh/ssh_host_ecdsa_key.pub
35
+ fi
36
+
37
+ if [ ! -f /etc/ssh/ssh_host_ed25519_key ]; then
38
+ ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -q -N ''
39
+ echo "ED25519 key fingerprint:"
40
+ ssh-keygen -lf /etc/ssh/ssh_host_ed25519_key.pub
41
+ fi
42
+
43
+ service ssh start
44
+
45
+ echo "SSH host keys:"
46
+ for key in /etc/ssh/*.pub; do
47
+ echo "Key: $key"
48
+ ssh-keygen -lf $key
49
+ done
50
+ fi
51
+ }
52
+
53
+ # Export env vars
54
+ export_env_vars() {
55
+ echo "Exporting environment variables..."
56
+ printenv | grep -E '^RUNPOD_|^PATH=|^_=' | awk -F = '{ print "export " $1 "=\"" $2 "\"" }' >> /etc/rp_environment
57
+ echo 'source /etc/rp_environment' >> ~/.bashrc
58
+ }
59
+
60
+ # ---------------------------------------------------------------------------- #
61
+ # Main Program #
62
+ # ---------------------------------------------------------------------------- #
63
+
64
+
65
+ echo "Pod Started"
66
+
67
+ setup_ssh
68
+ export_env_vars
69
+ echo "Starting AI Toolkit UI..."
70
+ cd /app/ai-toolkit/ui && npm run start
ai-toolkit/extensions/example/ExampleMergeModels.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import gc
3
+ from collections import OrderedDict
4
+ from typing import TYPE_CHECKING
5
+ from jobs.process import BaseExtensionProcess
6
+ from toolkit.config_modules import ModelConfig
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ from toolkit.train_tools import get_torch_dtype
9
+ from tqdm import tqdm
10
+
11
+ # Type check imports. Prevents circular imports
12
+ if TYPE_CHECKING:
13
+ from jobs import ExtensionJob
14
+
15
+
16
+ # extend standard config classes to add weight
17
+ class ModelInputConfig(ModelConfig):
18
+ def __init__(self, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.weight = kwargs.get('weight', 1.0)
21
+ # overwrite default dtype unless user specifies otherwise
22
+ # float 32 will give up better precision on the merging functions
23
+ self.dtype: str = kwargs.get('dtype', 'float32')
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ # this is our main class process
32
+ class ExampleMergeModels(BaseExtensionProcess):
33
+ def __init__(
34
+ self,
35
+ process_id: int,
36
+ job: 'ExtensionJob',
37
+ config: OrderedDict
38
+ ):
39
+ super().__init__(process_id, job, config)
40
+ # this is the setup process, do not do process intensive stuff here, just variable setup and
41
+ # checking requirements. This is called before the run() function
42
+ # no loading models or anything like that, it is just for setting up the process
43
+ # all of your process intensive stuff should be done in the run() function
44
+ # config will have everything from the process item in the config file
45
+
46
+ # convenience methods exist on BaseProcess to get config values
47
+ # if required is set to true and the value is not found it will throw an error
48
+ # you can pass a default value to get_conf() as well if it was not in the config file
49
+ # as well as a type to cast the value to
50
+ self.save_path = self.get_conf('save_path', required=True)
51
+ self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
52
+ self.device = self.get_conf('device', default='cpu', as_type=torch.device)
53
+
54
+ # build models to merge list
55
+ models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
56
+ # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
57
+ # this way you can add methods to it and it is easier to read and code. There are a lot of
58
+ # inbuilt config classes located in toolkit.config_modules as well
59
+ self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
60
+ # setup is complete. Don't load anything else here, just setup variables and stuff
61
+
62
+ # this is the entire run process; be sure to call super().run() first
63
+ def run(self):
64
+ # always call first
65
+ super().run()
66
+ print(f"Running process: {self.__class__.__name__}")
67
+
68
+ # let's adjust our weights first to normalize them so the total is 1.0
69
+ total_weight = sum([model.weight for model in self.models_to_merge])
70
+ weight_adjust = 1.0 / total_weight
71
+ for model in self.models_to_merge:
72
+ model.weight *= weight_adjust
73
+
74
+ output_model: StableDiffusion = None
75
+ # let's do the merge, it is a good idea to use tqdm to show progress
76
+ for model_config in tqdm(self.models_to_merge, desc="Merging models"):
77
+ # setup model class with our helper class
78
+ sd_model = StableDiffusion(
79
+ device=self.device,
80
+ model_config=model_config,
81
+ dtype="float32"
82
+ )
83
+ # load the model
84
+ sd_model.load_model()
85
+
86
+ # adjust the weight of the text encoder
87
+ if isinstance(sd_model.text_encoder, list):
88
+ # sdxl model
89
+ for text_encoder in sd_model.text_encoder:
90
+ for key, value in text_encoder.state_dict().items():
91
+ value *= model_config.weight
92
+ else:
93
+ # normal model
94
+ for key, value in sd_model.text_encoder.state_dict().items():
95
+ value *= model_config.weight
96
+ # adjust the weights of the unet
97
+ for key, value in sd_model.unet.state_dict().items():
98
+ value *= model_config.weight
99
+
100
+ if output_model is None:
101
+ # use this one as the base
102
+ output_model = sd_model
103
+ else:
104
+ # merge the models
105
+ # text encoder
106
+ if isinstance(output_model.text_encoder, list):
107
+ # sdxl model
108
+ for i, text_encoder in enumerate(output_model.text_encoder):
109
+ for key, value in text_encoder.state_dict().items():
110
+ value += sd_model.text_encoder[i].state_dict()[key]
111
+ else:
112
+ # normal model
113
+ for key, value in output_model.text_encoder.state_dict().items():
114
+ value += sd_model.text_encoder.state_dict()[key]
115
+ # unet
116
+ for key, value in output_model.unet.state_dict().items():
117
+ value += sd_model.unet.state_dict()[key]
118
+
119
+ # remove the model to free memory
120
+ del sd_model
121
+ flush()
122
+
123
+ # merge loop is done, let's save the model
124
+ print(f"Saving merged model to {self.save_path}")
125
+ output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
126
+ print(f"Saved merged model to {self.save_path}")
127
+ # do cleanup here
128
+ del output_model
129
+ flush()
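
The run() method above normalizes the relative weights so they sum to 1.0 and then accumulates each model's parameters scaled by its weight. A toy version of that arithmetic on plain tensors (stand-ins for the real state dicts):

import torch

weights = [1.0, 1.0, 0.3, 1.0]               # relative weights, as in the example config
total = sum(weights)
normalized = [w / total for w in weights]     # sums to 1.0 -> roughly [0.303, 0.303, 0.091, 0.303]

params = [torch.full((2, 2), float(i + 1)) for i in range(len(weights))]  # fake model tensors
merged = sum(w * p for w, p in zip(normalized, params))                   # weighted average
print(normalized, merged)
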
ai-toolkit/extensions/example/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # We make a subclass of Extension
6
+ class ExampleMergeExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "example_merge_extension"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Example Merge Extension"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ExampleMergeModels import ExampleMergeModels
19
+ return ExampleMergeModels
20
+
21
+
22
+ AI_TOOLKIT_EXTENSIONS = [
23
+ # you can put a list of extensions here
24
+ ExampleMergeExtension
25
+ ]
ai-toolkit/extensions/example/__pycache__/ExampleMergeModels.cpython-312.pyc ADDED
Binary file (5.61 kB).
 
ai-toolkit/extensions/example/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (760 Bytes).
 
ai-toolkit/extensions/example/config/config.example.yaml ADDED
@@ -0,0 +1,48 @@
1
+ ---
2
+ # Always include at least one example config file to show how to use your extension.
3
+ # use plenty of comments so users know how to use it and what everything does
4
+
5
+ # all extensions will use this job name
6
+ job: extension
7
+ config:
8
+ name: 'my_awesome_merge'
9
+ process:
10
+ # Put your example processes here. This will be passed
11
+ # to your extension process in the config argument.
12
+ # the type MUST match your extension uid
13
+ - type: "example_merge_extension"
14
+ # save path for the merged model
15
+ save_path: "output/merge/[name].safetensors"
16
+ # save type
17
+ dtype: fp16
18
+ # device to run it on
19
+ device: cuda:0
20
+ # input models can only be SD1.x and SD2.x models for this example (currently)
21
+ models_to_merge:
22
+ # weights are relative, total weights will be normalized
23
+ # for example. If you have 2 models with weight 1.0, they will
24
+ # both be weighted 0.5. If you have 1 model with weight 1.0 and
25
+ # another with weight 2.0, the first will be weighted 1/3 and the
26
+ # second will be weighted 2/3
27
+ - name_or_path: "input/model1.safetensors"
28
+ weight: 1.0
29
+ - name_or_path: "input/model2.safetensors"
30
+ weight: 1.0
31
+ - name_or_path: "input/model3.safetensors"
32
+ weight: 0.3
33
+ - name_or_path: "input/model4.safetensors"
34
+ weight: 1.0
35
+
36
+
37
+ # you can put any information you want here, and it will be saved in the model
38
+ # the below is an example. I recommend doing trigger words at a minimum
39
+ # in the metadata. The software will include this plus some other information
40
+ meta:
41
+ name: "[name]" # [name] gets replaced with the name above
42
+ description: A short description of your model
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py ADDED
@@ -0,0 +1,256 @@
1
+ import math
2
+ import os
3
+ import random
4
+ from collections import OrderedDict
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from diffusers import T2IAdapter
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from torch.utils.data import DataLoader
12
+ from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
13
+ from tqdm import tqdm
14
+
15
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
16
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
17
+ from toolkit.sampler import get_sampler
18
+ from toolkit.stable_diffusion_model import StableDiffusion
19
+ import gc
20
+ import torch
21
+ from jobs.process import BaseExtensionProcess
22
+ from toolkit.data_loader import get_dataloader_from_datasets
23
+ from toolkit.train_tools import get_torch_dtype
24
+ from controlnet_aux.midas import MidasDetector
25
+ from diffusers.utils import load_image
26
+ from torchvision.transforms import ToTensor
27
+
28
+
29
+ def flush():
30
+ torch.cuda.empty_cache()
31
+ gc.collect()
32
+
33
+
34
+
35
+
36
+
37
+ class GenerateConfig:
38
+
39
+ def __init__(self, **kwargs):
40
+ self.prompts: List[str]
41
+ self.sampler = kwargs.get('sampler', 'ddpm')
42
+ self.neg = kwargs.get('neg', '')
43
+ self.seed = kwargs.get('seed', -1)
44
+ self.walk_seed = kwargs.get('walk_seed', False)
45
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
46
+ self.sample_steps = kwargs.get('sample_steps', 20)
47
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
48
+ self.ext = kwargs.get('ext', 'png')
49
+ self.denoise_strength = kwargs.get('denoise_strength', 0.5)
50
+ self.trigger_word = kwargs.get('trigger_word', None)
51
+
52
+
53
+ class Img2ImgGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
59
+ self.device = self.get_conf('device', 'cuda')
60
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
61
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
62
+ self.is_latents_cached = True
63
+ raw_datasets = self.get_conf('datasets', None)
64
+ if raw_datasets is not None and len(raw_datasets) > 0:
65
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
66
+ self.datasets = None
67
+ self.datasets_reg = None
68
+ self.dtype = self.get_conf('dtype', 'float16')
69
+ self.torch_dtype = get_torch_dtype(self.dtype)
70
+ self.params = []
71
+ if raw_datasets is not None and len(raw_datasets) > 0:
72
+ for raw_dataset in raw_datasets:
73
+ dataset = DatasetConfig(**raw_dataset)
74
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
75
+ if not is_caching:
76
+ self.is_latents_cached = False
77
+ if dataset.is_reg:
78
+ if self.datasets_reg is None:
79
+ self.datasets_reg = []
80
+ self.datasets_reg.append(dataset)
81
+ else:
82
+ if self.datasets is None:
83
+ self.datasets = []
84
+ self.datasets.append(dataset)
85
+
86
+ self.progress_bar = None
87
+ self.sd = StableDiffusion(
88
+ device=self.device,
89
+ model_config=self.model_config,
90
+ dtype=self.dtype,
91
+ )
92
+ print(f"Using device {self.device}")
93
+ self.data_loader: DataLoader = None
94
+ self.adapter: T2IAdapter = None
95
+
96
+ def to_pil(self, img):
97
+ # image comes in -1 to 1. convert to a PIL RGB image
98
+ img = (img + 1) / 2
99
+ img = img.clamp(0, 1)
100
+ img = img[0].permute(1, 2, 0).cpu().numpy()
101
+ img = (img * 255).astype(np.uint8)
102
+ image = Image.fromarray(img)
103
+ return image
104
+
105
+ def run(self):
106
+ with torch.no_grad():
107
+ super().run()
108
+ print("Loading model...")
109
+ self.sd.load_model()
110
+ device = torch.device(self.device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLImg2ImgPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ ).to(device, dtype=self.torch_dtype)
122
+ elif self.model_config.is_pixart:
123
+ pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
124
+ else:
125
+ raise NotImplementedError("Only SDXL and PixArt models are supported")
126
+ pipe.set_progress_bar_config(disable=True)
127
+
128
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
129
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
130
+
131
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
132
+
133
+ num_batches = len(self.data_loader)
134
+ pbar = tqdm(total=num_batches, desc="Generating images")
135
+ seed = self.generate_config.seed
136
+ # load images from datasets, use tqdm
137
+ for i, batch in enumerate(self.data_loader):
138
+ batch: DataLoaderBatchDTO = batch
139
+
140
+ gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
141
+ generator = torch.manual_seed(gen_seed)
142
+
143
+ file_item: FileItemDTO = batch.file_items[0]
144
+ img_path = file_item.path
145
+ img_filename = os.path.basename(img_path)
146
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
147
+ img_filename = img_filename_no_ext + '.' + self.generate_config.ext
148
+ output_path = os.path.join(self.output_folder, img_filename)
149
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
150
+
151
+ if self.copy_inputs_to is not None:
152
+ output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
153
+ output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
154
+ else:
155
+ output_inputs_path = None
156
+ output_inputs_caption_path = None
157
+
158
+ caption = batch.get_caption_list()[0]
159
+ if self.generate_config.trigger_word is not None:
160
+ caption = caption.replace('[trigger]', self.generate_config.trigger_word)
161
+
162
+ img: torch.Tensor = batch.tensor.clone()
163
+ image = self.to_pil(img)
164
+
165
+ # image.save(output_depth_path)
166
+ if self.model_config.is_pixart:
167
+ pipe: PixArtSigmaPipeline = pipe
168
+
169
+ # Encode the full image once
170
+ encoded_image = pipe.vae.encode(
171
+ pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
172
+ if hasattr(encoded_image, "latent_dist"):
173
+ latents = encoded_image.latent_dist.sample(generator)
174
+ elif hasattr(encoded_image, "latents"):
175
+ latents = encoded_image.latents
176
+ else:
177
+ raise AttributeError("Could not access latents of provided encoder_output")
178
+ latents = pipe.vae.config.scaling_factor * latents
179
+
180
+ # latents = self.sd.encode_images(img)
181
+
182
+ # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
183
+ # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
184
+ # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
185
+ # timestep = timestep.to(device, dtype=torch.int32)
186
+ # latent = latent.to(device, dtype=self.torch_dtype)
187
+ # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
188
+ # latent = self.sd.add_noise(latent, noise, timestep)
189
+ # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
190
+ batch_size = 1
191
+ num_images_per_prompt = 1
192
+
193
+ shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
194
+ image.width // pipe.vae_scale_factor)
195
+ noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
196
+
197
+ # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
198
+ num_inference_steps = self.generate_config.sample_steps
199
+ strength = self.generate_config.denoise_strength
200
+ # Get timesteps
201
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
202
+ t_start = max(num_inference_steps - init_timestep, 0)
203
+ pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
204
+ timesteps = pipe.scheduler.timesteps[t_start:]
205
+ timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
206
+ latents = pipe.scheduler.add_noise(latents, noise, timestep)
207
+
208
+ gen_images = pipe.__call__(
209
+ prompt=caption,
210
+ negative_prompt=self.generate_config.neg,
211
+ latents=latents,
212
+ timesteps=timesteps,
213
+ width=image.width,
214
+ height=image.height,
215
+ num_inference_steps=num_inference_steps,
216
+ num_images_per_prompt=num_images_per_prompt,
217
+ guidance_scale=self.generate_config.guidance_scale,
218
+ # strength=self.generate_config.denoise_strength,
219
+ use_resolution_binning=False,
220
+ output_type="np"
221
+ ).images[0]
222
+ gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
223
+ gen_images = Image.fromarray(gen_images)
224
+ else:
225
+ pipe: StableDiffusionXLImg2ImgPipeline = pipe
226
+
227
+ gen_images = pipe.__call__(
228
+ prompt=caption,
229
+ negative_prompt=self.generate_config.neg,
230
+ image=image,
231
+ num_inference_steps=self.generate_config.sample_steps,
232
+ guidance_scale=self.generate_config.guidance_scale,
233
+ strength=self.generate_config.denoise_strength,
234
+ ).images[0]
235
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
236
+ gen_images.save(output_path)
237
+
238
+ # save caption
239
+ with open(output_caption_path, 'w') as f:
240
+ f.write(caption)
241
+
242
+ if output_inputs_path is not None:
243
+ os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
244
+ image.save(output_inputs_path)
245
+ with open(output_inputs_caption_path, 'w') as f:
246
+ f.write(caption)
247
+
248
+ pbar.update(1)
249
+ batch.cleanup()
250
+
251
+ pbar.close()
252
+ print("Done generating images")
253
+ # cleanup
254
+ del self.sd
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
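
The PixArt branch above derives the denoising window from denoise_strength, following the usual diffusers img2img timestep selection. A worked example of that arithmetic with the defaults used in GenerateConfig above (sample_steps 20, denoise_strength 0.5):

num_inference_steps = 20
strength = 0.5                                 # generate.denoise_strength
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 10
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# the scheduler's timesteps[t_start:] are used, so 10 of the 20 steps actually denoise
print(init_timestep, t_start)
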
ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ from collections import OrderedDict
+
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
+ from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
+ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
+ from toolkit.stable_diffusion_model import StableDiffusion
+ import gc
+ import torch
+ from jobs.process import BaseExtensionProcess
+ from toolkit.train_tools import get_torch_dtype
+
+
+ def flush():
+     torch.cuda.empty_cache()
+     gc.collect()
+
+
+ class PureLoraGenerator(BaseExtensionProcess):
+
+     def __init__(self, process_id: int, job, config: OrderedDict):
+         super().__init__(process_id, job, config)
+         self.output_folder = self.get_conf('output_folder', required=True)
+         self.device = self.get_conf('device', 'cuda')
+         self.device_torch = torch.device(self.device)
+         self.model_config = ModelConfig(**self.get_conf('model', required=True))
+         self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
+         self.dtype = self.get_conf('dtype', 'float16')
+         self.torch_dtype = get_torch_dtype(self.dtype)
+         lorm_config = self.get_conf('lorm', None)
+         self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
+
+         self.device_state_preset = get_train_sd_device_state_preset(
+             device=torch.device(self.device),
+         )
+
+         self.progress_bar = None
+         self.sd = StableDiffusion(
+             device=self.device,
+             model_config=self.model_config,
+             dtype=self.dtype,
+         )
+
+     def run(self):
+         super().run()
+         print("Loading model...")
+         with torch.no_grad():
+             self.sd.load_model()
+             self.sd.unet.eval()
+             self.sd.unet.to(self.device_torch)
+             if isinstance(self.sd.text_encoder, list):
+                 for te in self.sd.text_encoder:
+                     te.eval()
+                     te.to(self.device_torch)
+             else:
+                 self.sd.text_encoder.eval()
+                 self.sd.to(self.device_torch)
+
+             print(f"Converting to LoRM UNet")
+             # replace the unet with LoRMUnet
+             convert_diffusers_unet_to_lorm(
+                 self.sd.unet,
+                 config=self.lorm_config,
+             )
+
+             sample_folder = os.path.join(self.output_folder)
+             gen_img_config_list = []
+
+             sample_config = self.generate_config
+             start_seed = sample_config.seed
+             current_seed = start_seed
+             for i in range(len(sample_config.prompts)):
+                 if sample_config.walk_seed:
+                     current_seed = start_seed + i
+
+                 filename = f"[time]_[count].{self.generate_config.ext}"
+                 output_path = os.path.join(sample_folder, filename)
+                 prompt = sample_config.prompts[i]
+                 extra_args = {}
+                 gen_img_config_list.append(GenerateImageConfig(
+                     prompt=prompt,  # it will autoparse the prompt
+                     width=sample_config.width,
+                     height=sample_config.height,
+                     negative_prompt=sample_config.neg,
+                     seed=current_seed,
+                     guidance_scale=sample_config.guidance_scale,
+                     guidance_rescale=sample_config.guidance_rescale,
+                     num_inference_steps=sample_config.sample_steps,
+                     network_multiplier=sample_config.network_multiplier,
+                     output_path=output_path,
+                     output_ext=sample_config.ext,
+                     adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
+                     **extra_args
+                 ))
+
+             # send to be generated
+             self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
+             print("Done generating images")
+         # cleanup
+         del self.sd
+         gc.collect()
+         torch.cuda.empty_cache()
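
Editor's note (not part of the commit): the process above pulls all of its settings through `get_conf`. Below is a minimal sketch of the keys it expects, written as a plain Python dict; every value is an illustrative assumption, and in practice these settings would live in a job YAML like the ones under `config/examples/`.

```python
# Sketch only: keys mirror the get_conf() calls in PureLoraGenerator.__init__.
# The nested dicts are expanded into ModelConfig, SampleConfig, and LoRMConfig
# kwargs; the model path and sample values are placeholders, not repo defaults.
pure_lora_generator_config = {
    "output_folder": "output/pure_lora_samples",  # required
    "device": "cuda",                             # defaults to 'cuda'
    "dtype": "float16",                           # defaults to 'float16'
    "model": {                                    # required; ModelConfig kwargs
        "name_or_path": "runwayml/stable-diffusion-v1-5",  # illustrative
    },
    "sample": {                                   # required; SampleConfig kwargs
        "prompts": ["a photo of a cat"],
        "width": 512,
        "height": 512,
        "seed": 42,
        "walk_seed": True,                        # seed is offset per prompt
        "sample_steps": 20,
        "guidance_scale": 7.0,
        "ext": "png",
    },
    "lorm": None,                                 # optional; LoRMConfig kwargs
}
```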
ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py ADDED
@@ -0,0 +1,212 @@
+ import os
+ import random
+ from collections import OrderedDict
+ from typing import List
+
+ import numpy as np
+ from PIL import Image
+ from diffusers import T2IAdapter
+ from torch.utils.data import DataLoader
+ from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
+ from tqdm import tqdm
+
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
+ from toolkit.sampler import get_sampler
+ from toolkit.stable_diffusion_model import StableDiffusion
+ import gc
+ import torch
+ from jobs.process import BaseExtensionProcess
+ from toolkit.data_loader import get_dataloader_from_datasets
+ from toolkit.train_tools import get_torch_dtype
+ from controlnet_aux.midas import MidasDetector
+ from diffusers.utils import load_image
+
+
+ def flush():
+     torch.cuda.empty_cache()
+     gc.collect()
+
+
+ class GenerateConfig:
+
+     def __init__(self, **kwargs):
+         self.prompts: List[str]
+         self.sampler = kwargs.get('sampler', 'ddpm')
+         self.neg = kwargs.get('neg', '')
+         self.seed = kwargs.get('seed', -1)
+         self.walk_seed = kwargs.get('walk_seed', False)
+         self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
+         self.guidance_scale = kwargs.get('guidance_scale', 7)
+         self.sample_steps = kwargs.get('sample_steps', 20)
+         self.prompt_2 = kwargs.get('prompt_2', None)
+         self.neg_2 = kwargs.get('neg_2', None)
+         self.prompts = kwargs.get('prompts', None)
+         self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+         self.ext = kwargs.get('ext', 'png')
+         self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
+         if kwargs.get('shuffle', False):
+             # shuffle the prompts
+             random.shuffle(self.prompts)
+
+
+ class ReferenceGenerator(BaseExtensionProcess):
+
+     def __init__(self, process_id: int, job, config: OrderedDict):
+         super().__init__(process_id, job, config)
+         self.output_folder = self.get_conf('output_folder', required=True)
+         self.device = self.get_conf('device', 'cuda')
+         self.model_config = ModelConfig(**self.get_conf('model', required=True))
+         self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
+         self.is_latents_cached = True
+         raw_datasets = self.get_conf('datasets', None)
+         if raw_datasets is not None and len(raw_datasets) > 0:
+             raw_datasets = preprocess_dataset_raw_config(raw_datasets)
+         self.datasets = None
+         self.datasets_reg = None
+         self.dtype = self.get_conf('dtype', 'float16')
+         self.torch_dtype = get_torch_dtype(self.dtype)
+         self.params = []
+         if raw_datasets is not None and len(raw_datasets) > 0:
+             for raw_dataset in raw_datasets:
+                 dataset = DatasetConfig(**raw_dataset)
+                 is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
+                 if not is_caching:
+                     self.is_latents_cached = False
+                 if dataset.is_reg:
+                     if self.datasets_reg is None:
+                         self.datasets_reg = []
+                     self.datasets_reg.append(dataset)
+                 else:
+                     if self.datasets is None:
+                         self.datasets = []
+                     self.datasets.append(dataset)
+
+         self.progress_bar = None
+         self.sd = StableDiffusion(
+             device=self.device,
+             model_config=self.model_config,
+             dtype=self.dtype,
+         )
+         print(f"Using device {self.device}")
+         self.data_loader: DataLoader = None
+         self.adapter: T2IAdapter = None
+
+     def run(self):
+         super().run()
+         print("Loading model...")
+         self.sd.load_model()
+         device = torch.device(self.device)
+
+         if self.generate_config.t2i_adapter_path is not None:
+             self.adapter = T2IAdapter.from_pretrained(
+                 self.generate_config.t2i_adapter_path,
+                 torch_dtype=self.torch_dtype,
+                 variant="fp16"
+             ).to(device)
+
+         midas_depth = MidasDetector.from_pretrained(
+             "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+         ).to(device)
+
+         if self.model_config.is_xl:
+             pipe = StableDiffusionXLAdapterPipeline(
+                 vae=self.sd.vae,
+                 unet=self.sd.unet,
+                 text_encoder=self.sd.text_encoder[0],
+                 text_encoder_2=self.sd.text_encoder[1],
+                 tokenizer=self.sd.tokenizer[0],
+                 tokenizer_2=self.sd.tokenizer[1],
+                 scheduler=get_sampler(self.generate_config.sampler),
+                 adapter=self.adapter,
+             ).to(device, dtype=self.torch_dtype)
+         else:
+             pipe = StableDiffusionAdapterPipeline(
+                 vae=self.sd.vae,
+                 unet=self.sd.unet,
+                 text_encoder=self.sd.text_encoder,
+                 tokenizer=self.sd.tokenizer,
+                 scheduler=get_sampler(self.generate_config.sampler),
+                 safety_checker=None,
+                 feature_extractor=None,
+                 requires_safety_checker=False,
+                 adapter=self.adapter,
+             ).to(device, dtype=self.torch_dtype)
+         pipe.set_progress_bar_config(disable=True)
+
+         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+         # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
+
+         self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
+
+         num_batches = len(self.data_loader)
+         pbar = tqdm(total=num_batches, desc="Generating images")
+         seed = self.generate_config.seed
+         # load images from datasets, use tqdm
+         for i, batch in enumerate(self.data_loader):
+             batch: DataLoaderBatchDTO = batch
+
+             file_item: FileItemDTO = batch.file_items[0]
+             img_path = file_item.path
+             img_filename = os.path.basename(img_path)
+             img_filename_no_ext = os.path.splitext(img_filename)[0]
+             output_path = os.path.join(self.output_folder, img_filename)
+             output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
+             output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
+
+             caption = batch.get_caption_list()[0]
+
+             img: torch.Tensor = batch.tensor.clone()
+             # image comes in -1 to 1. convert to a PIL RGB image
+             img = (img + 1) / 2
+             img = img.clamp(0, 1)
+             img = img[0].permute(1, 2, 0).cpu().numpy()
+             img = (img * 255).astype(np.uint8)
+             image = Image.fromarray(img)
+
+             width, height = image.size
+             min_res = min(width, height)
+
+             if self.generate_config.walk_seed:
+                 seed = seed + 1
+
+             if self.generate_config.seed == -1:
+                 # random
+                 seed = random.randint(0, 1000000)
+
+             torch.manual_seed(seed)
+             torch.cuda.manual_seed(seed)
+
+             # generate depth map
+             image = midas_depth(
+                 image,
+                 detect_resolution=min_res,  # do 512 ?
+                 image_resolution=min_res
+             )
+
+             # image.save(output_depth_path)
+
+             gen_images = pipe(
+                 prompt=caption,
+                 negative_prompt=self.generate_config.neg,
+                 image=image,
+                 num_inference_steps=self.generate_config.sample_steps,
+                 adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
+                 guidance_scale=self.generate_config.guidance_scale,
+             ).images[0]
+             os.makedirs(os.path.dirname(output_path), exist_ok=True)
+             gen_images.save(output_path)
+
+             # save caption
+             with open(output_caption_path, 'w') as f:
+                 f.write(caption)
+
+             pbar.update(1)
+             batch.cleanup()
+
+         pbar.close()
+         print("Done generating images")
+         # cleanup
+         del self.sd
+         gc.collect()
+         torch.cuda.empty_cache()
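
Editor's note (not part of the commit): this process walks the configured datasets, builds a Midas depth map for each image, and regenerates the image with a T2I-Adapter pipeline conditioned on that depth map, reusing the dataset caption as the prompt. The sketch below lists the config keys the process reads, again as a plain Python dict; the dataset key, model path, and adapter path are illustrative assumptions.

```python
# Sketch only: keys mirror the get_conf() calls in ReferenceGenerator.__init__
# and the kwargs consumed by GenerateConfig above. Values are placeholders.
reference_generator_config = {
    "output_folder": "output/reference_gen",      # required
    "device": "cuda",
    "dtype": "float16",
    "model": {                                    # ModelConfig kwargs (illustrative path)
        "name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
        "is_xl": True,                            # selects the SDXL adapter pipeline branch
    },
    "generate": {                                 # GenerateConfig kwargs
        "sampler": "ddpm",
        "neg": "blurry, low quality",
        "seed": -1,                               # -1 => random seed per image
        "walk_seed": False,
        "t2i_adapter_path": "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",  # illustrative
        "guidance_scale": 7,
        "sample_steps": 20,
        "adapter_conditioning_scale": 1.0,
        "ext": "png",
    },
    "datasets": [                                 # DatasetConfig entries; captions become prompts
        {"folder_path": "/path/to/reference/images"},  # illustrative key and path
    ],
}
```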
ai-toolkit/extensions_built_in/advanced_generator/__init__.py ADDED
@@ -0,0 +1,59 @@
+ # Built-in "advanced generator" extensions. Each class below registers one generation process with the toolkit.
+ from toolkit.extension import Extension
+
+
+ # Registers the reference-image (depth / T2I-Adapter) generation process
+ class AdvancedReferenceGeneratorExtension(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "reference_generator"
+
+     # name is the name of the extension for printing
+     name = "Reference Generator"
+
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .ReferenceGenerator import ReferenceGenerator
+         return ReferenceGenerator
+
+
+ # Registers the pure LoRA sample generation process
+ class PureLoraGenerator(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "pure_lora_generator"
+
+     # name is the name of the extension for printing
+     name = "Pure LoRA Generator"
+
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .PureLoraGenerator import PureLoraGenerator
+         return PureLoraGenerator
+
+
+ # Registers the batch img2img generation process
+ class Img2ImgGeneratorExtension(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "batch_img2img"
+
+     # name is the name of the extension for printing
+     name = "Img2ImgGeneratorExtension"
+
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .Img2ImgGenerator import Img2ImgGenerator
+         return Img2ImgGenerator
+
+
+ AI_TOOLKIT_EXTENSIONS = [
+     # you can put a list of extensions here
+     AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
+ ]
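
Editor's note (not part of the commit): the file above is the entire extension contract: a unique `uid` used to select the process from a job config, a display `name`, a `get_process()` classmethod that lazily imports the process class, and membership in `AI_TOOLKIT_EXTENSIONS`. A hedged sketch of adding one more generator to this module follows; `MyDepthGenerator` and its module are hypothetical and only illustrate the pattern.

```python
# Hypothetical example: registering an additional process in this module.
# "MyDepthGenerator" does not exist in the repo; it only mirrors the pattern
# used by the three extensions above.
from toolkit.extension import Extension


class MyDepthGeneratorExtension(Extension):
    # uid must be unique; a job config selects this process by this string
    uid = "my_depth_generator"

    # human-readable name used when printing
    name = "My Depth Generator"

    @classmethod
    def get_process(cls):
        # lazy import keeps startup fast; the process module loads only when used
        from .MyDepthGenerator import MyDepthGenerator
        return MyDepthGenerator


# The class would then be added to the module-level list so the toolkit can discover it:
# AI_TOOLKIT_EXTENSIONS.append(MyDepthGeneratorExtension)
```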
ai-toolkit/extensions_built_in/advanced_generator/__pycache__/Img2ImgGenerator.cpython-312.pyc ADDED
Binary file (14.1 kB).
 
ai-toolkit/extensions_built_in/advanced_generator/__pycache__/PureLoraGenerator.cpython-312.pyc ADDED
Binary file (5.91 kB).
 
ai-toolkit/extensions_built_in/advanced_generator/__pycache__/ReferenceGenerator.cpython-312.pyc ADDED
Binary file (11.7 kB).
 
ai-toolkit/extensions_built_in/advanced_generator/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.61 kB).