Yuanshi committed
Commit fb6a167 · verified · 1 Parent(s): f623d12

Upload 61 files

This view is limited to 50 files because it contains too many changes.

Files changed (50):
  1. .gitattributes +12 -0
  2. OminiControl/LICENSE +201 -0
  3. OminiControl/README.md +170 -0
  4. OminiControl/assets/book.jpg +0 -0
  5. OminiControl/assets/cartoon_boy.png +3 -0
  6. OminiControl/assets/clock.jpg +3 -0
  7. OminiControl/assets/coffee.png +0 -0
  8. OminiControl/assets/demo/book_omini.jpg +0 -0
  9. OminiControl/assets/demo/clock_omini.jpg +0 -0
  10. OminiControl/assets/demo/demo_this_is_omini_control.jpg +3 -0
  11. OminiControl/assets/demo/dreambooth_res.jpg +3 -0
  12. OminiControl/assets/demo/man_omini.jpg +0 -0
  13. OminiControl/assets/demo/monalisa_omini.jpg +3 -0
  14. OminiControl/assets/demo/oranges_omini.jpg +0 -0
  15. OminiControl/assets/demo/panda_omini.jpg +0 -0
  16. OminiControl/assets/demo/penguin_omini.jpg +0 -0
  17. OminiControl/assets/demo/rc_car_omini.jpg +0 -0
  18. OminiControl/assets/demo/room_corner_canny.jpg +0 -0
  19. OminiControl/assets/demo/room_corner_coloring.jpg +0 -0
  20. OminiControl/assets/demo/room_corner_deblurring.jpg +0 -0
  21. OminiControl/assets/demo/room_corner_depth.jpg +0 -0
  22. OminiControl/assets/demo/scene_variation.jpg +3 -0
  23. OminiControl/assets/demo/shirt_omini.jpg +0 -0
  24. OminiControl/assets/demo/try_on.jpg +3 -0
  25. OminiControl/assets/monalisa.jpg +3 -0
  26. OminiControl/assets/oranges.jpg +0 -0
  27. OminiControl/assets/penguin.jpg +0 -0
  28. OminiControl/assets/rc_car.jpg +3 -0
  29. OminiControl/assets/room_corner.jpg +3 -0
  30. OminiControl/assets/test_in.jpg +0 -0
  31. OminiControl/assets/test_out.jpg +0 -0
  32. OminiControl/assets/tshirt.jpg +3 -0
  33. OminiControl/assets/vase.jpg +0 -0
  34. OminiControl/assets/vase_hq.jpg +3 -0
  35. OminiControl/examples/inpainting.ipynb +143 -0
  36. OminiControl/examples/spatial.ipynb +184 -0
  37. OminiControl/examples/subject.ipynb +214 -0
  38. OminiControl/examples/subject_1024.ipynb +221 -0
  39. OminiControl/requirements.txt +9 -0
  40. OminiControl/src/flux/block.py +339 -0
  41. OminiControl/src/flux/condition.py +138 -0
  42. OminiControl/src/flux/generate.py +321 -0
  43. OminiControl/src/flux/lora_controller.py +75 -0
  44. OminiControl/src/flux/pipeline_tools.py +52 -0
  45. OminiControl/src/flux/transformer.py +252 -0
  46. OminiControl/src/gradio/gradio_app.py +115 -0
  47. OminiControl/src/train/callbacks.py +253 -0
  48. OminiControl/src/train/data.py +323 -0
  49. OminiControl/src/train/model.py +185 -0
  50. OminiControl/src/train/train.py +178 -0
.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ OminiControl/assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text
37
+ OminiControl/assets/clock.jpg filter=lfs diff=lfs merge=lfs -text
38
+ OminiControl/assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text
39
+ OminiControl/assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text
40
+ OminiControl/assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text
41
+ OminiControl/assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text
42
+ OminiControl/assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text
43
+ OminiControl/assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text
44
+ OminiControl/assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text
45
+ OminiControl/assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text
46
+ OminiControl/assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text
47
+ OminiControl/assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text
OminiControl/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2024] [Zhenxiong Tan]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
OminiControl/README.md ADDED
@@ -0,0 +1,170 @@
1
+ # OminiControl
2
+
3
+
4
+ <img src='./assets/demo/demo_this_is_omini_control.jpg' width='100%' />
5
+ <br>
6
+
7
+ <a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-2411.15098-A42C25.svg" alt="arXiv"></a>
8
+ <a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
9
+ <a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
10
+ <a href="https://github.com/Yuanshi9815/Subjects200K"><img src="https://img.shields.io/badge/GitHub-Dataset-blue.svg?logo=github&" alt="GitHub"></a>
11
+ <a href="https://huggingface.co/datasets/Yuanshi/Subjects200K"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
12
+
13
+ > **OminiControl: Minimal and Universal Control for Diffusion Transformer**
14
+ > <br>
15
+ > Zhenxiong Tan,
16
+ > [Songhua Liu](http://121.37.94.87/),
17
+ > [Xingyi Yang](https://adamdad.github.io/),
18
+ > Qiaochu Xue,
19
+ > and
20
+ > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
21
+ > <br>
22
+ > [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
23
+ > <br>
24
+
25
+
26
+ ## Features
27
+
28
+ OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
29
+
30
+ * **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
31
+
32
+ * **Minimal Design 🚀**: Injects control signals while preserving the original model structure, introducing only 0.1% additional parameters on top of the base model.
33
+
34
+ ## News
35
+ - **2024-12-26**: ⭐️ Training code is released. You can now create your own OminiControl model by customizing any control task (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details.
36
+
37
+ ## Quick Start
38
+ ### Setup (Optional)
39
+ 1. **Environment setup**
40
+ ```bash
41
+ conda create -n omini python=3.10
42
+ conda activate omini
43
+ ```
44
+ 2. **Requirements installation**
45
+ ```bash
46
+ pip install -r requirements.txt
47
+ ```
48
+ ### Usage example
49
+ 1. Subject-driven generation: `examples/subject.ipynb` (a condensed sketch is shown after this list)
50
+ 2. In-painting: `examples/inpainting.ipynb`
51
+ 3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb`
52
+
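For orientation, here is a minimal sketch of the subject-driven flow, condensed from `examples/subject.ipynb` (model ID, weight name, asset path, and generation parameters are taken from that notebook; the output filename is arbitrary):

```python
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Load the FLUX.1-schnell base model and attach the subject-control LoRA.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# The condition image is resized to 512x512; the prompt refers to it as "this item".
image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
condition = Condition("subject", image, position_delta=(0, 32))

seed_everything(0)
result_img = generate(
    pipe,
    prompt="On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.",
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
).images[0]
result_img.save("penguin_result.jpg")  # output path is arbitrary
```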
53
+ ### Gradio app
54
+ To run the Gradio app for subject-driven generation:
55
+ ```bash
56
+ python -m src.gradio.gradio_app
57
+ ```
58
+
59
+ ### Guidelines for subject-driven generation
60
+ 1. Input images are automatically center-cropped and resized to 512x512 resolution (a PIL sketch of this preprocessing follows the list).
61
+ 2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`, e.g.:
62
+ 1. *A close up view of this item. It is placed on a wooden table.*
63
+ 2. *A young lady is wearing this shirt.*
64
+ 3. The model currently works primarily with objects rather than human subjects, due to the absence of human data in training.
65
+
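The example notebooks implement this center-crop with plain PIL; below is a small sketch mirroring the crop used in `examples/spatial.ipynb` (the helper name is illustrative, not part of the repo):

```python
from PIL import Image

def center_crop_512(path: str) -> Image.Image:
    """Center-crop an image to a square and resize it to 512x512 (illustrative helper)."""
    image = Image.open(path).convert("RGB")
    w, h = image.size
    min_dim = min(w, h)
    box = ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
    return image.crop(box).resize((512, 512))
```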
66
+ ## Generated samples
67
+ ### Subject-driven generation
68
+ <a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
69
+
70
+ **Demos** (Left: condition image; Right: generated image)
71
+
72
+ <div float="left">
73
+ <img src='./assets/demo/oranges_omini.jpg' width='48%'/>
74
+ <img src='./assets/demo/rc_car_omini.jpg' width='48%' />
75
+ <img src='./assets/demo/clock_omini.jpg' width='48%' />
76
+ <img src='./assets/demo/shirt_omini.jpg' width='48%' />
77
+ </div>
78
+
79
+ <details>
80
+ <summary>Text Prompts</summary>
81
+
82
+ - Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
83
+ - Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
84
+ - Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
85
+ - Prompt4: *"On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."*
86
+ </details>
87
+ <details>
88
+ <summary>More results</summary>
89
+
90
+ * Try on:
91
+ <img src='./assets/demo/try_on.jpg'/>
92
+ * Scene variations:
93
+ <img src='./assets/demo/scene_variation.jpg'/>
94
+ * Dreambooth dataset:
95
+ <img src='./assets/demo/dreambooth_res.jpg'/>
96
+ * Oye-cartoon finetune:
97
+ <div float="left">
98
+ <img src='./assets/demo/man_omini.jpg' width='48%' />
99
+ <img src='./assets/demo/panda_omini.jpg' width='48%' />
100
+ </div>
101
+ </details>
102
+
103
+ ### Spatially aligned control
104
+ 1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image)
105
+ - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.*
106
+ </br>
107
+ <img src='./assets/demo/monalisa_omini.jpg' width='700px' />
108
+ - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.*
109
+ </br>
110
+ <img src='./assets/demo/book_omini.jpg' width='700px' />
111
+ 2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring)
112
+ </br>
113
+ <details>
114
+ <summary>Click to show</summary>
115
+ <div float="left">
116
+ <img src='./assets/demo/room_corner_canny.jpg' width='48%'/>
117
+ <img src='./assets/demo/room_corner_depth.jpg' width='48%' />
118
+ <img src='./assets/demo/room_corner_coloring.jpg' width='48%' />
119
+ <img src='./assets/demo/room_corner_deblurring.jpg' width='48%' />
120
+ </div>
121
+
122
+ Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.*
123
+ </details>
124
+
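As a sketch of the in-painting flow (condensed from `examples/inpainting.ipynb`): the regions to regenerate are painted black, and the masked image is passed as a `fill` condition. It assumes a `FluxPipeline` named `pipe` with the `experimental/fill.safetensors` LoRA already loaded, as set up in that notebook.

```python
from PIL import Image
from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Paint the region to be regenerated black, then use the masked image as a "fill" condition.
image = Image.open("assets/monalisa.jpg").convert("RGB").resize((512, 512))
masked_image = image.copy()
masked_image.paste((0, 0, 0), (128, 100, 384, 220))

condition = Condition("fill", masked_image)

seed_everything()
result_img = generate(
    pipe,  # FluxPipeline (FLUX.1-dev) with the "fill" adapter loaded
    prompt="The Mona Lisa is wearing a white VR headset with 'Omini' written on it.",
    conditions=[condition],
).images[0]
```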
125
+
126
+
127
+
128
+ ## Models
129
+
130
+ **Subject-driven control:**
131
+ | Model | Base model | Description | Resolution |
132
+ | ------------------------------------------------------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------- | ------------ |
133
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
134
+ | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
135
+ | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. (To be released) | (1024, 1024) |
136
+ | [`oye-cartoon`](https://huggingface.co/saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on the [oye-cartoon](https://huggingface.co/datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764). | (512, 512) |
137
+
138
+ **Spatially aligned control:**
139
+ | Model | Base model | Description | Resolution |
140
+ | --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ |
141
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |
142
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>_1024` | FLUX.1 | Supports higher resolution.(To be released) | (1024, 1024) |
143
+
144
+ ## Community Extensions
145
+ - [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron)
146
+ - [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub)
147
+
148
+ ## Limitations
149
+ 1. The model's subject-driven generation primarily works with objects rather than human subjects due to the absence of human data in training.
150
+ 2. The subject-driven generation model may not work well with `FLUX.1-dev`.
151
+ 3. The released model currently only supports the resolution of 512x512.
152
+
153
+ ## Training
154
+ Training instructions can be found in this [folder](./train).
155
+
156
+
157
+ ## To-do
158
+ - [x] Release the training code.
159
+ - [ ] Release the model for higher resolution (1024x1024).
160
+
161
+ ## Citation
162
+ ```
163
+ @article{tan2024ominicontrol,
164
+ title={Ominicontrol: Minimal and universal control for diffusion transformer},
165
+ author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
166
+ journal={arXiv preprint arXiv:2411.15098},
167
+ volume={3},
168
+ year={2024}
169
+ }
170
+ ```
OminiControl/assets/book.jpg ADDED
OminiControl/assets/cartoon_boy.png ADDED

Git LFS Details

  • SHA256: d4a82c0f9ed09b9468bded7d901beffaf29addc30ed5f72ad72451e1b6344b1c
  • Pointer size: 131 Bytes
  • Size of remote file: 429 kB
OminiControl/assets/clock.jpg ADDED

Git LFS Details

  • SHA256: 41235973f26152ac92d32bfc166fb5f9f1e352c5e16807920238473316ec462b
  • Pointer size: 131 Bytes
  • Size of remote file: 289 kB
OminiControl/assets/coffee.png ADDED
OminiControl/assets/demo/book_omini.jpg ADDED
OminiControl/assets/demo/clock_omini.jpg ADDED
OminiControl/assets/demo/demo_this_is_omini_control.jpg ADDED

Git LFS Details

  • SHA256: 798b7c25be6be118dc0de97c444c840869afca633a0d48f99d940aec040a7518
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
OminiControl/assets/demo/dreambooth_res.jpg ADDED

Git LFS Details

  • SHA256: ba36bd861989564dc679acf3b5e56f382f1a11b1596e6f611ea0bd7d81b89680
  • Pointer size: 132 Bytes
  • Size of remote file: 1.94 MB
OminiControl/assets/demo/man_omini.jpg ADDED
OminiControl/assets/demo/monalisa_omini.jpg ADDED

Git LFS Details

  • SHA256: e5ca6c2bf44f19d216b2eb16dcc67d19f11d87220d3ee80f5e5e1ad98a5536dc
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
OminiControl/assets/demo/oranges_omini.jpg ADDED
OminiControl/assets/demo/panda_omini.jpg ADDED
OminiControl/assets/demo/penguin_omini.jpg ADDED
OminiControl/assets/demo/rc_car_omini.jpg ADDED
OminiControl/assets/demo/room_corner_canny.jpg ADDED
OminiControl/assets/demo/room_corner_coloring.jpg ADDED
OminiControl/assets/demo/room_corner_deblurring.jpg ADDED
OminiControl/assets/demo/room_corner_depth.jpg ADDED
OminiControl/assets/demo/scene_variation.jpg ADDED

Git LFS Details

  • SHA256: 39e4e16d2eeb58b3775b6d34c8b3e125d0d19cc36fa90b07c6c8d57624ad4333
  • Pointer size: 131 Bytes
  • Size of remote file: 958 kB
OminiControl/assets/demo/shirt_omini.jpg ADDED
OminiControl/assets/demo/try_on.jpg ADDED

Git LFS Details

  • SHA256: 6adce5194329a83f0109b4375e00667c341879e64fb55831c70ea3f3b2f99f7e
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
OminiControl/assets/monalisa.jpg ADDED

Git LFS Details

  • SHA256: 188b8b6499e4541f9dfef2a9daf6f1eb920079c9208f587fd97566d6aa4a9719
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
OminiControl/assets/oranges.jpg ADDED
OminiControl/assets/penguin.jpg ADDED
OminiControl/assets/rc_car.jpg ADDED

Git LFS Details

  • SHA256: ae8aed11029fa3b084deb286c07a8cab5056840c9c123816fe2b504e94233e95
  • Pointer size: 131 Bytes
  • Size of remote file: 254 kB
OminiControl/assets/room_corner.jpg ADDED

Git LFS Details

  • SHA256: f97bd63df05f5f15ad5dd1a2ccef803e74e12caadd8fe145493fd6d5219045e7
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
OminiControl/assets/test_in.jpg ADDED
OminiControl/assets/test_out.jpg ADDED
OminiControl/assets/tshirt.jpg ADDED

Git LFS Details

  • SHA256: cb1803315765302113a9e7a64dedd4ecba2672028cf093cbc33ef2edd2247c39
  • Pointer size: 131 Bytes
  • Size of remote file: 301 kB
OminiControl/assets/vase.jpg ADDED
OminiControl/assets/vase_hq.jpg ADDED

Git LFS Details

  • SHA256: 279905e32116792f118802d23b0d96629d98ccbdac9e704e65eaf2e98c752679
  • Pointer size: 132 Bytes
  • Size of remote file: 2.9 MB
OminiControl/examples/inpainting.ipynb ADDED
@@ -0,0 +1,143 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from src.flux.condition import Condition\n",
23
+ "from PIL import Image\n",
24
+ "\n",
25
+ "from src.flux.generate import generate, seed_everything"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "pipe.load_lora_weights(\n",
47
+ " \"Yuanshi/OminiControl\",\n",
48
+ " weight_name=f\"experimental/fill.safetensors\",\n",
49
+ " adapter_name=\"fill\",\n",
50
+ ")"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n",
60
+ "\n",
61
+ "masked_image = image.copy()\n",
62
+ "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n",
63
+ "\n",
64
+ "condition = Condition(\"fill\", masked_image)\n",
65
+ "\n",
66
+ "seed_everything()\n",
67
+ "result_img = generate(\n",
68
+ " pipe,\n",
69
+ " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n",
70
+ " conditions=[condition],\n",
71
+ ").images[0]\n",
72
+ "\n",
73
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
74
+ "concat_image.paste(image, (0, 0))\n",
75
+ "concat_image.paste(condition.condition, (512, 0))\n",
76
+ "concat_image.paste(result_img, (1024, 0))\n",
77
+ "concat_image"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n",
87
+ "\n",
88
+ "w, h, min_dim = image.size + (min(image.size),)\n",
89
+ "image = image.crop(\n",
90
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
91
+ ").resize((512, 512))\n",
92
+ "\n",
93
+ "\n",
94
+ "masked_image = image.copy()\n",
95
+ "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n",
96
+ "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n",
97
+ "\n",
98
+ "condition = Condition(\"fill\", masked_image)\n",
99
+ "\n",
100
+ "seed_everything()\n",
101
+ "result_img = generate(\n",
102
+ " pipe,\n",
103
+ " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n",
104
+ " conditions=[condition],\n",
105
+ ").images[0]\n",
106
+ "\n",
107
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
108
+ "concat_image.paste(image, (0, 0))\n",
109
+ "concat_image.paste(condition.condition, (512, 0))\n",
110
+ "concat_image.paste(result_img, (1024, 0))\n",
111
+ "concat_image"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": []
120
+ }
121
+ ],
122
+ "metadata": {
123
+ "kernelspec": {
124
+ "display_name": "base",
125
+ "language": "python",
126
+ "name": "python3"
127
+ },
128
+ "language_info": {
129
+ "codemirror_mode": {
130
+ "name": "ipython",
131
+ "version": 3
132
+ },
133
+ "file_extension": ".py",
134
+ "mimetype": "text/x-python",
135
+ "name": "python",
136
+ "nbconvert_exporter": "python",
137
+ "pygments_lexer": "ipython3",
138
+ "version": "3.12.7"
139
+ }
140
+ },
141
+ "nbformat": 4,
142
+ "nbformat_minor": 2
143
+ }
OminiControl/examples/spatial.ipynb ADDED
@@ -0,0 +1,184 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from src.flux.condition import Condition\n",
23
+ "from PIL import Image\n",
24
+ "\n",
25
+ "from src.flux.generate import generate, seed_everything"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n",
47
+ " pipe.load_lora_weights(\n",
48
+ " \"Yuanshi/OminiControl\",\n",
49
+ " weight_name=f\"experimental/{condition_type}.safetensors\",\n",
50
+ " adapter_name=condition_type,\n",
51
+ " )"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n",
61
+ "\n",
62
+ "w, h, min_dim = image.size + (min(image.size),)\n",
63
+ "image = image.crop(\n",
64
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
65
+ ").resize((512, 512))\n",
66
+ "\n",
67
+ "prompt = \"In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.\""
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "condition = Condition(\"canny\", image)\n",
77
+ "\n",
78
+ "seed_everything()\n",
79
+ "\n",
80
+ "result_img = generate(\n",
81
+ " pipe,\n",
82
+ " prompt=prompt,\n",
83
+ " conditions=[condition],\n",
84
+ ").images[0]\n",
85
+ "\n",
86
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
87
+ "concat_image.paste(image, (0, 0))\n",
88
+ "concat_image.paste(condition.condition, (512, 0))\n",
89
+ "concat_image.paste(result_img, (1024, 0))\n",
90
+ "concat_image"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "condition = Condition(\"depth\", image)\n",
100
+ "\n",
101
+ "seed_everything()\n",
102
+ "\n",
103
+ "result_img = generate(\n",
104
+ " pipe,\n",
105
+ " prompt=prompt,\n",
106
+ " conditions=[condition],\n",
107
+ ").images[0]\n",
108
+ "\n",
109
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
110
+ "concat_image.paste(image, (0, 0))\n",
111
+ "concat_image.paste(condition.condition, (512, 0))\n",
112
+ "concat_image.paste(result_img, (1024, 0))\n",
113
+ "concat_image"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "condition = Condition(\"deblurring\", image)\n",
123
+ "\n",
124
+ "seed_everything()\n",
125
+ "\n",
126
+ "result_img = generate(\n",
127
+ " pipe,\n",
128
+ " prompt=prompt,\n",
129
+ " conditions=[condition],\n",
130
+ ").images[0]\n",
131
+ "\n",
132
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
133
+ "concat_image.paste(image, (0, 0))\n",
134
+ "concat_image.paste(condition.condition, (512, 0))\n",
135
+ "concat_image.paste(result_img, (1024, 0))\n",
136
+ "concat_image"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "condition = Condition(\"coloring\", image)\n",
146
+ "\n",
147
+ "seed_everything()\n",
148
+ "\n",
149
+ "result_img = generate(\n",
150
+ " pipe,\n",
151
+ " prompt=prompt,\n",
152
+ " conditions=[condition],\n",
153
+ ").images[0]\n",
154
+ "\n",
155
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
156
+ "concat_image.paste(image, (0, 0))\n",
157
+ "concat_image.paste(condition.condition, (512, 0))\n",
158
+ "concat_image.paste(result_img, (1024, 0))\n",
159
+ "concat_image"
160
+ ]
161
+ }
162
+ ],
163
+ "metadata": {
164
+ "kernelspec": {
165
+ "display_name": "base",
166
+ "language": "python",
167
+ "name": "python3"
168
+ },
169
+ "language_info": {
170
+ "codemirror_mode": {
171
+ "name": "ipython",
172
+ "version": 3
173
+ },
174
+ "file_extension": ".py",
175
+ "mimetype": "text/x-python",
176
+ "name": "python",
177
+ "nbconvert_exporter": "python",
178
+ "pygments_lexer": "ipython3",
179
+ "version": "3.12.7"
180
+ }
181
+ },
182
+ "nbformat": 4,
183
+ "nbformat_minor": 2
184
+ }
OminiControl/examples/subject.ipynb ADDED
@@ -0,0 +1,214 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from src.flux.condition import Condition\n",
23
+ "from PIL import Image\n",
24
+ "\n",
25
+ "from src.flux.generate import generate, seed_everything"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")\n",
38
+ "pipe.load_lora_weights(\n",
39
+ " \"Yuanshi/OminiControl\",\n",
40
+ " weight_name=f\"omini/subject_512.safetensors\",\n",
41
+ " adapter_name=\"subject\",\n",
42
+ ")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
52
+ "\n",
53
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
54
+ "\n",
55
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
56
+ "\n",
57
+ "\n",
58
+ "seed_everything(0)\n",
59
+ "\n",
60
+ "result_img = generate(\n",
61
+ " pipe,\n",
62
+ " prompt=prompt,\n",
63
+ " conditions=[condition],\n",
64
+ " num_inference_steps=8,\n",
65
+ " height=512,\n",
66
+ " width=512,\n",
67
+ ").images[0]\n",
68
+ "\n",
69
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
70
+ "concat_image.paste(image, (0, 0))\n",
71
+ "concat_image.paste(result_img, (512, 0))\n",
72
+ "concat_image"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
82
+ "\n",
83
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
84
+ "\n",
85
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
86
+ "\n",
87
+ "\n",
88
+ "seed_everything()\n",
89
+ "\n",
90
+ "result_img = generate(\n",
91
+ " pipe,\n",
92
+ " prompt=prompt,\n",
93
+ " conditions=[condition],\n",
94
+ " num_inference_steps=8,\n",
95
+ " height=512,\n",
96
+ " width=512,\n",
97
+ ").images[0]\n",
98
+ "\n",
99
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
100
+ "concat_image.paste(condition.condition, (0, 0))\n",
101
+ "concat_image.paste(result_img, (512, 0))\n",
102
+ "concat_image"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
112
+ "\n",
113
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
114
+ "\n",
115
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
116
+ "\n",
117
+ "seed_everything()\n",
118
+ "\n",
119
+ "result_img = generate(\n",
120
+ " pipe,\n",
121
+ " prompt=prompt,\n",
122
+ " conditions=[condition],\n",
123
+ " num_inference_steps=8,\n",
124
+ " height=512,\n",
125
+ " width=512,\n",
126
+ ").images[0]\n",
127
+ "\n",
128
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
129
+ "concat_image.paste(condition.condition, (0, 0))\n",
130
+ "concat_image.paste(result_img, (512, 0))\n",
131
+ "concat_image"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
141
+ "\n",
142
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
143
+ "\n",
144
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
145
+ "\n",
146
+ "seed_everything()\n",
147
+ "\n",
148
+ "result_img = generate(\n",
149
+ " pipe,\n",
150
+ " prompt=prompt,\n",
151
+ " conditions=[condition],\n",
152
+ " num_inference_steps=8,\n",
153
+ " height=512,\n",
154
+ " width=512,\n",
155
+ ").images[0]\n",
156
+ "\n",
157
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
158
+ "concat_image.paste(condition.condition, (0, 0))\n",
159
+ "concat_image.paste(result_img, (512, 0))\n",
160
+ "concat_image"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
170
+ "\n",
171
+ "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
172
+ "\n",
173
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
174
+ "\n",
175
+ "seed_everything()\n",
176
+ "\n",
177
+ "result_img = generate(\n",
178
+ " pipe,\n",
179
+ " prompt=prompt,\n",
180
+ " conditions=[condition],\n",
181
+ " num_inference_steps=8,\n",
182
+ " height=512,\n",
183
+ " width=512,\n",
184
+ ").images[0]\n",
185
+ "\n",
186
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
187
+ "concat_image.paste(condition.condition, (0, 0))\n",
188
+ "concat_image.paste(result_img, (512, 0))\n",
189
+ "concat_image"
190
+ ]
191
+ }
192
+ ],
193
+ "metadata": {
194
+ "kernelspec": {
195
+ "display_name": "base",
196
+ "language": "python",
197
+ "name": "python3"
198
+ },
199
+ "language_info": {
200
+ "codemirror_mode": {
201
+ "name": "ipython",
202
+ "version": 3
203
+ },
204
+ "file_extension": ".py",
205
+ "mimetype": "text/x-python",
206
+ "name": "python",
207
+ "nbconvert_exporter": "python",
208
+ "pygments_lexer": "ipython3",
209
+ "version": "3.12.7"
210
+ }
211
+ },
212
+ "nbformat": 4,
213
+ "nbformat_minor": 2
214
+ }
OminiControl/examples/subject_1024.ipynb ADDED
@@ -0,0 +1,221 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from src.flux.condition import Condition\n",
23
+ "from PIL import Image\n",
24
+ "\n",
25
+ "from src.flux.generate import generate, seed_everything"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")\n",
38
+ "pipe.load_lora_weights(\n",
39
+ " \"Yuanshi/OminiControl\",\n",
40
+ " weight_name=f\"omini/subject_1024_beta.safetensors\",\n",
41
+ " adapter_name=\"subject\",\n",
42
+ ")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
52
+ "\n",
53
+ "condition = Condition(\"subject\", image)\n",
54
+ "\n",
55
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
56
+ "\n",
57
+ "\n",
58
+ "seed_everything(0)\n",
59
+ "\n",
60
+ "result_img = generate(\n",
61
+ " pipe,\n",
62
+ " prompt=prompt,\n",
63
+ " conditions=[condition],\n",
64
+ " num_inference_steps=8,\n",
65
+ " height=1024,\n",
66
+ " width=1024,\n",
67
+ ").images[0]\n",
68
+ "\n",
69
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
70
+ "concat_image.paste(image, (0, 0))\n",
71
+ "concat_image.paste(result_img, (512, 0))\n",
72
+ "concat_image"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
82
+ "\n",
83
+ "condition = Condition(\"subject\", image)\n",
84
+ "\n",
85
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
86
+ "\n",
87
+ "\n",
88
+ "seed_everything(0)\n",
89
+ "\n",
90
+ "result_img = generate(\n",
91
+ " pipe,\n",
92
+ " prompt=prompt,\n",
93
+ " conditions=[condition],\n",
94
+ " num_inference_steps=8,\n",
95
+ " height=1024,\n",
96
+ " width=1024,\n",
97
+ ").images[0]\n",
98
+ "\n",
99
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
100
+ "concat_image.paste(image, (0, 0))\n",
101
+ "concat_image.paste(result_img, (512, 0))\n",
102
+ "concat_image"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
112
+ "\n",
113
+ "condition = Condition(\"subject\", image)\n",
114
+ "\n",
115
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
116
+ "\n",
117
+ "seed_everything()\n",
118
+ "\n",
119
+ "result_img = generate(\n",
120
+ " pipe,\n",
121
+ " prompt=prompt,\n",
122
+ " conditions=[condition],\n",
123
+ " num_inference_steps=8,\n",
124
+ " height=1024,\n",
125
+ " width=1024,\n",
126
+ ").images[0]\n",
127
+ "\n",
128
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
129
+ "concat_image.paste(image, (0, 0))\n",
130
+ "concat_image.paste(result_img, (512, 0))\n",
131
+ "concat_image"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
141
+ "\n",
142
+ "condition = Condition(\"subject\", image)\n",
143
+ "\n",
144
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
145
+ "\n",
146
+ "seed_everything(0)\n",
147
+ "\n",
148
+ "result_img = generate(\n",
149
+ " pipe,\n",
150
+ " prompt=prompt,\n",
151
+ " conditions=[condition],\n",
152
+ " num_inference_steps=8,\n",
153
+ " height=1024,\n",
154
+ " width=1024,\n",
155
+ ").images[0]\n",
156
+ "\n",
157
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
158
+ "concat_image.paste(image, (0, 0))\n",
159
+ "concat_image.paste(result_img, (512, 0))\n",
160
+ "concat_image"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
170
+ "\n",
171
+ "condition = Condition(\"subject\", image)\n",
172
+ "\n",
173
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
174
+ "\n",
175
+ "seed_everything()\n",
176
+ "\n",
177
+ "result_img = generate(\n",
178
+ " pipe,\n",
179
+ " prompt=prompt,\n",
180
+ " conditions=[condition],\n",
181
+ " num_inference_steps=8,\n",
182
+ " height=1024,\n",
183
+ " width=1024,\n",
184
+ ").images[0]\n",
185
+ "\n",
186
+ "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
187
+ "concat_image.paste(image, (0, 0))\n",
188
+ "concat_image.paste(result_img, (512, 0))\n",
189
+ "concat_image"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": []
198
+ }
199
+ ],
200
+ "metadata": {
201
+ "kernelspec": {
202
+ "display_name": "Python 3 (ipykernel)",
203
+ "language": "python",
204
+ "name": "python3"
205
+ },
206
+ "language_info": {
207
+ "codemirror_mode": {
208
+ "name": "ipython",
209
+ "version": 3
210
+ },
211
+ "file_extension": ".py",
212
+ "mimetype": "text/x-python",
213
+ "name": "python",
214
+ "nbconvert_exporter": "python",
215
+ "pygments_lexer": "ipython3",
216
+ "version": "3.9.21"
217
+ }
218
+ },
219
+ "nbformat": 4,
220
+ "nbformat_minor": 2
221
+ }
OminiControl/requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ transformers
2
+ diffusers
3
+ peft
4
+ opencv-python
5
+ protobuf
6
+ sentencepiece
7
+ gradio
8
+ jupyter
9
+ torchao
OminiControl/src/flux/block.py ADDED
@@ -0,0 +1,339 @@
1
+ import torch
2
+ from typing import List, Union, Optional, Dict, Any, Callable
3
+ from diffusers.models.attention_processor import Attention, F
4
+ from .lora_controller import enable_lora
5
+
6
+
7
+ def attn_forward(
8
+ attn: Attention,
9
+ hidden_states: torch.FloatTensor,
10
+ encoder_hidden_states: torch.FloatTensor = None,
11
+ condition_latents: torch.FloatTensor = None,
12
+ attention_mask: Optional[torch.FloatTensor] = None,
13
+ image_rotary_emb: Optional[torch.Tensor] = None,
14
+ cond_rotary_emb: Optional[torch.Tensor] = None,
15
+ model_config: Optional[Dict[str, Any]] = {},
16
+ ) -> torch.FloatTensor:
17
+ batch_size, _, _ = (
18
+ hidden_states.shape
19
+ if encoder_hidden_states is None
20
+ else encoder_hidden_states.shape
21
+ )
22
+
23
+ with enable_lora(
24
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
25
+ ):
26
+ # `sample` projections.
27
+ query = attn.to_q(hidden_states)
28
+ key = attn.to_k(hidden_states)
29
+ value = attn.to_v(hidden_states)
30
+
31
+ inner_dim = key.shape[-1]
32
+ head_dim = inner_dim // attn.heads
33
+
34
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
35
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
36
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
37
+
38
+ if attn.norm_q is not None:
39
+ query = attn.norm_q(query)
40
+ if attn.norm_k is not None:
41
+ key = attn.norm_k(key)
42
+
43
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
44
+ if encoder_hidden_states is not None:
45
+ # `context` projections.
46
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
47
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
48
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
49
+
50
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
51
+ batch_size, -1, attn.heads, head_dim
52
+ ).transpose(1, 2)
53
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
54
+ batch_size, -1, attn.heads, head_dim
55
+ ).transpose(1, 2)
56
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
57
+ batch_size, -1, attn.heads, head_dim
58
+ ).transpose(1, 2)
59
+
60
+ if attn.norm_added_q is not None:
61
+ encoder_hidden_states_query_proj = attn.norm_added_q(
62
+ encoder_hidden_states_query_proj
63
+ )
64
+ if attn.norm_added_k is not None:
65
+ encoder_hidden_states_key_proj = attn.norm_added_k(
66
+ encoder_hidden_states_key_proj
67
+ )
68
+
69
+ # attention
70
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
71
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
72
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
73
+
74
+ if image_rotary_emb is not None:
75
+ from diffusers.models.embeddings import apply_rotary_emb
76
+
77
+ query = apply_rotary_emb(query, image_rotary_emb)
78
+ key = apply_rotary_emb(key, image_rotary_emb)
79
+
80
+ if condition_latents is not None:
81
+ cond_query = attn.to_q(condition_latents)
82
+ cond_key = attn.to_k(condition_latents)
83
+ cond_value = attn.to_v(condition_latents)
84
+
85
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
86
+ 1, 2
87
+ )
88
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
89
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
90
+ 1, 2
91
+ )
92
+ if attn.norm_q is not None:
93
+ cond_query = attn.norm_q(cond_query)
94
+ if attn.norm_k is not None:
95
+ cond_key = attn.norm_k(cond_key)
96
+
97
+ if cond_rotary_emb is not None:
98
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
99
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
100
+
101
+ if condition_latents is not None:
102
+ query = torch.cat([query, cond_query], dim=2)
103
+ key = torch.cat([key, cond_key], dim=2)
104
+ value = torch.cat([value, cond_value], dim=2)
105
+
106
+ if not model_config.get("union_cond_attn", True):
107
+ # If we don't want to use the union condition attention, we need to mask the attention
108
+ # between the hidden states and the condition latents
109
+ attention_mask = torch.ones(
110
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
111
+ )
112
+ condition_n = cond_query.shape[2]
113
+ attention_mask[-condition_n:, :-condition_n] = False
114
+ attention_mask[:-condition_n, -condition_n:] = False
115
+ elif model_config.get("independent_condition", False):
116
+ attention_mask = torch.ones(
117
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
118
+ )
119
+ condition_n = cond_query.shape[2]
120
+ attention_mask[-condition_n:, :-condition_n] = False
121
+ if hasattr(attn, "c_factor"):
122
+ attention_mask = torch.zeros(
123
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
124
+ )
125
+ condition_n = cond_query.shape[2]
126
+ bias = torch.log(attn.c_factor[0])
127
+ attention_mask[-condition_n:, :-condition_n] = bias
128
+ attention_mask[:-condition_n, -condition_n:] = bias
129
+ hidden_states = F.scaled_dot_product_attention(
130
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
131
+ )
132
+ hidden_states = hidden_states.transpose(1, 2).reshape(
133
+ batch_size, -1, attn.heads * head_dim
134
+ )
135
+ hidden_states = hidden_states.to(query.dtype)
136
+
137
+ if encoder_hidden_states is not None:
138
+ if condition_latents is not None:
139
+ encoder_hidden_states, hidden_states, condition_latents = (
140
+ hidden_states[:, : encoder_hidden_states.shape[1]],
141
+ hidden_states[
142
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
143
+ ],
144
+ hidden_states[:, -condition_latents.shape[1] :],
145
+ )
146
+ else:
147
+ encoder_hidden_states, hidden_states = (
148
+ hidden_states[:, : encoder_hidden_states.shape[1]],
149
+ hidden_states[:, encoder_hidden_states.shape[1] :],
150
+ )
151
+
152
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
153
+ # linear proj
154
+ hidden_states = attn.to_out[0](hidden_states)
155
+ # dropout
156
+ hidden_states = attn.to_out[1](hidden_states)
157
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
158
+
159
+ if condition_latents is not None:
160
+ condition_latents = attn.to_out[0](condition_latents)
161
+ condition_latents = attn.to_out[1](condition_latents)
162
+
163
+ return (
164
+ (hidden_states, encoder_hidden_states, condition_latents)
165
+ if condition_latents is not None
166
+ else (hidden_states, encoder_hidden_states)
167
+ )
168
+ elif condition_latents is not None:
169
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
170
+ hidden_states, condition_latents = (
171
+ hidden_states[:, : -condition_latents.shape[1]],
172
+ hidden_states[:, -condition_latents.shape[1] :],
173
+ )
174
+ return hidden_states, condition_latents
175
+ else:
176
+ return hidden_states
177
+
178
+
179
+ def block_forward(
180
+ self,
181
+ hidden_states: torch.FloatTensor,
182
+ encoder_hidden_states: torch.FloatTensor,
183
+ condition_latents: torch.FloatTensor,
184
+ temb: torch.FloatTensor,
185
+ cond_temb: torch.FloatTensor,
186
+ cond_rotary_emb=None,
187
+ image_rotary_emb=None,
188
+ model_config: Optional[Dict[str, Any]] = {},
189
+ ):
190
+ use_cond = condition_latents is not None
191
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
192
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
193
+ hidden_states, emb=temb
194
+ )
195
+
196
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
197
+ self.norm1_context(encoder_hidden_states, emb=temb)
198
+ )
199
+
200
+ if use_cond:
201
+ (
202
+ norm_condition_latents,
203
+ cond_gate_msa,
204
+ cond_shift_mlp,
205
+ cond_scale_mlp,
206
+ cond_gate_mlp,
207
+ ) = self.norm1(condition_latents, emb=cond_temb)
208
+
209
+ # Attention.
210
+ result = attn_forward(
211
+ self.attn,
212
+ model_config=model_config,
213
+ hidden_states=norm_hidden_states,
214
+ encoder_hidden_states=norm_encoder_hidden_states,
215
+ condition_latents=norm_condition_latents if use_cond else None,
216
+ image_rotary_emb=image_rotary_emb,
217
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
218
+ )
219
+ attn_output, context_attn_output = result[:2]
220
+ cond_attn_output = result[2] if use_cond else None
221
+
222
+ # Process attention outputs for the `hidden_states`.
223
+ # 1. hidden_states
224
+ attn_output = gate_msa.unsqueeze(1) * attn_output
225
+ hidden_states = hidden_states + attn_output
226
+ # 2. encoder_hidden_states
227
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
228
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
229
+ # 3. condition_latents
230
+ if use_cond:
231
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
232
+ condition_latents = condition_latents + cond_attn_output
233
+ if model_config.get("add_cond_attn", False):
234
+ hidden_states += cond_attn_output
235
+
236
+ # LayerNorm + MLP.
237
+ # 1. hidden_states
238
+ norm_hidden_states = self.norm2(hidden_states)
239
+ norm_hidden_states = (
240
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
241
+ )
242
+ # 2. encoder_hidden_states
243
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
244
+ norm_encoder_hidden_states = (
245
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
246
+ )
247
+ # 3. condition_latents
248
+ if use_cond:
249
+ norm_condition_latents = self.norm2(condition_latents)
250
+ norm_condition_latents = (
251
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
252
+ + cond_shift_mlp[:, None]
253
+ )
254
+
255
+ # Feed-forward.
256
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
257
+ # 1. hidden_states
258
+ ff_output = self.ff(norm_hidden_states)
259
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
260
+ # 2. encoder_hidden_states
261
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
262
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
263
+ # 3. condition_latents
264
+ if use_cond:
265
+ cond_ff_output = self.ff(norm_condition_latents)
266
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
267
+
268
+ # Process feed-forward outputs.
269
+ hidden_states = hidden_states + ff_output
270
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
271
+ if use_cond:
272
+ condition_latents = condition_latents + cond_ff_output
273
+
274
+ # Clip to avoid overflow.
275
+ if encoder_hidden_states.dtype == torch.float16:
276
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
277
+
278
+ return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
279
+
280
+
281
+ def single_block_forward(
282
+ self,
283
+ hidden_states: torch.FloatTensor,
284
+ temb: torch.FloatTensor,
285
+ image_rotary_emb=None,
286
+ condition_latents: torch.FloatTensor = None,
287
+ cond_temb: torch.FloatTensor = None,
288
+ cond_rotary_emb=None,
289
+ model_config: Optional[Dict[str, Any]] = {},
290
+ ):
291
+
292
+ using_cond = condition_latents is not None
293
+ residual = hidden_states
294
+ with enable_lora(
295
+ (
296
+ self.norm.linear,
297
+ self.proj_mlp,
298
+ ),
299
+ model_config.get("latent_lora", False),
300
+ ):
301
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
302
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
303
+ if using_cond:
304
+ residual_cond = condition_latents
305
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
306
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
307
+
308
+ attn_output = attn_forward(
309
+ self.attn,
310
+ model_config=model_config,
311
+ hidden_states=norm_hidden_states,
312
+ image_rotary_emb=image_rotary_emb,
313
+ **(
314
+ {
315
+ "condition_latents": norm_condition_latents,
316
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
317
+ }
318
+ if using_cond
319
+ else {}
320
+ ),
321
+ )
322
+ if using_cond:
323
+ attn_output, cond_attn_output = attn_output
324
+
325
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
326
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
327
+ gate = gate.unsqueeze(1)
328
+ hidden_states = gate * self.proj_out(hidden_states)
329
+ hidden_states = residual + hidden_states
330
+ if using_cond:
331
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
332
+ cond_gate = cond_gate.unsqueeze(1)
333
+ condition_latents = cond_gate * self.proj_out(condition_latents)
334
+ condition_latents = residual_cond + condition_latents
335
+
336
+ if hidden_states.dtype == torch.float16:
337
+ hidden_states = hidden_states.clip(-65504, 65504)
338
+
339
+ return hidden_states if not using_cond else (hidden_states, condition_latents)
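The `c_factor` branch in attn_forward above is what `condition_scale` in generate.py ends up driving: adding log(c) to the pre-softmax logits of the image/condition blocks multiplies those attention weights by c before renormalization. A minimal, self-contained sketch of that equivalence (toy tensors only, not the repo's modules; shapes and names are illustrative):
import torch
import torch.nn.functional as F
# Toy setup: 1 batch, 1 head, 4 image-token queries attending to 4 image + 2 condition keys.
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 6, 8)
v = torch.randn(1, 1, 6, 8)
condition_scale = 1.5
bias = torch.zeros(4, 6)
bias[:, -2:] = torch.log(torch.tensor(condition_scale))  # plays the role of attn.c_factor
out_biased = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
# Same result written as an explicit reweighting of the softmax columns:
weights = ((q @ k.transpose(-1, -2)) / 8**0.5).softmax(dim=-1)
weights = torch.cat([weights[..., :-2], weights[..., -2:] * condition_scale], dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)
out_manual = weights @ v
assert torch.allclose(out_biased, out_manual, atol=1e-5)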
OminiControl/src/flux/condition.py ADDED
@@ -0,0 +1,138 @@
1
+ import torch
2
+ from typing import Optional, Union, List, Tuple
3
+ from diffusers.pipelines import FluxPipeline
4
+ from PIL import Image, ImageFilter
5
+ import numpy as np
6
+ import cv2
7
+
8
+ from .pipeline_tools import encode_images
9
+
10
+ condition_dict = {
11
+ "depth": 0,
12
+ "canny": 1,
13
+ "subject": 4,
14
+ "coloring": 6,
15
+ "deblurring": 7,
16
+ "depth_pred": 8,
17
+ "fill": 9,
18
+ "sr": 10,
19
+ "cartoon": 11,
20
+ }
21
+
22
+
23
+ class Condition(object):
24
+ def __init__(
25
+ self,
26
+ condition_type: str,
27
+ raw_img: Union[Image.Image, torch.Tensor] = None,
28
+ condition: Union[Image.Image, torch.Tensor] = None,
29
+ mask=None,
30
+ position_delta=None,
31
+ position_scale=1.0,
32
+ ) -> None:
33
+ self.condition_type = condition_type
34
+ assert raw_img is not None or condition is not None
35
+ if raw_img is not None:
36
+ self.condition = self.get_condition(condition_type, raw_img)
37
+ else:
38
+ self.condition = condition
39
+ self.position_delta = position_delta
40
+ self.position_scale = position_scale
41
+ # TODO: Add mask support
42
+ assert mask is None, "Mask not supported yet"
43
+
44
+ def get_condition(
45
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
46
+ ) -> Union[Image.Image, torch.Tensor]:
47
+ """
48
+ Returns the condition image.
49
+ """
50
+ if condition_type == "depth":
51
+ from transformers import pipeline
52
+
53
+ depth_pipe = pipeline(
54
+ task="depth-estimation",
55
+ model="LiheYoung/depth-anything-small-hf",
56
+ device="cuda",
57
+ )
58
+ source_image = raw_img.convert("RGB")
59
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
60
+ return condition_img
61
+ elif condition_type == "canny":
62
+ img = np.array(raw_img)
63
+ edges = cv2.Canny(img, 100, 200)
64
+ edges = Image.fromarray(edges).convert("RGB")
65
+ return edges
66
+ elif condition_type == "subject":
67
+ return raw_img
68
+ elif condition_type == "coloring":
69
+ return raw_img.convert("L").convert("RGB")
70
+ elif condition_type == "deblurring":
71
+ condition_image = (
72
+ raw_img.convert("RGB")
73
+ .filter(ImageFilter.GaussianBlur(10))
74
+ .convert("RGB")
75
+ )
76
+ return condition_image
77
+ elif condition_type == "fill":
78
+ return raw_img.convert("RGB")
79
+ elif condition_type == "cartoon":
80
+ return raw_img.convert("RGB")
81
+ return self.condition
82
+
83
+ @property
84
+ def type_id(self) -> int:
85
+ """
86
+ Returns the type id of the condition.
87
+ """
88
+ return condition_dict[self.condition_type]
89
+
90
+ @classmethod
91
+ def get_type_id(cls, condition_type: str) -> int:
92
+ """
93
+ Returns the type id of the condition.
94
+ """
95
+ return condition_dict[condition_type]
96
+
97
+ def encode(
98
+ self, pipe: FluxPipeline, empty: bool = False
99
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
100
+ """
101
+ Encodes the condition into tokens, ids and type_id.
102
+ """
103
+ if self.condition_type in [
104
+ "depth",
105
+ "canny",
106
+ "subject",
107
+ "coloring",
108
+ "deblurring",
109
+ "depth_pred",
110
+ "fill",
111
+ "sr",
112
+ "cartoon",
113
+ ]:
114
+ if empty:
115
+ # make the condition black
116
+ e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
117
+ e_condition = e_condition.convert("RGB")
118
+ tokens, ids = encode_images(pipe, e_condition)
119
+ else:
120
+ tokens, ids = encode_images(pipe, self.condition)
122
+ else:
123
+ raise NotImplementedError(
124
+ f"Condition type {self.condition_type} not implemented"
125
+ )
126
+ if self.position_delta is None and self.condition_type == "subject":
127
+ self.position_delta = [0, -self.condition.size[0] // 16]
128
+ if self.position_delta is not None:
129
+ ids[:, 1] += self.position_delta[0]
130
+ ids[:, 2] += self.position_delta[1]
131
+ if self.position_scale != 1.0:
132
+ scale_bias = (self.position_scale - 1.0) / 2
133
+ ids[:, 1] *= self.position_scale
134
+ ids[:, 2] *= self.position_scale
135
+ ids[:, 1] += scale_bias
136
+ ids[:, 2] += scale_bias
137
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
138
+ return tokens, ids, type_id
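A hedged usage sketch of the Condition class with a loaded FluxPipeline; the checkpoint id mirrors the Gradio demo later in this commit, and the import path assumes the code is run from the OminiControl repo root:
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition  # import path assumed (run from the repo root)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
image = Image.open("assets/penguin.jpg").resize((512, 512)).convert("RGB")
# "subject" keeps the raw image; other types (canny, depth, coloring, ...) are derived
# inside get_condition(). Without an explicit position_delta, encode() defaults to
# [0, -width // 16] for subject conditions.
condition = Condition("subject", raw_img=image, position_delta=(0, 32))
tokens, ids, type_id = condition.encode(pipe)
print(tokens.shape, ids.shape, type_id.shape)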
OminiControl/src/flux/generate.py ADDED
@@ -0,0 +1,321 @@
1
+ import torch
2
+ import yaml, os
3
+ from diffusers.pipelines import FluxPipeline
4
+ from typing import List, Union, Optional, Dict, Any, Callable
5
+ from .transformer import tranformer_forward
6
+ from .condition import Condition
7
+
8
+ from diffusers.pipelines.flux.pipeline_flux import (
9
+ FluxPipelineOutput,
10
+ calculate_shift,
11
+ retrieve_timesteps,
12
+ np,
13
+ )
14
+
15
+
16
+ def get_config(config_path: str = None):
17
+ config_path = config_path or os.environ.get("XFL_CONFIG")
18
+ if not config_path:
19
+ return {}
20
+ with open(config_path, "r") as f:
21
+ config = yaml.safe_load(f)
22
+ return config
23
+
24
+
25
+ def prepare_params(
26
+ prompt: Union[str, List[str]] = None,
27
+ prompt_2: Optional[Union[str, List[str]]] = None,
28
+ height: Optional[int] = 512,
29
+ width: Optional[int] = 512,
30
+ num_inference_steps: int = 28,
31
+ timesteps: List[int] = None,
32
+ guidance_scale: float = 3.5,
33
+ num_images_per_prompt: Optional[int] = 1,
34
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
35
+ latents: Optional[torch.FloatTensor] = None,
36
+ prompt_embeds: Optional[torch.FloatTensor] = None,
37
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
38
+ output_type: Optional[str] = "pil",
39
+ return_dict: bool = True,
40
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
41
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
42
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
43
+ max_sequence_length: int = 512,
44
+ **kwargs: dict,
45
+ ):
46
+ return (
47
+ prompt,
48
+ prompt_2,
49
+ height,
50
+ width,
51
+ num_inference_steps,
52
+ timesteps,
53
+ guidance_scale,
54
+ num_images_per_prompt,
55
+ generator,
56
+ latents,
57
+ prompt_embeds,
58
+ pooled_prompt_embeds,
59
+ output_type,
60
+ return_dict,
61
+ joint_attention_kwargs,
62
+ callback_on_step_end,
63
+ callback_on_step_end_tensor_inputs,
64
+ max_sequence_length,
65
+ )
66
+
67
+
68
+ def seed_everything(seed: int = 42):
69
+ torch.backends.cudnn.deterministic = True
70
+ torch.manual_seed(seed)
71
+ np.random.seed(seed)
72
+
73
+
74
+ @torch.no_grad()
75
+ def generate(
76
+ pipeline: FluxPipeline,
77
+ conditions: List[Condition] = None,
78
+ config_path: str = None,
79
+ model_config: Optional[Dict[str, Any]] = {},
80
+ condition_scale: float = 1.0,
81
+ default_lora: bool = False,
82
+ image_guidance_scale: float = 1.0,
83
+ **params: dict,
84
+ ):
85
+ model_config = model_config or get_config(config_path).get("model", {})
86
+ if condition_scale != 1:
87
+ for name, module in pipeline.transformer.named_modules():
88
+ if not name.endswith(".attn"):
89
+ continue
90
+ module.c_factor = torch.ones(1, 1) * condition_scale
91
+
92
+ self = pipeline
93
+ (
94
+ prompt,
95
+ prompt_2,
96
+ height,
97
+ width,
98
+ num_inference_steps,
99
+ timesteps,
100
+ guidance_scale,
101
+ num_images_per_prompt,
102
+ generator,
103
+ latents,
104
+ prompt_embeds,
105
+ pooled_prompt_embeds,
106
+ output_type,
107
+ return_dict,
108
+ joint_attention_kwargs,
109
+ callback_on_step_end,
110
+ callback_on_step_end_tensor_inputs,
111
+ max_sequence_length,
112
+ ) = prepare_params(**params)
113
+
114
+ height = height or self.default_sample_size * self.vae_scale_factor
115
+ width = width or self.default_sample_size * self.vae_scale_factor
116
+
117
+ # 1. Check inputs. Raise error if not correct
118
+ self.check_inputs(
119
+ prompt,
120
+ prompt_2,
121
+ height,
122
+ width,
123
+ prompt_embeds=prompt_embeds,
124
+ pooled_prompt_embeds=pooled_prompt_embeds,
125
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
126
+ max_sequence_length=max_sequence_length,
127
+ )
128
+
129
+ self._guidance_scale = guidance_scale
130
+ self._joint_attention_kwargs = joint_attention_kwargs
131
+ self._interrupt = False
132
+
133
+ # 2. Define call parameters
134
+ if prompt is not None and isinstance(prompt, str):
135
+ batch_size = 1
136
+ elif prompt is not None and isinstance(prompt, list):
137
+ batch_size = len(prompt)
138
+ else:
139
+ batch_size = prompt_embeds.shape[0]
140
+
141
+ device = self._execution_device
142
+
143
+ lora_scale = (
144
+ self.joint_attention_kwargs.get("scale", None)
145
+ if self.joint_attention_kwargs is not None
146
+ else None
147
+ )
148
+ (
149
+ prompt_embeds,
150
+ pooled_prompt_embeds,
151
+ text_ids,
152
+ ) = self.encode_prompt(
153
+ prompt=prompt,
154
+ prompt_2=prompt_2,
155
+ prompt_embeds=prompt_embeds,
156
+ pooled_prompt_embeds=pooled_prompt_embeds,
157
+ device=device,
158
+ num_images_per_prompt=num_images_per_prompt,
159
+ max_sequence_length=max_sequence_length,
160
+ lora_scale=lora_scale,
161
+ )
162
+
163
+ # 4. Prepare latent variables
164
+ num_channels_latents = self.transformer.config.in_channels // 4
165
+ latents, latent_image_ids = self.prepare_latents(
166
+ batch_size * num_images_per_prompt,
167
+ num_channels_latents,
168
+ height,
169
+ width,
170
+ prompt_embeds.dtype,
171
+ device,
172
+ generator,
173
+ latents,
174
+ )
175
+
176
+ # 4.1. Prepare conditions
177
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
178
+ use_condition = conditions is not None and len(conditions) > 0
179
+ if use_condition:
180
+ assert len(conditions) <= 1, "Only one condition is supported for now."
181
+ if not default_lora:
182
+ pipeline.set_adapters(conditions[0].condition_type)
183
+ for condition in conditions:
184
+ tokens, ids, type_id = condition.encode(self)
185
+ condition_latents.append(tokens) # [batch_size, token_n, token_dim]
186
+ condition_ids.append(ids) # [token_n, id_dim(3)]
187
+ condition_type_ids.append(type_id) # [token_n, 1]
188
+ condition_latents = torch.cat(condition_latents, dim=1)
189
+ condition_ids = torch.cat(condition_ids, dim=0)
190
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
191
+
192
+ # 5. Prepare timesteps
193
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
194
+ image_seq_len = latents.shape[1]
195
+ mu = calculate_shift(
196
+ image_seq_len,
197
+ self.scheduler.config.base_image_seq_len,
198
+ self.scheduler.config.max_image_seq_len,
199
+ self.scheduler.config.base_shift,
200
+ self.scheduler.config.max_shift,
201
+ )
202
+ timesteps, num_inference_steps = retrieve_timesteps(
203
+ self.scheduler,
204
+ num_inference_steps,
205
+ device,
206
+ timesteps,
207
+ sigmas,
208
+ mu=mu,
209
+ )
210
+ num_warmup_steps = max(
211
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
212
+ )
213
+ self._num_timesteps = len(timesteps)
214
+
215
+ # 6. Denoising loop
216
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
217
+ for i, t in enumerate(timesteps):
218
+ if self.interrupt:
219
+ continue
220
+
221
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
222
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
223
+
224
+ # handle guidance
225
+ if self.transformer.config.guidance_embeds:
226
+ guidance = torch.tensor([guidance_scale], device=device)
227
+ guidance = guidance.expand(latents.shape[0])
228
+ else:
229
+ guidance = None
230
+ noise_pred = tranformer_forward(
231
+ self.transformer,
232
+ model_config=model_config,
233
+ # Inputs of the condition (new feature)
234
+ condition_latents=condition_latents if use_condition else None,
235
+ condition_ids=condition_ids if use_condition else None,
236
+ condition_type_ids=condition_type_ids if use_condition else None,
237
+ # Inputs to the original transformer
238
+ hidden_states=latents,
239
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
240
+ timestep=timestep / 1000,
241
+ guidance=guidance,
242
+ pooled_projections=pooled_prompt_embeds,
243
+ encoder_hidden_states=prompt_embeds,
244
+ txt_ids=text_ids,
245
+ img_ids=latent_image_ids,
246
+ joint_attention_kwargs=self.joint_attention_kwargs,
247
+ return_dict=False,
248
+ )[0]
249
+
250
+ if image_guidance_scale != 1.0:
251
+ uncondition_latents = condition.encode(self, empty=True)[0]
252
+ unc_pred = tranformer_forward(
253
+ self.transformer,
254
+ model_config=model_config,
255
+ # Inputs of the condition (new feature)
256
+ condition_latents=uncondition_latents if use_condition else None,
257
+ condition_ids=condition_ids if use_condition else None,
258
+ condition_type_ids=condition_type_ids if use_condition else None,
259
+ # Inputs to the original transformer
260
+ hidden_states=latents,
261
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
262
+ timestep=timestep / 1000,
263
+ guidance=torch.ones_like(guidance),
264
+ pooled_projections=pooled_prompt_embeds,
265
+ encoder_hidden_states=prompt_embeds,
266
+ txt_ids=text_ids,
267
+ img_ids=latent_image_ids,
268
+ joint_attention_kwargs=self.joint_attention_kwargs,
269
+ return_dict=False,
270
+ )[0]
271
+
272
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
273
+
274
+ # compute the previous noisy sample x_t -> x_t-1
275
+ latents_dtype = latents.dtype
276
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
277
+
278
+ if latents.dtype != latents_dtype:
279
+ if torch.backends.mps.is_available():
280
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
281
+ latents = latents.to(latents_dtype)
282
+
283
+ if callback_on_step_end is not None:
284
+ callback_kwargs = {}
285
+ for k in callback_on_step_end_tensor_inputs:
286
+ callback_kwargs[k] = locals()[k]
287
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
288
+
289
+ latents = callback_outputs.pop("latents", latents)
290
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
291
+
292
+ # call the callback, if provided
293
+ if i == len(timesteps) - 1 or (
294
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
295
+ ):
296
+ progress_bar.update()
297
+
298
+ if output_type == "latent":
299
+ image = latents
300
+
301
+ else:
302
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
303
+ latents = (
304
+ latents / self.vae.config.scaling_factor
305
+ ) + self.vae.config.shift_factor
306
+ image = self.vae.decode(latents, return_dict=False)[0]
307
+ image = self.image_processor.postprocess(image, output_type=output_type)
308
+
309
+ # Offload all models
310
+ self.maybe_free_model_hooks()
311
+
312
+ if condition_scale != 1:
313
+ for name, module in pipeline.transformer.named_modules():
314
+ if not name.endswith(".attn"):
315
+ continue
316
+ del module.c_factor
317
+
318
+ if not return_dict:
319
+ return (image,)
320
+
321
+ return FluxPipelineOutput(images=image)
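A hedged end-to-end sketch of calling generate() with one condition; the checkpoint, LoRA weight name, and adapter name are copied from the Gradio demo below and assumed to be available locally:
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from src.flux.condition import Condition               # import paths assumed
from src.flux.generate import generate, seed_everything
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)
seed_everything(42)
image = Image.open("assets/oranges.jpg").resize((512, 512)).convert("RGB")
condition = Condition("subject", image, position_delta=(0, 32))
result = generate(
    pipe,
    prompt="A close-up view of this item, placed on a wooden table.",
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
    condition_scale=1.0,  # values != 1 set attn.c_factor and bias image<->condition attention
)
result.images[0].save("subject_result.jpg")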
OminiControl/src/flux/lora_controller.py ADDED
@@ -0,0 +1,75 @@
1
+ from peft.tuners.tuners_utils import BaseTunerLayer
2
+ from typing import List, Any, Optional, Type
3
+
4
+
5
+ class enable_lora:
6
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
7
+ self.activated: bool = activated
8
+ if activated:
9
+ return
10
+ self.lora_modules: List[BaseTunerLayer] = [
11
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
12
+ ]
13
+ self.scales = [
14
+ {
15
+ active_adapter: lora_module.scaling[active_adapter]
16
+ for active_adapter in lora_module.active_adapters
17
+ }
18
+ for lora_module in self.lora_modules
19
+ ]
20
+
21
+ def __enter__(self) -> None:
22
+ if self.activated:
23
+ return
24
+
25
+ for lora_module in self.lora_modules:
26
+ if not isinstance(lora_module, BaseTunerLayer):
27
+ continue
28
+ lora_module.scale_layer(0)
29
+
30
+ def __exit__(
31
+ self,
32
+ exc_type: Optional[Type[BaseException]],
33
+ exc_val: Optional[BaseException],
34
+ exc_tb: Optional[Any],
35
+ ) -> None:
36
+ if self.activated:
37
+ return
38
+ for i, lora_module in enumerate(self.lora_modules):
39
+ if not isinstance(lora_module, BaseTunerLayer):
40
+ continue
41
+ for active_adapter in lora_module.active_adapters:
42
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
43
+
44
+
45
+ class set_lora_scale:
46
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
47
+ self.lora_modules: List[BaseTunerLayer] = [
48
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
49
+ ]
50
+ self.scales = [
51
+ {
52
+ active_adapter: lora_module.scaling[active_adapter]
53
+ for active_adapter in lora_module.active_adapters
54
+ }
55
+ for lora_module in self.lora_modules
56
+ ]
57
+ self.scale = scale
58
+
59
+ def __enter__(self) -> None:
60
+ for lora_module in self.lora_modules:
61
+ if not isinstance(lora_module, BaseTunerLayer):
62
+ continue
63
+ lora_module.scale_layer(self.scale)
64
+
65
+ def __exit__(
66
+ self,
67
+ exc_type: Optional[Type[BaseException]],
68
+ exc_val: Optional[BaseException],
69
+ exc_tb: Optional[Any],
70
+ ) -> None:
71
+ for i, lora_module in enumerate(self.lora_modules):
72
+ if not isinstance(lora_module, BaseTunerLayer):
73
+ continue
74
+ for active_adapter in lora_module.active_adapters:
75
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
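Both context managers mutate the PEFT scaling dict in place and restore the saved values on exit; modules that are not BaseTunerLayer instances are silently skipped. A hedged, self-contained sketch on a toy layer, assuming a peft version that provides inject_adapter_in_model; the repo itself passes transformer sub-modules such as attn.to_q instead:
import torch
from torch import nn
from peft import LoraConfig, inject_adapter_in_model
from src.flux.lora_controller import enable_lora, set_lora_scale  # import path assumed
# Toy stand-in for a LoRA-wrapped projection layer.
model = inject_adapter_in_model(
    LoraConfig(r=4, target_modules=["proj"]),
    nn.ModuleDict({"proj": nn.Linear(16, 16)}),
)
lora_linear = model["proj"]  # a peft BaseTunerLayer with a .scaling dict
x = torch.randn(2, 16)
# activated=False: the LoRA scale is set to 0 inside the block and restored afterwards,
# so only the frozen base weight contributes (this is how latent_lora=False is handled).
with enable_lora((lora_linear,), False):
    base_only = lora_linear(x)
# Temporarily change the adapter strength for one forward pass.
with set_lora_scale((lora_linear,), 2.0):
    boosted = lora_linear(x)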
OminiControl/src/flux/pipeline_tools.py ADDED
@@ -0,0 +1,52 @@
1
+ from diffusers.pipelines import FluxPipeline
2
+ from diffusers.utils import logging
3
+ from diffusers.pipelines.flux.pipeline_flux import logger
4
+ from torch import Tensor
5
+
6
+
7
+ def encode_images(pipeline: FluxPipeline, images: Tensor):
8
+ images = pipeline.image_processor.preprocess(images)
9
+ images = images.to(pipeline.device).to(pipeline.dtype)
10
+ images = pipeline.vae.encode(images).latent_dist.sample()
11
+ images = (
12
+ images - pipeline.vae.config.shift_factor
13
+ ) * pipeline.vae.config.scaling_factor
14
+ images_tokens = pipeline._pack_latents(images, *images.shape)
15
+ images_ids = pipeline._prepare_latent_image_ids(
16
+ images.shape[0],
17
+ images.shape[2],
18
+ images.shape[3],
19
+ pipeline.device,
20
+ pipeline.dtype,
21
+ )
22
+ if images_tokens.shape[1] != images_ids.shape[0]:
23
+ images_ids = pipeline._prepare_latent_image_ids(
24
+ images.shape[0],
25
+ images.shape[2] // 2,
26
+ images.shape[3] // 2,
27
+ pipeline.device,
28
+ pipeline.dtype,
29
+ )
30
+ return images_tokens, images_ids
31
+
32
+
33
+ def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
34
+ # Turn off warnings (CLIP overflow)
35
+ logger.setLevel(logging.ERROR)
36
+ (
37
+ prompt_embeds,
38
+ pooled_prompt_embeds,
39
+ text_ids,
40
+ ) = pipeline.encode_prompt(
41
+ prompt=prompts,
42
+ prompt_2=None,
43
+ prompt_embeds=None,
44
+ pooled_prompt_embeds=None,
45
+ device=pipeline.device,
46
+ num_images_per_prompt=1,
47
+ max_sequence_length=max_sequence_length,
48
+ lora_scale=None,
49
+ )
50
+ # Turn on warnings
51
+ logger.setLevel(logging.WARNING)
52
+ return prompt_embeds, pooled_prompt_embeds, text_ids
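encode_images packs VAE latents into Flux's packed-token format together with matching position ids, and prepare_text_input wraps encode_prompt while silencing the CLIP truncation warning. A hedged shape sketch for a 512x512 input (8x VAE downsampling, then 2x2 packing -> 32x32 = 1024 tokens), assuming pipe is an already-loaded FluxPipeline:
from PIL import Image
from src.flux.pipeline_tools import encode_images, prepare_text_input  # import path assumed
img = Image.open("assets/vase.jpg").resize((512, 512)).convert("RGB")
tokens, ids = encode_images(pipe, img)        # pipe: an already-loaded FluxPipeline (assumed)
print(tokens.shape)  # [1, 1024, 64]: 512 -> 64x64 VAE latent -> 32x32 packed grid of tokens
print(ids.shape)     # [1024, 3]: one (batch, row, col) position id per packed token
prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(pipe, ["a vase on a table"])
print(prompt_embeds.shape)  # [1, 512, 4096] with the default max_sequence_length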
OminiControl/src/flux/transformer.py ADDED
@@ -0,0 +1,252 @@
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from typing import List, Union, Optional, Dict, Any, Callable
4
+ from .block import block_forward, single_block_forward
5
+ from .lora_controller import enable_lora
6
+ from accelerate.utils import is_torch_version
7
+ from diffusers.models.transformers.transformer_flux import (
8
+ FluxTransformer2DModel,
9
+ Transformer2DModelOutput,
10
+ USE_PEFT_BACKEND,
11
+ scale_lora_layers,
12
+ unscale_lora_layers,
13
+ logger,
14
+ )
15
+ import numpy as np
16
+
17
+
18
+ def prepare_params(
19
+ hidden_states: torch.Tensor,
20
+ encoder_hidden_states: torch.Tensor = None,
21
+ pooled_projections: torch.Tensor = None,
22
+ timestep: torch.LongTensor = None,
23
+ img_ids: torch.Tensor = None,
24
+ txt_ids: torch.Tensor = None,
25
+ guidance: torch.Tensor = None,
26
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27
+ controlnet_block_samples=None,
28
+ controlnet_single_block_samples=None,
29
+ return_dict: bool = True,
30
+ **kwargs: dict,
31
+ ):
32
+ return (
33
+ hidden_states,
34
+ encoder_hidden_states,
35
+ pooled_projections,
36
+ timestep,
37
+ img_ids,
38
+ txt_ids,
39
+ guidance,
40
+ joint_attention_kwargs,
41
+ controlnet_block_samples,
42
+ controlnet_single_block_samples,
43
+ return_dict,
44
+ )
45
+
46
+
47
+ def tranformer_forward(
48
+ transformer: FluxTransformer2DModel,
49
+ condition_latents: torch.Tensor,
50
+ condition_ids: torch.Tensor,
51
+ condition_type_ids: torch.Tensor,
52
+ model_config: Optional[Dict[str, Any]] = {},
53
+ c_t=0,
54
+ **params: dict,
55
+ ):
56
+ self = transformer
57
+ use_condition = condition_latents is not None
58
+
59
+ (
60
+ hidden_states,
61
+ encoder_hidden_states,
62
+ pooled_projections,
63
+ timestep,
64
+ img_ids,
65
+ txt_ids,
66
+ guidance,
67
+ joint_attention_kwargs,
68
+ controlnet_block_samples,
69
+ controlnet_single_block_samples,
70
+ return_dict,
71
+ ) = prepare_params(**params)
72
+
73
+ if joint_attention_kwargs is not None:
74
+ joint_attention_kwargs = joint_attention_kwargs.copy()
75
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
76
+ else:
77
+ lora_scale = 1.0
78
+
79
+ if USE_PEFT_BACKEND:
80
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
81
+ scale_lora_layers(self, lora_scale)
82
+ else:
83
+ if (
84
+ joint_attention_kwargs is not None
85
+ and joint_attention_kwargs.get("scale", None) is not None
86
+ ):
87
+ logger.warning(
88
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
89
+ )
90
+
91
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
92
+ hidden_states = self.x_embedder(hidden_states)
93
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
94
+
95
+ timestep = timestep.to(hidden_states.dtype) * 1000
96
+
97
+ if guidance is not None:
98
+ guidance = guidance.to(hidden_states.dtype) * 1000
99
+ else:
100
+ guidance = None
101
+
102
+ temb = (
103
+ self.time_text_embed(timestep, pooled_projections)
104
+ if guidance is None
105
+ else self.time_text_embed(timestep, guidance, pooled_projections)
106
+ )
107
+
108
+ cond_temb = (
109
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
110
+ if guidance is None
111
+ else self.time_text_embed(
112
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
113
+ )
114
+ )
115
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
116
+
117
+ if txt_ids.ndim == 3:
118
+ logger.warning(
119
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
120
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
121
+ )
122
+ txt_ids = txt_ids[0]
123
+ if img_ids.ndim == 3:
124
+ logger.warning(
125
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
126
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
127
+ )
128
+ img_ids = img_ids[0]
129
+
130
+ ids = torch.cat((txt_ids, img_ids), dim=0)
131
+ image_rotary_emb = self.pos_embed(ids)
132
+ if use_condition:
133
+ # condition_ids[:, :1] = condition_type_ids
134
+ cond_rotary_emb = self.pos_embed(condition_ids)
135
+
136
+ # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
137
+
138
+ for index_block, block in enumerate(self.transformer_blocks):
139
+ if self.training and self.gradient_checkpointing:
140
+ ckpt_kwargs: Dict[str, Any] = (
141
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
142
+ )
143
+ encoder_hidden_states, hidden_states, condition_latents = (
144
+ torch.utils.checkpoint.checkpoint(
145
+ block_forward,
146
+ self=block,
147
+ model_config=model_config,
148
+ hidden_states=hidden_states,
149
+ encoder_hidden_states=encoder_hidden_states,
150
+ condition_latents=condition_latents if use_condition else None,
151
+ temb=temb,
152
+ cond_temb=cond_temb if use_condition else None,
153
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
154
+ image_rotary_emb=image_rotary_emb,
155
+ **ckpt_kwargs,
156
+ )
157
+ )
158
+
159
+ else:
160
+ encoder_hidden_states, hidden_states, condition_latents = block_forward(
161
+ block,
162
+ model_config=model_config,
163
+ hidden_states=hidden_states,
164
+ encoder_hidden_states=encoder_hidden_states,
165
+ condition_latents=condition_latents if use_condition else None,
166
+ temb=temb,
167
+ cond_temb=cond_temb if use_condition else None,
168
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
169
+ image_rotary_emb=image_rotary_emb,
170
+ )
171
+
172
+ # controlnet residual
173
+ if controlnet_block_samples is not None:
174
+ interval_control = len(self.transformer_blocks) / len(
175
+ controlnet_block_samples
176
+ )
177
+ interval_control = int(np.ceil(interval_control))
178
+ hidden_states = (
179
+ hidden_states
180
+ + controlnet_block_samples[index_block // interval_control]
181
+ )
182
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
183
+
184
+ for index_block, block in enumerate(self.single_transformer_blocks):
185
+ if self.training and self.gradient_checkpointing:
186
+ ckpt_kwargs: Dict[str, Any] = (
187
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
188
+ )
189
+ result = torch.utils.checkpoint.checkpoint(
190
+ single_block_forward,
191
+ self=block,
192
+ model_config=model_config,
193
+ hidden_states=hidden_states,
194
+ temb=temb,
195
+ image_rotary_emb=image_rotary_emb,
196
+ **(
197
+ {
198
+ "condition_latents": condition_latents,
199
+ "cond_temb": cond_temb,
200
+ "cond_rotary_emb": cond_rotary_emb,
201
+ }
202
+ if use_condition
203
+ else {}
204
+ ),
205
+ **ckpt_kwargs,
206
+ )
207
+
208
+ else:
209
+ result = single_block_forward(
210
+ block,
211
+ model_config=model_config,
212
+ hidden_states=hidden_states,
213
+ temb=temb,
214
+ image_rotary_emb=image_rotary_emb,
215
+ **(
216
+ {
217
+ "condition_latents": condition_latents,
218
+ "cond_temb": cond_temb,
219
+ "cond_rotary_emb": cond_rotary_emb,
220
+ }
221
+ if use_condition
222
+ else {}
223
+ ),
224
+ )
225
+ if use_condition:
226
+ hidden_states, condition_latents = result
227
+ else:
228
+ hidden_states = result
229
+
230
+ # controlnet residual
231
+ if controlnet_single_block_samples is not None:
232
+ interval_control = len(self.single_transformer_blocks) / len(
233
+ controlnet_single_block_samples
234
+ )
235
+ interval_control = int(np.ceil(interval_control))
236
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
237
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
238
+ + controlnet_single_block_samples[index_block // interval_control]
239
+ )
240
+
241
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
242
+
243
+ hidden_states = self.norm_out(hidden_states, temb)
244
+ output = self.proj_out(hidden_states)
245
+
246
+ if USE_PEFT_BACKEND:
247
+ # remove `lora_scale` from each PEFT layer
248
+ unscale_lora_layers(self, lora_scale)
249
+
250
+ if not return_dict:
251
+ return (output,)
252
+ return Transformer2DModelOutput(sample=output)
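Condition tokens get their own rotary embedding from self.pos_embed(condition_ids) and a timestep embedding computed at the fixed condition time c_t (0 by default). A small numeric sketch of the id arithmetic only (no Flux modules), showing how the default subject position_delta of [0, -32] places a 32x32 condition grid in a RoPE region disjoint from the 32x32 target grid:
import torch
# 512x512 condition image -> 32x32 packed latent grid; id columns are (0, row, col).
rows, cols = torch.meshgrid(torch.arange(32), torch.arange(32), indexing="ij")
cond_ids = torch.zeros(32 * 32, 3)
cond_ids[:, 1] = rows.flatten()
cond_ids[:, 2] = cols.flatten()
position_delta = [0, -32]           # default for "subject" in condition.py
cond_ids[:, 1] += position_delta[0]
cond_ids[:, 2] += position_delta[1]
print(cond_ids[:, 2].min().item(), cond_ids[:, 2].max().item())  # -32.0 -1.0
# Target-image columns occupy 0..31, so the condition grid lands in a disjoint RoPE
# region; transformer.py then builds cond_rotary_emb = self.pos_embed(condition_ids).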
OminiControl/src/gradio/gradio_app.py ADDED
@@ -0,0 +1,115 @@
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ from diffusers.pipelines import FluxPipeline
5
+ from diffusers import FluxTransformer2DModel
6
+ import numpy as np
7
+
8
+ from ..flux.condition import Condition
9
+ from ..flux.generate import seed_everything, generate
10
+
11
+ pipe = None
12
+ use_int8 = False
13
+
14
+
15
+ def get_gpu_memory():
16
+ return torch.cuda.get_device_properties(0).total_memory / 1024**3
17
+
18
+
19
+ def init_pipeline():
20
+ global pipe
21
+ if use_int8 or get_gpu_memory() < 33:
22
+ transformer_model = FluxTransformer2DModel.from_pretrained(
23
+ "sayakpaul/flux.1-schell-int8wo-improved",
24
+ torch_dtype=torch.bfloat16,
25
+ use_safetensors=False,
26
+ )
27
+ pipe = FluxPipeline.from_pretrained(
28
+ "black-forest-labs/FLUX.1-schnell",
29
+ transformer=transformer_model,
30
+ torch_dtype=torch.bfloat16,
31
+ )
32
+ else:
33
+ pipe = FluxPipeline.from_pretrained(
34
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
35
+ )
36
+ pipe = pipe.to("cuda")
37
+ pipe.load_lora_weights(
38
+ "Yuanshi/OminiControl",
39
+ weight_name="omini/subject_512.safetensors",
40
+ adapter_name="subject",
41
+ )
42
+
43
+
44
+ def process_image_and_text(image, text):
45
+ # center crop image
46
+ w, h, min_size = image.size[0], image.size[1], min(image.size)
47
+ image = image.crop(
48
+ (
49
+ (w - min_size) // 2,
50
+ (h - min_size) // 2,
51
+ (w + min_size) // 2,
52
+ (h + min_size) // 2,
53
+ )
54
+ )
55
+ image = image.resize((512, 512))
56
+
57
+ condition = Condition("subject", image, position_delta=(0, 32))
58
+
59
+ if pipe is None:
60
+ init_pipeline()
61
+
62
+ result_img = generate(
63
+ pipe,
64
+ prompt=text.strip(),
65
+ conditions=[condition],
66
+ num_inference_steps=8,
67
+ height=512,
68
+ width=512,
69
+ ).images[0]
70
+
71
+ return result_img
72
+
73
+
74
+ def get_samples():
75
+ sample_list = [
76
+ {
77
+ "image": "assets/oranges.jpg",
78
+ "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
79
+ },
80
+ {
81
+ "image": "assets/penguin.jpg",
82
+ "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
83
+ },
84
+ {
85
+ "image": "assets/rc_car.jpg",
86
+ "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
87
+ },
88
+ {
89
+ "image": "assets/clock.jpg",
90
+ "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
91
+ },
92
+ {
93
+ "image": "assets/tshirt.jpg",
94
+ "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.",
95
+ },
96
+ ]
97
+ return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
98
+
99
+
100
+ demo = gr.Interface(
101
+ fn=process_image_and_text,
102
+ inputs=[
103
+ gr.Image(type="pil"),
104
+ gr.Textbox(lines=2),
105
+ ],
106
+ outputs=gr.Image(type="pil"),
107
+ title="OminiControl / Subject driven generation",
108
+ examples=get_samples(),
109
+ )
110
+
111
+ if __name__ == "__main__":
112
+ init_pipeline()
113
+ demo.launch(
114
+ debug=True,
115
+ )
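Because of the relative imports, the demo is presumably launched as a module from the OminiControl directory (the command below is an assumption, not documented in this commit); process_image_and_text can also be called directly and lazily initializes the global pipeline:
# Assumed launch command:
#   python -m src.gradio.gradio_app
#
# Programmatic use shares the same lazily-initialized global pipeline:
from PIL import Image
from src.gradio.gradio_app import process_image_and_text  # import path assumed
out = process_image_and_text(
    Image.open("assets/clock.jpg"),
    "In a Bauhaus style room, this item is placed on a shiny glass table.",
)
out.save("demo_out.jpg")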
OminiControl/src/train/callbacks.py ADDED
@@ -0,0 +1,253 @@
1
+ import lightning as L
2
+ from PIL import Image, ImageFilter, ImageDraw
3
+ import numpy as np
4
+ from transformers import pipeline
5
+ import cv2
6
+ import torch
7
+ import os
8
+
9
+ try:
10
+ import wandb
11
+ except ImportError:
12
+ wandb = None
13
+
14
+ from ..flux.condition import Condition
15
+ from ..flux.generate import generate
16
+
17
+
18
+ class TrainingCallback(L.Callback):
19
+ def __init__(self, run_name, training_config: dict = {}):
20
+ self.run_name, self.training_config = run_name, training_config
21
+
22
+ self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
23
+ self.save_interval = training_config.get("save_interval", 1000)
24
+ self.sample_interval = training_config.get("sample_interval", 1000)
25
+ self.save_path = training_config.get("save_path", "./output")
26
+
27
+ self.wandb_config = training_config.get("wandb", None)
28
+ self.use_wandb = (
29
+ wandb is not None and os.environ.get("WANDB_API_KEY") is not None
30
+ )
31
+
32
+ self.total_steps = 0
33
+
34
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
35
+ gradient_size = 0
36
+ max_gradient_size = 0
37
+ count = 0
38
+ for _, param in pl_module.named_parameters():
39
+ if param.grad is not None:
40
+ gradient_size += param.grad.norm(2).item()
41
+ max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
42
+ count += 1
43
+ if count > 0:
44
+ gradient_size /= count
45
+
46
+ self.total_steps += 1
47
+
48
+ # Print training progress every n steps
49
+ if self.use_wandb:
50
+ report_dict = {
51
+ "steps": batch_idx,
52
+ "steps": self.total_steps,
53
+ "epoch": trainer.current_epoch,
54
+ "gradient_size": gradient_size,
55
+ }
56
+ loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
57
+ report_dict["loss"] = loss_value
58
+ report_dict["t"] = pl_module.last_t
59
+ wandb.log(report_dict)
60
+
61
+ if self.total_steps % self.print_every_n_steps == 0:
62
+ print(
63
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
64
+ )
65
+
66
+ # Save LoRA weights at specified intervals
67
+ if self.total_steps % self.save_interval == 0:
68
+ print(
69
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
70
+ )
71
+ pl_module.save_lora(
72
+ f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
73
+ )
74
+
75
+ # Generate and save a sample image at specified intervals
76
+ if self.total_steps % self.sample_interval == 0:
77
+ print(
78
+ f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
79
+ )
80
+ self.generate_a_sample(
81
+ trainer,
82
+ pl_module,
83
+ f"{self.save_path}/{self.run_name}/output",
84
+ f"lora_{self.total_steps}",
85
+ batch["condition_type"][
86
+ 0
87
+ ], # Use the condition type from the current batch
88
+ )
89
+
90
+ @torch.no_grad()
91
+ def generate_a_sample(
92
+ self,
93
+ trainer,
94
+ pl_module,
95
+ save_path,
96
+ file_name,
97
+ condition_type="super_resolution",
98
+ ):
99
+ # TODO: change this two variables to parameters
100
+ condition_size = trainer.training_config["dataset"]["condition_size"]
101
+ target_size = trainer.training_config["dataset"]["target_size"]
102
+ position_scale = trainer.training_config["dataset"].get("position_scale", 1.0)
103
+
104
+ generator = torch.Generator(device=pl_module.device)
105
+ generator.manual_seed(42)
106
+
107
+ test_list = []
108
+
109
+ if condition_type == "subject":
110
+ test_list.extend(
111
+ [
112
+ (
113
+ Image.open("assets/test_in.jpg"),
114
+ [0, -32],
115
+ "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
116
+ ),
117
+ (
118
+ Image.open("assets/test_out.jpg"),
119
+ [0, -32],
120
+ "In a bright room. It is placed on a table.",
121
+ ),
122
+ ]
123
+ )
124
+ elif condition_type == "canny":
125
+ condition_img = Image.open("assets/vase_hq.jpg").resize(
126
+ (condition_size, condition_size)
127
+ )
128
+ condition_img = np.array(condition_img)
129
+ condition_img = cv2.Canny(condition_img, 100, 200)
130
+ condition_img = Image.fromarray(condition_img).convert("RGB")
131
+ test_list.append(
132
+ (
133
+ condition_img,
134
+ [0, 0],
135
+ "A beautiful vase on a table.",
136
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
137
+ )
138
+ )
139
+ elif condition_type == "coloring":
140
+ condition_img = (
141
+ Image.open("assets/vase_hq.jpg")
142
+ .resize((condition_size, condition_size))
143
+ .convert("L")
144
+ .convert("RGB")
145
+ )
146
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
147
+ elif condition_type == "depth":
148
+ if not hasattr(self, "deepth_pipe"):
149
+ self.deepth_pipe = pipeline(
150
+ task="depth-estimation",
151
+ model="LiheYoung/depth-anything-small-hf",
152
+ device="cpu",
153
+ )
154
+ condition_img = (
155
+ Image.open("assets/vase_hq.jpg")
156
+ .resize((condition_size, condition_size))
157
+ .convert("RGB")
158
+ )
159
+ condition_img = self.deepth_pipe(condition_img)["depth"].convert("RGB")
160
+ test_list.append(
161
+ (
162
+ condition_img,
163
+ [0, 0],
164
+ "A beautiful vase on a table.",
165
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
166
+ )
167
+ )
168
+ elif condition_type == "depth_pred":
169
+ condition_img = (
170
+ Image.open("assets/vase_hq.jpg")
171
+ .resize((condition_size, condition_size))
172
+ .convert("RGB")
173
+ )
174
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
175
+ elif condition_type == "deblurring":
176
+ blur_radius = 5
177
+ image = Image.open("./assets/vase_hq.jpg")
178
+ condition_img = (
179
+ image.convert("RGB")
180
+ .resize((condition_size, condition_size))
181
+ .filter(ImageFilter.GaussianBlur(blur_radius))
182
+ .convert("RGB")
183
+ )
184
+ test_list.append(
185
+ (
186
+ condition_img,
187
+ [0, 0],
188
+ "A beautiful vase on a table.",
189
+ {"position_scale": position_scale} if position_scale != 1.0 else {},
190
+ )
191
+ )
192
+ elif condition_type == "fill":
193
+ condition_img = (
194
+ Image.open("./assets/vase_hq.jpg")
195
+ .resize((condition_size, condition_size))
196
+ .convert("RGB")
197
+ )
198
+ mask = Image.new("L", condition_img.size, 0)
199
+ draw = ImageDraw.Draw(mask)
200
+ a = condition_img.size[0] // 4
201
+ b = a * 3
202
+ draw.rectangle([a, a, b, b], fill=255)
203
+ condition_img = Image.composite(
204
+ condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
205
+ )
206
+ test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
207
+ elif condition_type == "sr":
208
+ condition_img = (
209
+ Image.open("assets/vase_hq.jpg")
210
+ .resize((condition_size, condition_size))
211
+ .convert("RGB")
212
+ )
213
+ test_list.append((condition_img, [0, -16], "A beautiful vase on a table."))
214
+ elif condition_type == "cartoon":
215
+ condition_img = (
216
+ Image.open("assets/cartoon_boy.png")
217
+ .resize((condition_size, condition_size))
218
+ .convert("RGB")
219
+ )
220
+ test_list.append(
221
+ (
222
+ condition_img,
223
+ [0, -16],
224
+ "A cartoon character in a white background. He is looking right, and running.",
225
+ )
226
+ )
227
+ else:
228
+ raise NotImplementedError
229
+
230
+ if not os.path.exists(save_path):
231
+ os.makedirs(save_path)
232
+ for i, (condition_img, position_delta, prompt, *others) in enumerate(test_list):
233
+ condition = Condition(
234
+ condition_type=condition_type,
235
+ condition=condition_img.resize(
236
+ (condition_size, condition_size)
237
+ ).convert("RGB"),
238
+ position_delta=position_delta,
239
+ **(others[0] if others else {}),
240
+ )
241
+ res = generate(
242
+ pl_module.flux_pipe,
243
+ prompt=prompt,
244
+ conditions=[condition],
245
+ height=target_size,
246
+ width=target_size,
247
+ generator=generator,
248
+ model_config=pl_module.model_config,
249
+ default_lora=True,
250
+ )
251
+ res.images[0].save(
252
+ os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
253
+ )
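TrainingCallback reads trainer.training_config inside generate_a_sample and expects the LightningModule to expose flux_pipe, model_config, last_t, log_loss, and save_lora. A hedged wiring sketch; OminiModel is an assumed class name for src/train/model.py, which is not shown in this section:
import lightning as L
from src.train.callbacks import TrainingCallback  # import path assumed
from src.train.model import OminiModel            # class name assumed; model.py is not shown here
training_config = {
    "print_every_n_steps": 10,
    "save_interval": 1000,
    "sample_interval": 1000,
    "save_path": "./output",
    "dataset": {"condition_size": 512, "target_size": 512},
}
callback = TrainingCallback("subject_512_run", training_config=training_config)
trainer = L.Trainer(max_steps=10_000, callbacks=[callback], precision="bf16-mixed")
trainer.training_config = training_config  # generate_a_sample() reads this attribute off the trainer
# trainer.fit(OminiModel(...), train_dataloaders=...)  # constructor arguments omitted (unknown here)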
OminiControl/src/train/data.py ADDED
@@ -0,0 +1,323 @@
1
+ from PIL import Image, ImageFilter, ImageDraw
2
+ import cv2
3
+ import numpy as np
4
+ from torch.utils.data import Dataset
5
+ import torchvision.transforms as T
6
+ import random
7
+
8
+
9
+ class Subject200KDataset(Dataset):
10
+ def __init__(
11
+ self,
12
+ base_dataset,
13
+ condition_size: int = 512,
14
+ target_size: int = 512,
15
+ image_size: int = 512,
16
+ padding: int = 0,
17
+ condition_type: str = "subject",
18
+ drop_text_prob: float = 0.1,
19
+ drop_image_prob: float = 0.1,
20
+ return_pil_image: bool = False,
21
+ ):
22
+ self.base_dataset = base_dataset
23
+ self.condition_size = condition_size
24
+ self.target_size = target_size
25
+ self.image_size = image_size
26
+ self.padding = padding
27
+ self.condition_type = condition_type
28
+ self.drop_text_prob = drop_text_prob
29
+ self.drop_image_prob = drop_image_prob
30
+ self.return_pil_image = return_pil_image
31
+
32
+ self.to_tensor = T.ToTensor()
33
+
34
+    def __len__(self):
+        return len(self.base_dataset) * 2
+
+    def __getitem__(self, idx):
+        # If target is 0, left image is target, right image is condition
+        target = idx % 2
+        item = self.base_dataset[idx // 2]
+
+        # Crop the image to target and condition
+        image = item["image"]
+        left_img = image.crop(
+            (
+                self.padding,
+                self.padding,
+                self.image_size + self.padding,
+                self.image_size + self.padding,
+            )
+        )
+        right_img = image.crop(
+            (
+                self.image_size + self.padding * 2,
+                self.padding,
+                self.image_size * 2 + self.padding * 2,
+                self.image_size + self.padding,
+            )
+        )
+
+        # Get the target and condition image
+        target_image, condition_img = (
+            (left_img, right_img) if target == 0 else (right_img, left_img)
+        )
+
+        # Resize the image
+        condition_img = condition_img.resize(
+            (self.condition_size, self.condition_size)
+        ).convert("RGB")
+        target_image = target_image.resize(
+            (self.target_size, self.target_size)
+        ).convert("RGB")
+
+        # Get the description
+        description = item["description"][
+            "description_0" if target == 0 else "description_1"
+        ]
+
+        # Randomly drop text or image
+        drop_text = random.random() < self.drop_text_prob
+        drop_image = random.random() < self.drop_image_prob
+        if drop_text:
+            description = ""
+        if drop_image:
+            condition_img = Image.new(
+                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
+            )
+
+        return {
+            "image": self.to_tensor(target_image),
+            "condition": self.to_tensor(condition_img),
+            "condition_type": self.condition_type,
+            "description": description,
+            # 16 is the downscale factor of the image
+            "position_delta": np.array([0, -self.condition_size // 16]),
+            **({"pil_image": image} if self.return_pil_image else {}),
+        }
+
+
+class ImageConditionDataset(Dataset):
+    def __init__(
+        self,
+        base_dataset,
+        condition_size: int = 512,
+        target_size: int = 512,
+        condition_type: str = "canny",
+        drop_text_prob: float = 0.1,
+        drop_image_prob: float = 0.1,
+        return_pil_image: bool = False,
+        position_scale=1.0,
+    ):
+        self.base_dataset = base_dataset
+        self.condition_size = condition_size
+        self.target_size = target_size
+        self.condition_type = condition_type
+        self.drop_text_prob = drop_text_prob
+        self.drop_image_prob = drop_image_prob
+        self.return_pil_image = return_pil_image
+        self.position_scale = position_scale
+
+        self.to_tensor = T.ToTensor()
+
+    def __len__(self):
+        return len(self.base_dataset)
+
+    @property
+    def depth_pipe(self):
+        if not hasattr(self, "_depth_pipe"):
+            from transformers import pipeline
+
+            self._depth_pipe = pipeline(
+                task="depth-estimation",
+                model="LiheYoung/depth-anything-small-hf",
+                device="cpu",
+            )
+        return self._depth_pipe
+
+    def _get_canny_edge(self, img):
+        resize_ratio = self.condition_size / max(img.size)
+        img = img.resize(
+            (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
+        )
+        img_np = np.array(img)
+        img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
+        edges = cv2.Canny(img_gray, 100, 200)
+        return Image.fromarray(edges).convert("RGB")
+
+    def __getitem__(self, idx):
+        image = self.base_dataset[idx]["jpg"]
+        image = image.resize((self.target_size, self.target_size)).convert("RGB")
+        description = self.base_dataset[idx]["json"]["prompt"]
+
+        # Note: random.random() < 1 is always True, so position scaling is
+        # effectively always enabled and self.position_scale is used as-is.
+        enable_scale = random.random() < 1
+        if not enable_scale:
+            condition_size = int(self.condition_size * self.position_scale)
+            position_scale = 1.0
+        else:
+            condition_size = self.condition_size
+            position_scale = self.position_scale
+
+        # Get the condition image
+        position_delta = np.array([0, 0])
+        if self.condition_type == "canny":
+            condition_img = self._get_canny_edge(image)
+        elif self.condition_type == "coloring":
+            condition_img = (
+                image.resize((condition_size, condition_size))
+                .convert("L")
+                .convert("RGB")
+            )
+        elif self.condition_type == "deblurring":
+            blur_radius = random.randint(1, 10)
+            condition_img = (
+                image.convert("RGB")
+                .filter(ImageFilter.GaussianBlur(blur_radius))
+                .resize((condition_size, condition_size))
+                .convert("RGB")
+            )
+        elif self.condition_type == "depth":
+            condition_img = self.depth_pipe(image)["depth"].convert("RGB")
+            condition_img = condition_img.resize((condition_size, condition_size))
+        elif self.condition_type == "depth_pred":
+            condition_img = image
+            image = self.depth_pipe(condition_img)["depth"].convert("RGB")
+            description = f"[depth] {description}"
+        elif self.condition_type == "fill":
+            condition_img = image.resize((condition_size, condition_size)).convert(
+                "RGB"
+            )
+            w, h = image.size
+            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
+            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
+            mask = Image.new("L", image.size, 0)
+            draw = ImageDraw.Draw(mask)
+            draw.rectangle([x1, y1, x2, y2], fill=255)
+            if random.random() > 0.5:
+                mask = Image.eval(mask, lambda a: 255 - a)
+            condition_img = Image.composite(
+                image, Image.new("RGB", image.size, (0, 0, 0)), mask
+            )
+        elif self.condition_type == "sr":
+            condition_img = image.resize((condition_size, condition_size)).convert(
+                "RGB"
+            )
+            position_delta = np.array([0, -condition_size // 16])
+
+        else:
+            raise ValueError(f"Condition type {self.condition_type} not implemented")
+
+        # Randomly drop text or image
+        drop_text = random.random() < self.drop_text_prob
+        drop_image = random.random() < self.drop_image_prob
+        if drop_text:
+            description = ""
+        if drop_image:
+            condition_img = Image.new(
+                "RGB", (condition_size, condition_size), (0, 0, 0)
+            )
+
+        return {
+            "image": self.to_tensor(image),
+            "condition": self.to_tensor(condition_img),
+            "condition_type": self.condition_type,
+            "description": description,
+            "position_delta": position_delta,
+            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
+            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
+        }
+
+
+class CartoonDataset(Dataset):
+    def __init__(
+        self,
+        base_dataset,
+        condition_size: int = 1024,
+        target_size: int = 1024,
+        image_size: int = 1024,
+        padding: int = 0,
+        condition_type: str = "cartoon",
+        drop_text_prob: float = 0.1,
+        drop_image_prob: float = 0.1,
+        return_pil_image: bool = False,
+    ):
+        self.base_dataset = base_dataset
+        self.condition_size = condition_size
+        self.target_size = target_size
+        self.image_size = image_size
+        self.padding = padding
+        self.condition_type = condition_type
+        self.drop_text_prob = drop_text_prob
+        self.drop_image_prob = drop_image_prob
+        self.return_pil_image = return_pil_image
+
+        self.to_tensor = T.ToTensor()
+
+    def __len__(self):
+        return len(self.base_dataset)
+
+    def __getitem__(self, idx):
+        data = self.base_dataset[idx]
+        condition_img = data["condition"]
+        target_image = data["target"]
+
+        # Tag
+        tag = data["tags"][0]
+
+        target_description = data["target_description"]
+
+        description = {
+            "lion": "lion like animal",
+            "bear": "bear like animal",
+            "gorilla": "gorilla like animal",
+            "dog": "dog like animal",
+            "elephant": "elephant like animal",
+            "eagle": "eagle like bird",
+            "tiger": "tiger like animal",
+            "owl": "owl like bird",
+            "woman": "woman",
+            "parrot": "parrot like bird",
+            "mouse": "mouse like animal",
+            "man": "man",
+            "pigeon": "pigeon like bird",
+            "girl": "girl",
+            "panda": "panda like animal",
+            "crocodile": "crocodile like animal",
+            "rabbit": "rabbit like animal",
+            "boy": "boy",
+            "monkey": "monkey like animal",
+            "cat": "cat like animal",
+        }
+
+        # Resize the image
+        condition_img = condition_img.resize(
+            (self.condition_size, self.condition_size)
+        ).convert("RGB")
+        target_image = target_image.resize(
+            (self.target_size, self.target_size)
+        ).convert("RGB")
+
+        # Process datum to create description
+        description = data.get(
+            "description",
+            f"Photo of a {description[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. Character pose is {target_description['pose']}.",
+        )
+
+        # Randomly drop text or image
+        drop_text = random.random() < self.drop_text_prob
+        drop_image = random.random() < self.drop_image_prob
+        if drop_text:
+            description = ""
+        if drop_image:
+            condition_img = Image.new(
+                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
+            )
+
+        return {
+            "image": self.to_tensor(target_image),
+            "condition": self.to_tensor(condition_img),
+            "condition_type": self.condition_type,
+            "description": description,
+            # 16 is the downscale factor of the image
+            "position_delta": np.array([0, -16]),
+        }
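
Each dataset above returns the same dictionary contract (`image`, `condition`, `condition_type`, `description`, `position_delta`, plus optional `pil_image` / `position_scale`), which is what `OminiModel.step` in model.py below consumes. A minimal smoke-test sketch for `ImageConditionDataset`, assuming the module is importable as `src.train.data` and feeding it a stub list in the webdataset-style format it indexes (`item["jpg"]`, `item["json"]["prompt"]`); this is illustrative only, not repo code:

# Illustrative only; the import path and stub sample format are assumptions.
from PIL import Image
from torch.utils.data import DataLoader
from src.train.data import ImageConditionDataset

stub = [
    {"jpg": Image.new("RGB", (640, 480), (128, 128, 128)), "json": {"prompt": "a gray square"}}
] * 4
dataset = ImageConditionDataset(stub, condition_size=512, target_size=512, condition_type="coloring")
batch = next(iter(DataLoader(dataset, batch_size=2)))
print(batch["image"].shape, batch["condition"].shape, batch["position_delta"])
# e.g. torch.Size([2, 3, 512, 512]) torch.Size([2, 3, 512, 512]) tensor([[0, 0], [0, 0]])
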
OminiControl/src/train/model.py ADDED
@@ -0,0 +1,185 @@
+import lightning as L
+from diffusers.pipelines import FluxPipeline
+import torch
+from peft import LoraConfig, get_peft_model_state_dict
+
+import prodigyopt
+
+from ..flux.transformer import tranformer_forward
+from ..flux.condition import Condition
+from ..flux.pipeline_tools import encode_images, prepare_text_input
+
+
+class OminiModel(L.LightningModule):
+    def __init__(
+        self,
+        flux_pipe_id: str,
+        lora_path: str = None,
+        lora_config: dict = None,
+        device: str = "cuda",
+        dtype: torch.dtype = torch.bfloat16,
+        model_config: dict = {},
+        optimizer_config: dict = None,
+        gradient_checkpointing: bool = False,
+    ):
+        # Initialize the LightningModule
+        super().__init__()
+        self.model_config = model_config
+        self.optimizer_config = optimizer_config
+
+        # Load the Flux pipeline
+        self.flux_pipe: FluxPipeline = (
+            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
+        )
+        self.transformer = self.flux_pipe.transformer
+        self.transformer.gradient_checkpointing = gradient_checkpointing
+        self.transformer.train()
+
+        # Freeze the Flux pipeline
+        self.flux_pipe.text_encoder.requires_grad_(False).eval()
+        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
+        self.flux_pipe.vae.requires_grad_(False).eval()
+
+        # Initialize LoRA layers
+        self.lora_layers = self.init_lora(lora_path, lora_config)
+
+        self.to(device).to(dtype)
+
+    def init_lora(self, lora_path: str, lora_config: dict):
+        assert lora_path or lora_config
+        if lora_path:
+            # TODO: Implement this
+            raise NotImplementedError
+        else:
+            self.transformer.add_adapter(LoraConfig(**lora_config))
+            # TODO: Check if this is correct (p.requires_grad)
+            lora_layers = filter(
+                lambda p: p.requires_grad, self.transformer.parameters()
+            )
+        return list(lora_layers)
+
+    def save_lora(self, path: str):
+        FluxPipeline.save_lora_weights(
+            save_directory=path,
+            transformer_lora_layers=get_peft_model_state_dict(self.transformer),
+            safe_serialization=True,
+        )
+
+    def configure_optimizers(self):
+        # Freeze the transformer
+        self.transformer.requires_grad_(False)
+        opt_config = self.optimizer_config
+
+        # Set the trainable parameters
+        self.trainable_params = self.lora_layers
+
+        # Unfreeze trainable parameters
+        for p in self.trainable_params:
+            p.requires_grad_(True)
+
+        # Initialize the optimizer
+        if opt_config["type"] == "AdamW":
+            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
+        elif opt_config["type"] == "Prodigy":
+            optimizer = prodigyopt.Prodigy(
+                self.trainable_params,
+                **opt_config["params"],
+            )
+        elif opt_config["type"] == "SGD":
+            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
+        else:
+            raise NotImplementedError
+
+        return optimizer
+
+    def training_step(self, batch, batch_idx):
+        step_loss = self.step(batch)
+        self.log_loss = (
+            step_loss.item()
+            if not hasattr(self, "log_loss")
+            else self.log_loss * 0.95 + step_loss.item() * 0.05
+        )
+        return step_loss
+
+    def step(self, batch):
+        imgs = batch["image"]
+        conditions = batch["condition"]
+        condition_types = batch["condition_type"]
+        prompts = batch["description"]
+        position_delta = batch["position_delta"][0]
+        position_scale = float(batch.get("position_scale", [1.0])[0])
+
+        # Prepare inputs
+        with torch.no_grad():
+            # Prepare image input
+            x_0, img_ids = encode_images(self.flux_pipe, imgs)
+
+            # Prepare text input
+            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
+                self.flux_pipe, prompts
+            )
+
+            # Prepare t and x_t
+            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
+            x_1 = torch.randn_like(x_0).to(self.device)
+            t_ = t.unsqueeze(1).unsqueeze(1)
+            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
+
+            # Prepare conditions
+            condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
+
+            # Add position delta
+            condition_ids[:, 1] += position_delta[0]
+            condition_ids[:, 2] += position_delta[1]
+
+            if position_scale != 1.0:
+                scale_bias = (position_scale - 1.0) / 2
+                condition_ids[:, 1] *= position_scale
+                condition_ids[:, 2] *= position_scale
+                condition_ids[:, 1] += scale_bias
+                condition_ids[:, 2] += scale_bias
+
+            # Prepare condition type
+            condition_type_ids = torch.tensor(
+                [
+                    Condition.get_type_id(condition_type)
+                    for condition_type in condition_types
+                ]
+            ).to(self.device)
+            condition_type_ids = (
+                torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
+            ).unsqueeze(1)
+
+            # Prepare guidance
+            guidance = (
+                torch.ones_like(t).to(self.device)
+                if self.transformer.config.guidance_embeds
+                else None
+            )
+
+        # Forward pass
+        transformer_out = tranformer_forward(
+            self.transformer,
+            # Model config
+            model_config=self.model_config,
+            # Inputs of the condition (new feature)
+            condition_latents=condition_latents,
+            condition_ids=condition_ids,
+            condition_type_ids=condition_type_ids,
+            # Inputs to the original transformer
+            hidden_states=x_t,
+            timestep=t,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=img_ids,
+            joint_attention_kwargs=None,
+            return_dict=False,
+        )
+        pred = transformer_out[0]
+
+        # Compute loss
+        loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
+        self.last_t = t.mean().item()
+        return loss
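
The `step` method implements the rectified-flow (flow-matching) objective used for FLUX: the clean latent x_0 is interpolated with Gaussian noise x_1 as x_t = (1 - t) * x_0 + t * x_1, and the transformer is regressed with an MSE loss onto the constant path velocity x_1 - x_0 = dx_t/dt. A standalone numeric sketch of that identity (illustrative, not repo code):

import torch

x_0, x_1 = torch.randn(4, 16), torch.randn(4, 16)  # "clean latent" and noise
t = torch.rand(4, 1)
x_t = (1 - t) * x_0 + t * x_1                       # same interpolation as OminiModel.step
velocity = x_1 - x_0                                # regression target of the MSE loss
dt = 1e-3                                           # finite-difference check along the path
x_t_next = (1 - (t + dt)) * x_0 + (t + dt) * x_1
assert torch.allclose((x_t_next - x_t) / dt, velocity, atol=1e-3)
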
OminiControl/src/train/train.py ADDED
@@ -0,0 +1,178 @@
+from torch.utils.data import DataLoader
+import torch
+import lightning as L
+import yaml
+import os
+import time
+
+from datasets import load_dataset
+
+from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset
+from .model import OminiModel
+from .callbacks import TrainingCallback
+
+
+def get_rank():
+    try:
+        rank = int(os.environ.get("LOCAL_RANK"))
+    except:
+        rank = 0
+    return rank
+
+
+def get_config():
+    config_path = os.environ.get("XFL_CONFIG")
+    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
+    with open(config_path, "r") as f:
+        config = yaml.safe_load(f)
+    return config
+
+
+def init_wandb(wandb_config, run_name):
+    import wandb
+
+    try:
+        assert os.environ.get("WANDB_API_KEY") is not None
+        wandb.init(
+            project=wandb_config["project"],
+            name=run_name,
+            config={},
+        )
+    except Exception as e:
+        print("Failed to initialize WanDB:", e)
+
+
+def main():
+    # Initialize
+    is_main_process, rank = get_rank() == 0, get_rank()
+    torch.cuda.set_device(rank)
+    config = get_config()
+    training_config = config["train"]
+    run_name = time.strftime("%Y%m%d-%H%M%S")
+
+    # Initialize WanDB
+    wandb_config = training_config.get("wandb", None)
+    if wandb_config is not None and is_main_process:
+        init_wandb(wandb_config, run_name)
+
+    print("Rank:", rank)
+    if is_main_process:
+        print("Config:", config)
+
+    # Initialize dataset and dataloader
+    if training_config["dataset"]["type"] == "subject":
+        dataset = load_dataset("Yuanshi/Subjects200K")
+
+        # Define filter function
+        def filter_func(item):
+            if not item.get("quality_assessment"):
+                return False
+            return all(
+                item["quality_assessment"].get(key, 0) >= 5
+                for key in ["compositeStructure", "objectConsistency", "imageQuality"]
+            )
+
+        # Filter dataset
+        if not os.path.exists("./cache/dataset"):
+            os.makedirs("./cache/dataset")
+        data_valid = dataset["train"].filter(
+            filter_func,
+            num_proc=16,
+            cache_file_name="./cache/dataset/data_valid.arrow",
+        )
+        dataset = Subject200KDataset(
+            data_valid,
+            condition_size=training_config["dataset"]["condition_size"],
+            target_size=training_config["dataset"]["target_size"],
+            image_size=training_config["dataset"]["image_size"],
+            padding=training_config["dataset"]["padding"],
+            condition_type=training_config["condition_type"],
+            drop_text_prob=training_config["dataset"]["drop_text_prob"],
+            drop_image_prob=training_config["dataset"]["drop_image_prob"],
+        )
+    elif training_config["dataset"]["type"] == "img":
+        # Load dataset text-to-image-2M
+        dataset = load_dataset(
+            "webdataset",
+            data_files={"train": training_config["dataset"]["urls"]},
+            split="train",
+            cache_dir="cache/t2i2m",
+            num_proc=32,
+        )
+        dataset = ImageConditionDataset(
+            dataset,
+            condition_size=training_config["dataset"]["condition_size"],
+            target_size=training_config["dataset"]["target_size"],
+            condition_type=training_config["condition_type"],
+            drop_text_prob=training_config["dataset"]["drop_text_prob"],
+            drop_image_prob=training_config["dataset"]["drop_image_prob"],
+            position_scale=training_config["dataset"].get("position_scale", 1.0),
+        )
+    elif training_config["dataset"]["type"] == "cartoon":
+        dataset = load_dataset("saquiboye/oye-cartoon", split="train")
+        dataset = CartoonDataset(
+            dataset,
+            condition_size=training_config["dataset"]["condition_size"],
+            target_size=training_config["dataset"]["target_size"],
+            image_size=training_config["dataset"]["image_size"],
+            padding=training_config["dataset"]["padding"],
+            condition_type=training_config["condition_type"],
+            drop_text_prob=training_config["dataset"]["drop_text_prob"],
+            drop_image_prob=training_config["dataset"]["drop_image_prob"],
+        )
+    else:
+        raise NotImplementedError
+
+    print("Dataset length:", len(dataset))
+    train_loader = DataLoader(
+        dataset,
+        batch_size=training_config["batch_size"],
+        shuffle=True,
+        num_workers=training_config["dataloader_workers"],
+    )
+
+    # Initialize model
+    trainable_model = OminiModel(
+        flux_pipe_id=config["flux_path"],
+        lora_config=training_config["lora_config"],
+        device=f"cuda",
+        dtype=getattr(torch, config["dtype"]),
+        optimizer_config=training_config["optimizer"],
+        model_config=config.get("model", {}),
+        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
+    )
+
+    # Callbacks for logging and saving checkpoints
+    training_callbacks = (
+        [TrainingCallback(run_name, training_config=training_config)]
+        if is_main_process
+        else []
+    )
+
+    # Initialize trainer
+    trainer = L.Trainer(
+        accumulate_grad_batches=training_config["accumulate_grad_batches"],
+        callbacks=training_callbacks,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        logger=False,
+        max_steps=training_config.get("max_steps", -1),
+        max_epochs=training_config.get("max_epochs", -1),
+        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
+    )
+
+    setattr(trainer, "training_config", training_config)
+
+    # Save config
+    save_path = training_config.get("save_path", "./output")
+    if is_main_process:
+        os.makedirs(f"{save_path}/{run_name}")
+        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
+            yaml.dump(config, f)
+
+    # Start training
+    trainer.fit(trainable_model, train_loader)
+
+
+if __name__ == "__main__":
+    main()
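
train.py expects the path of a YAML config in the `XFL_CONFIG` environment variable and reads `flux_path`, `dtype`, `model`, and the `train` block shown above. A hypothetical minimal config, generated from Python for convenience; the key names come from this file, but every value (model id, LoRA rank and target modules, batch size, dataset shard URLs) is a placeholder assumption, not a repo default:

# Illustrative config sketch; values are placeholders, only the key names are taken from train.py.
import os, yaml

config = {
    "flux_path": "black-forest-labs/FLUX.1-dev",   # any FluxPipeline-compatible model id
    "dtype": "bfloat16",
    "model": {},
    "train": {
        "condition_type": "canny",
        "batch_size": 1,
        "dataloader_workers": 4,
        "accumulate_grad_batches": 1,
        "gradient_checkpointing": True,
        "save_path": "./output",
        "dataset": {
            "type": "img",
            "urls": ["path/to/shard-{00000..00009}.tar"],  # placeholder webdataset shards
            "condition_size": 512,
            "target_size": 512,
            "drop_text_prob": 0.1,
            "drop_image_prob": 0.1,
        },
        "lora_config": {"r": 4, "lora_alpha": 4, "target_modules": ["to_q", "to_k", "to_v"]},
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    },
}

with open("train_config.yaml", "w") as f:
    yaml.dump(config, f)
os.environ["XFL_CONFIG"] = "train_config.yaml"  # train.py reads the config path from here
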