Efficiency is the Bottleneck for the Modeling Long Sequences with Attention.
→ Context Length가 늘어나면 학습속도가 줄어들거나 OOM 문제로 학습이 중단될 수도 있다.
Our observation: Attention is Bottlenecked by Memory Reads
→ Score matrix를 계산하고 Attention을 계산하는것보다, 이 과정에사 Memory에 접근해서 R/W하는게 더 많은 cost를 야기한다.
→ 이걸 이해하기 위해서는 GPU hardware 구조를 이해할 필요가 있다. (CPU, DRAM이랑 거의 유사함)
•
HBM: GPU memory라고 보면 좋을 것 같다.
•
Compute Unit: 연상장치
•
SRAM: small cache (L1,L2 cache랑 유사 → 메모리 공간 작고 속도 빠름)
How to reduce HBM R/W: Compute by Blocks
#### Challenges
(1) Compute softmax challenges normalization without access to full input. (softmax를 계산하려면 전체 row term을 전부 다 계산해야함으로 block처리를 하기 어렵다)
(2) Back-prop 계산하려면 forward pass때 계산했던 large attention matrixr가 필요하다.
#### Approaches
(1) Tiliing: Restructure the algorithm to load block by block from HBM to SRAM to compute attention.
→ Tiliing이라는 기술을 통해 block단위로 attention_score*value, value를 sram에 올린 후 계산
(2) Recomputation: Don’t store attn. matrix from forward, recompute it in the backward.
→ Attention Matrix를 VRAM에 가지고 있지 않음. normlized term만 가지고 있고, back-prop때 다시 계산함. 따라서 당연히 flops는 증가한다고 함.
Version 02
Flash_Attention-2에서 이 파티션을 개선하여 서로 다른 워프 간의 동기화 및 통신 양을 줄임으로써 공유 메모리 읽기 및 쓰기를 줄였습니다 .
→ Ver01에는 K,V를 4개의 워프에 걸쳐서 분할하고 모든 워프에서 Q에 엑세스할 수 있도록 했다. 하지만 이럴 경우, 모든 워프가 중간 결과를 공유 메모리에 쓰고 동기화한 다음 주간 결과를 추가해야하기 때문에 비효율적이다
→ Ver02에는 이를 개선하기 위해 Q를 4개 워프에 걸쳐서 분할하면, 워프간에 통신이 필요 없다. 각각 하나의 워프가 바로 K랑 곱하고 → 마찬가지로 하나의 워프가 바로 V랑 곱하면 되기 때문이다. 결론적으로 공유 메모리 및 쓰기의 감소는 속도를 향상시킨다.
→ Flash_AttentionVER01은 헤드 치수를 최대 128까지만 지원하며 대부분의 모델에서 작동하지만 일부 모델은 제외됩니다. 하지만 Flash_AttentionVER02는 최대 256개의 헤드 치수를 지원.