Breaking the $O(N^2)$ ceiling: A Deep Dive into Flash Attention with Triton

<aside>

Prerequisites:

</aside>
<aside>

Summary of the blog:

This blog is a walkthrough of the code I wrote for FlashAttention v2 and of how I added support for Grouped Query Attention (referencing Dao AI Lab's implementation). I'll try to keep this a chill read with no long math derivations, covering the key insights on the differences between FlashAttention v1 and v2. Yeah, so stay with me till the end ;)

</aside>

You can expect topics to be covered in the following order :]

  1. A quick brush-up on standard attn & flash attn
  2. Key differences between the two
  3. A walkthrough of the code
  4. How GQA is added
  5. Visualizing the “flash” in Flash Attention through benchmarks against standard attention

Time to grab your favourite coffee (or any refresher) and enjoy the read 🥂

Standard Attention Bottleneck:

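Before the fix, it helps to see the problem. Here is a minimal NumPy sketch of standard attention (the function name `standard_attention` is mine, for illustration; a real implementation would run on the GPU, e.g. via PyTorch). The point to notice: both the score matrix `S` and the softmax matrix `P` are full `(N, N)` arrays, so memory (and HBM traffic) grows quadratically with sequence length `N`:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full (N, N) score matrix.

    The (N, N) intermediates S and P are the O(N^2) memory bottleneck
    that FlashAttention is designed to avoid.
    """
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)               # (N, N) scores
    S = S - S.max(axis=-1, keepdims=True)  # subtract row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)  # row-wise softmax, still (N, N)
    return P @ V                           # (N, d) output
```

On a GPU the arithmetic itself is cheap; it is the reads and writes of these `(N, N)` intermediates to and from HBM that dominate the runtime.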

Flash Attention Idea:
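The core trick is to stream K and V in tiles and maintain a running row-max `m` and softmax denominator `l`, rescaling the partial output as each new tile arrives, so the full `(N, N)` matrix never exists in memory. A hedged NumPy sketch of this online softmax (the name `flash_attention_ref` and the `block` parameter are my own; the actual kernel does this per query block inside Triton):

```python
import numpy as np

def flash_attention_ref(Q, K, V, block=64):
    """Tiled attention with an online softmax (FlashAttention-style sketch).

    Only a (N, block) score tile is live at any moment; the running max m
    and normalizer l let us rescale earlier partial results on the fly.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)        # running row max
    l = np.zeros(N)                # running softmax denominator
    for j in range(0, K.shape[0], block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T * scale                   # (N, block) tile only
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)              # correction for old stats
        P = np.exp(S - m_new[:, None])
        l = l * alpha + P.sum(axis=1)
        O = O * alpha[:, None] + P @ Vj        # rescale, then accumulate
        m = m_new
    return O / l[:, None]                      # final normalization
```

The output is mathematically identical to standard attention; what changes is that peak memory is `O(N · block)` instead of `O(N^2)`.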

Ummm… then what did they change in FlashAttention v2?

So here it is: