본문 바로가기
Paper Review/Model Compression

[논문 리뷰 - FlashAttention] FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness (NIPS 2022)

by je0nsye0n 2025. 3. 21.
본 포스팅은 "FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness (NIPS 2022)"을 Review한 것으로, 해당 강의를 통해 이해하는데 많은 도움을 받았습니다. (+ 자료 첨부)

 

Introduction

Attnetion 연산의 최적화를 위한 paper의 Intro에서는 Transformer의 한계를 언급한다. 본 논문에서도 역시 Transformer의 한계인 긴 문맥을 처리하는데 발생하는 시간과 메모리 복잡도의 제곱 증가를 지적한다. 따라서 빠르고 메모리 효율적인 Attn을 통해 Transformer의 긴 시퀀스 처리 시 직면하는 실행시간과 메모리 문제를 해결해야 한다.

 

 

당시의 기존 방법들은 approximation한 방법들로 연구가 되어왔다.

Sparse  [1] Reformer: The efficient transformer (2020)
 [2] Efficient content-based sparse attention with routing transformers (2021)
Low-Rank  [3] Rethinking attention with performers
 [4] Linformer: Self-attention with linear complexity (2020)
 [5] Transformers are RNNs: Fast autoregressive transformers with linear attention (2020)
결합  [6] Big bird: Transformers for longer sequences (2020)
 [7] Scatterbrain: Unifying sparse and low-rank attention (2021)

 

→ 이 방법은 FLOPs 감소에 초점을 맞추어, 계산 복잡도를 시퀀스 길이에 대해 선형 또는 근선형으로 줄였으나, 말 그대로 "approximation"하다. 즉 trade-off가 발생한다는 의미이다.

 

 

반면, 본 논문의 제목을 살펴보면 다음과 같은 특징을 발견할 수 있다.

 

- 빠르고 적은 공간으로 연산을 가능하게 하나, "Exact Attention". 즉, 연산 결과에 trade off가 발생하지 않는다는 것이다.

해당 방식을 IO-Aware 알고리즘을 도입하여 정확한 어텐션 연산을 수행할 수 있다고 주장한다.

 

✅ 핵심 아이디어

- 목표 : HBM과의 불필요한 Read/Write을 줄이면서 정확한 Attention 연산을 수행하는 새로운 알고리즘 제안
- 아이디어: IO-Aware 알고리즘의 원칙을 도입하여, SRAM과 HBM 사이의 Read/Write 오버헤드를 고려해야한다 주장. 현대 GPU에서는 계산 속도가 메모리 속도를 앞서고, Transformer의 대부분 연산이 메모리 접근에 의해 병목 현상이 일어난다고 판단

Background

모델의 핵심 연산을 이해하기 위해서는 논문에서 제시하는 Background를 보는 것이 중요하다.

 

[1] GPU Hierarchy

 

[2] Arithmetic Intensity

Compute bound SRAM에서의 FLOPs 연산이 병목인 현상 (ex. Matmul)
Memory bound HBM ↔ SRAM의 통신(IO)이 병목인 연산 (ex. Elem-wise, Reduction)

 

→ 현 시점의 SRAM 속도는 매우 빠르기 때문에 FLOPs를 줄이는 것보다 IO를 줄이는 것이 더 유리하다고 판단

 

[3] Motivation

 

Attn의 동작 과정을 다루어보면 다음과 같다.

 

QK를 행렬곱한 값을 S, softmax에 S를 넣은 값을 P, P와 V를 행렬곱한 값을 O라고 한다면

다음과 같은 과정의 Read(초록 화살표), Write(주황 화살표)가 발생한다.

 

본 논문에서는 S와 P에 대한 Read/Write과정을 없애고 아래와 같은 동작 방식을 구성한다는 의미이다!

 


Model

[모델의 전략]

  • O=Attn(Q,K,V)가 하나의 GPU 연산 호출이 되도록 융합 (즉. kernel fusion을 의미)
  • matmul에서 효율적인 연산을 위해 Tiling 진행

 

[Tiling Softmax]

  • Softmax의 한 행씩 이루어진다는 특징에 의해 matmul의 타일링을 위해서는 Softmax에도 타일링을 해주어야 한다.

 

 

Tiling Softmax

Softmax 계산을 블록 단위로 분할해서 계산하는 방식으로, HBM에서 한번에 모든 데이터를 접근하는 것이 아니라 작은 블록 단위로 계산하여 SRAM에 데이터를 불러오고 Softmax결과를 조합하는 방식

 

입력이 다음과 같을 때, 블록 단위(x_1과 x_2)로 나누었다고 하자.

 

그렇다면, 동작 흐름은 다음과 같다.

 

FlashAttention 알고리즘


Experiments

1. Training Efficiency

  • FLOPs는 그대로였음에도 IO를 줄이는 것이 저자들의 주장과 동일하게 속도에 영향을 미치는 것을 확인
    Encoder와 Decoder에 동일하게 적용

 

 

2. Long-Range Arena (LRA) Benchmark 결과 비교

  • LRA는 길이가 긴 시퀀스(1024~4096) 에서 다양한 Task (ListOps, Text, Retrieval, Image, Pathfinder)에 대해 정확도, 처리 속도(throughput), 학습 시간 등을 평가하는 벤치마크

 


3. Runtime and Memory usage

  • FlashAttn는 N이 커질 수록 큰 효과를 보이며, Memory Usage는 항상 Superior함