Search
✴️

Reflections on Optimizer and LM parameter values

Category
BlogPost
Venue
Backbone
Text
PPT
Writer: Joonwon Jang
**** LLM Fine-Tuning시 Optimizer가 차지하는 VRAM의 비율이 높다고 하고 특히나 Learning Rate의 설정이 중요한 것을 깨닫는 요즘입니다, 오늘 그 2가지가 왜 LLM 학습에 중요한지 제가 이해한 바를 적어보겠습니다! ****

Optimizer’s Configuration

⇒ Mistral 7B를 AdamW를 활용해 1 epoch 학습하고 나면 아래와 같은 optimizer state가 출력됩니다.
Optimizer state at beginning of epoch 1: {'state': {0: {'step': tensor([584.], device='cuda:2'), 'exp_avg': tensor([ 2.2361e-07, -1.4407e-07, 1.0716e-06, ..., -3.3894e-07, 0.0000e+00, 0.0000e+00], device='cuda:2'), 'exp_avg_sq': tensor([2.2285e-12, 3.9868e-12, 5.2110e-12, ..., 5.1709e-12, 0.0000e+00, .... 0.0000e+00], device='cuda:2')}}, 'param_groups': [{'lr': 2e-06, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': True, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]}]}
Python
복사
Optimizer의 전체 구조는 아래와 같은데요.
{ 'state': { ... }, 'param_groups': [ ... ] }
Python
복사
state: 모델의 각 parameter에 대한 optimizer의 상태를 저장하는 딕셔너리입니다. 키는 parameter의 index이고, value은 해당 parameter의 state 딕셔너리입니다.
param_groups: optimizer의 하이퍼파라미터와 parameter 그룹에 대한 정보를 담고 있는 리스트입니다. 일반적으로 하나의 딕셔너리를 포함하며, 그 안에 lr, momentum term 등과 함께 파라미터의 인덱스 리스트(params)가 포함됩니다.

State’s Configuration

0: { 'step': tensor([584.], device='cuda:2'), 'exp_avg': tensor([...], device='cuda:2'), 'exp_avg_sq': tensor([...], device='cuda:2') }, 1: { ... },
Python
복사
state 는 아래와 같은 필드로 구성되어 있습니다.
step: 해당 parameter가 업데이트된 횟수를 나타냅니다. 여기서는 tensor([584.], device='cuda:2')로 되어 있는데, 이는 해당 parameter가 총 584번 업데이트 되었다는 의미입니다. (i.e., total training step = 584)
exp_avg: 기울기의 지수 이동 평균(i.e., First Moment Estimate)을 나타내는 텐서입니다. Adam 옵티마이저에서 사용되는 First Order Moment로, 아래와 같이 계산됩니다.
gt=θLt(θt1)mt=β1mt1+(1β1)gtg_t = \nabla_\theta L_t(\theta_{t-1}) \\ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
exp_avg_sq: 기울기 제곱의 지수 이동 평균(Second Moment Estimate)을 나타내는 텐서입니다. 이는 Second Order Moment로, 아래와 같이 계산됩니다.
gt=θLt(θt1)vt=β2vt1+(1β2)gt2g_t = \nabla_\theta L_t(\theta_{t-1}) \\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
모델의 각 Parameter와 동일한 크기를 가지는 First Order MomentumSecond Order Momentum을 사용하는 순간 모델 전체 Parameter 크기의 2배에 해당하는 추가 메모리를 VRAM에 로드해야 됩니다.
Code로 간단히 확인해 보는 VRAM Memory 사용량
# Instantiate and move the model to device with correct dtype ptdtype=bfloat16 model = Mistral(config).to(device, dtype=ptdtype) print_memory_usage("After Model Instantiation") '''bfloat16 기준: 14.48 GB''' for batch in tqdm.tqdm(train_dataloader): input_ids = batch['input_ids'].to(device) labels = batch['labels'].to(device) # Forward pass with autocast if bf16 is enabled with autocast(dtype=ptdtype) if args.bf16 else torch.cuda.amp.autocast(enabled=False): outputs = model.compute_loss(input_ids=input_ids, labels=labels) loss = outputs.loss print("Loss : ", loss) # Backward pass if use_scaler: scaler.scale(loss).div(args.accumulate).backward() else: loss.div(args.accumulate).backward() print_memory_usage("After Backward Pass") '''bfloat16 기준: 29.02GB (모델 크기 x2)''' optimizer.step() optimizer.zero_grad(set_to_none=True) print_memory_usage("After Optimizer Step") '''bfloat16 기준: 45GB (약 14.48GB MODEL + 28.96GB First & Sececond Moment) '''
Python
복사

Param Group’s Configuration

[{ 'lr': 2e-06, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': True, 'params': [0, 1, 2, ..., 34] }]
Python
복사
param_groups는 처음에 optimizer를 초기화할때 선언했던 하이퍼파라미터들과 그 외의 정보들을 확인할 수 있습니다.
lr : 다들 아시는 learning rate 입니다.
betas: Adam에서 사용하는 momentum 계수로, (beta1, beta2) 형태로 지정됩니다. beta1은 First Order Moment의 지수 감소율, beta2는 Second Order Moment의 지수 감소율을 나타냅니다.
eps : 수치적 안정성을 위한 작은 값으로, 옵티마이저의 분모에 추가되어 분모가 0이 되는 것을 방지하는 값입니다.
weight_decay: 가중치 감쇠 계수로, L2 정규화를 위한 항입니다.
amsgrad: AMSGrad 알고리즘 사용 여부를 나타내는 불리언 값입니다.
foreach: 연산을 벡터화하여 성능을 향상시킬지 여부를 결정합니다.
maximize: 손실 함수를 최소화할지(False), 최대화할지(True)를 결정합니다. 일반적으로 손실 함수를 최소화하므로 False로 설정됩니다.
capturable: 옵티마이저 상태가 CUDA 그래프에서 캡처될 수 있는지 여부를 나타냅니다.
differentiable: 옵티마이저가 고계 미분을 지원하는지 여부를 나타냅니다. 여기서는 False로 설정되어 있습니다.
fused: 옵티마이저가 fused kernel을 사용하여 연산을 수행하는지 여부를 나타냅니다. LLM에서 성능을 향상시키기 위해 True로 설정되어 있습니다.
[Kernel Fusion]
GPU에서는 MatMul, SoftMax같은 연산을 수행하는데 이런 연산들을 하나의 커널로 통합해 메모리 엑세스 시간과 커널 실행시간을 상각해 전체 계산그래프를 최적화하는 기법을 kernel fusion이라고 합니다.
기본적인 데이터의 흐름은 ‘메모리에서 불러옴 → GPU연산 → 메모리에 쓰기’인데, 직렬적이고 독립적인 연산을 처리하기 위해서는 메모리 엑세스 시간이 많이 소모됩니다. 이를 위해 여러 독립적인 커널 연산을 CUDA C++ 코드 작성, Pytorch JIT를 통해 하나의 연산으로 통합하여 메모리 이동을 최소화시킬 수 있습니다.
params: 이 파라미터 그룹에 포함된 parameter index.

Parameter Update

First Momentum과 Second Momentum은 Bias 보정을 마치고 Weight decay까지 마지면 아래와 같이 업데이트 됩니다.
m^t=mt1β1tv^t=vt1β2tθt=θt1η(m^tv^t+ϵ+λθt1)\hat{m}_t = \frac{m_t}{1 - \beta_1^t} \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \theta_t = \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1} \right)
exp_avg(first order momentum)exp_avg_sq(second order momemtum) 의 역할을 상기해보면,
exp_avg → 현재와 과거의 기울기를 누적하여 방향성을 파악
exp_avg_sq → 기울기 제곱의 지수 이동 평균으로, 기울기의 변화량을 추적하여 학습률을 조정
Parameter 업데이트의 관점에서 해석을 해보면,
파라미터 업데이트는 주로 m^tv^t+ϵ\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} 항에 의해 결정이 되는데,
m^t\hat{m}_t: 기울기의 평균값으로, 파라미터가 이동할 방향을 결정
v^t+ϵ\sqrt{\hat{v}_t} + \epsilon: 기울기의 표준편차에 해당하며, 학습률을 조정하여 안정성을 높임
⇒ Gradient clipping을 하지 않았다고 가정했을 시 (물론 설정해도 마찬가지 입니다.), Gradient가 너무 커지면 학습률 Learning Rate가 적정하다고 판단된 값이라도 모델 ParameterActivation 값이 NaN(Not a Number)으로 수렴할 수 있습니다.
⇒ LLM 같은 경우에는 (제가 목격하기론 특정 Task에서 Training시) Gradient Explosion 현상에 더 취약한것 같아, 아래와 같은 Norm Clipping으로 Gradient의 L2 Norm이 특정 임계값을 초과하면 스케일링하여 Norm이 임계값에 맞도록 조정해주는게 중요한 것 같습니다.
g=g×clip_valuegifg>clip_value g = g \times \frac{\text{clip\_value}}{\|g\|} \quad \text{if} \quad \|g\| > \text{clip\_value} 
⇒ 낮은 정밀도나 숫자 범위표현력을 가진 bf16, fp16에서는 작은 값의 언더플로우나 큰 값의 오버플로우로 인해 NaN이나 Inf가 발생할 확률이 높기 때문에 어쩔 수 없더라도 낮은 LR과 함께 fp32를 써주는 방법도 수치적 안정성을 확보하기 위한 하나의 해결책으로 사료됩니다.
⇒ (굉장히 강한 휴리스틱이긴 하지만) 안정적인 학습을 위해서는 logits 값의 범위 제한하는 Clipping 기법을 사용해볼 수도 있습니다.
torch.clamp()를 활용해 logits 값을 특정 범위로 제한하여 극단적인 값이 나오지 않도록 합니다.
logits = torch.clamp(logits, min=-max_value, max=max_value)
Python
복사

References