All sorts of famous Attention Layers
I was interviewing at a startup and they asked me to code Flash Attention in 20 minutes. I got wrecked (and embarrassed). So I'm writing about it here.
Self-Attention
It's at the core of transformer models. Its compute and memory costs grow quadratically with the input sequence length, so inference on long sequences (e.g. RAG applications) becomes very expensive.
Algorithm 0 - Standard Attention Implementation
Require: Matrices $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM.
1: Load Q, K by blocks from HBM, compute $S = QK^\top$, write S to HBM.
2: Read S from HBM, compute P = softmax(S), write P to HBM.
3: Load P and V by blocks from HBM, compute O = PV, write O to HBM.
4: Return O.
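Algorithm 0 can be sketched in a few lines of NumPy for a single head. This is a minimal sketch: the function names are my own, and I add the usual $1/\sqrt{d}$ scaling, which Algorithm 0 leaves out. The point is that S and P are full N×N arrays, and on a GPU these are exactly the arrays that get written to and read back from HBM.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability; the result is unchanged.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Algorithm 0 for one head: Q, K, V are (N, d); returns O of shape (N, d)."""
    N, d = Q.shape
    S = Q @ K.T / np.sqrt(d)   # step 1: full N x N score matrix
    P = softmax(S, axis=-1)    # step 2: full N x N probability matrix
    return P @ V               # step 3: output, N x d
```

Every row of P sums to 1, so each output row is a weighted average of the rows of V.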
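HBM (around 1.5 TB/s) is much slower than the GPU's on-chip SRAM. HBM is not on the GPU die itself; it's a set of memory stacks sitting right next to it. In Algorithm 0, the N×N matrices S and P are written to HBM and read back, so the number of HBM accesses grows quadratically with sequence length. At inference, the K and V tensors (the KV cache) also live in HBM and have to be re-read for every generated token. At scale, memory bandwidth, not compute, is the bottleneck.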
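To put the quadratic term in numbers, here is a quick back-of-the-envelope calculation of how big one score matrix gets. It assumes fp16 storage (2 bytes per element) and counts a single head in a single layer; the helper name is just for illustration.

```python
def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Size of one N x N attention matrix (S or P) for a single head."""
    return seq_len * seq_len * bytes_per_elem

for n in (2_048, 8_192, 32_768):
    mib = score_matrix_bytes(n) / 2**20
    print(f"N={n:>6}: {mib:,.0f} MiB per head per layer")
```

This gives 8 MiB at N = 2,048, 128 MiB at N = 8,192 and 2,048 MiB (2 GiB) at N = 32,768, and Algorithm 0 writes and then reads both S and P.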
For reference, see Multi-Head Attention in BERT.
Multi-Query Attention
It's almost like multi-head self-attention, except that each head i no longer needs its own K_i and V_i. All query heads share a single K tensor and a single V tensor. One write-head is all you need! This greatly reduces the amount of data that has to be loaded from HBM, and since K and V are cached during decoding, the KV cache shrinks by a factor equal to the number of heads. Awesome! Less memory pressure (so you can batch more) and faster decoding at inference. But there is a small accuracy drop, since there are fewer parameters. You also have to train the model with MQA; you can't just take an MHA-trained model and switch to MQA at inference. Finally, MQA fits poorly with tensor parallelism: when heads are split across GPUs, the single KV head has to be replicated on every GPU, which partly defeats the purpose.
For MQA, see Falcon 7B.
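A minimal NumPy sketch of both sides of MQA: a KV-cache size estimate and a forward pass where every query head attends to the same K and V. The function names and the model shape in the usage note are hypothetical, chosen only to illustrate the arithmetic (fp16, 2 bytes per element, no masking).

```python
import numpy as np

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    # Leading factor 2: one cache for K and one for V.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

def multi_query_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Q: (H, N, d), one query projection per head.
    K, V: (N, d), a single key/value head shared by all H query heads."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)               # (H, N, N): K broadcasts across heads
    S = S - S.max(axis=-1, keepdims=True)  # stable softmax
    P = np.exp(S)
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V                           # (H, N, d)
```

For a hypothetical model with 32 layers, 32 heads, head_dim 128 and a 4,096-token context at batch size 1, `kv_cache_bytes(32, 32, 128, 4096, 1)` is 2 GiB for MHA, while MQA (`n_kv_heads=1`) needs 64 MiB.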
Grouped-Query Attention
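It sits between MHA and MQA. It adds one more hyperparameter, the number of KV groups: the query heads are split into groups, and each group shares one (K, V) head. This gives the best of both worlds, a nice balance between speed and accuracy; 4 and 8 groups worked quite well. Interestingly, existing MHA models can be uptrained to GQA (not really fine-tuning, more like an upgrade): the K and V heads within each group are mean-pooled, and the model is then trained briefly on a small fraction of the original pre-training compute. GQA is also a better fit for tensor parallelism, since the KV heads can be split across GPUs instead of replicated on each one.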
For GQA, see Llama 2.
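A minimal NumPy sketch of GQA, assuming H query heads and G KV heads with H divisible by G, where consecutive query heads share a KV head. Setting G = H gives MHA and G = 1 gives MQA. The second function sketches the mean-pooling step used to start uptraining from an MHA checkpoint; both function names are hypothetical, and the pooled weights would still need the short uptraining run described above.

```python
import numpy as np

def grouped_query_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Q: (H, N, d); K, V: (G, N, d). Query head h uses KV head h // (H // G)."""
    H, N, d = Q.shape
    G = K.shape[0]
    assert H % G == 0, "number of query heads must be a multiple of KV heads"
    # Repeat each KV head for the H // G query heads in its group.
    K_rep = np.repeat(K, H // G, axis=0)            # (H, N, d)
    V_rep = np.repeat(V, H // G, axis=0)            # (H, N, d)
    S = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d)   # (H, N, N)
    S = S - S.max(axis=-1, keepdims=True)           # stable softmax
    P = np.exp(S)
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V_rep                                # (H, N, d)

def mha_to_gqa_kv(W: np.ndarray, n_groups: int) -> np.ndarray:
    """Mean-pool per-head K (or V) projection weights within each group.
    W: (H, ...) -> (n_groups, ...), grouping consecutive heads."""
    H = W.shape[0]
    return W.reshape(n_groups, H // n_groups, *W.shape[1:]).mean(axis=1)
```

In a real implementation the repeat is usually avoided by indexing the shared KV head directly, but `np.repeat` keeps the grouping easy to see.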
Sliding Window Attention
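In vanilla causal attention, every token computes attention scores against all previous tokens. During training and prompt processing we apply a causal mask, because a token must not look at the future, so the mask is triangular and the cost is still quadratic in sequence length. SWA limits each token's attention to a fixed window of the W most recent tokens, so within one layer a token can't see further back than the window. Information still propagates across layers, though, so the theoretical attention span is about W × number of layers (for Mistral 7B, 4096 × 32 ≈ 131K tokens). The per-layer cost drops from O(N^2) to O(N·W), which is linear in sequence length for a fixed window. So we shorten each layer's attention span, not the model's.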
See the Mistral 7B paper and its reference sliding-window causal mask code.
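A small NumPy sketch of a sliding-window causal mask, independent of Mistral's code. I assume the window counts the current token, so each row has at most `window` allowed positions; other implementations may be off by one from this convention. The function name is hypothetical.

```python
import numpy as np

def sliding_window_causal_mask(n: int, window: int) -> np.ndarray:
    """Boolean (n, n) mask: True where query i may attend to key j,
    i.e. j <= i (causal) and i - j < window (inside the window)."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (i - j < window)

print(sliding_window_causal_mask(6, 3).astype(int))
```

With n = 6 and window = 3, the rows allow 1, 2, 3, 3, 3, 3 positions: a band along the diagonal instead of the full lower triangle, so the number of scores per layer grows as N·W rather than N^2.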
Faster Attention Layers
Flash Attention
As we saw, HBM is much slower than the GPU's on-chip SRAM. Wouldn't it be better to run the self-attention computation in SRAM, with minimal HBM accesses? That's exactly what FlashAttention does:
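Load blocks of Q, K and V from HBM into SRAM (tiling)
Compute the score block $S = QK^\top$ in SRAM
Apply the softmax incrementally in SRAM, rescaling with a running row max and row sum, and accumulate the block's contribution PV
Write only the output O (plus the small softmax statistics needed for the backward pass) back to HBM; the full S and P are never written to HBM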
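The steps above can be sketched in NumPy for a single head without masking. NumPy has no notion of SRAM, so this only shows the math of tiling and the online softmax, not the memory savings or the speed; real kernels are written in CUDA or Triton. The function name and block sizes are my own choices. It uses the FlashAttention-2 loop order (outer loop over query blocks, output normalized once at the end); FlashAttention-1 puts the K/V loop on the outside and rescales the output at every step, which gives the same result.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=16, block_k=16):
    """Tiled attention with an online softmax (single head, no mask).
    Only block_q x block_k score tiles exist at any time."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q, dtype=np.float64)
    for qs in range(0, N, block_q):               # outer loop: query blocks
        q = Q[qs:qs + block_q]                     # "load" a Q block
        m = np.full(q.shape[0], -np.inf)           # running row max
        l = np.zeros(q.shape[0])                   # running softmax denominator
        acc = np.zeros((q.shape[0], d))            # unnormalized output
        for ks in range(0, N, block_k):           # inner loop: K/V blocks
            k = K[ks:ks + block_k]
            v = V[ks:ks + block_k]
            s = q @ k.T * scale                    # score tile
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])         # unnormalized probabilities
            correction = np.exp(m - m_new)         # rescale earlier blocks
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]      # normalize once, "write" O
    return O
```

The result matches standard softmax attention up to floating-point error for any block sizes, including ones that don't divide N. The running max `m` keeps `exp` from overflowing, and `correction` retroactively rescales everything accumulated before a larger max was seen.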
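FlashAttention also parallelizes over batch size and number of heads. Take N as the sequence length, d as the head dimension and M as the SRAM size, with $d \le M \le Nd$. Standard attention needs $\Theta(Nd + N^2)$ HBM accesses, while FlashAttention needs $\Theta(N^2 d^2 M^{-1})$. That still looks quadratic, but for typical values d^2 is many times smaller than M, so the N^2 term is divided by a large factor. In the extreme case M = N it becomes $O(Nd^2)$ HBM accesses, linear in sequence length. The same tiling is used in the backward pass, which recomputes S and P block by block from the saved softmax statistics instead of reading them from HBM, so training gets faster too.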
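To get a feel for the two bounds, we can plug numbers into them with constants dropped, so these are order-of-magnitude counts only. The helper names are hypothetical, and M = 100,000 elements is an assumed SRAM size for illustration, not a measured value for any particular GPU.

```python
def hbm_accesses_standard(N: int, d: int) -> float:
    # Theta(N d + N^2): Q, K, V, O plus the N x N matrices S and P.
    return N * d + N * N

def hbm_accesses_flash(N: int, d: int, M: int) -> float:
    # Theta(N^2 d^2 / M), valid for d <= M <= N d.
    return N * N * d * d / M

N, d, M = 4_096, 64, 100_000
print(hbm_accesses_standard(N, d) / hbm_accesses_flash(N, d, M))
```

With these values the FlashAttention count comes out roughly 25x smaller, and setting M = N reproduces the $O(Nd^2)$ case from the text.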
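Later came FlashAttention-2, which restructures the algorithm to cut the number of non-matmul operations, since GPUs run matmuls far faster than anything else, and so maximizes GPU throughput. It also supports Multi-Query Attention and Grouped-Query Attention, and it parallelizes over the sequence-length dimension in addition to batch and heads, which keeps the GPU busy for long sequences with small batches. It is roughly 2x faster than the original FlashAttention and up to a staggering 9x faster than standard attention.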
See the FlashAttention and FlashAttention-2 papers.
Paged Attention
It's a well-known vLLM optimization that lets each inference request's KV cache grow and shrink dynamically. In hindsight, managing this cache is an old-school OS problem: GPU memory fragmentation wastes memory and makes it hard to increase batch size. PagedAttention divides the KV cache into fixed-size blocks (pages) that don't have to be contiguous in GPU memory, similar to virtual memory pages in an operating system. A per-request block table maps logical blocks to physical blocks, and blocks are allocated on demand. This eliminates external fragmentation and limits internal fragmentation to the last, partially filled block of each sequence.
See the vLLM PagedAttention paper.
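A toy sketch of the block-table bookkeeping behind PagedAttention. This is not vLLM's API: the class and method names are hypothetical, and it only tracks which physical block and offset each new token's K/V would go to; the attention kernel that reads through the block table is left out.

```python
class PagedKVCache:
    """Toy allocator: maps each request's logical KV blocks to physical
    blocks taken from a shared free list (hypothetical API)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))      # physical block ids
        self.block_tables: dict[int, list[int]] = {}    # request -> physical ids
        self.num_tokens: dict[int, int] = {}            # request -> tokens stored

    def append_token(self, req_id: int) -> tuple[int, int]:
        """Reserve a slot for one new token; returns (physical_block, offset)."""
        table = self.block_tables.setdefault(req_id, [])
        n = self.num_tokens.get(req_id, 0)
        if n % self.block_size == 0:                    # no block yet, or last one full
            if not self.free_blocks:
                raise MemoryError("out of KV cache blocks")
            table.append(self.free_blocks.pop())        # any free block will do
        self.num_tokens[req_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def free(self, req_id: int) -> None:
        """Request finished: return all of its blocks to the free list."""
        self.free_blocks.extend(self.block_tables.pop(req_id, []))
        self.num_tokens.pop(req_id, None)
```

Because a block is only allocated when the previous one fills up, each request wastes at most `block_size - 1` slots, and since any free block can be used, there are no gaps between allocations to fragment memory.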