Upload 61 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
- .gitattributes +12 -0
- OminiControl/LICENSE +201 -0
- OminiControl/README.md +170 -0
- OminiControl/assets/book.jpg +0 -0
- OminiControl/assets/cartoon_boy.png +3 -0
- OminiControl/assets/clock.jpg +3 -0
- OminiControl/assets/coffee.png +0 -0
- OminiControl/assets/demo/book_omini.jpg +0 -0
- OminiControl/assets/demo/clock_omini.jpg +0 -0
- OminiControl/assets/demo/demo_this_is_omini_control.jpg +3 -0
- OminiControl/assets/demo/dreambooth_res.jpg +3 -0
- OminiControl/assets/demo/man_omini.jpg +0 -0
- OminiControl/assets/demo/monalisa_omini.jpg +3 -0
- OminiControl/assets/demo/oranges_omini.jpg +0 -0
- OminiControl/assets/demo/panda_omini.jpg +0 -0
- OminiControl/assets/demo/penguin_omini.jpg +0 -0
- OminiControl/assets/demo/rc_car_omini.jpg +0 -0
- OminiControl/assets/demo/room_corner_canny.jpg +0 -0
- OminiControl/assets/demo/room_corner_coloring.jpg +0 -0
- OminiControl/assets/demo/room_corner_deblurring.jpg +0 -0
- OminiControl/assets/demo/room_corner_depth.jpg +0 -0
- OminiControl/assets/demo/scene_variation.jpg +3 -0
- OminiControl/assets/demo/shirt_omini.jpg +0 -0
- OminiControl/assets/demo/try_on.jpg +3 -0
- OminiControl/assets/monalisa.jpg +3 -0
- OminiControl/assets/oranges.jpg +0 -0
- OminiControl/assets/penguin.jpg +0 -0
- OminiControl/assets/rc_car.jpg +3 -0
- OminiControl/assets/room_corner.jpg +3 -0
- OminiControl/assets/test_in.jpg +0 -0
- OminiControl/assets/test_out.jpg +0 -0
- OminiControl/assets/tshirt.jpg +3 -0
- OminiControl/assets/vase.jpg +0 -0
- OminiControl/assets/vase_hq.jpg +3 -0
- OminiControl/examples/inpainting.ipynb +143 -0
- OminiControl/examples/spatial.ipynb +184 -0
- OminiControl/examples/subject.ipynb +214 -0
- OminiControl/examples/subject_1024.ipynb +221 -0
- OminiControl/requirements.txt +9 -0
- OminiControl/src/flux/block.py +339 -0
- OminiControl/src/flux/condition.py +138 -0
- OminiControl/src/flux/generate.py +321 -0
- OminiControl/src/flux/lora_controller.py +75 -0
- OminiControl/src/flux/pipeline_tools.py +52 -0
- OminiControl/src/flux/transformer.py +252 -0
- OminiControl/src/gradio/gradio_app.py +115 -0
- OminiControl/src/train/callbacks.py +253 -0
- OminiControl/src/train/data.py +323 -0
- OminiControl/src/train/model.py +185 -0
- OminiControl/src/train/train.py +178 -0
.gitattributes
CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/clock.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text
+OminiControl/assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text
OminiControl/LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [2024] [Zhenxiong Tan]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
OminiControl/README.md
ADDED
@@ -0,0 +1,170 @@
+# OminiControl
+
+<img src='./assets/demo/demo_this_is_omini_control.jpg' width='100%' />
+<br>
+
+<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-2411.15098-A42C25.svg" alt="arXiv"></a>
+<a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
+<a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
+<a href="https://github.com/Yuanshi9815/Subjects200K"><img src="https://img.shields.io/badge/GitHub-Dataset-blue.svg?logo=github&" alt="GitHub"></a>
+<a href="https://huggingface.co/datasets/Yuanshi/Subjects200K"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
+
+> **OminiControl: Minimal and Universal Control for Diffusion Transformer**
+> <br>
+> Zhenxiong Tan,
+> [Songhua Liu](http://121.37.94.87/),
+> [Xingyi Yang](https://adamdad.github.io/),
+> Qiaochu Xue,
+> and
+> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
+> <br>
+> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
+> <br>
+
+## Features
+
+OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
+
+* **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
+
+* **Minimal Design 🚀**: Injects control signals while preserving the original model structure, introducing only 0.1% additional parameters to the base model.
+
+## News
+- **2024-12-26**: ⭐️ Training code is released. You can now create your own OminiControl model by customizing any control task (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details.
+
+## Quick Start
+### Setup (Optional)
+1. **Environment setup**
+   ```bash
+   conda create -n omini python=3.10
+   conda activate omini
+   ```
+2. **Requirements installation**
+   ```bash
+   pip install -r requirements.txt
+   ```
+### Usage example
+1. Subject-driven generation: `examples/subject.ipynb`
+2. In-painting: `examples/inpainting.ipynb`
+3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb`
+
+### Gradio app
+To run the Gradio app for subject-driven generation:
+```bash
+python -m src.gradio.gradio_app
+```
+
+### Guidelines for subject-driven generation
+1. Input images are automatically center-cropped and resized to 512x512 resolution.
+2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`, e.g.
+   1. *A close up view of this item. It is placed on a wooden table.*
+   2. *A young lady is wearing this shirt.*
+3. The model currently works primarily with objects rather than human subjects, due to the absence of human data in training.
+
+## Generated samples
+### Subject-driven generation
+<a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
+
+**Demos** (Left: condition image; Right: generated image)
+
+<div float="left">
+  <img src='./assets/demo/oranges_omini.jpg' width='48%'/>
+  <img src='./assets/demo/rc_car_omini.jpg' width='48%' />
+  <img src='./assets/demo/clock_omini.jpg' width='48%' />
+  <img src='./assets/demo/shirt_omini.jpg' width='48%' />
+</div>
+
+<details>
+<summary>Text Prompts</summary>
+
+- Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
+- Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
+- Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
+- Prompt4: *"On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."*
+</details>
+<details>
+<summary>More results</summary>
+
+* Try on:
+  <img src='./assets/demo/try_on.jpg'/>
+* Scene variations:
+  <img src='./assets/demo/scene_variation.jpg'/>
+* Dreambooth dataset:
+  <img src='./assets/demo/dreambooth_res.jpg'/>
+* Oye-cartoon finetune:
+  <div float="left">
+    <img src='./assets/demo/man_omini.jpg' width='48%' />
+    <img src='./assets/demo/panda_omini.jpg' width='48%' />
+  </div>
+</details>
+
+### Spatially aligned control
+1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image)
+   - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.*
+     </br>
+     <img src='./assets/demo/monalisa_omini.jpg' width='700px' />
+   - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.*
+     </br>
+     <img src='./assets/demo/book_omini.jpg' width='700px' />
+2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring)
+   </br>
+   <details>
+   <summary>Click to show</summary>
+   <div float="left">
+     <img src='./assets/demo/room_corner_canny.jpg' width='48%'/>
+     <img src='./assets/demo/room_corner_depth.jpg' width='48%' />
+     <img src='./assets/demo/room_corner_coloring.jpg' width='48%' />
+     <img src='./assets/demo/room_corner_deblurring.jpg' width='48%' />
+   </div>
+
+   Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.*
+   </details>
+
+## Models
+
+**Subject-driven control:**
+| Model | Base model | Description | Resolution |
+| ----- | ---------- | ----------- | ---------- |
+| [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
+| [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
+| [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. (To be released) | (1024, 1024) |
+| [`oye-cartoon`](https://huggingface.co/saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on the [oye-cartoon](https://huggingface.co/datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764). | (512, 512) |
+
+**Spatially aligned control:**
+| Model | Base model | Description | Resolution |
+| ----- | ---------- | ----------- | ---------- |
+| [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |
+| [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>_1024` | FLUX.1 | Supports higher resolution. (To be released) | (1024, 1024) |
+
+## Community Extensions
+- [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron)
+- [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub)
+
+## Limitations
+1. Subject-driven generation works primarily with objects rather than human subjects, due to the absence of human data in training.
+2. The subject-driven generation model may not work well with `FLUX.1-dev`.
+3. The released model currently only supports 512x512 resolution.
+
+## Training
+Training instructions can be found in this [folder](./train).
+
+## To-do
+- [x] Release the training code.
+- [ ] Release the model for higher resolution (1024x1024).
+
+## Citation
+```
+@article{tan2024ominicontrol,
+  title={Ominicontrol: Minimal and universal control for diffusion transformer},
+  author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
+  journal={arXiv preprint arXiv:2411.15098},
+  volume={3},
+  year={2024}
+}
+```
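For quick reference, here is a condensed end-to-end sketch of the subject-driven workflow described in the README above and in `examples/subject.ipynb` (added later in this diff). The paths, weight name, and generation parameters are taken from that notebook; the prompt and the output filename are illustrative only.

```python
# Condensed sketch of the subject-driven example from examples/subject.ipynb.
# Assumes the OminiControl repository root as the working directory and a CUDA GPU;
# the prompt and "subject_result.jpg" are illustrative.
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Load FLUX.1-schnell and attach the subject-driven OminiControl LoRA weights.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# Condition image resized to 512x512, following the notebook.
image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
condition = Condition("subject", image, position_delta=(0, 32))

seed_everything(0)
result = generate(
    pipe,
    prompt="A close up view of this item. It is placed on a wooden table.",
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
).images[0]
result.save("subject_result.jpg")
```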
OminiControl/assets/book.jpg
ADDED (binary image)
OminiControl/assets/cartoon_boy.png
ADDED (binary image, stored with Git LFS)
OminiControl/assets/clock.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/coffee.png
ADDED (binary image)
OminiControl/assets/demo/book_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/clock_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/demo_this_is_omini_control.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/demo/dreambooth_res.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/demo/man_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/monalisa_omini.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/demo/oranges_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/panda_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/penguin_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/rc_car_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/room_corner_canny.jpg
ADDED (binary image)
OminiControl/assets/demo/room_corner_coloring.jpg
ADDED (binary image)
OminiControl/assets/demo/room_corner_deblurring.jpg
ADDED (binary image)
OminiControl/assets/demo/room_corner_depth.jpg
ADDED (binary image)
OminiControl/assets/demo/scene_variation.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/demo/shirt_omini.jpg
ADDED (binary image)
OminiControl/assets/demo/try_on.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/monalisa.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/oranges.jpg
ADDED (binary image)
OminiControl/assets/penguin.jpg
ADDED (binary image)
OminiControl/assets/rc_car.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/room_corner.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/test_in.jpg
ADDED (binary image)
OminiControl/assets/test_out.jpg
ADDED (binary image)
OminiControl/assets/tshirt.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/assets/vase.jpg
ADDED (binary image)
OminiControl/assets/vase_hq.jpg
ADDED (binary image, stored with Git LFS)
OminiControl/examples/inpainting.ipynb
ADDED
@@ -0,0 +1,143 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.chdir(\"..\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from diffusers.pipelines import FluxPipeline\n",
+    "from src.flux.condition import Condition\n",
+    "from PIL import Image\n",
+    "\n",
+    "from src.flux.generate import generate, seed_everything"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipe = FluxPipeline.from_pretrained(\n",
+    "    \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
+    ")\n",
+    "pipe = pipe.to(\"cuda\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipe.load_lora_weights(\n",
+    "    \"Yuanshi/OminiControl\",\n",
+    "    weight_name=f\"experimental/fill.safetensors\",\n",
+    "    adapter_name=\"fill\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n",
+    "\n",
+    "masked_image = image.copy()\n",
+    "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n",
+    "\n",
+    "condition = Condition(\"fill\", masked_image)\n",
+    "\n",
+    "seed_everything()\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n",
+    "    conditions=[condition],\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+    "concat_image.paste(image, (0, 0))\n",
+    "concat_image.paste(condition.condition, (512, 0))\n",
+    "concat_image.paste(result_img, (1024, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n",
+    "\n",
+    "w, h, min_dim = image.size + (min(image.size),)\n",
+    "image = image.crop(\n",
+    "    ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
+    ").resize((512, 512))\n",
+    "\n",
+    "\n",
+    "masked_image = image.copy()\n",
+    "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n",
+    "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n",
+    "\n",
+    "condition = Condition(\"fill\", masked_image)\n",
+    "\n",
+    "seed_everything()\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n",
+    "    conditions=[condition],\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+    "concat_image.paste(image, (0, 0))\n",
+    "concat_image.paste(condition.condition, (512, 0))\n",
+    "concat_image.paste(result_img, (1024, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
OminiControl/examples/spatial.ipynb
ADDED
@@ -0,0 +1,184 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.chdir(\"..\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from diffusers.pipelines import FluxPipeline\n",
+    "from src.flux.condition import Condition\n",
+    "from PIL import Image\n",
+    "\n",
+    "from src.flux.generate import generate, seed_everything"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipe = FluxPipeline.from_pretrained(\n",
+    "    \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
+    ")\n",
+    "pipe = pipe.to(\"cuda\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n",
+    "    pipe.load_lora_weights(\n",
+    "        \"Yuanshi/OminiControl\",\n",
+    "        weight_name=f\"experimental/{condition_type}.safetensors\",\n",
+    "        adapter_name=condition_type,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n",
+    "\n",
+    "w, h, min_dim = image.size + (min(image.size),)\n",
+    "image = image.crop(\n",
+    "    ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
+    ").resize((512, 512))\n",
+    "\n",
+    "prompt = \"In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "condition = Condition(\"canny\", image)\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+    "concat_image.paste(image, (0, 0))\n",
+    "concat_image.paste(condition.condition, (512, 0))\n",
+    "concat_image.paste(result_img, (1024, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "condition = Condition(\"depth\", image)\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+    "concat_image.paste(image, (0, 0))\n",
+    "concat_image.paste(condition.condition, (512, 0))\n",
+    "concat_image.paste(result_img, (1024, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "condition = Condition(\"deblurring\", image)\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+    "concat_image.paste(image, (0, 0))\n",
+    "concat_image.paste(condition.condition, (512, 0))\n",
+    "concat_image.paste(result_img, (1024, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "condition = Condition(\"coloring\", image)\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+    "concat_image.paste(image, (0, 0))\n",
+    "concat_image.paste(condition.condition, (512, 0))\n",
+    "concat_image.paste(result_img, (1024, 0))\n",
+    "concat_image"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
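The spatial notebook above registers one LoRA adapter per task under its own name. A condensed sketch of the same flow follows, assuming the repository root as the working directory; the direct square resize and the output filename are simplifications for illustration.

```python
# Condensed sketch of the spatially aligned example from examples/spatial.ipynb.
# Each task (canny, depth, coloring, deblurring) is loaded as a separately named
# LoRA adapter; swap the condition type to switch tasks.
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

for condition_type in ["canny", "depth", "coloring", "deblurring"]:
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name=f"experimental/{condition_type}.safetensors",
        adapter_name=condition_type,
    )

# The notebook center-crops to a square before resizing; a direct resize is used
# here only to keep the sketch short.
image = Image.open("assets/coffee.png").convert("RGB").resize((512, 512))
prompt = (
    "In a bright room. A cup of a coffee with some beans on the side. "
    "They are placed on a dark wooden table."
)

seed_everything()
condition = Condition("canny", image)  # or "depth", "coloring", "deblurring"
result = generate(pipe, prompt=prompt, conditions=[condition]).images[0]
result.save("coffee_canny_result.jpg")  # illustrative output path
```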
OminiControl/examples/subject.ipynb
ADDED
@@ -0,0 +1,214 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.chdir(\"..\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from diffusers.pipelines import FluxPipeline\n",
+    "from src.flux.condition import Condition\n",
+    "from PIL import Image\n",
+    "\n",
+    "from src.flux.generate import generate, seed_everything"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipe = FluxPipeline.from_pretrained(\n",
+    "    \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
+    ")\n",
+    "pipe = pipe.to(\"cuda\")\n",
+    "pipe.load_lora_weights(\n",
+    "    \"Yuanshi/OminiControl\",\n",
+    "    weight_name=f\"omini/subject_512.safetensors\",\n",
+    "    adapter_name=\"subject\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
+    "\n",
+    "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+    "\n",
+    "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
+    "\n",
+    "\n",
+    "seed_everything(0)\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    "    num_inference_steps=8,\n",
+    "    height=512,\n",
+    "    width=512,\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+    "concat_image.paste(image, (0, 0))\n",
+    "concat_image.paste(result_img, (512, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
+    "\n",
+    "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+    "\n",
+    "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
+    "\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    "    num_inference_steps=8,\n",
+    "    height=512,\n",
+    "    width=512,\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+    "concat_image.paste(condition.condition, (0, 0))\n",
+    "concat_image.paste(result_img, (512, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
+    "\n",
+    "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+    "\n",
+    "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    "    num_inference_steps=8,\n",
+    "    height=512,\n",
+    "    width=512,\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+    "concat_image.paste(condition.condition, (0, 0))\n",
+    "concat_image.paste(result_img, (512, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
+    "\n",
+    "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+    "\n",
+    "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    "    num_inference_steps=8,\n",
+    "    height=512,\n",
+    "    width=512,\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+    "concat_image.paste(condition.condition, (0, 0))\n",
+    "concat_image.paste(result_img, (512, 0))\n",
+    "concat_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
+    "\n",
+    "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
+    "\n",
+    "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
+    "\n",
+    "seed_everything()\n",
+    "\n",
+    "result_img = generate(\n",
+    "    pipe,\n",
+    "    prompt=prompt,\n",
+    "    conditions=[condition],\n",
+    "    num_inference_steps=8,\n",
+    "    height=512,\n",
+    "    width=512,\n",
+    ").images[0]\n",
+    "\n",
+    "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+    "concat_image.paste(condition.condition, (0, 0))\n",
+    "concat_image.paste(result_img, (512, 0))\n",
+    "concat_image"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
OminiControl/examples/subject_1024.ipynb
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 4,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import os\n",
|
10 |
+
"\n",
|
11 |
+
"os.chdir(\"..\")"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": null,
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"import torch\n",
|
21 |
+
"from diffusers.pipelines import FluxPipeline\n",
|
22 |
+
"from src.flux.condition import Condition\n",
|
23 |
+
"from PIL import Image\n",
|
24 |
+
"\n",
|
25 |
+
"from src.flux.generate import generate, seed_everything"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": null,
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"pipe = FluxPipeline.from_pretrained(\n",
|
35 |
+
" \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
|
36 |
+
")\n",
|
37 |
+
"pipe = pipe.to(\"cuda\")\n",
|
38 |
+
"pipe.load_lora_weights(\n",
|
39 |
+
" \"Yuanshi/OminiControl\",\n",
|
40 |
+
" weight_name=f\"omini/subject_1024_beta.safetensors\",\n",
|
41 |
+
" adapter_name=\"subject\",\n",
|
42 |
+
")"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
|
52 |
+
"\n",
|
53 |
+
"condition = Condition(\"subject\", image)\n",
|
54 |
+
"\n",
|
55 |
+
"prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
|
56 |
+
"\n",
|
57 |
+
"\n",
|
58 |
+
"seed_everything(0)\n",
|
59 |
+
"\n",
|
60 |
+
"result_img = generate(\n",
|
61 |
+
" pipe,\n",
|
62 |
+
" prompt=prompt,\n",
|
63 |
+
" conditions=[condition],\n",
|
64 |
+
" num_inference_steps=8,\n",
|
65 |
+
" height=1024,\n",
|
66 |
+
" width=1024,\n",
|
67 |
+
").images[0]\n",
|
68 |
+
"\n",
|
69 |
+
"concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
|
70 |
+
"concat_image.paste(image, (0, 0))\n",
|
71 |
+
"concat_image.paste(result_img, (512, 0))\n",
|
72 |
+
"concat_image"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [],
|
80 |
+
"source": [
|
81 |
+
"image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
|
82 |
+
"\n",
|
83 |
+
"condition = Condition(\"subject\", image)\n",
|
84 |
+
"\n",
|
85 |
+
"prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
|
86 |
+
"\n",
|
87 |
+
"\n",
|
88 |
+
"seed_everything(0)\n",
|
89 |
+
"\n",
|
90 |
+
"result_img = generate(\n",
|
91 |
+
" pipe,\n",
|
92 |
+
" prompt=prompt,\n",
|
93 |
+
" conditions=[condition],\n",
|
94 |
+
" num_inference_steps=8,\n",
|
95 |
+
" height=1024,\n",
|
96 |
+
" width=1024,\n",
|
97 |
+
").images[0]\n",
|
98 |
+
"\n",
|
99 |
+
"concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
|
100 |
+
"concat_image.paste(image, (0, 0))\n",
|
101 |
+
"concat_image.paste(result_img, (512, 0))\n",
|
102 |
+
"concat_image"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "code",
|
107 |
+
"execution_count": null,
|
108 |
+
"metadata": {},
|
109 |
+
"outputs": [],
|
110 |
+
"source": [
|
111 |
+
"image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
|
112 |
+
"\n",
|
113 |
+
"condition = Condition(\"subject\", image)\n",
|
114 |
+
"\n",
|
115 |
+
"prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
|
116 |
+
"\n",
|
117 |
+
"seed_everything()\n",
|
118 |
+
"\n",
|
119 |
+
"result_img = generate(\n",
|
120 |
+
" pipe,\n",
|
121 |
+
" prompt=prompt,\n",
|
122 |
+
" conditions=[condition],\n",
|
123 |
+
" num_inference_steps=8,\n",
|
124 |
+
" height=1024,\n",
|
125 |
+
" width=1024,\n",
|
126 |
+
").images[0]\n",
|
127 |
+
"\n",
|
128 |
+
"concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
|
129 |
+
"concat_image.paste(image, (0, 0))\n",
|
130 |
+
"concat_image.paste(result_img, (512, 0))\n",
|
131 |
+
"concat_image"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": null,
|
137 |
+
"metadata": {},
|
138 |
+
"outputs": [],
|
139 |
+
"source": [
|
140 |
+
"image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
|
141 |
+
"\n",
|
142 |
+
"condition = Condition(\"subject\", image)\n",
|
143 |
+
"\n",
|
144 |
+
"prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
|
145 |
+
"\n",
|
146 |
+
"seed_everything(0)\n",
|
147 |
+
"\n",
|
148 |
+
"result_img = generate(\n",
|
149 |
+
" pipe,\n",
|
150 |
+
" prompt=prompt,\n",
|
151 |
+
" conditions=[condition],\n",
|
152 |
+
" num_inference_steps=8,\n",
|
153 |
+
" height=1024,\n",
|
154 |
+
" width=1024,\n",
|
155 |
+
").images[0]\n",
|
156 |
+
"\n",
|
157 |
+
"concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
|
158 |
+
"concat_image.paste(image, (0, 0))\n",
|
159 |
+
"concat_image.paste(result_img, (512, 0))\n",
|
160 |
+
"concat_image"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "code",
|
165 |
+
"execution_count": null,
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": [
|
169 |
+
"image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
|
170 |
+
"\n",
|
171 |
+
"condition = Condition(\"subject\", image)\n",
|
172 |
+
"\n",
|
173 |
+
"prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
|
174 |
+
"\n",
|
175 |
+
"seed_everything()\n",
|
176 |
+
"\n",
|
177 |
+
"result_img = generate(\n",
|
178 |
+
" pipe,\n",
|
179 |
+
" prompt=prompt,\n",
|
180 |
+
" conditions=[condition],\n",
|
181 |
+
" num_inference_steps=8,\n",
|
182 |
+
" height=1024,\n",
|
183 |
+
" width=1024,\n",
|
184 |
+
").images[0]\n",
|
185 |
+
"\n",
|
186 |
+
"concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
|
187 |
+
"concat_image.paste(image, (0, 0))\n",
|
188 |
+
"concat_image.paste(result_img, (512, 0))\n",
|
189 |
+
"concat_image"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"execution_count": null,
|
195 |
+
"metadata": {},
|
196 |
+
"outputs": [],
|
197 |
+
"source": []
|
198 |
+
}
|
199 |
+
],
|
200 |
+
"metadata": {
|
201 |
+
"kernelspec": {
|
202 |
+
"display_name": "Python 3 (ipykernel)",
|
203 |
+
"language": "python",
|
204 |
+
"name": "python3"
|
205 |
+
},
|
206 |
+
"language_info": {
|
207 |
+
"codemirror_mode": {
|
208 |
+
"name": "ipython",
|
209 |
+
"version": 3
|
210 |
+
},
|
211 |
+
"file_extension": ".py",
|
212 |
+
"mimetype": "text/x-python",
|
213 |
+
"name": "python",
|
214 |
+
"nbconvert_exporter": "python",
|
215 |
+
"pygments_lexer": "ipython3",
|
216 |
+
"version": "3.9.21"
|
217 |
+
}
|
218 |
+
},
|
219 |
+
"nbformat": 4,
|
220 |
+
"nbformat_minor": 2
|
221 |
+
}
|
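The four cells above repeat the same pattern: load an asset, wrap it in a subject Condition, call generate, then paste the condition and the result side by side. A minimal refactor sketch of that pattern, reusing the notebook's existing pipe, Condition, generate and seed_everything (the helper name subject_demo is hypothetical):

from PIL import Image

def subject_demo(asset_path, prompt, seed=0, size=1024, cond_size=512):
    # Load and square-resize the reference image used as the subject condition.
    image = Image.open(asset_path).convert("RGB").resize((cond_size, cond_size))
    condition = Condition("subject", image)
    seed_everything(seed)
    result = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
        num_inference_steps=8,
        height=size,
        width=size,
    ).images[0]
    # Paste the condition image and the generated image side by side for comparison.
    canvas = Image.new("RGB", (size + cond_size, size))
    canvas.paste(image, (0, 0))
    canvas.paste(result, (cond_size, 0))
    return canvas

# e.g. subject_demo("assets/clock.jpg", "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it.")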
OminiControl/requirements.txt
ADDED
@@ -0,0 +1,9 @@
1 |
+
transformers
|
2 |
+
diffusers
|
3 |
+
peft
|
4 |
+
opencv-python
|
5 |
+
protobuf
|
6 |
+
sentencepiece
|
7 |
+
gradio
|
8 |
+
jupyter
|
9 |
+
torchao
|
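Note that the source files added below also import torch, lightning, accelerate and (optionally) wandb, which are not pinned here. A quick, hypothetical sanity check for the pinned packages, using their import-time module names (cv2 for opencv-python, google.protobuf for protobuf):

import importlib

modules = ["transformers", "diffusers", "peft", "cv2", "google.protobuf", "sentencepiece", "gradio", "torchao"]
missing = []
for name in modules:
    try:
        importlib.import_module(name)
    except ImportError:
        missing.append(name)
print("missing:", missing or "none")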
OminiControl/src/flux/block.py
ADDED
@@ -0,0 +1,339 @@
1 |
+
import torch
|
2 |
+
from typing import List, Union, Optional, Dict, Any, Callable
|
3 |
+
from diffusers.models.attention_processor import Attention, F
|
4 |
+
from .lora_controller import enable_lora
|
5 |
+
|
6 |
+
|
7 |
+
def attn_forward(
|
8 |
+
attn: Attention,
|
9 |
+
hidden_states: torch.FloatTensor,
|
10 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
11 |
+
condition_latents: torch.FloatTensor = None,
|
12 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
13 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
14 |
+
cond_rotary_emb: Optional[torch.Tensor] = None,
|
15 |
+
model_config: Optional[Dict[str, Any]] = {},
|
16 |
+
) -> torch.FloatTensor:
|
17 |
+
batch_size, _, _ = (
|
18 |
+
hidden_states.shape
|
19 |
+
if encoder_hidden_states is None
|
20 |
+
else encoder_hidden_states.shape
|
21 |
+
)
|
22 |
+
|
23 |
+
with enable_lora(
|
24 |
+
(attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
|
25 |
+
):
|
26 |
+
# `sample` projections.
|
27 |
+
query = attn.to_q(hidden_states)
|
28 |
+
key = attn.to_k(hidden_states)
|
29 |
+
value = attn.to_v(hidden_states)
|
30 |
+
|
31 |
+
inner_dim = key.shape[-1]
|
32 |
+
head_dim = inner_dim // attn.heads
|
33 |
+
|
34 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
35 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
36 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
37 |
+
|
38 |
+
if attn.norm_q is not None:
|
39 |
+
query = attn.norm_q(query)
|
40 |
+
if attn.norm_k is not None:
|
41 |
+
key = attn.norm_k(key)
|
42 |
+
|
43 |
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
44 |
+
if encoder_hidden_states is not None:
|
45 |
+
# `context` projections.
|
46 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
47 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
48 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
49 |
+
|
50 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
51 |
+
batch_size, -1, attn.heads, head_dim
|
52 |
+
).transpose(1, 2)
|
53 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
54 |
+
batch_size, -1, attn.heads, head_dim
|
55 |
+
).transpose(1, 2)
|
56 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
57 |
+
batch_size, -1, attn.heads, head_dim
|
58 |
+
).transpose(1, 2)
|
59 |
+
|
60 |
+
if attn.norm_added_q is not None:
|
61 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(
|
62 |
+
encoder_hidden_states_query_proj
|
63 |
+
)
|
64 |
+
if attn.norm_added_k is not None:
|
65 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(
|
66 |
+
encoder_hidden_states_key_proj
|
67 |
+
)
|
68 |
+
|
69 |
+
# attention
|
70 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
71 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
72 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
73 |
+
|
74 |
+
if image_rotary_emb is not None:
|
75 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
76 |
+
|
77 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
78 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
79 |
+
|
80 |
+
if condition_latents is not None:
|
81 |
+
cond_query = attn.to_q(condition_latents)
|
82 |
+
cond_key = attn.to_k(condition_latents)
|
83 |
+
cond_value = attn.to_v(condition_latents)
|
84 |
+
|
85 |
+
cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
|
86 |
+
1, 2
|
87 |
+
)
|
88 |
+
cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
89 |
+
cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
|
90 |
+
1, 2
|
91 |
+
)
|
92 |
+
if attn.norm_q is not None:
|
93 |
+
cond_query = attn.norm_q(cond_query)
|
94 |
+
if attn.norm_k is not None:
|
95 |
+
cond_key = attn.norm_k(cond_key)
|
96 |
+
|
97 |
+
if cond_rotary_emb is not None:
|
98 |
+
cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
|
99 |
+
cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
|
100 |
+
|
101 |
+
if condition_latents is not None:
|
102 |
+
query = torch.cat([query, cond_query], dim=2)
|
103 |
+
key = torch.cat([key, cond_key], dim=2)
|
104 |
+
value = torch.cat([value, cond_value], dim=2)
|
105 |
+
|
106 |
+
if not model_config.get("union_cond_attn", True):
|
107 |
+
# If we don't want to use the union condition attention, we need to mask the attention
|
108 |
+
# between the hidden states and the condition latents
|
109 |
+
attention_mask = torch.ones(
|
110 |
+
query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
|
111 |
+
)
|
112 |
+
condition_n = cond_query.shape[2]
|
113 |
+
attention_mask[-condition_n:, :-condition_n] = False
|
114 |
+
attention_mask[:-condition_n, -condition_n:] = False
|
115 |
+
elif model_config.get("independent_condition", False):
|
116 |
+
attention_mask = torch.ones(
|
117 |
+
query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
|
118 |
+
)
|
119 |
+
condition_n = cond_query.shape[2]
|
120 |
+
attention_mask[-condition_n:, :-condition_n] = False
|
121 |
+
if hasattr(attn, "c_factor"):
|
122 |
+
attention_mask = torch.zeros(
|
123 |
+
query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
|
124 |
+
)
|
125 |
+
condition_n = cond_query.shape[2]
|
126 |
+
bias = torch.log(attn.c_factor[0])
|
127 |
+
attention_mask[-condition_n:, :-condition_n] = bias
|
128 |
+
attention_mask[:-condition_n, -condition_n:] = bias
|
129 |
+
hidden_states = F.scaled_dot_product_attention(
|
130 |
+
query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
|
131 |
+
)
|
132 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
133 |
+
batch_size, -1, attn.heads * head_dim
|
134 |
+
)
|
135 |
+
hidden_states = hidden_states.to(query.dtype)
|
136 |
+
|
137 |
+
if encoder_hidden_states is not None:
|
138 |
+
if condition_latents is not None:
|
139 |
+
encoder_hidden_states, hidden_states, condition_latents = (
|
140 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
141 |
+
hidden_states[
|
142 |
+
:, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
|
143 |
+
],
|
144 |
+
hidden_states[:, -condition_latents.shape[1] :],
|
145 |
+
)
|
146 |
+
else:
|
147 |
+
encoder_hidden_states, hidden_states = (
|
148 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
149 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
150 |
+
)
|
151 |
+
|
152 |
+
with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
|
153 |
+
# linear proj
|
154 |
+
hidden_states = attn.to_out[0](hidden_states)
|
155 |
+
# dropout
|
156 |
+
hidden_states = attn.to_out[1](hidden_states)
|
157 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
158 |
+
|
159 |
+
if condition_latents is not None:
|
160 |
+
condition_latents = attn.to_out[0](condition_latents)
|
161 |
+
condition_latents = attn.to_out[1](condition_latents)
|
162 |
+
|
163 |
+
return (
|
164 |
+
(hidden_states, encoder_hidden_states, condition_latents)
|
165 |
+
if condition_latents is not None
|
166 |
+
else (hidden_states, encoder_hidden_states)
|
167 |
+
)
|
168 |
+
elif condition_latents is not None:
|
169 |
+
# if there are condition_latents, we need to separate the hidden_states and the condition_latents
|
170 |
+
hidden_states, condition_latents = (
|
171 |
+
hidden_states[:, : -condition_latents.shape[1]],
|
172 |
+
hidden_states[:, -condition_latents.shape[1] :],
|
173 |
+
)
|
174 |
+
return hidden_states, condition_latents
|
175 |
+
else:
|
176 |
+
return hidden_states
|
177 |
+
|
178 |
+
|
179 |
+
def block_forward(
|
180 |
+
self,
|
181 |
+
hidden_states: torch.FloatTensor,
|
182 |
+
encoder_hidden_states: torch.FloatTensor,
|
183 |
+
condition_latents: torch.FloatTensor,
|
184 |
+
temb: torch.FloatTensor,
|
185 |
+
cond_temb: torch.FloatTensor,
|
186 |
+
cond_rotary_emb=None,
|
187 |
+
image_rotary_emb=None,
|
188 |
+
model_config: Optional[Dict[str, Any]] = {},
|
189 |
+
):
|
190 |
+
use_cond = condition_latents is not None
|
191 |
+
with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
|
192 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
193 |
+
hidden_states, emb=temb
|
194 |
+
)
|
195 |
+
|
196 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
197 |
+
self.norm1_context(encoder_hidden_states, emb=temb)
|
198 |
+
)
|
199 |
+
|
200 |
+
if use_cond:
|
201 |
+
(
|
202 |
+
norm_condition_latents,
|
203 |
+
cond_gate_msa,
|
204 |
+
cond_shift_mlp,
|
205 |
+
cond_scale_mlp,
|
206 |
+
cond_gate_mlp,
|
207 |
+
) = self.norm1(condition_latents, emb=cond_temb)
|
208 |
+
|
209 |
+
# Attention.
|
210 |
+
result = attn_forward(
|
211 |
+
self.attn,
|
212 |
+
model_config=model_config,
|
213 |
+
hidden_states=norm_hidden_states,
|
214 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
215 |
+
condition_latents=norm_condition_latents if use_cond else None,
|
216 |
+
image_rotary_emb=image_rotary_emb,
|
217 |
+
cond_rotary_emb=cond_rotary_emb if use_cond else None,
|
218 |
+
)
|
219 |
+
attn_output, context_attn_output = result[:2]
|
220 |
+
cond_attn_output = result[2] if use_cond else None
|
221 |
+
|
222 |
+
# Process attention outputs for the `hidden_states`.
|
223 |
+
# 1. hidden_states
|
224 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
225 |
+
hidden_states = hidden_states + attn_output
|
226 |
+
# 2. encoder_hidden_states
|
227 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
228 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
229 |
+
# 3. condition_latents
|
230 |
+
if use_cond:
|
231 |
+
cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
|
232 |
+
condition_latents = condition_latents + cond_attn_output
|
233 |
+
if model_config.get("add_cond_attn", False):
|
234 |
+
hidden_states += cond_attn_output
|
235 |
+
|
236 |
+
# LayerNorm + MLP.
|
237 |
+
# 1. hidden_states
|
238 |
+
norm_hidden_states = self.norm2(hidden_states)
|
239 |
+
norm_hidden_states = (
|
240 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
241 |
+
)
|
242 |
+
# 2. encoder_hidden_states
|
243 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
244 |
+
norm_encoder_hidden_states = (
|
245 |
+
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
246 |
+
)
|
247 |
+
# 3. condition_latents
|
248 |
+
if use_cond:
|
249 |
+
norm_condition_latents = self.norm2(condition_latents)
|
250 |
+
norm_condition_latents = (
|
251 |
+
norm_condition_latents * (1 + cond_scale_mlp[:, None])
|
252 |
+
+ cond_shift_mlp[:, None]
|
253 |
+
)
|
254 |
+
|
255 |
+
# Feed-forward.
|
256 |
+
with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
|
257 |
+
# 1. hidden_states
|
258 |
+
ff_output = self.ff(norm_hidden_states)
|
259 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
260 |
+
# 2. encoder_hidden_states
|
261 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
262 |
+
context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
|
263 |
+
# 3. condition_latents
|
264 |
+
if use_cond:
|
265 |
+
cond_ff_output = self.ff(norm_condition_latents)
|
266 |
+
cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
|
267 |
+
|
268 |
+
# Process feed-forward outputs.
|
269 |
+
hidden_states = hidden_states + ff_output
|
270 |
+
encoder_hidden_states = encoder_hidden_states + context_ff_output
|
271 |
+
if use_cond:
|
272 |
+
condition_latents = condition_latents + cond_ff_output
|
273 |
+
|
274 |
+
# Clip to avoid overflow.
|
275 |
+
if encoder_hidden_states.dtype == torch.float16:
|
276 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
277 |
+
|
278 |
+
return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
|
279 |
+
|
280 |
+
|
281 |
+
def single_block_forward(
|
282 |
+
self,
|
283 |
+
hidden_states: torch.FloatTensor,
|
284 |
+
temb: torch.FloatTensor,
|
285 |
+
image_rotary_emb=None,
|
286 |
+
condition_latents: torch.FloatTensor = None,
|
287 |
+
cond_temb: torch.FloatTensor = None,
|
288 |
+
cond_rotary_emb=None,
|
289 |
+
model_config: Optional[Dict[str, Any]] = {},
|
290 |
+
):
|
291 |
+
|
292 |
+
using_cond = condition_latents is not None
|
293 |
+
residual = hidden_states
|
294 |
+
with enable_lora(
|
295 |
+
(
|
296 |
+
self.norm.linear,
|
297 |
+
self.proj_mlp,
|
298 |
+
),
|
299 |
+
model_config.get("latent_lora", False),
|
300 |
+
):
|
301 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
302 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
303 |
+
if using_cond:
|
304 |
+
residual_cond = condition_latents
|
305 |
+
norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
|
306 |
+
mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
|
307 |
+
|
308 |
+
attn_output = attn_forward(
|
309 |
+
self.attn,
|
310 |
+
model_config=model_config,
|
311 |
+
hidden_states=norm_hidden_states,
|
312 |
+
image_rotary_emb=image_rotary_emb,
|
313 |
+
**(
|
314 |
+
{
|
315 |
+
"condition_latents": norm_condition_latents,
|
316 |
+
"cond_rotary_emb": cond_rotary_emb if using_cond else None,
|
317 |
+
}
|
318 |
+
if using_cond
|
319 |
+
else {}
|
320 |
+
),
|
321 |
+
)
|
322 |
+
if using_cond:
|
323 |
+
attn_output, cond_attn_output = attn_output
|
324 |
+
|
325 |
+
with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
|
326 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
327 |
+
gate = gate.unsqueeze(1)
|
328 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
329 |
+
hidden_states = residual + hidden_states
|
330 |
+
if using_cond:
|
331 |
+
condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
|
332 |
+
cond_gate = cond_gate.unsqueeze(1)
|
333 |
+
condition_latents = cond_gate * self.proj_out(condition_latents)
|
334 |
+
condition_latents = residual_cond + condition_latents
|
335 |
+
|
336 |
+
if hidden_states.dtype == torch.float16:
|
337 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
338 |
+
|
339 |
+
return hidden_states if not using_cond else (hidden_states, condition_latents)
|
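When generate() is called with condition_scale != 1 it attaches a c_factor to every attention module, and attn_forward above turns that into an additive log(c_factor) bias on the image-to-condition attention blocks before scaled_dot_product_attention. A self-contained toy sketch of that biasing idea (random tensors and hypothetical token counts, not the real module shapes):

import torch
import torch.nn.functional as F

heads, head_dim = 2, 8
n_img, n_cond = 6, 4                      # hypothetical token counts
n = n_img + n_cond
q, k, v = (torch.randn(1, heads, n, head_dim) for _ in range(3))

c_factor = torch.tensor(1.5)              # analogue of condition_scale
bias = torch.zeros(n, n)
bias[-n_cond:, :-n_cond] = torch.log(c_factor)   # condition tokens attending to image tokens
bias[:-n_cond, -n_cond:] = torch.log(c_factor)   # image tokens attending to condition tokens

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([1, 2, 10, 8])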
OminiControl/src/flux/condition.py
ADDED
@@ -0,0 +1,138 @@
1 |
+
import torch
|
2 |
+
from typing import Optional, Union, List, Tuple
|
3 |
+
from diffusers.pipelines import FluxPipeline
|
4 |
+
from PIL import Image, ImageFilter
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
from .pipeline_tools import encode_images
|
9 |
+
|
10 |
+
condition_dict = {
|
11 |
+
"depth": 0,
|
12 |
+
"canny": 1,
|
13 |
+
"subject": 4,
|
14 |
+
"coloring": 6,
|
15 |
+
"deblurring": 7,
|
16 |
+
"depth_pred": 8,
|
17 |
+
"fill": 9,
|
18 |
+
"sr": 10,
|
19 |
+
"cartoon": 11,
|
20 |
+
}
|
21 |
+
|
22 |
+
|
23 |
+
class Condition(object):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
condition_type: str,
|
27 |
+
raw_img: Union[Image.Image, torch.Tensor] = None,
|
28 |
+
condition: Union[Image.Image, torch.Tensor] = None,
|
29 |
+
mask=None,
|
30 |
+
position_delta=None,
|
31 |
+
position_scale=1.0,
|
32 |
+
) -> None:
|
33 |
+
self.condition_type = condition_type
|
34 |
+
assert raw_img is not None or condition is not None
|
35 |
+
if raw_img is not None:
|
36 |
+
self.condition = self.get_condition(condition_type, raw_img)
|
37 |
+
else:
|
38 |
+
self.condition = condition
|
39 |
+
self.position_delta = position_delta
|
40 |
+
self.position_scale = position_scale
|
41 |
+
# TODO: Add mask support
|
42 |
+
assert mask is None, "Mask not supported yet"
|
43 |
+
|
44 |
+
def get_condition(
|
45 |
+
self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
|
46 |
+
) -> Union[Image.Image, torch.Tensor]:
|
47 |
+
"""
|
48 |
+
Returns the condition image.
|
49 |
+
"""
|
50 |
+
if condition_type == "depth":
|
51 |
+
from transformers import pipeline
|
52 |
+
|
53 |
+
depth_pipe = pipeline(
|
54 |
+
task="depth-estimation",
|
55 |
+
model="LiheYoung/depth-anything-small-hf",
|
56 |
+
device="cuda",
|
57 |
+
)
|
58 |
+
source_image = raw_img.convert("RGB")
|
59 |
+
condition_img = depth_pipe(source_image)["depth"].convert("RGB")
|
60 |
+
return condition_img
|
61 |
+
elif condition_type == "canny":
|
62 |
+
img = np.array(raw_img)
|
63 |
+
edges = cv2.Canny(img, 100, 200)
|
64 |
+
edges = Image.fromarray(edges).convert("RGB")
|
65 |
+
return edges
|
66 |
+
elif condition_type == "subject":
|
67 |
+
return raw_img
|
68 |
+
elif condition_type == "coloring":
|
69 |
+
return raw_img.convert("L").convert("RGB")
|
70 |
+
elif condition_type == "deblurring":
|
71 |
+
condition_image = (
|
72 |
+
raw_img.convert("RGB")
|
73 |
+
.filter(ImageFilter.GaussianBlur(10))
|
74 |
+
.convert("RGB")
|
75 |
+
)
|
76 |
+
return condition_image
|
77 |
+
elif condition_type == "fill":
|
78 |
+
return raw_img.convert("RGB")
|
79 |
+
elif condition_type == "cartoon":
|
80 |
+
return raw_img.convert("RGB")
|
81 |
+
return self.condition
|
82 |
+
|
83 |
+
@property
|
84 |
+
def type_id(self) -> int:
|
85 |
+
"""
|
86 |
+
Returns the type id of the condition.
|
87 |
+
"""
|
88 |
+
return condition_dict[self.condition_type]
|
89 |
+
|
90 |
+
@classmethod
|
91 |
+
def get_type_id(cls, condition_type: str) -> int:
|
92 |
+
"""
|
93 |
+
Returns the type id of the condition.
|
94 |
+
"""
|
95 |
+
return condition_dict[condition_type]
|
96 |
+
|
97 |
+
def encode(
|
98 |
+
self, pipe: FluxPipeline, empty: bool = False
|
99 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
100 |
+
"""
|
101 |
+
Encodes the condition into tokens, ids and type_id.
|
102 |
+
"""
|
103 |
+
if self.condition_type in [
|
104 |
+
"depth",
|
105 |
+
"canny",
|
106 |
+
"subject",
|
107 |
+
"coloring",
|
108 |
+
"deblurring",
|
109 |
+
"depth_pred",
|
110 |
+
"fill",
|
111 |
+
"sr",
|
112 |
+
"cartoon",
|
113 |
+
]:
|
114 |
+
if empty:
|
115 |
+
# make the condition black
|
116 |
+
e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
|
117 |
+
e_condition = e_condition.convert("RGB")
|
118 |
+
tokens, ids = encode_images(pipe, e_condition)
|
119 |
+
else:
|
120 |
+
tokens, ids = encode_images(pipe, self.condition)
|
122 |
+
else:
|
123 |
+
raise NotImplementedError(
|
124 |
+
f"Condition type {self.condition_type} not implemented"
|
125 |
+
)
|
126 |
+
if self.position_delta is None and self.condition_type == "subject":
|
127 |
+
self.position_delta = [0, -self.condition.size[0] // 16]
|
128 |
+
if self.position_delta is not None:
|
129 |
+
ids[:, 1] += self.position_delta[0]
|
130 |
+
ids[:, 2] += self.position_delta[1]
|
131 |
+
if self.position_scale != 1.0:
|
132 |
+
scale_bias = (self.position_scale - 1.0) / 2
|
133 |
+
ids[:, 1] *= self.position_scale
|
134 |
+
ids[:, 2] *= self.position_scale
|
135 |
+
ids[:, 1] += scale_bias
|
136 |
+
ids[:, 2] += scale_bias
|
137 |
+
type_id = torch.ones_like(ids[:, :1]) * self.type_id
|
138 |
+
return tokens, ids, type_id
|
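A short usage sketch for the Condition class (the import path assumes the code is run from the OminiControl directory; encode needs a loaded FluxPipeline, so that call is only indicated here):

from PIL import Image
from src.flux.condition import Condition

img = Image.open("assets/clock.jpg").convert("RGB").resize((512, 512))

subject = Condition("subject", img)    # raw image used as-is
canny = Condition("canny", img)        # Canny edge map computed with OpenCV
print(subject.type_id, canny.type_id)  # 4 and 1, per condition_dict

# Encoding into latent tokens, position ids and a type id requires a FluxPipeline (GPU recommended):
# tokens, ids, type_id = subject.encode(pipe)
# For "subject" conditions, encode() defaults position_delta to [0, -width // 16],
# shifting the condition tokens left of the generated image in position space.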
OminiControl/src/flux/generate.py
ADDED
@@ -0,0 +1,321 @@
1 |
+
import torch
|
2 |
+
import yaml, os
|
3 |
+
from diffusers.pipelines import FluxPipeline
|
4 |
+
from typing import List, Union, Optional, Dict, Any, Callable
|
5 |
+
from .transformer import tranformer_forward
|
6 |
+
from .condition import Condition
|
7 |
+
|
8 |
+
from diffusers.pipelines.flux.pipeline_flux import (
|
9 |
+
FluxPipelineOutput,
|
10 |
+
calculate_shift,
|
11 |
+
retrieve_timesteps,
|
12 |
+
np,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def get_config(config_path: str = None):
|
17 |
+
config_path = config_path or os.environ.get("XFL_CONFIG")
|
18 |
+
if not config_path:
|
19 |
+
return {}
|
20 |
+
with open(config_path, "r") as f:
|
21 |
+
config = yaml.safe_load(f)
|
22 |
+
return config
|
23 |
+
|
24 |
+
|
25 |
+
def prepare_params(
|
26 |
+
prompt: Union[str, List[str]] = None,
|
27 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
28 |
+
height: Optional[int] = 512,
|
29 |
+
width: Optional[int] = 512,
|
30 |
+
num_inference_steps: int = 28,
|
31 |
+
timesteps: List[int] = None,
|
32 |
+
guidance_scale: float = 3.5,
|
33 |
+
num_images_per_prompt: Optional[int] = 1,
|
34 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
35 |
+
latents: Optional[torch.FloatTensor] = None,
|
36 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
37 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
38 |
+
output_type: Optional[str] = "pil",
|
39 |
+
return_dict: bool = True,
|
40 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
41 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
42 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
43 |
+
max_sequence_length: int = 512,
|
44 |
+
**kwargs: dict,
|
45 |
+
):
|
46 |
+
return (
|
47 |
+
prompt,
|
48 |
+
prompt_2,
|
49 |
+
height,
|
50 |
+
width,
|
51 |
+
num_inference_steps,
|
52 |
+
timesteps,
|
53 |
+
guidance_scale,
|
54 |
+
num_images_per_prompt,
|
55 |
+
generator,
|
56 |
+
latents,
|
57 |
+
prompt_embeds,
|
58 |
+
pooled_prompt_embeds,
|
59 |
+
output_type,
|
60 |
+
return_dict,
|
61 |
+
joint_attention_kwargs,
|
62 |
+
callback_on_step_end,
|
63 |
+
callback_on_step_end_tensor_inputs,
|
64 |
+
max_sequence_length,
|
65 |
+
)
|
66 |
+
|
67 |
+
|
68 |
+
def seed_everything(seed: int = 42):
|
69 |
+
torch.backends.cudnn.deterministic = True
|
70 |
+
torch.manual_seed(seed)
|
71 |
+
np.random.seed(seed)
|
72 |
+
|
73 |
+
|
74 |
+
@torch.no_grad()
|
75 |
+
def generate(
|
76 |
+
pipeline: FluxPipeline,
|
77 |
+
conditions: List[Condition] = None,
|
78 |
+
config_path: str = None,
|
79 |
+
model_config: Optional[Dict[str, Any]] = {},
|
80 |
+
condition_scale: float = 1.0,
|
81 |
+
default_lora: bool = False,
|
82 |
+
image_guidance_scale: float = 1.0,
|
83 |
+
**params: dict,
|
84 |
+
):
|
85 |
+
model_config = model_config or get_config(config_path).get("model", {})
|
86 |
+
if condition_scale != 1:
|
87 |
+
for name, module in pipeline.transformer.named_modules():
|
88 |
+
if not name.endswith(".attn"):
|
89 |
+
continue
|
90 |
+
module.c_factor = torch.ones(1, 1) * condition_scale
|
91 |
+
|
92 |
+
self = pipeline
|
93 |
+
(
|
94 |
+
prompt,
|
95 |
+
prompt_2,
|
96 |
+
height,
|
97 |
+
width,
|
98 |
+
num_inference_steps,
|
99 |
+
timesteps,
|
100 |
+
guidance_scale,
|
101 |
+
num_images_per_prompt,
|
102 |
+
generator,
|
103 |
+
latents,
|
104 |
+
prompt_embeds,
|
105 |
+
pooled_prompt_embeds,
|
106 |
+
output_type,
|
107 |
+
return_dict,
|
108 |
+
joint_attention_kwargs,
|
109 |
+
callback_on_step_end,
|
110 |
+
callback_on_step_end_tensor_inputs,
|
111 |
+
max_sequence_length,
|
112 |
+
) = prepare_params(**params)
|
113 |
+
|
114 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
115 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
116 |
+
|
117 |
+
# 1. Check inputs. Raise error if not correct
|
118 |
+
self.check_inputs(
|
119 |
+
prompt,
|
120 |
+
prompt_2,
|
121 |
+
height,
|
122 |
+
width,
|
123 |
+
prompt_embeds=prompt_embeds,
|
124 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
125 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
126 |
+
max_sequence_length=max_sequence_length,
|
127 |
+
)
|
128 |
+
|
129 |
+
self._guidance_scale = guidance_scale
|
130 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
131 |
+
self._interrupt = False
|
132 |
+
|
133 |
+
# 2. Define call parameters
|
134 |
+
if prompt is not None and isinstance(prompt, str):
|
135 |
+
batch_size = 1
|
136 |
+
elif prompt is not None and isinstance(prompt, list):
|
137 |
+
batch_size = len(prompt)
|
138 |
+
else:
|
139 |
+
batch_size = prompt_embeds.shape[0]
|
140 |
+
|
141 |
+
device = self._execution_device
|
142 |
+
|
143 |
+
lora_scale = (
|
144 |
+
self.joint_attention_kwargs.get("scale", None)
|
145 |
+
if self.joint_attention_kwargs is not None
|
146 |
+
else None
|
147 |
+
)
|
148 |
+
(
|
149 |
+
prompt_embeds,
|
150 |
+
pooled_prompt_embeds,
|
151 |
+
text_ids,
|
152 |
+
) = self.encode_prompt(
|
153 |
+
prompt=prompt,
|
154 |
+
prompt_2=prompt_2,
|
155 |
+
prompt_embeds=prompt_embeds,
|
156 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
157 |
+
device=device,
|
158 |
+
num_images_per_prompt=num_images_per_prompt,
|
159 |
+
max_sequence_length=max_sequence_length,
|
160 |
+
lora_scale=lora_scale,
|
161 |
+
)
|
162 |
+
|
163 |
+
# 4. Prepare latent variables
|
164 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
165 |
+
latents, latent_image_ids = self.prepare_latents(
|
166 |
+
batch_size * num_images_per_prompt,
|
167 |
+
num_channels_latents,
|
168 |
+
height,
|
169 |
+
width,
|
170 |
+
prompt_embeds.dtype,
|
171 |
+
device,
|
172 |
+
generator,
|
173 |
+
latents,
|
174 |
+
)
|
175 |
+
|
176 |
+
# 4.1. Prepare conditions
|
177 |
+
condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
|
178 |
+
use_condition = conditions is not None and len(conditions) > 0
|
179 |
+
if use_condition:
|
180 |
+
assert len(conditions) <= 1, "Only one condition is supported for now."
|
181 |
+
if not default_lora:
|
182 |
+
pipeline.set_adapters(conditions[0].condition_type)
|
183 |
+
for condition in conditions:
|
184 |
+
tokens, ids, type_id = condition.encode(self)
|
185 |
+
condition_latents.append(tokens) # [batch_size, token_n, token_dim]
|
186 |
+
condition_ids.append(ids) # [token_n, id_dim(3)]
|
187 |
+
condition_type_ids.append(type_id) # [token_n, 1]
|
188 |
+
condition_latents = torch.cat(condition_latents, dim=1)
|
189 |
+
condition_ids = torch.cat(condition_ids, dim=0)
|
190 |
+
condition_type_ids = torch.cat(condition_type_ids, dim=0)
|
191 |
+
|
192 |
+
# 5. Prepare timesteps
|
193 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
194 |
+
image_seq_len = latents.shape[1]
|
195 |
+
mu = calculate_shift(
|
196 |
+
image_seq_len,
|
197 |
+
self.scheduler.config.base_image_seq_len,
|
198 |
+
self.scheduler.config.max_image_seq_len,
|
199 |
+
self.scheduler.config.base_shift,
|
200 |
+
self.scheduler.config.max_shift,
|
201 |
+
)
|
202 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
203 |
+
self.scheduler,
|
204 |
+
num_inference_steps,
|
205 |
+
device,
|
206 |
+
timesteps,
|
207 |
+
sigmas,
|
208 |
+
mu=mu,
|
209 |
+
)
|
210 |
+
num_warmup_steps = max(
|
211 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
212 |
+
)
|
213 |
+
self._num_timesteps = len(timesteps)
|
214 |
+
|
215 |
+
# 6. Denoising loop
|
216 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
217 |
+
for i, t in enumerate(timesteps):
|
218 |
+
if self.interrupt:
|
219 |
+
continue
|
220 |
+
|
221 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
222 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
223 |
+
|
224 |
+
# handle guidance
|
225 |
+
if self.transformer.config.guidance_embeds:
|
226 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
227 |
+
guidance = guidance.expand(latents.shape[0])
|
228 |
+
else:
|
229 |
+
guidance = None
|
230 |
+
noise_pred = tranformer_forward(
|
231 |
+
self.transformer,
|
232 |
+
model_config=model_config,
|
233 |
+
# Inputs of the condition (new feature)
|
234 |
+
condition_latents=condition_latents if use_condition else None,
|
235 |
+
condition_ids=condition_ids if use_condition else None,
|
236 |
+
condition_type_ids=condition_type_ids if use_condition else None,
|
237 |
+
# Inputs to the original transformer
|
238 |
+
hidden_states=latents,
|
239 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
240 |
+
timestep=timestep / 1000,
|
241 |
+
guidance=guidance,
|
242 |
+
pooled_projections=pooled_prompt_embeds,
|
243 |
+
encoder_hidden_states=prompt_embeds,
|
244 |
+
txt_ids=text_ids,
|
245 |
+
img_ids=latent_image_ids,
|
246 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
247 |
+
return_dict=False,
|
248 |
+
)[0]
|
249 |
+
|
250 |
+
if image_guidance_scale != 1.0:
|
251 |
+
uncondition_latents = condition.encode(self, empty=True)[0]
|
252 |
+
unc_pred = tranformer_forward(
|
253 |
+
self.transformer,
|
254 |
+
model_config=model_config,
|
255 |
+
# Inputs of the condition (new feature)
|
256 |
+
condition_latents=uncondition_latents if use_condition else None,
|
257 |
+
condition_ids=condition_ids if use_condition else None,
|
258 |
+
condition_type_ids=condition_type_ids if use_condition else None,
|
259 |
+
# Inputs to the original transformer
|
260 |
+
hidden_states=latents,
|
261 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
262 |
+
timestep=timestep / 1000,
|
263 |
+
guidance=torch.ones_like(guidance),
|
264 |
+
pooled_projections=pooled_prompt_embeds,
|
265 |
+
encoder_hidden_states=prompt_embeds,
|
266 |
+
txt_ids=text_ids,
|
267 |
+
img_ids=latent_image_ids,
|
268 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
269 |
+
return_dict=False,
|
270 |
+
)[0]
|
271 |
+
|
272 |
+
noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
|
273 |
+
|
274 |
+
# compute the previous noisy sample x_t -> x_t-1
|
275 |
+
latents_dtype = latents.dtype
|
276 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
277 |
+
|
278 |
+
if latents.dtype != latents_dtype:
|
279 |
+
if torch.backends.mps.is_available():
|
280 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
281 |
+
latents = latents.to(latents_dtype)
|
282 |
+
|
283 |
+
if callback_on_step_end is not None:
|
284 |
+
callback_kwargs = {}
|
285 |
+
for k in callback_on_step_end_tensor_inputs:
|
286 |
+
callback_kwargs[k] = locals()[k]
|
287 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
288 |
+
|
289 |
+
latents = callback_outputs.pop("latents", latents)
|
290 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
291 |
+
|
292 |
+
# call the callback, if provided
|
293 |
+
if i == len(timesteps) - 1 or (
|
294 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
295 |
+
):
|
296 |
+
progress_bar.update()
|
297 |
+
|
298 |
+
if output_type == "latent":
|
299 |
+
image = latents
|
300 |
+
|
301 |
+
else:
|
302 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
303 |
+
latents = (
|
304 |
+
latents / self.vae.config.scaling_factor
|
305 |
+
) + self.vae.config.shift_factor
|
306 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
307 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
308 |
+
|
309 |
+
# Offload all models
|
310 |
+
self.maybe_free_model_hooks()
|
311 |
+
|
312 |
+
if condition_scale != 1:
|
313 |
+
for name, module in pipeline.transformer.named_modules():
|
314 |
+
if not name.endswith(".attn"):
|
315 |
+
continue
|
316 |
+
del module.c_factor
|
317 |
+
|
318 |
+
if not return_dict:
|
319 |
+
return (image,)
|
320 |
+
|
321 |
+
return FluxPipelineOutput(images=image)
|
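A minimal end-to-end sketch of calling generate() (the model id and LoRA weight name are taken from the Gradio app added below; everything else is illustrative):

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
seed_everything(0)
result = generate(
    pipe,
    prompt="On a snowy sidewalk, this item sits on the road wearing a Christmas hat.",
    conditions=[Condition("subject", image)],
    num_inference_steps=8,
    height=512,
    width=512,
    # condition_scale biases image<->condition attention via the c_factor log-bias;
    # image_guidance_scale re-runs the transformer with an empty (black) condition.
    condition_scale=1.0,
    image_guidance_scale=1.0,
).images[0]
result.save("penguin_subject.jpg")

Note that the image_guidance_scale branch calls torch.ones_like(guidance), so it assumes a guidance-distilled checkpoint (guidance_embeds=True); with FLUX.1-schnell keep it at the default 1.0.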
OminiControl/src/flux/lora_controller.py
ADDED
@@ -0,0 +1,75 @@
1 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
2 |
+
from typing import List, Any, Optional, Type
|
3 |
+
|
4 |
+
|
5 |
+
class enable_lora:
|
6 |
+
def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
|
7 |
+
self.activated: bool = activated
|
8 |
+
if activated:
|
9 |
+
return
|
10 |
+
self.lora_modules: List[BaseTunerLayer] = [
|
11 |
+
each for each in lora_modules if isinstance(each, BaseTunerLayer)
|
12 |
+
]
|
13 |
+
self.scales = [
|
14 |
+
{
|
15 |
+
active_adapter: lora_module.scaling[active_adapter]
|
16 |
+
for active_adapter in lora_module.active_adapters
|
17 |
+
}
|
18 |
+
for lora_module in self.lora_modules
|
19 |
+
]
|
20 |
+
|
21 |
+
def __enter__(self) -> None:
|
22 |
+
if self.activated:
|
23 |
+
return
|
24 |
+
|
25 |
+
for lora_module in self.lora_modules:
|
26 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
27 |
+
continue
|
28 |
+
lora_module.scale_layer(0)
|
29 |
+
|
30 |
+
def __exit__(
|
31 |
+
self,
|
32 |
+
exc_type: Optional[Type[BaseException]],
|
33 |
+
exc_val: Optional[BaseException],
|
34 |
+
exc_tb: Optional[Any],
|
35 |
+
) -> None:
|
36 |
+
if self.activated:
|
37 |
+
return
|
38 |
+
for i, lora_module in enumerate(self.lora_modules):
|
39 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
40 |
+
continue
|
41 |
+
for active_adapter in lora_module.active_adapters:
|
42 |
+
lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
|
43 |
+
|
44 |
+
|
45 |
+
class set_lora_scale:
|
46 |
+
def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
|
47 |
+
self.lora_modules: List[BaseTunerLayer] = [
|
48 |
+
each for each in lora_modules if isinstance(each, BaseTunerLayer)
|
49 |
+
]
|
50 |
+
self.scales = [
|
51 |
+
{
|
52 |
+
active_adapter: lora_module.scaling[active_adapter]
|
53 |
+
for active_adapter in lora_module.active_adapters
|
54 |
+
}
|
55 |
+
for lora_module in self.lora_modules
|
56 |
+
]
|
57 |
+
self.scale = scale
|
58 |
+
|
59 |
+
def __enter__(self) -> None:
|
60 |
+
for lora_module in self.lora_modules:
|
61 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
62 |
+
continue
|
63 |
+
lora_module.scale_layer(self.scale)
|
64 |
+
|
65 |
+
def __exit__(
|
66 |
+
self,
|
67 |
+
exc_type: Optional[Type[BaseException]],
|
68 |
+
exc_val: Optional[BaseException],
|
69 |
+
exc_tb: Optional[Any],
|
70 |
+
) -> None:
|
71 |
+
for i, lora_module in enumerate(self.lora_modules):
|
72 |
+
if not isinstance(lora_module, BaseTunerLayer):
|
73 |
+
continue
|
74 |
+
for active_adapter in lora_module.active_adapters:
|
75 |
+
lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
|
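Note the inverted semantics of enable_lora: with activated=True it is a no-op (LoRA stays at its current scale), while activated=False temporarily scales the listed PEFT layers to zero and restores their per-adapter scaling on exit; set_lora_scale applies a fixed scale the same way. A toy usage sketch (assumes a PEFT-wrapped pipe.transformer; the lora_layers list is illustrative):

from peft.tuners.tuners_utils import BaseTunerLayer
from src.flux.lora_controller import enable_lora, set_lora_scale

lora_layers = [m for m in pipe.transformer.modules() if isinstance(m, BaseTunerLayer)]

with enable_lora(lora_layers, False):
    # LoRA contribution is zeroed inside this block ...
    pass
# ... and the original per-adapter scaling is restored here.

with set_lora_scale(lora_layers, 0.5):
    # All listed adapters run at half strength inside this block.
    pass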
OminiControl/src/flux/pipeline_tools.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
from diffusers.pipelines import FluxPipeline
|
2 |
+
from diffusers.utils import logging
|
3 |
+
from diffusers.pipelines.flux.pipeline_flux import logger
|
4 |
+
from torch import Tensor
|
5 |
+
|
6 |
+
|
7 |
+
def encode_images(pipeline: FluxPipeline, images: Tensor):
|
8 |
+
images = pipeline.image_processor.preprocess(images)
|
9 |
+
images = images.to(pipeline.device).to(pipeline.dtype)
|
10 |
+
images = pipeline.vae.encode(images).latent_dist.sample()
|
11 |
+
images = (
|
12 |
+
images - pipeline.vae.config.shift_factor
|
13 |
+
) * pipeline.vae.config.scaling_factor
|
14 |
+
images_tokens = pipeline._pack_latents(images, *images.shape)
|
15 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
16 |
+
images.shape[0],
|
17 |
+
images.shape[2],
|
18 |
+
images.shape[3],
|
19 |
+
pipeline.device,
|
20 |
+
pipeline.dtype,
|
21 |
+
)
|
22 |
+
if images_tokens.shape[1] != images_ids.shape[0]:
|
23 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
24 |
+
images.shape[0],
|
25 |
+
images.shape[2] // 2,
|
26 |
+
images.shape[3] // 2,
|
27 |
+
pipeline.device,
|
28 |
+
pipeline.dtype,
|
29 |
+
)
|
30 |
+
return images_tokens, images_ids
|
31 |
+
|
32 |
+
|
33 |
+
def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
|
34 |
+
# Turn off warnings (CLIP overflow)
|
35 |
+
logger.setLevel(logging.ERROR)
|
36 |
+
(
|
37 |
+
prompt_embeds,
|
38 |
+
pooled_prompt_embeds,
|
39 |
+
text_ids,
|
40 |
+
) = pipeline.encode_prompt(
|
41 |
+
prompt=prompts,
|
42 |
+
prompt_2=None,
|
43 |
+
prompt_embeds=None,
|
44 |
+
pooled_prompt_embeds=None,
|
45 |
+
device=pipeline.device,
|
46 |
+
num_images_per_prompt=1,
|
47 |
+
max_sequence_length=max_sequence_length,
|
48 |
+
lora_scale=None,
|
49 |
+
)
|
50 |
+
# Turn on warnings
|
51 |
+
logger.setLevel(logging.WARNING)
|
52 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
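encode_images packs VAE latents into FLUX's token layout and returns matching latent-image position ids; prepare_text_input wraps encode_prompt with the CLIP-overflow warning silenced. A short sketch of both (assumes a loaded FluxPipeline named pipe and is run from the OminiControl directory):

from PIL import Image
from src.flux.pipeline_tools import encode_images, prepare_text_input

img = Image.open("assets/vase.jpg").convert("RGB").resize((512, 512))

tokens, ids = encode_images(pipe, img)
print(tokens.shape, ids.shape)  # roughly [1, 1024, 64] and [1024, 3] for a 512x512 input

prompt_embeds, pooled_embeds, text_ids = prepare_text_input(pipe, ["a vase on a wooden table"])
print(prompt_embeds.shape, pooled_embeds.shape)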
OminiControl/src/flux/transformer.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
import torch
|
2 |
+
from diffusers.pipelines import FluxPipeline
|
3 |
+
from typing import List, Union, Optional, Dict, Any, Callable
|
4 |
+
from .block import block_forward, single_block_forward
|
5 |
+
from .lora_controller import enable_lora
|
6 |
+
from accelerate.utils import is_torch_version
|
7 |
+
from diffusers.models.transformers.transformer_flux import (
|
8 |
+
FluxTransformer2DModel,
|
9 |
+
Transformer2DModelOutput,
|
10 |
+
USE_PEFT_BACKEND,
|
11 |
+
scale_lora_layers,
|
12 |
+
unscale_lora_layers,
|
13 |
+
logger,
|
14 |
+
)
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
def prepare_params(
|
19 |
+
hidden_states: torch.Tensor,
|
20 |
+
encoder_hidden_states: torch.Tensor = None,
|
21 |
+
pooled_projections: torch.Tensor = None,
|
22 |
+
timestep: torch.LongTensor = None,
|
23 |
+
img_ids: torch.Tensor = None,
|
24 |
+
txt_ids: torch.Tensor = None,
|
25 |
+
guidance: torch.Tensor = None,
|
26 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
27 |
+
controlnet_block_samples=None,
|
28 |
+
controlnet_single_block_samples=None,
|
29 |
+
return_dict: bool = True,
|
30 |
+
**kwargs: dict,
|
31 |
+
):
|
32 |
+
return (
|
33 |
+
hidden_states,
|
34 |
+
encoder_hidden_states,
|
35 |
+
pooled_projections,
|
36 |
+
timestep,
|
37 |
+
img_ids,
|
38 |
+
txt_ids,
|
39 |
+
guidance,
|
40 |
+
joint_attention_kwargs,
|
41 |
+
controlnet_block_samples,
|
42 |
+
controlnet_single_block_samples,
|
43 |
+
return_dict,
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
def tranformer_forward(
|
48 |
+
transformer: FluxTransformer2DModel,
|
49 |
+
condition_latents: torch.Tensor,
|
50 |
+
condition_ids: torch.Tensor,
|
51 |
+
condition_type_ids: torch.Tensor,
|
52 |
+
model_config: Optional[Dict[str, Any]] = {},
|
53 |
+
c_t=0,
|
54 |
+
**params: dict,
|
55 |
+
):
|
56 |
+
self = transformer
|
57 |
+
use_condition = condition_latents is not None
|
58 |
+
|
59 |
+
(
|
60 |
+
hidden_states,
|
61 |
+
encoder_hidden_states,
|
62 |
+
pooled_projections,
|
63 |
+
timestep,
|
64 |
+
img_ids,
|
65 |
+
txt_ids,
|
66 |
+
guidance,
|
67 |
+
joint_attention_kwargs,
|
68 |
+
controlnet_block_samples,
|
69 |
+
controlnet_single_block_samples,
|
70 |
+
return_dict,
|
71 |
+
) = prepare_params(**params)
|
72 |
+
|
73 |
+
if joint_attention_kwargs is not None:
|
74 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
75 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
76 |
+
else:
|
77 |
+
lora_scale = 1.0
|
78 |
+
|
79 |
+
if USE_PEFT_BACKEND:
|
80 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
81 |
+
scale_lora_layers(self, lora_scale)
|
82 |
+
else:
|
83 |
+
if (
|
84 |
+
joint_attention_kwargs is not None
|
85 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
86 |
+
):
|
87 |
+
logger.warning(
|
88 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
89 |
+
)
|
90 |
+
|
91 |
+
with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
|
92 |
+
hidden_states = self.x_embedder(hidden_states)
|
93 |
+
condition_latents = self.x_embedder(condition_latents) if use_condition else None
|
94 |
+
|
95 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
96 |
+
|
97 |
+
if guidance is not None:
|
98 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
99 |
+
else:
|
100 |
+
guidance = None
|
101 |
+
|
102 |
+
temb = (
|
103 |
+
self.time_text_embed(timestep, pooled_projections)
|
104 |
+
if guidance is None
|
105 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
106 |
+
)
|
107 |
+
|
108 |
+
cond_temb = (
|
109 |
+
self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
|
110 |
+
if guidance is None
|
111 |
+
else self.time_text_embed(
|
112 |
+
torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
|
113 |
+
)
|
114 |
+
)
|
115 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
116 |
+
|
117 |
+
if txt_ids.ndim == 3:
|
118 |
+
logger.warning(
|
119 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
120 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
121 |
+
)
|
122 |
+
txt_ids = txt_ids[0]
|
123 |
+
if img_ids.ndim == 3:
|
124 |
+
logger.warning(
|
125 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
126 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
127 |
+
)
|
128 |
+
img_ids = img_ids[0]
|
129 |
+
|
130 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
131 |
+
image_rotary_emb = self.pos_embed(ids)
|
132 |
+
if use_condition:
|
133 |
+
# condition_ids[:, :1] = condition_type_ids
|
134 |
+
cond_rotary_emb = self.pos_embed(condition_ids)
|
135 |
+
|
136 |
+
# hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
|
137 |
+
|
138 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
139 |
+
if self.training and self.gradient_checkpointing:
|
140 |
+
ckpt_kwargs: Dict[str, Any] = (
|
141 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
142 |
+
)
|
143 |
+
encoder_hidden_states, hidden_states, condition_latents = (
|
144 |
+
torch.utils.checkpoint.checkpoint(
|
145 |
+
block_forward,
|
146 |
+
self=block,
|
147 |
+
model_config=model_config,
|
148 |
+
hidden_states=hidden_states,
|
149 |
+
encoder_hidden_states=encoder_hidden_states,
|
150 |
+
condition_latents=condition_latents if use_condition else None,
|
151 |
+
temb=temb,
|
152 |
+
cond_temb=cond_temb if use_condition else None,
|
153 |
+
cond_rotary_emb=cond_rotary_emb if use_condition else None,
|
154 |
+
image_rotary_emb=image_rotary_emb,
|
155 |
+
**ckpt_kwargs,
|
156 |
+
)
|
157 |
+
)
|
158 |
+
|
159 |
+
else:
|
160 |
+
encoder_hidden_states, hidden_states, condition_latents = block_forward(
|
161 |
+
block,
|
162 |
+
model_config=model_config,
|
163 |
+
hidden_states=hidden_states,
|
164 |
+
encoder_hidden_states=encoder_hidden_states,
|
165 |
+
condition_latents=condition_latents if use_condition else None,
|
166 |
+
temb=temb,
|
167 |
+
cond_temb=cond_temb if use_condition else None,
|
168 |
+
cond_rotary_emb=cond_rotary_emb if use_condition else None,
|
169 |
+
image_rotary_emb=image_rotary_emb,
|
170 |
+
)
|
171 |
+
|
172 |
+
# controlnet residual
|
173 |
+
if controlnet_block_samples is not None:
|
174 |
+
interval_control = len(self.transformer_blocks) / len(
|
175 |
+
controlnet_block_samples
|
176 |
+
)
|
177 |
+
interval_control = int(np.ceil(interval_control))
|
178 |
+
hidden_states = (
|
179 |
+
hidden_states
|
180 |
+
+ controlnet_block_samples[index_block // interval_control]
|
181 |
+
)
|
182 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
183 |
+
|
184 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
185 |
+
if self.training and self.gradient_checkpointing:
|
186 |
+
ckpt_kwargs: Dict[str, Any] = (
|
187 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
188 |
+
)
|
189 |
+
result = torch.utils.checkpoint.checkpoint(
|
190 |
+
single_block_forward,
|
191 |
+
self=block,
|
192 |
+
model_config=model_config,
|
193 |
+
hidden_states=hidden_states,
|
194 |
+
temb=temb,
|
195 |
+
image_rotary_emb=image_rotary_emb,
|
196 |
+
**(
|
197 |
+
{
|
198 |
+
"condition_latents": condition_latents,
|
199 |
+
"cond_temb": cond_temb,
|
200 |
+
"cond_rotary_emb": cond_rotary_emb,
|
201 |
+
}
|
202 |
+
if use_condition
|
203 |
+
else {}
|
204 |
+
),
|
205 |
+
**ckpt_kwargs,
|
206 |
+
)
|
207 |
+
|
208 |
+
else:
|
209 |
+
result = single_block_forward(
|
210 |
+
block,
|
211 |
+
model_config=model_config,
|
212 |
+
hidden_states=hidden_states,
|
213 |
+
temb=temb,
|
214 |
+
image_rotary_emb=image_rotary_emb,
|
215 |
+
**(
|
216 |
+
{
|
217 |
+
"condition_latents": condition_latents,
|
218 |
+
"cond_temb": cond_temb,
|
219 |
+
"cond_rotary_emb": cond_rotary_emb,
|
220 |
+
}
|
221 |
+
if use_condition
|
222 |
+
else {}
|
223 |
+
),
|
224 |
+
)
|
225 |
+
if use_condition:
|
226 |
+
hidden_states, condition_latents = result
|
227 |
+
else:
|
228 |
+
hidden_states = result
|
229 |
+
|
230 |
+
# controlnet residual
|
231 |
+
if controlnet_single_block_samples is not None:
|
232 |
+
interval_control = len(self.single_transformer_blocks) / len(
|
233 |
+
controlnet_single_block_samples
|
234 |
+
)
|
235 |
+
interval_control = int(np.ceil(interval_control))
|
236 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
237 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
238 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
239 |
+
)
|
240 |
+
|
241 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
242 |
+
|
243 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
244 |
+
output = self.proj_out(hidden_states)
|
245 |
+
|
246 |
+
if USE_PEFT_BACKEND:
|
247 |
+
# remove `lora_scale` from each PEFT layer
|
248 |
+
unscale_lora_layers(self, lora_scale)
|
249 |
+
|
250 |
+
if not return_dict:
|
251 |
+
return (output,)
|
252 |
+
return Transformer2DModelOutput(sample=output)
|
OminiControl/src/gradio/gradio_app.py
ADDED
@@ -0,0 +1,115 @@
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from PIL import Image, ImageDraw, ImageFont
|
4 |
+
from diffusers.pipelines import FluxPipeline
|
5 |
+
from diffusers import FluxTransformer2DModel
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from ..flux.condition import Condition
|
9 |
+
from ..flux.generate import seed_everything, generate
|
10 |
+
|
11 |
+
pipe = None
|
12 |
+
use_int8 = False
|
13 |
+
|
14 |
+
|
15 |
+
def get_gpu_memory():
|
16 |
+
return torch.cuda.get_device_properties(0).total_memory / 1024**3
|
17 |
+
|
18 |
+
|
19 |
+
def init_pipeline():
|
20 |
+
global pipe
|
21 |
+
if use_int8 or get_gpu_memory() < 33:
|
22 |
+
transformer_model = FluxTransformer2DModel.from_pretrained(
|
23 |
+
"sayakpaul/flux.1-schell-int8wo-improved",
|
24 |
+
torch_dtype=torch.bfloat16,
|
25 |
+
use_safetensors=False,
|
26 |
+
)
|
27 |
+
pipe = FluxPipeline.from_pretrained(
|
28 |
+
"black-forest-labs/FLUX.1-schnell",
|
29 |
+
transformer=transformer_model,
|
30 |
+
torch_dtype=torch.bfloat16,
|
31 |
+
)
|
32 |
+
else:
|
33 |
+
pipe = FluxPipeline.from_pretrained(
|
34 |
+
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
35 |
+
)
|
36 |
+
pipe = pipe.to("cuda")
|
37 |
+
pipe.load_lora_weights(
|
38 |
+
"Yuanshi/OminiControl",
|
39 |
+
weight_name="omini/subject_512.safetensors",
|
40 |
+
adapter_name="subject",
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
def process_image_and_text(image, text):
|
45 |
+
# center crop image
|
46 |
+
w, h, min_size = image.size[0], image.size[1], min(image.size)
|
47 |
+
image = image.crop(
|
48 |
+
(
|
49 |
+
(w - min_size) // 2,
|
50 |
+
(h - min_size) // 2,
|
51 |
+
(w + min_size) // 2,
|
52 |
+
(h + min_size) // 2,
|
53 |
+
)
|
54 |
+
)
|
55 |
+
image = image.resize((512, 512))
|
56 |
+
|
57 |
+
condition = Condition("subject", image, position_delta=(0, 32))
|
58 |
+
|
59 |
+
if pipe is None:
|
60 |
+
init_pipeline()
|
61 |
+
|
62 |
+
result_img = generate(
|
63 |
+
pipe,
|
64 |
+
prompt=text.strip(),
|
65 |
+
conditions=[condition],
|
66 |
+
num_inference_steps=8,
|
67 |
+
height=512,
|
68 |
+
width=512,
|
69 |
+
).images[0]
|
70 |
+
|
71 |
+
return result_img
|
72 |
+
|
73 |
+
|
74 |
+
def get_samples():
|
75 |
+
sample_list = [
|
76 |
+
{
|
77 |
+
"image": "assets/oranges.jpg",
|
78 |
+
"text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"image": "assets/penguin.jpg",
|
82 |
+
"text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"image": "assets/rc_car.jpg",
|
86 |
+
"text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"image": "assets/clock.jpg",
|
90 |
+
"text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"image": "assets/tshirt.jpg",
|
94 |
+
"text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.",
|
95 |
+
},
|
96 |
+
]
|
97 |
+
return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
|
98 |
+
|
99 |
+
|
100 |
+
demo = gr.Interface(
|
101 |
+
fn=process_image_and_text,
|
102 |
+
inputs=[
|
103 |
+
gr.Image(type="pil"),
|
104 |
+
gr.Textbox(lines=2),
|
105 |
+
],
|
106 |
+
outputs=gr.Image(type="pil"),
|
107 |
+
title="OminiControl / Subject driven generation",
|
108 |
+
examples=get_samples(),
|
109 |
+
)
|
110 |
+
|
111 |
+
if __name__ == "__main__":
|
112 |
+
init_pipeline()
|
113 |
+
demo.launch(
|
114 |
+
debug=True,
|
115 |
+
)
|
OminiControl/src/train/callbacks.py
ADDED
@@ -0,0 +1,253 @@
import lightning as L
from PIL import Image, ImageFilter, ImageDraw
import numpy as np
from transformers import pipeline
import cv2
import torch
import os

try:
    import wandb
except ImportError:
    wandb = None

from ..flux.condition import Condition
from ..flux.generate import generate


class TrainingCallback(L.Callback):
    def __init__(self, run_name, training_config: dict = {}):
        self.run_name, self.training_config = run_name, training_config

        self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
        self.save_interval = training_config.get("save_interval", 1000)
        self.sample_interval = training_config.get("sample_interval", 1000)
        self.save_path = training_config.get("save_path", "./output")

        self.wandb_config = training_config.get("wandb", None)
        self.use_wandb = (
            wandb is not None and os.environ.get("WANDB_API_KEY") is not None
        )

        self.total_steps = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Track the mean and max L2 norm of the gradients
        gradient_size = 0
        max_gradient_size = 0
        count = 0
        for _, param in pl_module.named_parameters():
            if param.grad is not None:
                gradient_size += param.grad.norm(2).item()
                max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
                count += 1
        if count > 0:
            gradient_size /= count

        self.total_steps += 1

        # Report metrics to WandB every step
        if self.use_wandb:
            report_dict = {
                "steps": self.total_steps,
                "epoch": trainer.current_epoch,
                "gradient_size": gradient_size,
            }
            loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
            report_dict["loss"] = loss_value
            report_dict["t"] = pl_module.last_t
            wandb.log(report_dict)

        # Print training progress every n steps
        if self.total_steps % self.print_every_n_steps == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
            )

        # Save LoRA weights at specified intervals
        if self.total_steps % self.save_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
            )
            pl_module.save_lora(
                f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
            )

        # Generate and save a sample image at specified intervals
        if self.total_steps % self.sample_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
            )
            self.generate_a_sample(
                trainer,
                pl_module,
                f"{self.save_path}/{self.run_name}/output",
                f"lora_{self.total_steps}",
                batch["condition_type"][0],  # Use the condition type from the current batch
            )

    @torch.no_grad()
    def generate_a_sample(
        self,
        trainer,
        pl_module,
        save_path,
        file_name,
        condition_type="super_resolution",
    ):
        # TODO: change these two variables to parameters
        condition_size = trainer.training_config["dataset"]["condition_size"]
        target_size = trainer.training_config["dataset"]["target_size"]
        position_scale = trainer.training_config["dataset"].get("position_scale", 1.0)

        generator = torch.Generator(device=pl_module.device)
        generator.manual_seed(42)

        test_list = []

        if condition_type == "subject":
            test_list.extend(
                [
                    (
                        Image.open("assets/test_in.jpg"),
                        [0, -32],
                        "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
                    ),
                    (
                        Image.open("assets/test_out.jpg"),
                        [0, -32],
                        "In a bright room. It is placed on a table.",
                    ),
                ]
            )
        elif condition_type == "canny":
            condition_img = Image.open("assets/vase_hq.jpg").resize(
                (condition_size, condition_size)
            )
            condition_img = np.array(condition_img)
            condition_img = cv2.Canny(condition_img, 100, 200)
            condition_img = Image.fromarray(condition_img).convert("RGB")
            test_list.append(
                (
                    condition_img,
                    [0, 0],
                    "A beautiful vase on a table.",
                    {"position_scale": position_scale} if position_scale != 1.0 else {},
                )
            )
        elif condition_type == "coloring":
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("L")
                .convert("RGB")
            )
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "depth":
            if not hasattr(self, "depth_pipe"):
                self.depth_pipe = pipeline(
                    task="depth-estimation",
                    model="LiheYoung/depth-anything-small-hf",
                    device="cpu",
                )
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            condition_img = self.depth_pipe(condition_img)["depth"].convert("RGB")
            test_list.append(
                (
                    condition_img,
                    [0, 0],
                    "A beautiful vase on a table.",
                    {"position_scale": position_scale} if position_scale != 1.0 else {},
                )
            )
        elif condition_type == "depth_pred":
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "deblurring":
            blur_radius = 5
            image = Image.open("./assets/vase_hq.jpg")
            condition_img = (
                image.convert("RGB")
                .resize((condition_size, condition_size))
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .convert("RGB")
            )
            test_list.append(
                (
                    condition_img,
                    [0, 0],
                    "A beautiful vase on a table.",
                    {"position_scale": position_scale} if position_scale != 1.0 else {},
                )
            )
        elif condition_type == "fill":
            condition_img = (
                Image.open("./assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            mask = Image.new("L", condition_img.size, 0)
            draw = ImageDraw.Draw(mask)
            a = condition_img.size[0] // 4
            b = a * 3
            draw.rectangle([a, a, b, b], fill=255)
            condition_img = Image.composite(
                condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
            )
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "sr":
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            test_list.append((condition_img, [0, -16], "A beautiful vase on a table."))
        elif condition_type == "cartoon":
            condition_img = (
                Image.open("assets/cartoon_boy.png")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            test_list.append(
                (
                    condition_img,
                    [0, -16],
                    "A cartoon character in a white background. He is looking right, and running.",
                )
            )
        else:
            raise NotImplementedError

        if not os.path.exists(save_path):
            os.makedirs(save_path)
        for i, (condition_img, position_delta, prompt, *others) in enumerate(test_list):
            condition = Condition(
                condition_type=condition_type,
                condition=condition_img.resize(
                    (condition_size, condition_size)
                ).convert("RGB"),
                position_delta=position_delta,
                **(others[0] if others else {}),
            )
            res = generate(
                pl_module.flux_pipe,
                prompt=prompt,
                conditions=[condition],
                height=target_size,
                width=target_size,
                generator=generator,
                model_config=pl_module.model_config,
                default_lora=True,
            )
            res.images[0].save(
                os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
            )
OminiControl/src/train/data.py
ADDED
@@ -0,0 +1,323 @@
from PIL import Image, ImageFilter, ImageDraw
import cv2
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as T
import random


class Subject200KDataset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        image_size: int = 512,
        padding: int = 0,
        condition_type: str = "subject",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        # Each source image holds two paired views, so every base item yields two samples
        return len(self.base_dataset) * 2

    def __getitem__(self, idx):
        # If target is 0, the left image is the target and the right image is the condition
        target = idx % 2
        item = self.base_dataset[idx // 2]

        # Crop the paired image into its left and right halves
        image = item["image"]
        left_img = image.crop(
            (
                self.padding,
                self.padding,
                self.image_size + self.padding,
                self.image_size + self.padding,
            )
        )
        right_img = image.crop(
            (
                self.image_size + self.padding * 2,
                self.padding,
                self.image_size * 2 + self.padding * 2,
                self.image_size + self.padding,
            )
        )

        # Get the target and condition image
        target_image, condition_img = (
            (left_img, right_img) if target == 0 else (right_img, left_img)
        )

        # Resize the images
        condition_img = condition_img.resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        target_image = target_image.resize(
            (self.target_size, self.target_size)
        ).convert("RGB")

        # Get the description
        description = item["description"][
            "description_0" if target == 0 else "description_1"
        ]

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # 16 is the downscale factor of the image
            "position_delta": np.array([0, -self.condition_size // 16]),
            **({"pil_image": image} if self.return_pil_image else {}),
        }


class ImageConditionDataset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "canny",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    @property
    def depth_pipe(self):
        # Lazily build the depth-estimation pipeline on first use
        if not hasattr(self, "_depth_pipe"):
            from transformers import pipeline

            self._depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cpu",
            )
        return self._depth_pipe

    def _get_canny_edge(self, img):
        resize_ratio = self.condition_size / max(img.size)
        img = img.resize(
            (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
        )
        img_np = np.array(img)
        img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(img_gray, 100, 200)
        return Image.fromarray(edges).convert("RGB")

    def __getitem__(self, idx):
        image = self.base_dataset[idx]["jpg"]
        image = image.resize((self.target_size, self.target_size)).convert("RGB")
        description = self.base_dataset[idx]["json"]["prompt"]

        # Note: `random.random() < 1` is always True, so the non-scaled branch below
        # is effectively disabled and the configured position_scale is always used
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        # Get the condition image
        position_delta = np.array([0, 0])
        if self.condition_type == "canny":
            condition_img = self._get_canny_edge(image)
        elif self.condition_type == "coloring":
            condition_img = (
                image.resize((condition_size, condition_size))
                .convert("L")
                .convert("RGB")
            )
        elif self.condition_type == "deblurring":
            blur_radius = random.randint(1, 10)
            condition_img = (
                image.convert("RGB")
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
        elif self.condition_type == "depth":
            condition_img = self.depth_pipe(image)["depth"].convert("RGB")
            condition_img = condition_img.resize((condition_size, condition_size))
        elif self.condition_type == "depth_pred":
            condition_img = image
            image = self.depth_pipe(condition_img)["depth"].convert("RGB")
            description = f"[depth] {description}"
        elif self.condition_type == "fill":
            condition_img = image.resize((condition_size, condition_size)).convert(
                "RGB"
            )
            w, h = image.size
            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
            mask = Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask)
            draw.rectangle([x1, y1, x2, y2], fill=255)
            if random.random() > 0.5:
                mask = Image.eval(mask, lambda a: 255 - a)
            condition_img = Image.composite(
                image, Image.new("RGB", image.size, (0, 0, 0)), mask
            )
        elif self.condition_type == "sr":
            condition_img = image.resize((condition_size, condition_size)).convert(
                "RGB"
            )
            position_delta = np.array([0, -condition_size // 16])
        else:
            raise ValueError(f"Condition type {self.condition_type} not implemented")

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (condition_size, condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }


class CartoonDataset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 1024,
        target_size: int = 1024,
        image_size: int = 1024,
        padding: int = 0,
        condition_type: str = "cartoon",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        data = self.base_dataset[idx]
        condition_img = data["condition"]
        target_image = data["target"]

        # Tag
        tag = data["tags"][0]

        target_description = data["target_description"]

        description_by_tag = {
            "lion": "lion like animal",
            "bear": "bear like animal",
            "gorilla": "gorilla like animal",
            "dog": "dog like animal",
            "elephant": "elephant like animal",
            "eagle": "eagle like bird",
            "tiger": "tiger like animal",
            "owl": "owl like bird",
            "woman": "woman",
            "parrot": "parrot like bird",
            "mouse": "mouse like animal",
            "man": "man",
            "pigeon": "pigeon like bird",
            "girl": "girl",
            "panda": "panda like animal",
            "crocodile": "crocodile like animal",
            "rabbit": "rabbit like animal",
            "boy": "boy",
            "monkey": "monkey like animal",
            "cat": "cat like animal",
        }

        # Resize the images
        condition_img = condition_img.resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        target_image = target_image.resize(
            (self.target_size, self.target_size)
        ).convert("RGB")

        # Process the datum to create a description
        description = data.get(
            "description",
            f"Photo of a {description_by_tag[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. Character pose is {target_description['pose']}.",
        )

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # 16 is the downscale factor of the image
            "position_delta": np.array([0, -16]),
        }
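Each of these datasets returns plain tensors and strings, so it plugs straight into a standard DataLoader. A minimal sketch for Subject200KDataset, assuming the OminiControl directory is on the import path; the padding value is a hypothetical placeholder, and data_valid stands for the filtered Subjects200K split built in train.py below:

from torch.utils.data import DataLoader
from src.train.data import Subject200KDataset

dataset = Subject200KDataset(
    data_valid,             # filtered "Yuanshi/Subjects200K" split (see train.py below)
    condition_size=512,
    target_size=512,
    image_size=512,
    padding=8,              # hypothetical; must match how the paired images are laid out
    condition_type="subject",
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
# "image" and "condition" are (1, 3, 512, 512) float tensors;
# "description" is the paired caption ("" when text is dropped);
# "position_delta" is [0, -32] here (condition_size // 16), shifting the condition tokens.
print(batch["image"].shape, batch["condition"].shape, batch["position_delta"])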
OminiControl/src/train/model.py
ADDED
@@ -0,0 +1,185 @@
import lightning as L
from diffusers.pipelines import FluxPipeline
import torch
from peft import LoraConfig, get_peft_model_state_dict

import prodigyopt

from ..flux.transformer import tranformer_forward
from ..flux.condition import Condition
from ..flux.pipeline_tools import encode_images, prepare_text_input


class OminiModel(L.LightningModule):
    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = {},
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        # Initialize the LightningModule
        super().__init__()
        self.model_config = model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze the Flux pipeline
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()

        # Initialize LoRA layers
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        assert lora_path or lora_config
        if lora_path:
            # TODO: Implement this
            raise NotImplementedError
        else:
            self.transformer.add_adapter(LoraConfig(**lora_config))
            # TODO: Check if this is correct (p.requires_grad)
            lora_layers = filter(
                lambda p: p.requires_grad, self.transformer.parameters()
            )
        return list(lora_layers)

    def save_lora(self, path: str):
        FluxPipeline.save_lora_weights(
            save_directory=path,
            transformer_lora_layers=get_peft_model_state_dict(self.transformer),
            safe_serialization=True,
        )

    def configure_optimizers(self):
        # Freeze the transformer
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        # Set the trainable parameters
        self.trainable_params = self.lora_layers

        # Unfreeze trainable parameters
        for p in self.trainable_params:
            p.requires_grad_(True)

        # Initialize the optimizer
        if opt_config["type"] == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
        elif opt_config["type"] == "Prodigy":
            optimizer = prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        elif opt_config["type"] == "SGD":
            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
        else:
            raise NotImplementedError

        return optimizer

    def training_step(self, batch, batch_idx):
        step_loss = self.step(batch)
        self.log_loss = (
            step_loss.item()
            if not hasattr(self, "log_loss")
            else self.log_loss * 0.95 + step_loss.item() * 0.05
        )
        return step_loss

    def step(self, batch):
        imgs = batch["image"]
        conditions = batch["condition"]
        condition_types = batch["condition_type"]
        prompts = batch["description"]
        position_delta = batch["position_delta"][0]
        position_scale = float(batch.get("position_scale", [1.0])[0])

        # Prepare inputs
        with torch.no_grad():
            # Prepare image input
            x_0, img_ids = encode_images(self.flux_pipe, imgs)

            # Prepare text input
            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
                self.flux_pipe, prompts
            )

            # Prepare t and x_t
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            # Prepare conditions
            condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)

            # Add position delta
            condition_ids[:, 1] += position_delta[0]
            condition_ids[:, 2] += position_delta[1]

            if position_scale != 1.0:
                scale_bias = (position_scale - 1.0) / 2
                condition_ids[:, 1] *= position_scale
                condition_ids[:, 2] *= position_scale
                condition_ids[:, 1] += scale_bias
                condition_ids[:, 2] += scale_bias

            # Prepare condition type
            condition_type_ids = torch.tensor(
                [
                    Condition.get_type_id(condition_type)
                    for condition_type in condition_types
                ]
            ).to(self.device)
            condition_type_ids = (
                torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
            ).unsqueeze(1)

            # Prepare guidance
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # Forward pass
        transformer_out = tranformer_forward(
            self.transformer,
            # Model config
            model_config=self.model_config,
            # Inputs of the condition (new feature)
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            condition_type_ids=condition_type_ids,
            # Inputs to the original transformer
            hidden_states=x_t,
            timestep=t,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=img_ids,
            joint_attention_kwargs=None,
            return_dict=False,
        )
        pred = transformer_out[0]

        # Compute loss
        loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        self.last_t = t.mean().item()
        return loss
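For reference, the module can be constructed on its own; the checkpoint id and LoRA settings below are illustrative assumptions rather than the repository's shipped defaults:

import torch
from src.train.model import OminiModel

model = OminiModel(
    flux_pipe_id="black-forest-labs/FLUX.1-schnell",  # assumed checkpoint; any FLUX pipeline id fits here
    lora_config={
        "r": 4,
        "lora_alpha": 4,
        "init_lora_weights": "gaussian",
        "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],  # hypothetical target set
    },
    dtype=torch.bfloat16,
    optimizer_config={"type": "AdamW", "params": {"lr": 1e-4}},
    gradient_checkpointing=True,
)
# configure_optimizers() will later unfreeze only the LoRA parameters collected in init_lora()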
OminiControl/src/train/train.py
ADDED
@@ -0,0 +1,178 @@
from torch.utils.data import DataLoader
import torch
import lightning as L
import yaml
import os
import time

from datasets import load_dataset

from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset
from .model import OminiModel
from .callbacks import TrainingCallback


def get_rank():
    try:
        rank = int(os.environ.get("LOCAL_RANK"))
    except (TypeError, ValueError):
        rank = 0
    return rank


def get_config():
    config_path = os.environ.get("XFL_CONFIG")
    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


def init_wandb(wandb_config, run_name):
    import wandb

    try:
        assert os.environ.get("WANDB_API_KEY") is not None
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize WandB:", e)


def main():
    # Initialize
    is_main_process, rank = get_rank() == 0, get_rank()
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WandB
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader
    if training_config["dataset"]["type"] == "subject":
        dataset = load_dataset("Yuanshi/Subjects200K")

        # Define filter function
        def filter_func(item):
            if not item.get("quality_assessment"):
                return False
            return all(
                item["quality_assessment"].get(key, 0) >= 5
                for key in ["compositeStructure", "objectConsistency", "imageQuality"]
            )

        # Filter dataset
        if not os.path.exists("./cache/dataset"):
            os.makedirs("./cache/dataset")
        data_valid = dataset["train"].filter(
            filter_func,
            num_proc=16,
            cache_file_name="./cache/dataset/data_valid.arrow",
        )
        dataset = Subject200KDataset(
            data_valid,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    elif training_config["dataset"]["type"] == "img":
        # Load the text-to-image-2M dataset
        dataset = load_dataset(
            "webdataset",
            data_files={"train": training_config["dataset"]["urls"]},
            split="train",
            cache_dir="cache/t2i2m",
            num_proc=32,
        )
        dataset = ImageConditionDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
            position_scale=training_config["dataset"].get("position_scale", 1.0),
        )
    elif training_config["dataset"]["type"] == "cartoon":
        dataset = load_dataset("saquiboye/oye-cartoon", split="train")
        dataset = CartoonDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    else:
        raise NotImplementedError

    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    # Callbacks for logging and saving checkpoints
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )

    setattr(trainer, "training_config", training_config)

    # Save config
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}")
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)


if __name__ == "__main__":
    main()
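train.py is driven entirely by a YAML file referenced through XFL_CONFIG. A minimal sketch that writes such a config for the subject task; every concrete value here (checkpoint id, LoRA rank, learning rate, intervals) is an illustrative assumption, not a shipped default:

import yaml

config = {
    "flux_path": "black-forest-labs/FLUX.1-schnell",  # assumed checkpoint
    "dtype": "bfloat16",
    "train": {
        "condition_type": "subject",
        "batch_size": 1,
        "accumulate_grad_batches": 1,
        "dataloader_workers": 4,
        "gradient_checkpointing": True,
        "max_steps": 10000,
        "save_interval": 1000,
        "sample_interval": 1000,
        "save_path": "./output",
        "dataset": {
            "type": "subject",
            "condition_size": 512,
            "target_size": 512,
            "image_size": 512,
            "padding": 8,
            "drop_text_prob": 0.1,
            "drop_image_prob": 0.1,
        },
        "lora_config": {
            "r": 4,
            "lora_alpha": 4,
            "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
        },
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    },
}

with open("train_subject.yaml", "w") as f:
    yaml.dump(config, f)

# Launch from the OminiControl directory:
#   XFL_CONFIG=train_subject.yaml python -m src.train.train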