Breaking the $O(N^2)$ ceiling: A Deep Dive into Flash Attention with Triton
<aside>
Prerequisites:
</aside>
<aside>
Summary of the blog:
This blog is a walkthrough of the code I wrote for FlashAttention v2 and how I added support for Grouped Query Attention (GQA), referring to the Dao AI Lab implementation. I'll try to keep this a chill read with no long math derivations, covering the key differences between FlashAttention v1 and v2. Yeah, so stay with me till the end ;)
</aside>
You can expect topics to be covered in the following order :]
- A quick brush-up on standard attn & flash attn
- Key differences between the two
- A walkthrough of the code
- How GQA is added
- Visualize “flash” in flash attention through benchmarks against standard attn.
Time to grab your favourite coffee (or any refresher) and enjoy the read 🥂
Standard Attention Bottleneck:
- Standard attention materializes the full N x N score matrix $S = QK^T$ in HBM, and this matrix grows quadratically with sequence length. It then loads S back into SRAM to compute P = softmax(S), writes P to HBM, loads P again to compute $O = PV$, and finally writes O back to HBM.
- All this back-and-forth means excessive memory read/write (R/W) traffic, so the GPU spends its time waiting on memory instead of computing.
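To make the bottleneck concrete, here is a minimal NumPy sketch of the standard path described above (not the blog's actual code). The two `(N, N)` intermediates `S` and `P` are exactly what gets shuttled between HBM and SRAM in the naive implementation:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full N x N score matrix.

    Q, K, V: (N, d) arrays. The intermediates S and P are both (N, N),
    which is what blows up memory traffic for long sequences.
    """
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                     # (N, N) scores -> "written to HBM"
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P = P / P.sum(axis=-1, keepdims=True)        # (N, N) softmax -> read/written again
    return P @ V                                 # (N, d) output

rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
```

At N = 64k tokens, `S` alone is 64k x 64k floats, while the inputs and output are only 64k x d.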

Flash Attention Idea:
- Recomputation: don't store these huge matrices at all. S and P are needed for the backward pass, so instead of saving them we recompute them on the fly (trading a few extra FLOPs for far fewer memory accesses).
- Tiling: split Q, K, V into blocks, load them into fast SRAM, and compute attention block by block. (Note that softmax done over "local blocks" alone would be inaccurate, since softmax needs a "global view" of the row; that's why we use online softmax, which keeps track of a "running max" and a normalization factor.)
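The tiling + online-softmax idea above can be sketched in plain NumPy (a toy single-head version for intuition, not the real Triton kernel). Each K/V block streams in, and per query row we maintain a running max `m`, a running denominator `l`, and an unnormalized accumulator `acc`, rescaling the old statistics whenever the max changes:

```python
import numpy as np

def flash_attention_blocked(Q, K, V, block=4):
    """Tiled attention with online softmax (minimal sketch).

    Processes K/V in blocks of `block` rows, never forming the
    full N x N matrix. Mathematically equivalent to full softmax.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((N, 1), -np.inf)   # running row-wise max
    l = np.zeros((N, 1))           # running softmax denominator
    acc = np.zeros((N, d))         # unnormalized output accumulator
    for j in range(0, N, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T * scale                               # (N, block) local scores
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)                          # rescale old stats
        p = np.exp(S - m_new)                              # local exponentials
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        acc = alpha * acc + p @ Vj
        m = m_new
    return acc / l                                         # normalize at the end

# sanity check against the full-matrix softmax
rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
ref = P @ V
out = flash_attention_blocked(Q, K, V)
```

The rescaling factor `alpha = exp(m_old - m_new)` is the whole trick: it lets previously accumulated exponentials be corrected retroactively, so the block-by-block result matches the global softmax exactly.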
Ummm, then what did they change in FlashAttention v2?
So here it is: