트랜스포머의 가장 큰 병목은 어텐션이다. 시퀀스 길이 n에 대해 n×n 어텐션 행렬을 만들어야 한다. n=1024면 1M 원소, n=16384(16K 컨텍스트)면 256M 원소다. 이 행렬을 GPU 메모리(HBM)에 읽고 쓰는 것이 병목이다.

Stanford의 Dao 등이 2022년 발표한 FlashAttention은 이 문제를 알고리즘 수준에서 해결했다.

GPU 메모리 계층

GPU에는 두 종류의 메모리가 있다.

HBM (High Bandwidth Memory): GPU 메모리라고 부르는 것. A100 기준 80GB, 대역폭 2TB/s. 크지만 느리다.

SRAM (Static RAM): GPU 코어 내부의 공유 메모리(shared memory). A100 기준 192KB/SM, 대역폭 19TB/s. 작지만 10배 빠르다.

기존 어텐션은 HBM에서 Q, K, V를 읽어 n×n 어텐션 행렬을 계산하고 다시 HBM에 쓴다. HBM 접근이 병목이다.

IO-Aware 타일링

FlashAttention의 핵심은 어텐션 행렬을 HBM에 쓰지 않는 것이다.

Q, K, V를 작은 블록(tile)으로 나눠 SRAM에 올려놓고 한 번에 계산한다. 어텐션 행렬을 전체 생성하지 않고 블록 단위로 처리하면서 최종 출력만 HBM에 쓴다.

기존 어텐션:
HBM → Q,K,V 로드 → S = QK^T 계산 → HBM 저장
HBM → S 로드 → P = softmax(S) 계산 → HBM 저장
HBM → P,V 로드 → O = PV 계산 → HBM 저장

FlashAttention:
HBM → Q_block, K_block, V_block 로드 (작은 블록) → SRAM
SRAM → 블록 단위 어텐션 계산 (HBM 저장 없음)
최종 출력만 HBM에 저장

수학적으로 정확히 같은 결과를 내면서 HBM 접근 횟수를 줄인다. 온라인 소프트맥스(online softmax)로 소프트맥스를 블록 단위로 점진적으로 계산한다.

성능

항목기존 어텐션FlashAttention
메모리 복잡도O(n²)O(n)
속도기준2~4배 빠름
수치 정확도기준동일
역전파 지원OO

메모리 O(n)이 핵심이다. 16K 컨텍스트에서 기존 어텐션은 256M×4바이트 = 1GB가 어텐션 행렬에만 필요하다. FlashAttention은 이 행렬을 저장하지 않아 수십 배 적은 메모리로 같은 컨텍스트 길이를 처리한다.

긴 컨텍스트를 가능하게 하다

FlashAttention이 없었다면 현재의 100K+ 컨텍스트 LLM은 불가능했다. 컨텍스트 길이가 늘어날수록 FlashAttention의 메모리 절약 효과가 커진다.

GPT-4, Claude, LLaMA 2 이후 대부분의 대형 모델이 FlashAttention을 기본으로 사용한다.

FlashAttention-2 (2023)

작업 분할 방식을 개선해 FlashAttention 대비 2배 추가 속도 향상을 달성했다. 시퀀스 길이 방향으로 병렬화해 GPU 점유율을 높였다.

FlashAttention-3 (2024)

H100 GPU의 비동기 실행(async warpgroups)과 FP8 저정밀도를 활용한다. A100 대비 H100에서 1.5~2배 추가 속도 향상.

PagedAttention과의 차이

086이 학습/추론 효율화라면, PagedAttention(vLLM, 092에서 다룸)은 서빙 효율화다. FlashAttention은 단일 요청의 어텐션 계산을 빠르게 하고, PagedAttention은 다수 요청의 KV 캐시를 효율적으로 관리한다. 두 기술은 상호 보완적이다.

트레이드오프

CUDA 커널을 직접 작성해야 해 구현이 복잡하다. PyTorch 기본 연산으로는 구현할 수 없고, 하드웨어(NVIDIA GPU)에 종속적이다. AMD GPU나 Apple Silicon에서는 별도 구현이 필요하다. 그러나 PyTorch 2.0부터 scaled_dot_product_attention 함수가 FlashAttention을 내부적으로 사용해 사용자가 직접 신경 쓸 필요가 없어졌다.