Search
✳️

DeepSpeed - Sharding Optimizer, Gradients, Parameters, and Reducing Activations for Efficient Training

Category
BlogPost
Venue
Backbone
Text
PPT

0. Preliminary: Where Did All the Memory Go?

우리는 pytorch 학습시 memory를 어디에서 소모하게 될까요? 모델, 특히 LLM을 학습하게 되면 VRAM이 부족한 상황을 늘 겪게 되는데요. Deepspeed 논문에서 설명하는 VRAM이 소모되는 순간들은 아래과 같습니다.

Model States: Optimizer States, Gradients and Parameters

모델을 학습하기 위해서는 GPU에 크게 3가지가 올라가야 합니다.
Model Parameters
Gradients
⇒ Model Parameters와 Gradients는 동일한 memory footprints를 차지합니다.
Optimizers
⇒ 어떤 optimizer를 사용하느냐에 따라 크게 달라지는데요, Adam과 같이 gradient의 (1) first order moments(기울기의 지수평균)와 (2) second order moments(기울기 제곱의 지수 평균: 크기가 큰 기울기에는 작은 학습률을, 작은 기울기에는 큰 학습률을 적용:how? 이 term을 분모로 씀)를 동시에 활용하는 optimizer는 Model Parameters의 2배를 memory footprints로 차지시킵니다.
Model Parameters과 Gradients의 memory footprints를 감소시키기 위해 parameters(config에 따라 다름)와 activations을 half precision으로 쓰는 Mixed Precision이 많이 활용되기도 하는데요.
optimizer에서는 (1) first order moments와 (2) second order moments를 계산해야하기 때문에 fp32를 사용합니다.
Half Precision으로 Model Copy를 뜨는 순간 실질적인 memory footprints 감소가 크지 않게 되며, Model Copy를 뜨지 않아도 LLM 같이 큰 모델을 학습시 Mixed precision의 효과를 극대화하려면 Half Precision인 activation을 최대한 많이 활용하는 실험환경이어야 된다고 생각합니다. (개인적으로 이거보단 model parameters의 precision를 처음부터 낮추던가 그냥 full precision으로 학습하는게 더 좋은 선택지라고 생각합니다.)

Residual Memory Consumption

Optimizer States, Gradients and Parameters을 제외하고도 학습중에도 memory footprints를 발생시키는 요인들이 더 있는데요.
Activations
가장 대표적인게 hidden state를 통과하고 나온 representations인 activations입니다.
이 activations는 bsz랑 seq_len이 길어짐에 따라 기하급수적으로 memory footprints를 증가시키는데요, 1.5B parameter GPT-2기준 1K, bsz32에서 60GB의 메모리를 발생시킨다고 합니다.
Activation memory footprint를 해결하는 방법은 특정 layer의 gradient를 계산할때마다 forwarding을 다시 해주는 Activation checkpointing이나 Gradient Accumulation등을 통해서 Activation으로 인한 과부하를 줄여줄 수 있겠죠.
Temporary buffers
Sharding을 하다보면 gradient partition에 all-reduce를 호출을 하거나 grad_norm 연산을 하기전에 gradient를 single flattened buffer에 먼저 넣어놓아야하는데요. 이를 위해 allocated 해 놓은 Temporary buffers도 non-trival한 memory를 차지하게 됩니다.
Memory Fragmentation
아래에서 설명한 다양한 테크닉들이 적용되거나 큰 모델 학습시, 극단적으로 30%의 memory가 남았음에도 할당되지 못한 경우가 존재한다고 합니다.

1. What is Deepspeed?

Deepspeed는 위에서 설명한 문제점들을 해결하기 위해 code 몇줄만 추가하면 pytorch 위에 wrapping을 진행해 줌으로써 ‘distributed training, mixed precision, gradient accumulation와 같이 모델 개발에 있어서 필요한 기능을 효율적으로 지원해주도록 개발된 프레임워크라고 생각하시면 됩니다.
논문에서 베이스라인으로 두는 방법론은 DP(Data Parallelism)으로 (1) Model, Optimizer를 모두 여러 device에 복사한 후 (2) 각 device에서 다른 Mini-batch로 forward와 backward를 진행한 후 (3) gradient 평균으로 모델을 동기화하는 방법을 언급하고 있습니다.
논문에서 이야기 하는 DP 방법론은 pytorch에서 이야기하는 DP보다는 DDP에 더 가까운 방법론이라고 보는게 맞는거 같습니다.
그렇다면 DeepSpeed에서는 대규모 모델의 memory footprint문제를 해결하기 위해서 어떤 테크닉을 제시했을까요?
(스포를 하자면) FSDP랑 크게 다르다고 느껴지지는 않으니, FSDP에 익숙하신 분들은 더 빠르게 이해하실 수 있을꺼라고 생각합니다.

2. ZeRO

DeepSpeed에서는 2개의 optimization을 제안했습니다.
1.
ZeRO-DP: Model States: Optimizer States, Gradients and Parameters 에서 발생한 memory footprint를 감소시키기 위한 방법론을 제안했습니다.
2.
ZeRO-R: Residual Memory Consumption 에서 발생한 memory footprint를 감소시키기 위한 방법론을 제안했습니다.
NdN_{d} : device 수
Ψ\Psi : parameter 수 (7.5B)
KK: Optimizer의 memory multiplier (12)

ZeRO-DP

PosP_{os} : Optimizer State Partitioning (Stage-1)

optimizer state만 NdN_{d} 에 분할되는 형식입니다. 여기서 optmizer state가 분할 된다는 것은, Adam기준으로 특정 layer에 대응되는 first momentum, second momentum은 0번째 device에, 다른 layer에 대응되는 first momentum, second momentum은 1번째 device에 분할되었다고 이해하시면 됩니다.
(1) model parameter와 gradient는 모든 device에 있는 상황임으로, (2) parameter update시에는 device내에 있는 optimizer는 device가 동일한 model parameter만 update해준 후 AllGather 연산을 해주어 모든 layer의 updated parameter를 동기화해면 됩니다.
#### Memory Saving
Before: 2Ψ+2Ψ+KΨ2\Psi + 2\Psi + K\Psi
After: 2Ψ+2Ψ+KΨNd2\Psi + 2\Psi + \frac{K\Psi}{N_{d}}
NdN_{d} 가 64일 경우 위의 config에서 4배이상 메모리 감소가 가능해집니다.

Pos+gP_{os+g} : Gradient Partitioning (Stage-2)

optimizer state partitioning을 보면 의문점이 생깁니다. partitioned된 optimizer는 partitioned된 parameter의 gradient만을 업데이트하고 나머지 gradient는 활용하지 않습니다. (불필요한 gradient를 매 optimizer step 전후로 들고 있다고 봐도 무방하죠.)
예를 들어 device가 4개 있고 각 device에 있는 model parameter의 마지막 layer에 대해서 backprop()이 진행된다고 할때,
3번 GPU에 마지막 layer를 update하는 optimizer stater가 있다면,
4개 device가 backprop을 마치면 3번 GPU로 마지막 layer의 gradient를 reduce-scatter한 후 0,1,2번 device는 gradient를 지우는 겁니다 (다른 layer도 마찬가지).
3번 GPU에 마지막 layer에 있는 gradient는 Optimizer State Update하고 이후에 Model Parameter도 Update합니다. (각 GPU마다 대응되는 layer gradient가 있고, 그 layer gradient로 Optimizer State Update하고 모델도 업데이트 합니다)
그리고 AllGather 연산을 해주어 모든 layer의 updated parameter를 동기화해줍니다.
논문에서는 gradient reduction할 partition들을 bucketize해놓고 한번에 reduction 해놓았다고 합니다. (코드를 확인 안해봤지만 아마 당연히 partition-wise로 진행하지 않았을까 싶습니다. 아마 4개 device가 backprop을 마치면 3번 GPU로 마지막 layer의 gradient를 reduce-scatter한 후 0,1,2번 device는 gradient를 지우는 겁니다 (다른 layer도 마찬가지). ← 이 이야기인거 같습니다.)
논문 문단 서두에 partition boundaries마담 reduce 연산 수행했다고 하니 맞는거 같습니다
gradient와 optimizer의 redundancy를 동시에 줄이는 방법으로 baseline DP와 동일한 communication volume이 발생한다고 합니다.
#### Memory Saving
Before: 2Ψ+2Ψ+KΨ2\Psi + 2\Psi + K\Psi
After: 2Ψ+2Ψ+KΨNd2\Psi + \frac{2\Psi+K\Psi}{N_{d}}
NdN_{d} 가 64일 경우 위의 config에서 8배이상 메모리 감소가 가능해집니다.

Pos+g+pP_{os+g+p} : Parameter Partitioning (Stage-3)

모든 device에 걸쳐서 model parameter를 sharding해 올리는 겁니다. FSDP랑 원리가 사실상 동일하며, 영상의 중간중간 설명을 통해 기록하는게 더 효율적일거라 판단이 됩니다.
실질적으로는 (Pos+g+pP_{os+g+p})가 다같이 쓰이며 Zero-3라고 불립니다.
[Setting]
4 device
Model을 Copy하는 Mixed Precision
각 device에 서로 다른 mini-batch가 forward되는 setting입니다.
위의 각 device 위의 청록색 줄은 activation들을 저장하기 위해 미리 할당해 놓은 buffer입니다.
(Pos+g+pP_{os+g+p})인 parameter, optimizer, gradient를 미리 partition 해놓기 위해 device들을 할당해놓았습니다.
model parameter는 sharding 후 각 device에 loading해 놓습니다.
[Forwarding]
0번 device에 있는 model을 broadcast해 모든 device에 복사해줍니다.
각 device에 할당된 mini-batch를 partition후 broadcast된 M0을 활용해 forwarding해줍니다.
activation buffer를 보면 일부 activation들은 vram에 안남겨둔 걸 확인할 수 있습니다.
M0에 대한 fowarding이 끝나면 device 1,2,3에서는 M0을 지워줍니다.
M1,M2,M3에 대해서 마찬가지의 작업을 진행합니다.
[Backprop]
forwarding 이후 (broadcast) M3는 모든 device에서 남겨줍니다 (loss 계산이 되었고, 그 loss로 backprop을 하기 위해).
각 mini-batch에서 학습한 loss를 가지고 backprop을 진행해줍니다.
이 때, device 0,1,2는 gradient 저장을 위한 temporary buffer (희미한 노란색)을 메모리에서 할당해줍니다.
또한, 중간중간 저장하지 않은 M3에 해당하는 activation은 cpu에서 가져오던가 activation checkpointing을 진행해줍니다. (그림에서 activation 부분 채워짐)
각 device에 M3에 해당하는 activation들이 다 채워졌으면 각 device의 M3에 해당하는 부분에 대해서 backprop을 진행해줍니다.
gradient tensor를 reduce해 GPU3의 M3에 모아줍니다. (reduction)
M3에 gradient reduction이 끝나면 device 0,1,2 temporary buffer를 지우고, 모든 device의 (M3 영역에 대한) activation를 지웁니다.
이후 M2에 대해서 backprop을 위해서 device 2에 있는 M2를 모든 device에 broadcasting하고 위의 과정을 계속 반복해줍니다.
각 device에 gradient 계산이 완료되었으면, optimizer momentum 계산도 완료해줍니다.
현재 Setting은 mixed precison임으로 fp32상태에서 optimizer step을 진행(GPU마다 진행: fp16 gradient copy→ fp32 gradient → momentum term calculate → fp32 weight update) 한 후 fp16 weight copy를 만든 후 다시 forward를 위해 재배치 시키는 것을 확인할 수 있습니다. (짙은 파란색이 device의 첫번째 행으로 이동)
gpu간 통신이 늘어나기 때문에 baseline DP 대비 communication overhead가 1.5x 증가하지만, memory의 경우 NdN_{d} 에 비례하게 감소시킬 수 있다고 합니다.
#### Memory Saving
Before: 2Ψ+2Ψ+KΨ2\Psi + 2\Psi + K\Psi
After: 2Ψ+2Ψ+KΨNd\frac{2\Psi+2\Psi+K\Psi}{N_{d}}
NdN_{d} 가 64일 경우 위의 config에서 8배이상 메모리 감소가 가능해집니다.

ZeRO-R

PaP_{a} : Partitioned Activation Checkpointing

이전 그림에서 보았듯이 forward가 끝나면 GPU는 (backward 계산에 필요하기 때문에) activation을 계속 memory에 hold하고 있어야 하는 단점이 존재합니다.
이를 위해 위의 그림처럼 (1) activation checkpointing을 활용하거나 (2) offloading CPU를 활용해 memory footprint를 낮췄다고 합니다.

CBC_{B} : Constant Buffer Size

NVIDIA Apex는 Megatron는 모든 parameters를 하나의 버퍼로 fuse해 연산을 처리해 효율성을 올린다고 합니다. 하지만, 3B모델에게 단일 32-bit fused buffer을 할당하면 12GB를 차지하기 때문에 constant-size fused buffer를 차용했다고 합니다.

MDM_{D} : Memory Defragmentation

activation checkpointing를 사용하면 forward시에 특정 메모리 구역에만 activation을 저장하고, 다시 recomputing하고 gradient를 계산하기 때문에 short lived memory (discarded activations: 저장안한 activation)와 long lived memory (checkpointed activation) 사이의 interleaving 때문에 memory fragmentation가 발생할 수 있다고 합니다.
LLM 학습시에 사용 가능한 메모리가 충분한 경우에도 연속 메모리 부족으로 인한 OOM이 발생하거나 메모리 할당자가 메모리 요청을 충족하기 위해 연속 메모리 조각을 검색하는 데 상당한 시간을 소비하므로 효율성 저하를 야기함으로, activation checkpoints and gradients를 위해 인접한 메모리 chunk를 미리 allocation하고, 생성되는 대로 미리 할당된 메모리로 복사시켜 효율을 최대화했다고 합니다.

3. Code

=== WORKING ====

4. References

DeepSpeed
microsoft