Search

LongLORA: EFFICIENT FINE-TUNING OF LONGCONTEXT LARGE LANGUAGE MODELS

Category
PaperReview
Venue
ARXIV
Backbone
LLAMA2
Text
- LoRA를 활용해 LLM에게 Long Context를 효율적으로 extend 시킬 수 있는 Fine-Tuning 방법론을 제안 - Efficient Attention 계열론
PPT

1. Introduction

→ Pre-trained LLM은 pre-defined length text로 pre-trained되어 왔음
→ LLM의 context window를 extend해주기 위해 long text를 finetuning하는 것은 compuationally 매우 비싸기 때문에 이에 대한 연구들이 많이 이어져 왔음.
#### Positional
최근에 연구된 Positional Interpolation (Chen et al., 2023)은 RoPE을 기반으로 LLM의 context window size를 32k까지 늘리는 방법을 제안했지만, 실제로 8k 이상의 길이를 처리하게끔 학습을 시킬 때 128대의 A100이 필요하다는 resource가 여전히 많이 필요하다는 문제점이 있음.
#### Contrastive Learning
또한, Focused Transformer (Tworkowski et al., 2023)은 contrastive learning에 영감을 받은 학습 방식을 사용해서 256k 길이의 prompt까지 처리할 수 있긴 하지만, 이 역시도 128대의 TPU가 필요하다는 문제점이 있다.
이 외에도 다양한 연구들이 최근에 진행이 되었지만, efficient하게 접근해서 long text를 처리하는 연구는 없음. 본 연구에서 efficient하게 적은 resource로 LLM을 long text에 training시키는 방법을 처음으로 제안함.
연구진들은 직관적으로 pre-trained LLM에 LoRA를 단순히 적용하는 것을 시도해봤는데, 실험적으로 이렇게 하면 long context에서 매우 높은 perplexity를 보인다고 발견함 (=no effectiveness). 또한, LoRA를 사용하는 것과 무관하게 self-attention 연산에서 computational cost가 높은 것을 지적함 (=no efficiency). 아래 그림에서 확인 가능 (LoRA)
본 연구에서 위에 문제점들을 개선한 LongLoRA를 소개함.
Contributions은 다음과 같음
shift short attention (S^2-Attn): standard self-attention 대체
→ S^2-Attn은 기존 attention architecture를 그대로 사용하고 있기 때문에, 최근에 나온 optimization 및 infrastructure 사용 가능함
fine-tuning layers: Embedding layers + normalization layers + LoRA training
→ LoRA weight뿐만 아니라, emb+norm layers를 학습 시키는 것이 실험적으로 매우 중요하다고 언급함. 실제로 embedding+normalization layers는 LLaMA2 7B기준으로 (embedding는 model의 2%) + ( normalization layers는 model의 0.004%) 수준의 params를 가지고 있어서 low cost.
LongQA dataset: SFT에 사용되는 long text dataset을 직접 구축하고 제안

2. Related Work

Long-Context Transformers

Retrieval-based work에서는 관련 있는 문서를 계속해서 context에서 포함시켜야함 → has been developed to increase the context w/o losing the information (inference setting에서는 full attention이 필요함)
Modify multi-head attention: alleviate quadratic complexity
sparse attention 적용: Longformer(2020), BigBird(2020)
과거 input에 대한 memory mechanism 활용: Recurrent Memory Transformer(2022), knn-augmented Transformer(2022)
→ full attention과 많이 다른 연산 방식이기 때문에, pre-trained LLMs을 직접적으로 finetuning하기에는 어려움. (inference & pretrain discrepancy)

Long-Context LLMs (About Training)

Fine-tuning을 통해 context length를 확장
Position Interpolation(2023), FOT(2023)
→ 많은 resource를 필요로 한다는 단점이 존재함
Compresses long context inputs into retrieved tokens
Landmark (2023)
→ efficient, but lossy
Postion embedding에 변화를 주는 연구
Position Interpolation(2023), positional Skipping(2023) etc.
→ 이런 방식들은 inference를 수행할 때 original architecture을 건들어야 하지만, LongLoRA는 그렇지 않기 때문에 original architecture 그대로 사용 가능

3. LongLoRA

Shift Short Attention

Pattern 01
Input token을 group 단위로 나누어 각 group마다 self-attention 연산을 진행
ex) input token이 8192 tokens이면, group를 4로 설정할 때, 각 group마다 2048 크기에 대한 연산 진행 (첫 번째 group: 1st~2048th)
Pattern 02
각 group 간의 communication을 만들어주기 위해 설계한 pattern. Group간의 partition을 group 크기의 절반만큼 shift를 진행시킴.
ex) 위와 동일한 상황에서, group partition은 1024 길이만큼 움직이기에, 첫 번째 group은 (1025th~3072th). 이렇게 하면 앞에 1024개와 뒤에 1024개의 tokens이 남는데, 남는 것들은 동일 group에 귀속시킴.
→ Pattern1과 Pattern2를 multi-head의 절반에서 각각 계산 (head1~4 & head5~8) 하고 합치는 방식을 택함.
→ 이 방식은 추가 연산을 필요로 하지 않고 groups들 간의 정보 흐름을 가능하게 해줌.
(1) shifting tokens in half attention heads
(2) transposing features from token dimension to batch dimension.

Pilot test

LLM에게 Long Context를 효과적으로 extend하기 위한 FT 효용성을 증명하기 위한 Pilot 실험 진행.
→ W/O FT를 보면 아무리 PE를 잘 갈아끼워도 FT를 안하면 Long Context 상황에서는 PPL이 증가함
→ PE 잘 갈아 끼운 상태에서 FT를 하는게 성능이 잘나오니깐 이걸 기본 Baseline으로 삼음.

Consistency to Full Attention

본 논문에서 제안하는 S^2-Attn을 다른 efficient attention과 비교 진행
Efficient Attention은 Training때 Overfitting되면 안된다고 주장함.
Baseline (Train/Test시에 동시에 적용할 수 있는 Efficient Attention)
dialted attention
stride sparse attention
→ S^2-Attn은 test시에 Full-attention을 적용할 수 있을 뿐만 아니라, pre-trained LLMs에 바로 long-context finetuning 적용할 수 있는 이점이 존재함. 아래 Table에서 있는 baselines은 finetuning이 가능한 dialted/stride sparse methods를 선별해서 실험 진행
cro.heads: S2-Attn을 진행할 때, head를 나눠서 attention 연산을 진행하고 합치는 원래 방식
cro.layers: S2-Attn을 진행할 때, layer별로 p1과 p1를 나워서 진행한다는 거 같음
only P1: all no shift (pattern 1)
only P2: all shift (pattern 2)
→ 실험적으로 Attention이 Training때 Overfitting되지 않아야함을 주장함.
LoRA를 바로 LLM에 적용해서 long-context를 다루는 것은 어렵다. 실제로, 아래 Table에서 LoRA와 Full-finetuning간의 perplexity는 많이 차이 나고, LoRA Rank를 키워봐도 별로 효과가 없음.
따라서, 단순히 LoRA만 학습시키는 것이 아니라 Normalization/embedding layers를 LoRA와 함께 학습 시키는 것이 Full Finetuning과 성능 면에서 비슷한 것을 실험적으로 확인함 (RedPajama dataset으로 tr/ PG19로 te)
두 가지 종류의 layers를 학습하는 것이, trainable params 개수에 큰 영향이 없음
Model: LLaMA2 7B (with S^2-Attn)
Target length: 32k
+Norm/Embed: Normalization layers 혹은 embedding layers를 학습 시키는지

4. Main Results

Models and maximum extended context window sizes
LLaMA2-7B up to 100k
LLaMA2-13B up to 65k
LLaMA2-70B up to 32k
→ Position indices들은 모두 Position Interpolation으로 조정 시킴 (따라서, 학습 과정도 Position Interpolation에서 사용한 setting과 동일하게 설정)
Resources
per-device batch: 1
gradient accumulation step: 8
→ 따라서, global batch size를 64로 설정 (with 8 A100 GPUs) [~대략 1000steps]
Datasets
Training: Redpajama (2023)
Test: PG19 (2020), cleaned Arxiv Math proof-pile dataset (2022)
Perplexity로 측정
Finetuning dataset: LongQA 직접 구축해 추가로 Fine-tuning
3k question(theme: technical, science fiction, other books & task type: summarization, relationsips, character, detail)-answer pairs

Long-sequence Language Modeling

Evaluation on Proof-Pile
→ LoRA+가 엄청 도움되지는 않는듯?
Evaluation on Proof-Pile
→ PG 데이터셋은 상대적으로 덜한 성능

Retrieval-based Evaluation

long conversation에서 topic을 retrieve하는 task.
we fine-tuned LLaMA2 13B with a context length of 18k.
→ Fully-FT한 LongChat에 비해서 효율적으로 tuning했음에도 불구하고 좋은 성능을 보여줌.

Efficiency Profile

context length를 65536까지 늘리면 72.2%증가하는데 S^2 Att로 FLOPs를 39.4% 떨굴 수 있음.
Fine-tuning steps
확실히 초반에는 Full-finetuning이 더 빠르게 수렴하지만, LoRA+ (LoRA+Norm/Emb)역시 200steps 이후에 비슷한 수준을 보임 (PPL)

5. Conclusion

→ LLM이 Long-context를 다루는데 효과적인 매우 간단한 방법론인 LongLoRA를 제안함. LLaMA가 2048 tokens, LLaMA2가 4096 tokens까지 처리하는 것을 생각하면, 본 실험에서 진행한 32k~100k는 매우 크다는 것을 알 수 있음.
→ 또한, A100 8장으로 모든 실험을 진행했기에 정말 효율적으로 학습을 진행시킬 수 있고, Full-FT와 비교해도 비슷한 수준을 선 보임.