Paper Link: https://arxiv.org/abs/2205.14135

The idea behind FlashAttention is to reduce the number of reads and writes (I/O) to HBM, which is significantly slower than the GPU's on-chip SRAM

Hardware Performance (Background)

GPU Memory Hierarchy:

The A100 has 40-80 GB of high-bandwidth memory (HBM) with 1.5-2.0 TB/s bandwidth, and 192 KB of on-chip SRAM on each of its 108 streaming multiprocessors, with bandwidth estimated around 19 TB/s

The on-chip SRAM is an order of magnitude faster than HBM but orders of magnitude smaller in size

Performance Characteristics: Depending on the balance of computation and memory accesses, operations can be classified as compute-bound or memory-bound (a rough back-of-the-envelope check is sketched after this list):

  1. Compute-bound: Time is dominated by the number of arithmetic operations, while time spent accessing HBM is comparatively small (e.g., matrix multiplication with a large inner dimension)
  2. Memory-bound: Time is dominated by the number of memory accesses, while time spent on computation is comparatively small (e.g., elementwise ops such as softmax, dropout, and masking)
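
A quick way to see which regime an op falls into is to compare its arithmetic intensity (FLOPs per byte moved to/from HBM) against the machine balance (peak FLOP/s divided by HBM bandwidth). A minimal sketch, assuming A100-like numbers (~312 TFLOP/s FP16 tensor-core peak, ~2 TB/s HBM) and a very rough FLOP count for softmax; all constants here are illustrative assumptions:

```python
# Roofline-style back-of-the-envelope check (all numbers assumed, A100-like).
PEAK_FLOPS = 312e12   # peak FP16 FLOP/s (assumption)
HBM_BW = 2.0e12       # HBM bandwidth in bytes/s (assumption)
machine_balance = PEAK_FLOPS / HBM_BW  # ~156 FLOPs per byte

def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte moved to/from HBM."""
    return flops / bytes_moved

# Elementwise softmax over an N x N fp16 score matrix:
N = 4096
elems = N * N
flops = 5 * elems            # rough per-element cost: max, sub, exp, sum, div
bytes_moved = 2 * elems * 2  # read + write, 2 bytes per fp16 element

ai = arithmetic_intensity(flops, bytes_moved)
kind = "memory-bound" if ai < machine_balance else "compute-bound"
print(f"machine balance: {machine_balance:.0f} FLOP/byte")
print(f"softmax intensity: {ai:.2f} FLOP/byte -> {kind}")
```

By this estimate softmax performs on the order of 1 FLOP per byte, far below the ~156 FLOP/byte machine balance, which is why the elementwise steps of attention are memory-bound and why cutting HBM traffic (rather than FLOPs) is what speeds them up.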

Efficient Attention Algorithm (tiling + recomputation)

Tiling is the process of splitting the attention computation into small blocks that fit in fast on-chip GPU memory (SRAM), computing the softmax incrementally so the full N×N score matrix never has to be materialized in HBM. Recomputation complements this in the backward pass: rather than storing the score matrix for the gradient computation, it is recomputed block by block from the saved softmax normalization statistics. A sketch of the forward-pass tiling idea follows.
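
Below is a minimal, unoptimized NumPy sketch of the tiling idea (forward pass only, single head). The function name and the `block_size` default are illustrative assumptions; the actual algorithm is a fused CUDA kernel that keeps each block resident in SRAM:

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=64):
    """Tiled attention with an online softmax (illustrative sketch).

    Q, K, V: (N, d) arrays. Shows the tiling/online-softmax idea only;
    the real kernel fuses these loops on-chip so the N x N score matrix
    is never written to HBM.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)

    O = np.zeros((N, d))       # unnormalized output accumulator
    m = np.full(N, -np.inf)    # running row-wise max of scores
    l = np.zeros(N)            # running row-wise softmax denominator

    for j in range(0, N, block_size):        # outer loop over K/V blocks
        Kj = K[j:j + block_size]
        Vj = V[j:j + block_size]
        for i in range(0, N, block_size):    # inner loop over Q blocks
            Qi = Q[i:i + block_size]
            S = (Qi @ Kj.T) * scale          # one block of scores (fits in SRAM)

            # Update the running max, then rescale prior partial results so
            # everything stays expressed relative to the new max.
            m_new = np.maximum(m[i:i + block_size], S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            correction = np.exp(m[i:i + block_size] - m_new)

            l[i:i + block_size] = correction * l[i:i + block_size] + P.sum(axis=1)
            O[i:i + block_size] = (O[i:i + block_size] * correction[:, None]
                                   + P @ Vj)
            m[i:i + block_size] = m_new

    return O / l[:, None]      # normalize once at the end
```

A quick sanity check is to compare the result against a reference `softmax(Q @ K.T / sqrt(d), axis=1) @ V`; the two agree because each block's contribution is rescaled by `exp(m_old - m_new)` whenever the running max grows, so the final division by `l` yields the exact softmax. The running statistics `m` and `l` are also what the backward pass reuses during recomputation.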