# OminiControl Training 🛠️

## Preparation

### Setup

1. **Environment**
   ```bash
   conda create -n omini python=3.10
   conda activate omini
   ```
2. **Requirements**
   ```bash
   pip install -r train/requirements.txt
   ```

### Dataset

1. Download the [Subjects200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (**subject-driven generation**).
   ```bash
   bash train/script/data_download/data_download1.sh
   ```
2. Download the [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M) dataset (**spatial control tasks**).
   ```bash
   bash train/script/data_download/data_download2.sh
   ```

**Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional data. If you do, update the config file so that it specifies the new training data.

## Training

### Start training

**Config file path**: `./train/config`

**Scripts path**: `./train/script`

1. Subject-driven generation
   ```bash
   bash train/script/train_subject.sh
   ```
2. Spatial control task
   ```bash
   bash train/script/train_canny.sh
   ```

**Note**: Detailed WandB and GPU settings can be found in the script files and the config files.

### Other spatial control tasks

This repository supports 7 spatial control tasks:

1. Canny edge to image (`canny`)
2. Image colorization (`coloring`)
3. Image deblurring (`deblurring`)
4. Depth map to image (`depth`)
5. Image to depth map (`depth_pred`)
6. Image inpainting (`fill`)
7. Super resolution (`sr`)

To switch between tasks, change the `condition_type` parameter in the config file `config/canny_512.yaml`.

### Customize your own task

You can add your own task by constructing a new dataset and modifying the training code.
**Instructions:**

1. **Dataset**: Construct a new dataset that returns samples in the following format (`src/train/data.py`):
   ```python
   from torch.utils.data import Dataset


   class MyDataset(Dataset):
       def __init__(self, *args, **kwargs):
           ...  # load or index your data here

       def __len__(self):
           ...  # return the number of samples

       def __getitem__(self, idx):
           ...  # load the target image, condition image and caption for sample `idx`
           return {
               "image": image,
               "condition": condition_img,
               "condition_type": "your_condition_type",
               "description": description,
               "position_delta": position_delta,
           }
   ```
   **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial control tasks, set `position_delta` to `[0, -condition_width // 16]`.

2. **Condition**: Add the new condition type to the `Condition` class (`src/flux/condition.py`):
   ```python
   condition_dict = {
       # ... existing condition types ...
       "your_condition_type": your_condition_id_number,  # Add your condition type here
   }

   # ...

   if condition_type in [
       # ... existing condition types ...
       "your_condition_type",  # Add your condition type here
   ]:
       ...
   ```

3. **Test**: Add a test case for your task (`src/train/callbacks.py`):
   ```python
   # Requires `from PIL import Image` at the top of callbacks.py
   if self.condition_type == "your_condition_type":
       condition_img = (
           Image.open("images/vase.jpg")
           .resize((condition_size, condition_size))
           .convert("RGB")
       )
       ...
       # Entry: (condition image, position_delta, prompt)
       test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
   ```

4. **Import the dataset in the training script**: Update the following sections of `src/train/train.py`:
   ```python
   from .data import (
       ImageConditionDataset,
       Subject200KDateset,
       MyDataset,
   )

   # ...

   # Initialize dataset and dataloader
   if training_config["dataset"]["type"] == "your_condition_type":
       ...
   ```
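Before launching a training run on a custom task, it can help to check what `__getitem__` returns. The sketch below is a minimal, hypothetical helper that is not part of the repository: `make_position_delta` reproduces the rule from the Dataset note, and `check_sample` verifies that a sample has the five keys listed in step 1 and a registered `condition_type`. The names `make_position_delta`, `check_sample`, and `REQUIRED_KEYS` are assumptions for illustration. The divisor 16 is presumably the number of pixels per latent token in FLUX (8x VAE downsampling followed by 2x2 patch packing), so a non-spatial condition is offset by its own width in tokens and does not share positions with the image tokens.

```python
# Hypothetical sanity-check helpers for a custom dataset (not part of the repo).

REQUIRED_KEYS = {"image", "condition", "condition_type", "description", "position_delta"}


def make_position_delta(condition_width: int, spatial: bool) -> list[int]:
    """Return position_delta following the rule in the Dataset note.

    Spatial tasks keep condition tokens aligned with the image: [0, 0].
    Non-spatial tasks use [0, -condition_width // 16]; note that Python parses
    this as (-condition_width) // 16, i.e. floor division of the negated width.
    """
    if spatial:
        return [0, 0]
    return [0, -condition_width // 16]


def check_sample(sample: dict, known_condition_types: set[str]) -> None:
    """Raise if a sample returned by __getitem__ does not match the expected format."""
    missing = REQUIRED_KEYS - sample.keys()
    if missing:
        raise KeyError(f"sample is missing keys: {sorted(missing)}")
    if sample["condition_type"] not in known_condition_types:
        raise ValueError(f"unregistered condition_type: {sample['condition_type']!r}")
    # list() accepts plain lists, tuples, or array-like values
    delta = list(sample["position_delta"])
    if len(delta) != 2:
        raise ValueError(f"position_delta must have two entries, got {delta!r}")


if __name__ == "__main__":
    print(make_position_delta(512, spatial=True))   # [0, 0]
    print(make_position_delta(512, spatial=False))  # [0, -32]
    sample = {
        "image": None,  # a real dataset would return images here
        "condition": None,
        "condition_type": "your_condition_type",
        "description": "A beautiful vase on a table.",
        "position_delta": make_position_delta(512, spatial=False),
    }
    check_sample(sample, {"your_condition_type"})
    print("sample ok")
```

Calling `check_sample(dataset[0], {...})` on your own dataset before training would surface a missing key or an unregistered condition type immediately rather than partway through a run.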
## Hardware requirements

**Note**: Memory optimizations (such as dynamic T5 model loading) have not been implemented yet.

**Recommended**
- Hardware: 2x NVIDIA H100 GPUs
- Memory: ~80GB GPU memory

**Minimal**
- Hardware: 1x NVIDIA L20 GPU
- Memory: ~48GB GPU memory