Search

MEND: Meta dEmonstratioN Distillation for Efficient and Effective In-Context Learning

Category
PaperReview
Venue
ICLR 2024
Backbone
GPT2-XL
T5
Text
- 긴 Demonstration을 효율적으로 처리하기 위한 KG distillation 기법을 제시함 - KG distillation을 통해 LLM이 demo를 어떻게 처리하는지에 대한 지식을 compression하는 모델에다가 흘리는 meta learning framework 잘 적립. - 역시 compression하는 모델과 LLM은 같은 family여야 효과적임.
PPT

1. Introduction

In-Context Learning에서 demonstration을 매 test instance inference마다 prepend해서 forwarding해주는 것은 computational overhead를 야기한다.
→ Extensive Demonstration (Natural Language) → Vector로 distillation하는 방법이 많이 활용됨
1.
Prompt Tuning
: Demonstration 역할을 해줄 수 있는 Vector Embedding을 Gradient Descent로 Training
(Unseen Demonstration을 prepend해야하는 상황에 직면한 경우 또 update를 해줘야하는 문제점 발생)
2.
Hypernetwork
: LLM에 직접 삽입되는 Vector (e.g., Word Embedding)을 생성하는 또다른 network를 학습
(긴 Unseen Demonstration도 Hypernetwork를 통과하면 pre-defined 길이의 vector embedding으로 mapping됨, CLM objective를 사용해 Distillation하는 Hypertuning이 가장 대표적)
위의 방법론들은 efficient하게 long natural language demonstration을 pre-defined length vector로 줄여서 ICL을 가능하게 만드나, long natural language demonstration을 직접 사용하는 것 대비 effectiveness가 떨어진다는 한계가 있다고 논문에서 제시함
논문에서는 KG distillation 방법론을 차용해 efficieny(긴 demonstration을 짧은 vector로 mapping하는 network를 학습)와 effectivenss(어떤 demonstration이라도 informative한 정보만 extracting할 수 있는 network를 학습)를 모두 달성하는 ‘Meta dEmonstratioN Distillation’을 제시

2. Problem Denfinition

Demonstration: D={(xi,yi)}i=1KD=\{(x_{i},y_{i})\}_{i=1}^{K}D=concat({(xi,yi)}i=1K)D=concat (\{(x_{i},y_{i})\}_{i=1}^{K})
ICL Performance
argmaxcCPLLM(cconcat(ED,Ex))argmax_{c \in C} P_{LLM}(c|concat(E_{D},E_{x}))
Our goal
argmaxcCPLLM(cconcat(SD,Ex)), where SDRl×dargmax_{c \in C} P_{LLM}(c|concat(S_{D},E_{x})), \ where \ S_{D} \in R^{l\times d}
l<<Dl << |D|
Prompt Tuning
SDS_{D} → learnable parameters
MEND (Hypernetwork)
SD=M(ED^) where E^ is the embedding of the hypernetworkS_{D} = M(\hat{E_{D}}) \ where \ \hat{E} \ is \ the \ embedding \ of \ the \ hypernetwork
(limitation of current hyper network: compatibility issue with LLM → resulting suboptimal quality of distilled vector)

3. Methods

MEND를 학습시키기 위해서는 3개의 LM이 필요함
MEND: Demonstration을 Vector로 mapping하는 Model
MEND Model에 SDS_{D} 길이만큼인 l개의 special token을 추가해 demonstration distillation placholder로 작용하도록 함
(For any demonstration DD, these placeholders embedding Eϕ^\hat{E_{\phi}} are appended to the demonstration embedding ED^\hat{E_{D}})
Student Model & Teacher Model: KG Distillation Loss를 위해 instantiate 되어야 하는 모델
#### 3.1. KG Distillation
SD=concat(ED^Eϕ^)[l:]>SD=MEND(ED^)S_{D}=concat(\hat{E_{D}}|\hat{E_{\phi}})_{[-l:]}>S_{D}=MEND(\hat{E_{D}})
Distilled Vector로 Condition된 ICL Performance를 Natural Language로 Condition된 ICL Performance만큼 나오게 만드는 것
Ldistill=KL(PLM(xED)PLM(xMEND(E^D)))\mathcal{L}_{\text{distill}} = \text{KL} \left( P_{\text{LM}}(x | E_D) \parallel P_{\text{LM}}(x | \text{MEND}(\hat{E}_D)) \right)
Student는 MEND Vector Distilled Demonstraiton Condition
Teacher는 Natural Language Demonstration Condition
#### 3.2 OPTIMIZATION
Meta-distillation Pretraining (MEND Network Pre-training)
C4 text
demonstration: 1024* β\beta (e.g., 102)
input: 1024* β\beta (e.g., 922)
input문장의 continuation한 생성문에 대해서 Ldistill\mathcal{L}_{\text{distill}} 계산
: conditional language modeling 기반 demonstration → vector로 compressing하는 모델에 비해서 더 instrinsic한 attribute를 capture할 수 있음.
Meta-distillation Fine-Tuning (Training MEND Network for ICL Prediction)
META-ICL data로 FT
K+1개 (x,y)(x,y) pair의 demonstration examples가 주어졌을때
demonstration: {(xi,yi)}i=1K>SD\{(x_{i},y_{i})\}_{i=1}^{K} > S_D
input: {(xk+1,yk+1)}\{(x_{k+1},y_{k+1})\}
Lpred=logPLM(yconcat(SD,Ex))\mathcal{L}_{\text{pred}} = \log P_{\text{LM}}(y|\text{concat}(S_D, E_x))
: LM이 distilled된 demonstration으로 ICL를 수행할 수 있게 MEND를 Training
Final Loss
Lfinetune=Lpred+λLdistill\mathcal{L}_{\text{finetune}} = \mathcal{L}_{\text{pred}} + \lambda \mathcal{L}_{\text{distill}}

4. Experiments

5. Analysis