# OminiControl Training 🛠️

## Preparation

### Setup
1. **Environment**
    ```bash
    conda create -n omini python=3.10
    conda activate omini
    ```
2. **Requirements**
    ```bash
    pip install -r train/requirements.txt
    ```

### Dataset
1. Download the [Subjects200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (for **subject-driven generation**):
    ```bash
    bash train/script/data_download/data_download1.sh
    ```
2. Download the [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M) dataset (for the **spatial control task**):
    ```bash
    bash train/script/data_download/data_download2.sh
    ```
    **Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional datasets. Remember to update the config file to specify the training data accordingly.

## Training

### Start training
**Config file path**: `./train/config`

**Scripts path**: `./train/script`

1. Subject-driven generation
    ```bash
    bash train/script/train_subject.sh
    ```
2. Spatial control task
    ```bash
    bash train/script/train_canny.sh
    ```

**Note**: Detailed WandB and GPU settings can be found in the script files and the config files.

### Other spatial control tasks
This repository supports 7 spatial control tasks:
1. Canny edge to image (`canny`)
2. Image colorization (`coloring`)
3. Image deblurring (`deblurring`)
4. Depth map to image (`depth`)
5. Image to depth map (`depth_pred`)
6. Image inpainting (`fill`)
7. Super resolution (`sr`)

You can switch between tasks by modifying the `condition_type` parameter in the config file `config/canny_512.yaml`.
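If you prefer to script the switch, a small helper can rewrite the key in place. This is a hypothetical sketch, not part of the repository; it assumes the config stores a simple top-level `condition_type:` entry in YAML:

```python
import re

def set_condition_type(config_text: str, task: str) -> str:
    # Hypothetical helper: rewrites the `condition_type:` entry in a YAML
    # config string. Assumes a plain top-level `condition_type: <task>` line.
    return re.sub(r"condition_type:\s*\S+", f"condition_type: {task}", config_text)

print(set_condition_type("condition_type: canny", "depth"))  # condition_type: depth
```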

### Customize your own task
You can customize your own task by constructing a new dataset and modifying the training code.

<details>
<summary>Instructions</summary>

1. **Dataset** : 
   
   Construct a new dataset with the following format: (`src/train/data.py`)
    ```python
    class MyDataset(Dataset):
        def __init__(self, ...):
            ...
        def __len__(self):
            ...
        def __getitem__(self, idx):
            ...
            return {
                "image": image,
                "condition": condition_img,
                "condition_type": "your_condition_type",
                "description": description,
                "position_delta": position_delta
            }
    ```
    **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial control tasks, set it to `[0, -condition_width // 16]`.
2. **Condition**:
   
   Add a new condition type in the `Condition` class. (`src/flux/condition.py`)
    ```python
    condition_dict = {
        ...
        "your_condition_type": your_condition_id_number, # Add your condition type here
    }
    ...
    if condition_type in [
        ...
        "your_condition_type", # Add your condition type here
    ]:
        ...
    ```
3. **Test**: 
   
   Add a new test function for your task. (`src/train/callbacks.py`)
    ```python
    if self.condition_type == "your_condition_type":
        condition_img = (
            Image.open("images/vase.jpg")
            .resize((condition_size, condition_size))
            .convert("RGB")
        )
        ...
        test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
    ```

4. **Import the relevant dataset in the training script**
   
   Update the following section of the file. (`src/train/train.py`)
   ```python
    from .data import (
        ImageConditionDataset,
        Subject200KDateset,
        MyDataset
    )
    ...
   
    # Initialize dataset and dataloader
    if training_config["dataset"]["type"] == "your_condition_type":
       ...
   ```
   
</details>
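The `position_delta` rule from step 1 above can be sketched as a small helper. This is an illustrative sketch, not repository code; the per-token factor of 16 pixels is taken from the `// 16` in the note:

```python
def position_delta(condition_width: int, spatial: bool) -> list:
    """Illustrative helper: token-position offset for a condition image."""
    if spatial:
        # Spatial control: condition tokens stay aligned with image tokens.
        return [0, 0]
    # Non-spatial control: shift the condition tokens by one token
    # per 16 condition pixels, as stated in the note.
    return [0, -condition_width // 16]

print(position_delta(512, spatial=False))  # [0, -32]
print(position_delta(512, spatial=True))   # [0, 0]
```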

## Hardware requirements
**Note**: Memory optimizations (such as dynamic T5 model loading) are not yet implemented.

**Recommended**
- Hardware: 2x NVIDIA H100 GPUs
- Memory: ~80GB GPU memory

**Minimal**
- Hardware: 1x NVIDIA L20 GPU
- Memory: ~48GB GPU memory