theonlyengine committed commit 036610e (verified) · parent: 4a98549

Create README.md

Files changed (1): README.md (+480, -0)
# FlashAttention
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the following papers.

**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)

**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao

Paper: https://tridao.me/publications/flash2/flash2.pdf

![FlashAttention-2](assets/flashattention_logo.png)

## Usage

We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.

## FlashAttention-3 beta release
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).

Blogpost: https://tridao.me/blog/2024/flash3/

Paper: https://tridao.me/publications/flash3/flash3.pdf

![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png)

This is a beta release for testing / benchmarking before we integrate it with
the rest of the repo.

Currently released:
- FP16 forward and backward

Coming soon in the next couple of days / next week:
- BF16
- Variable length (FP16, BF16)
- FP8 forward

Requirements: H100 / H800 GPU, CUDA >= 12.3.

To install:
```sh
cd hopper
python setup.py install
```
To run the test:
```sh
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```


## Installation and features

Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.

We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
container from Nvidia, which has all the required tools to install FlashAttention.

To install:
1. Make sure that PyTorch is installed.
2. Make sure that `packaging` is installed (`pip install packaging`).
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
   --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
   --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
   `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
   compiling can take a very long time (2h) since it does not use multiple CPU
   cores. With `ninja`, compiling takes 3-5 minutes on a 64-core machine.
4. Then:
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```sh
python setup.py install
```

If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
run too many parallel compilation jobs that could exhaust the amount of RAM. To
limit the number of parallel compilation jobs, you can set the environment
variable `MAX_JOBS`:
```sh
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```

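After installation, a quick optional sanity check is to import the package and confirm it sees your GPU and CUDA build (this snippet is illustrative and not part of the official test suite):

```python
import torch
import flash_attn

# Confirm the compiled extension imports and report the environment it was built for.
print("flash-attn version:", flash_attn.__version__)
print("torch version:", torch.__version__, "| CUDA:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))
```
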
Interface: `src/flash_attention_interface.py`

FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
   GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
   GPUs for now (a quick compute-capability check is sketched below).
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

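If you are unsure which category your GPU falls into, compute capability is a quick proxy: Ampere/Ada/Hopper report major version 8 or 9, while Turing reports 7.5. An optional check (not part of the package):

```python
import torch

# Compute capability of the current GPU, e.g. (8, 0) for A100, (9, 0) for H100, (7, 5) for T4.
major, minor = torch.cuda.get_device_capability(0)
print(f"compute capability: sm_{major}{minor}")

# FlashAttention-2 and bf16 both need Ampere (sm80) or newer; Turing should use FlashAttention 1.x.
print("FlashAttention-2 supported:", major >= 8)
print("bf16 supported:", major >= 8)
```
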

## How to use FlashAttention

The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

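For reference, this is the computation these kernels perform, written out in plain PyTorch (a slow, memory-hungry sketch useful only for checking semantics; `ref_attention` is a hypothetical helper, not part of the package):

```python
import math
import torch

def ref_attention(q, k, v, causal=False, softmax_scale=None):
    # q, k, v: (batch_size, seqlen, nheads, headdim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    # Attention scores: (batch_size, nheads, seqlen_q, seqlen_k)
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    if causal:
        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
        # Bottom-right aligned causal mask (matches the v2.1+ behavior described below).
        keep = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device).tril(seqlen_k - seqlen_q)
        scores = scores.masked_fill(~keep, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", attn, v.float())
    return out.to(q.dtype)
```
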
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
        the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

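A minimal usage sketch based on the signatures above (shapes and dtypes are illustrative; the GQA call passes fewer KV heads than Q heads):

```python
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

batch_size, seqlen, nheads, nheads_k, headdim = 2, 1024, 8, 2, 64
device, dtype = "cuda", torch.float16

# Grouped-query attention: K/V may have fewer heads than Q (nheads must be divisible by nheads_k).
q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads_k, headdim, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads_k, headdim, device=device, dtype=dtype)
out = flash_attn_func(q, k, v, causal=True)  # (batch_size, seqlen, nheads, headdim)

# Packed QKV self-attention (equal head counts), with a slightly faster backward pass.
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype)
out_packed = flash_attn_qkvpacked_func(qkv, causal=True)
```
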
```python
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
```

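A hedged sketch of incremental decoding with a pre-allocated cache, following the docstring above (shapes are illustrative and the model producing q/k/v at each step is omitted; note that the caller tracks `cache_seqlens`):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, max_seqlen, nheads, nheads_k, headdim = 2, 4096, 8, 2, 64
device, dtype = "cuda", torch.float16

# Pre-allocate the KV cache for the maximum sequence length.
k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

for step in range(16):
    # New query/key/value for one decoding step (normally produced by the model).
    q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
    v = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
    # Appends k/v to the cache at cache_seqlens and attends over the updated cache in one kernel.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                                  cache_seqlens=cache_seqlens, causal=True)
    cache_seqlens += 1  # the caller advances the per-sequence lengths
```
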
To see how these functions are used in a multi-head attention layer (which
includes QKV projection and output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

## Changelog

### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2

These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`

If the inputs have the same sequence lengths in the same batch, it is simpler
and faster to use these functions:
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```

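For variable-length batches (no padding), the `flash_attn_varlen_*` functions instead take tokens packed into one long dimension plus cumulative sequence lengths. A hedged sketch of a call, assuming the packed `(total_tokens, nheads, headdim)` layout and int32 `cu_seqlens` arguments used by these functions:

```python
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 8, 64
device, dtype = "cuda", torch.float16

# Two sequences of lengths 3 and 5 packed into a single tensor of 8 tokens.
seqlens = torch.tensor([3, 5], dtype=torch.int32, device=device)
cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device=device)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)  # [0, 3, 8]
total_tokens = int(cu_seqlens[-1])

q = torch.randn(total_tokens, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(total_tokens, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(total_tokens, nheads, headdim, device=device, dtype=dtype)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)
```
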
### 2.1: Change behavior of causal flag

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.

For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 =
masked out) is:

v2.0:

    1 0 0 0 0
    1 1 0 0 0

v2.1:

    1 1 1 1 0
    1 1 1 1 1

If seqlen_q = 5 and seqlen_k = 2, the causal mask is:

v2.0:

    1 0
    1 1
    1 1
    1 1
    1 1

v2.1:

    0 0
    0 0
    0 0
    1 0
    1 1

If the row of the mask is all zero, the output will be zero.

### 2.2: Optimize for inference

Optimize for inference (iterative decoding) when the query has a very small sequence
length (e.g., query sequence length = 1). The bottleneck here is to load the KV
cache as fast as possible, and we split the loading across different thread
blocks, with a separate kernel to combine the results.

See the function `flash_attn_with_kvcache`, which has more features for inference
(applying rotary embedding, updating the KV cache in place).

Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.

### 2.3: Local (i.e., sliding window) attention

Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.

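Sliding window attention is enabled purely through the `window_size` argument documented above; for example (a sketch with an illustrative window of 1024 tokens to the left):

```python
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 8192, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8192, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8192, 8, 64, device="cuda", dtype=torch.float16)

# Causal attention where each query only attends to the previous 1024 keys.
out = flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))
```
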
### 2.4: ALiBi (attention with linear bias), deterministic backward pass

Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.

Implement deterministic backward pass. Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.

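To use ALiBi, pass per-head slopes via `alibi_slopes` (an fp32 tensor of shape `(nheads,)` or `(batch_size, nheads)`, per the docstrings above). The geometric slopes below follow the common convention from Press et al. and are an assumption, not something fixed by the API:

```python
import torch
from flash_attn import flash_attn_func

nheads = 8
# Common ALiBi convention: slope 2^(-8 * h / nheads) for head h = 1..nheads.
alibi_slopes = torch.tensor(
    [2.0 ** (-8.0 * (h + 1) / nheads) for h in range(nheads)],
    dtype=torch.float32, device="cuda",
)

q = torch.randn(2, 2048, nheads, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 2048, nheads, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 2048, nheads, 64, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)
```
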
### 2.5: Paged KV cache

Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
Thanks to @beginlner for this contribution.

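With paging, the cache is a pool of fixed-size blocks and `block_table` maps each sequence to its blocks. A hedged sketch following the shapes given in the `flash_attn_with_kvcache` docstring above (block assignment here is trivial and purely illustrative):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, nheads, nheads_k, headdim = 2, 8, 2, 64
page_block_size, num_blocks, max_blocks_per_seq = 256, 64, 16
device, dtype = "cuda", torch.float16

# Block pool shared by all sequences, plus a per-sequence table of block indices.
k_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim, device=device, dtype=dtype)
block_table = torch.arange(batch_size * max_blocks_per_seq, dtype=torch.int32,
                           device=device).reshape(batch_size, max_blocks_per_seq)
cache_seqlens = torch.full((batch_size,), 500, dtype=torch.int32, device=device)

# One decoding step: append the new token's K/V and attend over the paged cache.
q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
v = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
                              block_table=block_table, causal=True)
```
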
### 2.6: Softcapping

Support attention with softcapping, as used in Gemma-2 and Grok models.
Thanks to @Narsil and @lucidrains for this contribution.

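Softcapping squashes attention scores through a tanh before the softmax. Assuming the argument added in 2.6 is named `softcap` (left at 0.0 to disable), usage looks like:

```python
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 2048, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 2048, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 2048, 8, 128, device="cuda", dtype=torch.bfloat16)

# Gemma-2 style attention logit softcapping (the cap value is illustrative).
out = flash_attn_func(q, k, v, causal=True, softcap=30.0)
```
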
## Performance

We present the expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth; we see more speedup on slower GPU memory).

We currently have benchmarks for these GPUs:
* [A100](#a100)
* [H100](#h100)
<!-- * [RTX 3090](#rtx-3090) -->
<!-- * [T4](#t4) -->

### A100

We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.

#### Speedup

![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)

#### Memory

![FlashAttention memory](assets/flashattn_memory.jpg)

We show memory savings in this graph (note that the memory footprint is the same no matter if you use dropout or masking).
Memory savings are proportional to sequence length, since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.

### H100

![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)

## Full model code and training script

We have released the full GPT model
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 225
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
any activation checkpointing).

We also include a training
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.

## Triton implementation of FlashAttention

Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.

We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py

## Tests
We test that FlashAttention produces the same output and gradient as a reference
implementation, up to some numerical tolerance. In particular, we check that the
maximum numerical error of FlashAttention is at most twice the numerical error
of a baseline implementation in Pytorch (for different head dimensions, input
dtype, sequence length, causal / non-causal).

To run the tests:
```sh
pytest -q -s tests/test_flash_attn.py
```

## When you encounter issues

This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter bugs, please open a GitHub Issue!

## AMD GPU/ROCm Support
The ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as its backend, which provides the implementation of FlashAttention-2.

### Installation and features
Requirements:
- ROCm 6.0+
- PyTorch 1.12.1+

We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

To compile from source:
```sh
python setup.py install
```

FlashAttention-2 on ROCm currently supports:
1. MI200 or MI300 GPUs.
2. Datatype fp16 and bf16.
3. Head dimensions up to 256 in the forward pass, and up to 128 in the backward pass.

### Tests
To run the tests:
```sh
pytest tests/test_flash_attn_ck.py
```

## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
@inproceedings{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024}
}
```