1. Introduction
→ Pre-trained LLM은 pre-defined length text로 pre-trained되어 왔음
→ LLM의 context window를 extend해주기 위해 long text를 finetuning하는 것은 compuationally 매우 비싸기 때문에 이에 대한 연구들이 많이 이어져 왔음.
#### Positional
최근에 연구된 Positional Interpolation (Chen et al., 2023)은 RoPE을 기반으로 LLM의 context window size를 32k까지 늘리는 방법을 제안했지만, 실제로 8k 이상의 길이를 처리하게끔 학습을 시킬 때 128대의 A100이 필요하다는 resource가 여전히 많이 필요하다는 문제점이 있음.
#### Contrastive Learning
또한, Focused Transformer (Tworkowski et al., 2023)은 contrastive learning에 영감을 받은 학습 방식을 사용해서 256k 길이의 prompt까지 처리할 수 있긴 하지만, 이 역시도 128대의 TPU가 필요하다는 문제점이 있다.
이 외에도 다양한 연구들이 최근에 진행이 되었지만, efficient하게 접근해서 long text를 처리하는 연구는 없음. 본 연구에서 efficient하게 적은 resource로 LLM을 long text에 training시키는 방법을 처음으로 제안함.
연구진들은 직관적으로 pre-trained LLM에 LoRA를 단순히 적용하는 것을 시도해봤는데, 실험적으로 이렇게 하면 long context에서 매우 높은 perplexity를 보인다고 발견함 (=no effectiveness). 또한, LoRA를 사용하는 것과 무관하게 self-attention 연산에서 computational cost가 높은 것을 지적함 (=no efficiency). 아래 그림에서 확인 가능 (LoRA)
본 연구에서 위에 문제점들을 개선한 LongLoRA를 소개함.
Contributions은 다음과 같음
•
shift short attention (S^2-Attn): standard self-attention 대체
→ S^2-Attn은 기존 attention architecture를 그대로 사용하고 있기 때문에, 최근에 나온 optimization 및 infrastructure 사용 가능함
•
fine-tuning layers: Embedding layers + normalization layers + LoRA training
→ LoRA weight뿐만 아니라, emb+norm layers를 학습 시키는 것이 실험적으로 매우 중요하다고 언급함. 실제로 embedding+normalization layers는 LLaMA2 7B기준으로 (embedding는 model의 2%) + ( normalization layers는 model의 0.004%) 수준의 params를 가지고 있어서 low cost.
•
LongQA dataset: SFT에 사용되는 long text dataset을 직접 구축하고 제안
2. Related Work
Long-Context Transformers
•
Retrieval-based work에서는 관련 있는 문서를 계속해서 context에서 포함시켜야함 → has been developed to increase the context w/o losing the information (inference setting에서는 full attention이 필요함)
•
Modify multi-head attention: alleviate quadratic complexity
◦
sparse attention 적용: Longformer(2020), BigBird(2020)
◦
과거 input에 대한 memory mechanism 활용: Recurrent Memory Transformer(2022), knn-augmented Transformer(2022)
→ full attention과 많이 다른 연산 방식이기 때문에, pre-trained LLMs을 직접적으로 finetuning하기에는 어려움. (inference & pretrain discrepancy)
Long-Context LLMs (About Training)
•
Fine-tuning을 통해 context length를 확장
◦
Position Interpolation(2023), FOT(2023)
→ 많은 resource를 필요로 한다는 단점이 존재함
•
Compresses long context inputs into retrieved tokens
◦
Landmark (2023)
→ efficient, but lossy
•
Postion embedding에 변화를 주는 연구
◦
Position Interpolation(2023), positional Skipping(2023) etc.
→ 이런 방식들은 inference를 수행할 때 original architecture을 건들어야 하지만, LongLoRA는 그렇지 않기 때문에 original architecture 그대로 사용 가능
3. LongLoRA
Shift Short Attention
•
Pattern 01
◦
Input token을 group 단위로 나누어 각 group마다 self-attention 연산을 진행
ex) input token이 8192 tokens이면, group를 4로 설정할 때, 각 group마다 2048 크기에 대한 연산 진행 (첫 번째 group: 1st~2048th)
•
Pattern 02
◦
각 group 간의 communication을 만들어주기 위해 설계한 pattern. Group간의 partition을 group 크기의 절반만큼 shift를 진행시킴.
ex) 위와 동일한 상황에서, group partition은 1024 길이만큼 움직이기에, 첫 번째 group은 (1025th~3072th). 이렇게 하면 앞에 1024개와 뒤에 1024개의 tokens이 남는데, 남는 것들은 동일 group에 귀속시킴.
→ Pattern1과 Pattern2를 multi-head의 절반에서 각각 계산 (head1~4 & head5~8) 하고 합치는 방식을 택함.
→ 이 방식은 추가 연산을 필요로 하지 않고 groups들 간의 정보 흐름을 가능하게 해줌.
(1) shifting tokens in half attention heads
(2) transposing features from token dimension to batch dimension.
Pilot test
•
LLM에게 Long Context를 효과적으로 extend하기 위한 FT 효용성을 증명하기 위한 Pilot 실험 진행.
→ W/O FT를 보면 아무리 PE를 잘 갈아끼워도 FT를 안하면 Long Context 상황에서는 PPL이 증가함
→ PE 잘 갈아 끼운 상태에서 FT를 하는게 성능이 잘나오니깐 이걸 기본 Baseline으로 삼음.
Consistency to Full Attention
•
본 논문에서 제안하는 S^2-Attn을 다른 efficient attention과 비교 진행
•
Efficient Attention은 Training때 Overfitting되면 안된다고 주장함.
•
Baseline (Train/Test시에 동시에 적용할 수 있는 Efficient Attention)
◦
dialted attention
◦
stride sparse attention
→ S^2-Attn은 test시에 Full-attention을 적용할 수 있을 뿐만 아니라, pre-trained LLMs에 바로 long-context finetuning 적용할 수 있는 이점이 존재함. 아래 Table에서 있는 baselines은 finetuning이 가능한 dialted/stride sparse methods를 선별해서 실험 진행
•
cro.heads: S2-Attn을 진행할 때, head를 나눠서 attention 연산을 진행하고 합치는 원래 방식
•
cro.layers: S2-Attn을 진행할 때, layer별로 p1과 p1를 나워서 진행한다는 거 같음
•
only P1: all no shift (pattern 1)
•
only P2: all shift (pattern 2)
→ 실험적으로 Attention이 Training때 Overfitting되지 않아야함을 주장함.
LoRA를 바로 LLM에 적용해서 long-context를 다루는 것은 어렵다. 실제로, 아래 Table에서 LoRA와 Full-finetuning간의 perplexity는 많이 차이 나고, LoRA Rank를 키워봐도 별로 효과가 없음.
따라서, 단순히 LoRA만 학습시키는 것이 아니라 Normalization/embedding layers를 LoRA와 함께 학습 시키는 것이 Full Finetuning과 성능 면에서 비슷한 것을 실험적으로 확인함 (RedPajama dataset으로 tr/ PG19로 te)
•
두 가지 종류의 layers를 학습하는 것이, trainable params 개수에 큰 영향이 없음
•
Model: LLaMA2 7B (with S^2-Attn)
•
Target length: 32k
•
+Norm/Embed: Normalization layers 혹은 embedding layers를 학습 시키는지
4. Main Results
•
Models and maximum extended context window sizes
◦
LLaMA2-7B up to 100k
◦
LLaMA2-13B up to 65k
◦
LLaMA2-70B up to 32k
→ Position indices들은 모두 Position Interpolation으로 조정 시킴 (따라서, 학습 과정도 Position Interpolation에서 사용한 setting과 동일하게 설정)
•
Resources
◦
per-device batch: 1
◦
gradient accumulation step: 8
→ 따라서, global batch size를 64로 설정 (with 8 A100 GPUs) [~대략 1000steps]
•
Datasets
◦
Training: Redpajama (2023)
◦
Test: PG19 (2020), cleaned Arxiv Math proof-pile dataset (2022)
▪
Perplexity로 측정
◦
Finetuning dataset: LongQA 직접 구축해 추가로 Fine-tuning
▪
3k question(theme: technical, science fiction, other books & task type: summarization, relationsips, character, detail)-answer pairs
Long-sequence Language Modeling
•
Evaluation on Proof-Pile
→ LoRA+가 엄청 도움되지는 않는듯?
•
Evaluation on Proof-Pile
→ PG 데이터셋은 상대적으로 덜한 성능
Retrieval-based Evaluation
•
long conversation에서 topic을 retrieve하는 task.
•
we fine-tuned LLaMA2 13B with a context length of 18k.
→ Fully-FT한 LongChat에 비해서 효율적으로 tuning했음에도 불구하고 좋은 성능을 보여줌.
Efficiency Profile
•
context length를 65536까지 늘리면 72.2%증가하는데 S^2 Att로 FLOPs를 39.4% 떨굴 수 있음.
•
Fine-tuning steps
◦
확실히 초반에는 Full-finetuning이 더 빠르게 수렴하지만, LoRA+ (LoRA+Norm/Emb)역시 200steps 이후에 비슷한 수준을 보임 (PPL)
5. Conclusion
→ LLM이 Long-context를 다루는데 효과적인 매우 간단한 방법론인 LongLoRA를 제안함. LLaMA가 2048 tokens, LLaMA2가 4096 tokens까지 처리하는 것을 생각하면, 본 실험에서 진행한 32k~100k는 매우 크다는 것을 알 수 있음.
→ 또한, A100 8장으로 모든 실험을 진행했기에 정말 효율적으로 학습을 진행시킬 수 있고, Full-FT와 비교해도 비슷한 수준을 선 보임.