Paper Link: https://arxiv.org/abs/2205.14135
Idea behind FlashAttention: reduce the number of reads/writes (I/O) to HBM, which has much lower bandwidth than the GPU's on-chip SRAM
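For contrast, here is a minimal NumPy sketch (function name and single-head (N, d) shapes are my own, not from the paper) of standard attention, which materializes the full N x N score and probability matrices; on a GPU each of those intermediates round-trips through HBM:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Q, K, V: (N, d) arrays. Returns softmax(Q K^T / sqrt(d)) V."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)               # (N, N) scores: O(N^2) memory
    S = S - S.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)  # (N, N) probabilities, also O(N^2)
    return P @ V                           # (N, d) output
```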
Hardware Performance (Background)
GPU Memory Hierarchy:
An A100 has 40-80 GB of high-bandwidth memory (HBM) with bandwidth 1.5-2.0 TB/s, and 192 KB of on-chip SRAM on each of its 108 streaming multiprocessors, with aggregate SRAM bandwidth around 19 TB/s
The on-chip SRAM is an order of magnitude faster than HBM but orders of magnitude smaller in capacity
Performance Characteristics: depending on the balance of computation and memory accesses, operations can be classified as compute-bound or memory-bound (see the sketch after this list)
- Compute-bound: the time taken is determined by the number of arithmetic operations, while time spent accessing HBM is much smaller (e.g., matrix multiplication with a large inner dimension)
- Memory-bound: the time taken is determined by the number of memory accesses, while time spent on computation is much smaller (e.g., elementwise ops like activation/dropout and reductions like softmax)
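A quick way to see which side of the line an op falls on is to compare its arithmetic intensity (FLOPs per byte of HBM traffic) against the hardware's balance point (peak FLOP/s divided by HBM bandwidth). The sketch below uses rough public A100 spec numbers as assumptions, purely for illustration:

```python
# Back-of-the-envelope classifier: an op is memory-bound when its
# arithmetic intensity falls below the hardware's balance point.
# Rough A100 specs (assumptions, not from the paper):
PEAK_FLOPS = 312e12   # ~312 TFLOP/s dense FP16 tensor-core peak
HBM_BW     = 1.5e12   # ~1.5 TB/s HBM bandwidth (low end of 1.5-2.0)
BALANCE    = PEAK_FLOPS / HBM_BW  # ~208 FLOPs per byte

def classify(flops, bytes_moved):
    return "compute-bound" if flops / bytes_moved > BALANCE else "memory-bound"

# Elementwise op on 1M fp16 values: ~1 FLOP per element, ~4 bytes moved
# (read + write), so intensity ~0.25 -> firmly memory-bound.
print(classify(flops=1e6, bytes_moved=4e6))

# Large fp16 matmul (N=4096): 2*N^3 FLOPs vs ~3*N^2 values * 2 bytes,
# intensity in the thousands -> compute-bound.
N = 4096
print(classify(flops=2 * N**3, bytes_moved=3 * N * N * 2))
```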
Efficient Attention Algorithm (tiling + recomputation)
Tiling splits the attention computation into small blocks of Q, K, and V that fit in fast on-chip GPU memory (SRAM), computing softmax incrementally with running max/sum statistics so the full N x N attention matrix is never materialized in HBM
Recomputation avoids storing the N x N attention matrix for the backward pass: only the output and the softmax normalization statistics are saved, and the attention blocks are recomputed on-chip during the backward pass
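Below is a minimal NumPy sketch of the tiled forward pass with the online-softmax rescaling trick. The function name, block size, and single-head (N, d) shapes are my assumptions; the paper's actual implementation is a fused CUDA kernel, and here the block-sized tiles simply stand in for data resident in SRAM:

```python
import numpy as np

def flash_attention_forward(Q, K, V, block=64):
    """Tiled attention forward pass over (N, d) inputs. Only
    (block x block) tiles of the N x N score matrix exist at any time."""
    N, d = Q.shape
    O = np.zeros((N, d))       # unnormalized running output
    m = np.full(N, -np.inf)    # running row max of scores
    l = np.zeros(N)            # running row normalizer (sum of exps)

    for j in range(0, N, block):          # loop over K/V tiles ("load to SRAM")
        Kj, Vj = K[j:j+block], V[j:j+block]
        for i in range(0, N, block):      # loop over Q tiles
            Qi = Q[i:i+block]
            S = Qi @ Kj.T / np.sqrt(d)    # small (block x block) score tile
            m_new = np.maximum(m[i:i+block], S.max(axis=-1))
            # Rescale previously accumulated output/normalizer to the new max.
            scale = np.exp(m[i:i+block] - m_new)
            P = np.exp(S - m_new[:, None])
            l[i:i+block] = l[i:i+block] * scale + P.sum(axis=-1)
            O[i:i+block] = O[i:i+block] * scale[:, None] + P @ Vj
            m[i:i+block] = m_new
    return O / l[:, None]      # apply the final softmax normalization
```

The running statistics m and l are exactly what gets saved for the backward pass, which then recomputes the attention tiles on-chip instead of reading a stored N x N matrix from HBM. A quick sanity check is to compare this against the naive_attention sketch above on random inputs with np.allclose.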