xizaoqu commited on
Commit
b919733
·
1 Parent(s): dbaa006

update README

Browse files
Files changed (1) hide show
  1. README.md +9 -226
README.md CHANGED
@@ -1,226 +1,9 @@
1
- # Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion
2
-
3
- #### [[Project Website]](https://boyuan.space/diffusion-forcing) [[Paper]](https://arxiv.org/abs/2407.01392)
4
-
5
- [Boyuan Chen<sup>1</sup>](https://boyuan.space/), [Diego Martí Monsó<sup>2</sup>](https://www.linkedin.com/in/diego-marti/?originalSubdomain=de), [ Yilun Du<sup>1</sup>](https://yilundu.github.io/), [Max Simchowitz<sup>1</sup>](https://msimchowitz.github.io/), [Russ Tedrake<sup>1</sup>](https://groups.csail.mit.edu/locomotion/russt.html), [Vincent Sitzmann<sup>1</sup>](https://www.vincentsitzmann.com/) <br/>
6
- <sup>1</sup>MIT <sup>2</sup>Technical University of Munich </br>
7
-
8
- This is the v1.5 code base for our paper [Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion](https://boyuan.space/diffusion-forcing). The **main** branch contains our latest reimplementation with temporal attention (recommended) while the **paper** branch contains RNN code used by original paper for reproduction purpose.
9
-
10
- Diffusion Forcing v2 is coming very soon! There is a stronger technique to achieve infinite, consistent video generation uniquely enabled by diffusion forcing. We are actively investigating that so please stay tuned. We will also release latent diffusion code by then that allows you to scale up to higher resolution / longer videos!
11
-
12
- ![plot](teaser.png)
13
-
14
- ```
15
- @misc{chen2024diffusionforcingnexttokenprediction,
16
- title={Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion},
17
- author={Boyuan Chen and Diego Marti Monso and Yilun Du and Max Simchowitz and Russ Tedrake and Vincent Sitzmann},
18
- year={2024},
19
- eprint={2407.01392},
20
- archivePrefix={arXiv},
21
- primaryClass={cs.LG},
22
- url={https://arxiv.org/abs/2407.01392},
23
- }
24
- ```
25
-
26
- # Project Instructions
27
-
28
- ## Setup
29
-
30
- If you want to use our latest improved implementation for video and planning with temporal attention instead of RNN, stay on this branch. If you are instead interested in reproducing claims by orignal paper, switch to the branch used by original paper via `git checkout paper`.
31
-
32
- Run `conda create python=3.10 -n diffusion-forcing` to create environment.
33
- Run `conda activate diffusion-forcing` to activate this environment.
34
-
35
- Install dependencies for time series, video and robotics:
36
-
37
- ```
38
- pip install -r requirements.txt
39
- ```
40
-
41
- [Sign up](https://wandb.ai/site) a wandb account for cloud logging and checkpointing. In command line, run `wandb login` to login.
42
-
43
- Then modify the wandb entity in `configurations/config.yaml` to your wandb account.
44
-
45
- Optionally, if you want to do maze planning, install the following complicated dependencies due to outdated dependencies of d4rl. This involves first installing mujoco 210 and then run
46
-
47
- ```
48
- pip install -r extra_requirements.txt
49
- ```
50
-
51
- ## Quick start with pretrained checkpoints
52
-
53
- Since dataset is huge, we provide a mini subset and pre-trained checkpoints for you to quickly test out our model! To do so, download mini dataset and checkpoints from [here](https://drive.google.com/file/d/1xAOQxWcLzcFyD4zc0_rC9jGXe_uaHb7b/view?usp=sharing) to project root and extract with `tar -xzvf quickstart_atten.tar.gz`. Files shall appear in `data` and `outputs/xxx.ckpt`. Make sure you also git pull upstream to use latest version of code if you forked before ckpt release!
54
-
55
- Then run the following commands and go to the wandb panel to see the results.
56
-
57
- ### Video Prediction:
58
-
59
- Our visualization is side by side, with prediction on the left and ground truth on the right. However, ground truth is expected to not align with prediction since the sequence is highly stochastic. Ground truth is provided to provide an idea about quality only.
60
-
61
- Autoregressively generate minecraft video with 1x the length it's trained on:
62
- `python -m main +name=sample_minecraft_pretrained load=outputs/minecraft.ckpt experiment.tasks=[validation]`
63
-
64
- To let the model roll out **longer than it's trained on**, simply append `dataset.validation_multiplier=8` to the above commands, and it will rollout `8x` longer than maximum sequence length it's trained on.
65
-
66
- The above checkpoint is trained for 100K steps with small number of frames. We've already verified diffusion forcing works in latent diffusion setting and can be extended to many more tokens without sacrificing compositionally (with some addition techniques outside this repo)! Stay tuned for our next project!
67
-
68
- ### Maze Planning:
69
-
70
- The maze planning setting is changed a bit as we gain more insighs, please see corresponding paragraphs in training section for details. We haven't reimplemented MCTG yet, but you can already see nice visualizations on wandb log.
71
-
72
- Medium Maze
73
-
74
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] load=outputs/maze2d_medium_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=3 +name=maze2d_medium_x_sampling`
75
-
76
- Large Maze
77
-
78
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] load=outputs/maze2d_large_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_large_x_sampling`
79
-
80
- We also explored a couple more settings but haven't reimplemented everything in original paper yet. If you are interestted in those checkpoints, see the source code of this README file for ckpt loading instructions that's commented out.
81
-
82
- <!--
83
- Here is also a position + velocity setting ckpt, but we don't recommend this because diffusing quantity and its derivative together creates some bad optimization landscape.
84
-
85
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.observation_std=[2.6742158,3.04204,9.3630628,9.4774808] dataset.action_mean=[] dataset.action_std=[] load=outputs/maze2d_medium_xv.ckpt experiment.tasks=[validation] algorithm.guidance_scale=4 +name=maze2d_medium_xv_sampling`
86
-
87
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_std=[3.6140624,5.1375184,9.747382,10.5974788] dataset.action_mean=[] dataset.action_std=[] load=outputs/maze2d_large_xv.ckpt experiment.tasks=[validation] algorithm.guidance_scale=4 +name=maze2d_large_xv_sampling`
88
-
89
- Here is also ckpt where we take diffused actions,a challenging setting that's not done in prior papers. We haven't got it working as well as original RNN version of diffusion forcing, but it does have okay numbers. You can tune up the guidance scale a bit.
90
-
91
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.observation_std=[2.67,3.04,8,8] dataset.action_std=[6,6] load=outputs/maze2d_medium_xva.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 algorithm.open_loop_horizon=10 +name=maze2d_medium_xva_sampling`
92
-
93
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_std=[3.62,5.14,9.76,10.6] dataset.action_std=[3,3] load=outputs/maze2d_large_xva.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 algorithm.open_loop_horizon=10 +name=maze2d_large_xva_sampling` -->
94
-
95
- ## Training
96
-
97
- ### Video
98
-
99
- Video prediction requires downloading giant datasets. First, if you downloaded the mini subset following `Quick start with pretrained checkpoints` section, delete the mini subset folders `data/minecraft` and `data/dmlab` because we have to download the whole dataset this time. We've coded in python that it will download the dataset for you it doesn't already exist. Due to the slowness of the [source](https://github.com/wilson1yan/teco), this may take a couple days. If you prefer to do it yourself via bash script, please refer to the bash scripts in original [TECO dataset](https://github.com/wilson1yan/teco) and use `dmlab.sh` and `minecraft.sh` in their Dataset section of README, any maybe split bash script into parallel scripts.
100
-
101
- Then just run the corresponding commands:
102
-
103
- #### Minecraft
104
-
105
- `python -m main +name=your_experiment_name algorithm=df_video dataset=video_minecraft`
106
-
107
- #### DMLab
108
-
109
- `python -m main +name=your_experiment_name algorithm=df_video dataset=video_dmlab algorithm.weight_decay=1e-3 algorithm.diffusion.architecture.network_size=48 algorithm.diffusion.architecture.attn_dim_head=32 algorithm.diffusion.architecture.attn_resolutions=[8,16,32,64] algorithm.diffusion.beta_schedule=cosine`
110
-
111
- #### No causal masking
112
-
113
- Simply append `algorithm.causal=False` to your command.
114
-
115
- #### Play with sampling
116
-
117
- Please take a look at "Load a checkpoint to eval" paragraph to understand how to use load checkpoint with `load=`. Then, run the exact training command with `experiment.tasks=[validation] load={wandb_run_id}` to load a checkpoint and experiment with sampling.
118
-
119
- To see how you can roll out longer than the sequence is trained on, you can find instructions in `quick start with pretrained checkpoints` section. Keep in mind that rolling out infinitely without sliding window is a property of original RNN implementation on `paper` branch, and this version has to use sliding window since it's temporal attention.
120
-
121
- By default, we run autoregressive sampling with stablization. To sample next 2 tokens jointly, you can append the following to the above command: `algorithm.scheduling_matrix=full_sequence algorithm.chunk_size=2`.
122
-
123
- ## Maze Planning
124
-
125
- For those who only wish to reproduce the original paper instead of transformer architecture, please checkout`paper` branch of the code instead.
126
-
127
- **Medium Maze**
128
-
129
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] +name=maze2d_medium_x`
130
-
131
- **Large Maze**
132
-
133
- `python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] +name=maze2d_large_x`
134
-
135
- **Run planning after model is trained**
136
-
137
- Please take a look at "Load a checkpoint to eval" paragraph to understand how to use load checkpoint with `load=`. To sample, simply append `load={wandb_id_of_above_runs} experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_sampling` to above command after trained. Feel free to tune the `guidance_scale` from 1 - 5.
138
-
139
- This version of maze planning uses a different version of diffusion forcing from original paper - while doing the follow up to diffusion forcing, we realized that training with independent noise actually constructed a smooth interpolation between causal and non-causal models too, since we can just masked out future by complete noise (fully causal) or some noise (interpolation). The best thing is, you can still account for causal uncertainty via pyramoid sampling in this setting, by masking out tokens at different noise levels, and you can still have flexible horizon because you can tell the model that padded entries are pure noise, a unique ability of diffusion forcing.
140
-
141
- We also reflected a bit about the environment and concluded that the original metric isn't necessarily a good metric, because maze planning should reward those who can plan the fastest route to goal, not a slow walking agent that goes there at the end of episode. The dataset never contains data of staying at the goal, so agents are supposed to walk away after reaching the goal. I think [Diffuser](https://arxiv.org/abs/2205.09991) had an unfair advantage of just generating slow plans, that happend to let the agent stay in the neighbour hood of goal for longer and got very high reward, exploiting flaws in the environment design (a good design would involve penalty of longer time taken to reach goal). So, in this version of code, we just optimize for flexible horizon planning that tries to reach goal asap, and the planner will automatically come back to goal if it left the goal since staying is never in dataset. You can see new metrics we designed in wandb logging interface.
142
-
143
- ## Timeseries and Robotics
144
-
145
- Please checkout `paper` branch for the code used by original paper. If I have time later, I will reimplement these two domains with transformer as well to complete this branch.
146
-
147
- # Change Log
148
-
149
- | Data | Notes |
150
- | --------- | :---------------------------------------------------------------------------------------------: |
151
- | Jul/30/24 | Upgrade RNN to temporal attention, move orignal code to 'paper' branch |
152
- | Jul/03/24 | Initial release of the code. Email me if you have questions or find any errors in this version. |
153
-
154
- # Infra instructions
155
-
156
- This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
157
-
158
- All experiments can be launched via `python -m main +name=xxxx {options}` where you can fine more details later in this article.
159
-
160
- The code base will automatically use cuda or your Macbook M1 GPU when available.
161
-
162
- For slurm clusters e.g. mit supercloud, you can run `python -m main cluster=mit_supercloud {options}` on login node.
163
- It will automatically generate slurm scripts and run them for you on a compute node. Even if compute nodes are offline,
164
- the script will still automatically sync wandb logging to cloud with <1min latency. It's also easy to add your own slurm
165
- by following the `Add slurm clusters` section.
166
-
167
- ## Modify for your own project
168
-
169
- First, create a new repository with this template. Make sure the new repository has the name you want to use for wandb
170
- logging.
171
-
172
- Add your method and baselines in `algorithms` following the `algorithms/README.md` as well as the example code in
173
- `algorithms/diffusion_forcing/df_video.py`. For pytorch experiments, write your algorithm as a [pytorch lightning](https://github.com/Lightning-AI/lightning)
174
- `pl.LightningModule` which has extensive
175
- [documentation](https://lightning.ai/docs/pytorch/stable/). For a quick start, read "Define a LightningModule" in this [link](https://lightning.ai/docs/pytorch/stable/starter/introduction.html). Finally, add a yaml config file to `configurations/algorithm` imitating that of `configurations/algorithm/df_video.yaml`, for each algorithm you added.
176
-
177
- Add your dataset in `datasets` following the `datasets/README.md` as well as the example code in
178
- `datasets/video`. Finally, add a yaml config file to `configurations/dataset` imitating that of
179
- `configurations/dataset/video_dmlab.yaml`, for each dataset you added.
180
-
181
- Add your experiment in `experiments` following the `experiments/README.md` or following the example code in
182
- `experiments/exp_video.py`. Then register your experiment in `experiments/__init__.py`.
183
- Finally, add a yaml config file to `configurations/experiment` imitating that of
184
- `configurations/experiment/exp_video.yaml`, for each experiment you added.
185
-
186
- Modify `configurations/config.yaml` to set `algorithm` to the yaml file you want to use in `configurations/algorithm`;
187
- set `experiment` to the yaml file you want to use in `configurations/experiment`; set `dataset` to the yaml file you
188
- want to use in `configurations/dataset`, or to `null` if no dataset is needed; Notice the fields should not contain the
189
- `.yaml` suffix.
190
-
191
- You are all set!
192
-
193
- `cd` into your project root. Now you can launch your new experiment with `python main.py +name=<name_your_experiment>`. You can run baselines or
194
- different datasets by add arguments like `algorithm=xxx` or `dataset=xxx`. You can also override any `yaml` configurations by following the next section.
195
-
196
- One special note, if your want to define a new task for your experiment, (e.g. other than `training` and `test`) you can define it as a method in your experiment class and use `experiment.tasks=[task_name]` to run it. Let's say you have a `generate_dataset` task before the task `training` and you implemented it in experiment class, you can then run `python -m main +name xxxx experiment.tasks=[generate_dataset,training]` to execute it before training.
197
-
198
- ## Pass in arguments
199
-
200
- We use [hydra](https://hydra.cc) instead of `argparse` to configure arguments at every code level. You can both write a static config in `configuration` folder or, at runtime,
201
- [override part of yur static config](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/) with command line arguments.
202
-
203
- For example, arguments `algorithm=example_classifier experiment.lr=1e-3` will override the `lr` variable in `configurations/experiment/example_classifier.yaml`. The argument `wandb.mode` will override the `mode` under `wandb` namesspace in the file `configurations/config.yaml`.
204
-
205
- All static config and runtime override will be logged to cloud automatically.
206
-
207
- ## Resume a checkpoint & logging
208
-
209
- For machine learning experiments, all checkpoints and logs are logged to cloud automatically so you can resume them on another server. Simply append `resume={wandb_run_id}` to your command line arguments to resume it. The run_id can be founded in a url of a wandb run in wandb dashboard. By default, latest checkpoint in a run is stored indefinitely and earlier checkpoints in the run will be deleted after 5 days to save your storage.
210
-
211
- On the other hand, sometimes you may want to start a new run with different run id but still load a prior ckpt. This can be done by setting the `load={wandb_run_id / ckpt path}` flag.
212
-
213
- ## Load a checkpoint to eval
214
-
215
- The argument `experiment.tasks=[task_name1,task_name2]` (note the `[]` brackets here needed) allows to select a sequence of tasks to execute, such as `training`, `validation` and `test`. Therefore, for testing a machine learning ckpt, you may run `python -m main load={your_wandb_run_id} experiment.tasks=[test]`.
216
-
217
- More generally, the task names are the corresponding method names of your experiment class. For `BaseLightningExperiment`, we already defined three methods `training`, `validation` and `test` for you, but you can also define your own tasks by creating methods to your experiment class under intended task names.
218
-
219
- ## Debug
220
-
221
- We provide a useful debug flag which you can enable by `python main.py debug=True`. This will enable numerical error tracking as well as setting `cfg.debug` to `True` for your experiments, algorithms and datasets class. However, this debug flag will make ML code very slow as it automatically tracks all parameter / gradients!
222
-
223
- ## Add slurm clusters
224
-
225
- It's very easy to add your own slurm clusters via adding a yaml file in `configurations/cluster`. You can take a look
226
- at `configurations/cluster/mit_vision.yaml` for example.
 
1
+ title: WORLDMEM: Long-term Consistent World Generation with Memory
2
+ emoji: 🎮
3
+ colorFrom: yellow
4
+ colorTo: yellow
5
+ sdk: gradio
6
+ sdk_version: 5.22.0
7
+ app_file: app.py
8
+ pinned: true
9
+ license: mit