Search

Reasoning in Flux: Enhancing Large Language Models Reasoning through Uncertainty-aware Adaptive Guidance

Category
PaperReview
Venue
ACL 2024
Backbone
Mistral
Text
PPT

1. Introduction

Reasoning Task에서 문제가 복잡해질 수록 어려운 문제를 여러 쉬운 문제로 decompose해서 푸는 technique가 취해지는데, 문제 난이도가 증가함에 따라 reasoning chain이 복잡해지고 lengthy해짐
multiple intermediate step 도중 error가 한번이라도 발생 → error accumulation → challenge
이전 연구들은 multiple intermediate step에서 기인된 error를 uncertainty 관점에서 해결하고자 함
Self-consistency decoding (Wang et al., 2023)
reasoning step을 여러개 생성 → majority vote answering (하나의 reasoning step에서 기인된 randomness를 제거)
Tree-of-thought (Yao et al., 2023)
intermediate step의 exploration and evaluation을 통한 uncertainty 제거
Self-evaluation guided beam search for reasoning. (Xie et al. (2023))
step-wise evaluation을 decoding process에 통합
⇒ 하지만, 위의 방법론은 ‘generation of individual intermediate steps’ 이후에나 reasoning process에 대한 post-hoc manipulation이 들어가 각 step에 대한 fine-grained adjustment를 하는데에 한계가 있다고 주장
논문에서는 직관적인 uncertainty 방법론은 채택, (1) LM이 잠재적 오류에 직면했을 때 자연스럽게 uncertainty의 징후를 보일 수 있음을 알 수 있음 보이고 (2) 어떤 reasoning step이 uncertainty가 높을경우 Bayesian Rule을 기반으로 demonstration을 추가하는 방법론을 제시
→ Wrong Reasoning에서 13 eggs가 “baked” 대신에 “left”라고 생성했을때 Uncertainty (NLL)이 높은 것을 알 수 있고, 이 이 후에도 지속적으로 NLL Distribution이 쭉 높은 것을 알 수 있음
→ 논문에서는 저 NLL이 높아지는 문장까지는 유지하고, 그 다음 문장까지는 demonstration을 추가해서 향후 reasoning path 문장을 재생성하는 방법론을 제시함

2. Related Work

2.1 Demonstration Guidance

Intro에서도 언급했듯이, 제안하는 방법론이 uncertatinty가 향상될때 demonstration을 추가시켜서 uncertatiny를 감소시키는 방법론이기 때문에 demonstration이 reasoning 성능 향상에 유용하다는 related work section을 기술함
AutoCoT - 자동으로 CoT Demonstration 구성
Boosting language models reasoning with chain-of-knowledge prompting. - LLM intrinsic knowledge를 활용해 정확성을 향상시킬 수 있는 CoT Demonstration 구성
Diao et al. (2023) - LM의 uncertainty를 활용해 informative demonstration 선정
….

2.2 Decomposition and Validation

LLM은 어려운 문제를 쉬운 sub-question으로 분해해서 풀면서 여러 문제에 직면 (Least-to-most prompting enables complex reasoning in large language models. In The Eleventh International Conference on Learning Representations.)
이런 문제를 해결하기 위해 분해과정 중에 verification 과정을 통합해준게 ‘Tree-of-thought’, ‘Graph of thoughts’ (근데 실전에서 과연 쓸까? 물론 내가 고민한 연구들도)
그외에도 local-scoring & global evaluation을 통해서 LLM의 evaluation을 향상시키거나, 여러 LLM에게 동일한 문제를 풀게 시켜서 verification을 향상시킨 연구들이 있었으나 computational demand라는 한계에 직면

2.3 Decoding Enhancement

Self-Consistency : generating multiple reasoning paths → selecting the most consistent answer
Fu et al. (2023b) : increased reasoning complexity를 demonsteation으로 줘서 더 complex reasoning이 가능하도록 하자
Contrastive decoding: inter-model contrasting & Chuang et al. (2023): intra-model contrasting (outputs from later versus earlier layers)

3 Preliminary

LLM이 Q\mathcal{Q}에 대해서 A\mathcal{A}를 생성하는것을 다음과 같이 formulation할 수 있음
P(AQ)=i=1APM(aiQ,a<i)P(\mathcal{A}|\mathcal{Q}) = \prod_{i=1}^{|\mathcal{A}|} P_{\mathcal{M}}(a_i | \mathcal{Q}, a_{<i})
CoT는 demonstration을 통해 LM에서 위의 식에서 reasoning path를 생성하도록 만드는 기법. 간단한 확률 조작으로 아래와 같이 전개가 가능함
P(R,AD,Q)=P(AD,Q,R)P(RD,Q)P(R, \mathcal{A} | \mathcal{D}, \mathcal{Q}) = P(\mathcal{A} | \mathcal{D}, \mathcal{Q}, R) P(R | \mathcal{D}, \mathcal{Q})
베이지안 전개로 Demonstration/Question 기반으로 Rationale/Answer을 생성하는 식은 아래와 같이 전개가 가능함
P(R,AD,Q)=P(DQ,R,A)P(R,AQ)P(Q)P(D,Q)=P(DQ,R,A)P(R,AQ)P(DQ)P(R, \mathcal{A} | \mathcal{D}, \mathcal{Q}) = \frac{P(\mathcal{D} | \mathcal{Q}, R, \mathcal{A}) P(R, \mathcal{A} | \mathcal{Q}) P(\mathcal{Q})}{P(\mathcal{D}, \mathcal{Q})} = \frac{P(\mathcal{D} | \mathcal{Q}, R, \mathcal{A}) P(R, \mathcal{A} | \mathcal{Q})}{P(\mathcal{D} | \mathcal{Q})}
Problem: low P(R,AQ)P(R, \mathcal{A} | \mathcal{Q}) (additional context (demonstration) 없이 desired rationale, answer를 생성하는데 어려움이 있다)
Objective: improve P(R,AD,Q)P(R, \mathcal{A} | \mathcal{D, Q})
Method: higher P(DQ,R,A)P(\mathcal{D} | \mathcal{Q}, R, \mathcal{A}) & lower P(DQ)P(\mathcal{D} | \mathcal{Q})
Relevance: D\mathcal{D} 가 expected reasoning process과 얼만큰 연관이 있는가? (당연히 커야하는 개념)
Originality: D\mathcal{D} 가 LM이 모르는 novel concept이나 모르는 지식을 얼만큼 많이 가져오는가? (당연히 적어야 하는 개념)

4 Uncertainty-aware Adaptive Guidance

위의 Preliminary에서 언급했듯, 논문의 목표는 (1) few-shot에서의 reasoning성능을 향상이지만 (2) 이를 위해 역으로 zero-shot에서 reasoning path uncertainty를 선제적으로 탐색한다.
결론적으로 논문에서 제안한 framework는 Uncertainty Identification (zero-shot), Adaptive Reasoning Adjustment (transform to few-shot), and Demonstration Clustering (engineering)으로 구성된다.

4.1 Uncertainty Identification (Zero-shot)

Reasoning에서 특정 추론 단계의 누적된 오류로 인해 발생하기 때문에 이를 사전에 탐지하게 위해 아래와 같은 metric을 정의함
P(RQ)=tPM(rtQ,r<t)P(\mathcal{R} | \mathcal{Q}) = \prod_{t} P_{\mathcal{M}}(r_t | \mathcal{Q}, r_{<t})
여기서 r번째 reasoning path의 token을 단순히 cross entropy loss로 정의
H(rt)=logP(rtr<t)\mathcal{H}(r_t) = -\log P(r_t | r_{<t})
decoding step에서 confidence 변화를 측정하기 위해 매 token마다의 uncertainty 변화를 uncertainty gap으로 정의
ΔH(rt)=H(rt)H(rt1)\Delta \mathcal{H}(r_t) = \mathcal{H}(r_t) - \mathcal{H}(r_{t-1})
uncertainty gap이 특정 threshold이상일시, intervention (i.e., additional context 도입)을 통해 LM이 신뢰할 수 있는 reasoning path를 generating하도록 하
if ΔH(rt)>θ\text{if } \Delta \mathcal{H}(r_t) > \theta (그림에서 4번 문장)

4.2 Adaptive Reasoning Adjustment (transform to few-shot)

ΔH(rt)>θ\Delta \mathcal{H}(r_t) > \theta이 발생하기 전인 rmr_{m}까지는 유지하되, demonstration을 prepend해 uncertainty를 감소
현재 status
LM이 rmr_{\leq m} 까지 생성
Demonstration 정의
D={Qd,Rd,Ad}\mathcal{D}= \{\mathcal{Q}_d, \mathcal{R}_d, \mathcal{A}_d\}
Method
높은 Relevance
SR=logP(DQ,rm)=logP(Qd,Rd,AdQ,rm)S_R = \log P(\mathcal{D} | \mathcal{Q}, r_{\leq m}) = \log P(\mathcal{Q}_d, \mathcal{R}_d, \mathcal{A}_d | \mathcal{Q}, r_{\leq m})
낮은 Originality
SO=logP(DQ)=logP(Qd,Rd,AdQ)S_O = -\log P(\mathcal{D} | \mathcal{Q}) = -\log P(\mathcal{Q}_d, \mathcal{R}_d, \mathcal{A}_d | \mathcal{Q})
Score 기반으로 최종 Demonstration 선정
S=λ1SR+λ2SOS = \lambda_1 S_R + \lambda_2 S_O
Objective
Di:k,ΔH(rm+k)θ\exists \mathcal{D}_i : \forall k, \Delta \mathcal{H}(r_{m+k}) \leq \theta

4.3 Optimizing Demonstration Selection Through Clustering

매번 ΔH(rt)>θ\Delta \mathcal{H}(r_t) > \theta이 발생할때마다 demonstration을 retrieve해야하기 때문에 computational overhead가 발생 → K-means로 search space를 사전에 정의해놓음
text-embedding-3-large\text{text-embedding-3-large}DiD_{i} compute → K-means
Cj=[D1j,D2j,],CjK-Means({D})C_j = [D_1^j, D_2^j, \dots], \quad C_j \in K\text{-}Means(\{\mathcal{D}\})
각 cluster 내 centriod내 거리를 기반으로 demonstration 정렬
Cluster내 Centroid랑 가장 가까운 DijD_i^j랑 Relevance, Originality 계산

5. Experiments

Baselines

Zero-shot CoT: Let’s think step by step
CoT: Demonstration에 사람이 직접 작성한 Rationale을 포함한 Example 적어주기
ComplexCoT: Demonstration에 CoT보다 보다 구체적인 Example 적어주기
위의 prompting에 self-consistency incorporate

Benchmarks

Arithmetic Reasoning
Commonsense Reasoning
Symbolic Reasoning

Main Results

(Mistral-7B)
limitation: computational cost를 generation cost로 측정함
generation 도중 prompt를 교체하면 이전에 caching해 두었던 query, key representation을 전부 다시 생성해야 하는데 이를 싹 다 무시함.
Arithmetic Reasoning - 도중에 Uncertainty를 detect한 이후에 필요한 demonstration을 추가했기에 처음부터 demonstration을 prepend한 CoT에 비해서 GSM8K, AddSub, AQuA에서 7.81%, 2.78%, 3.14% 성능 향상을 가져옴. (하지만 생성 token 수가 더 많다는 한계가 존재0
ArithmCommonsense and Symbolic Reasoning -BoolQ dataset같은 경우, CoT기 58.07% → 우리는 62.26% 달성. Zero-Shot CoT 보면 precise한 rationale 생성에 어려움을 겪는 것을 알 수 있음. Computational Cost 증가는 한계로 보임
(Model Models in Aqua dataset)
특정 LM의 경우 ZS이 FS보다 성능이 좋으며, on-demand때 적절한 demonstration을 주어야하는 본인들의 방법론의 필요성을 강조
(Importance of Relevance and Originality)
Originality 제거시 5.73%감소, Relevane 제거시 성능 더 감소
2개 dataset으로 Ablation 확인해보았는데, weight는 0.5로 주었는게 가장 좋았다.
(Threshold θ)
threshold를 너무 크게 설정하면 추가되는 samples(demonstration)이 적어지게 되고, 많은 test case가 사실상 Zero-shot가 동일하게 처리될 수 있음
그렇다고 threshold가 너무 낮으면 불필요하고 너무 앞에 (uncertainty 변동이 낮은데) demonstration을 추가해 error을 유발할 수 있다고 함
⇒ 16으로 설정
(Comparison to Existing Demonstration Selection Method)
demonstration을 pre-selection하는 Auto-CoT보다 성능이 좋았다. (computation은 훨씬 더 들었을거 같은데…?)

5. Conclusion

RQ를 어마어마하게 많이 던져준 논문
(Zero-shot setting을 상정)
문제가 복잡해질수록 LM은 더 긴 CoT Rationale을 생성해 문제를 해결한다 → 이때, 한번 error을 잘못 생성하면 LM은 error를 중첩적인 error를 생성한다. →
이 논문 error가 높은 문장을 uncertainty가 높은 문장이라고 가정 → uncertainty가 높은 문장을 삭제 → uncertainty를 낮추는 demonstration를 prepend하자
query, key caching 무시
‘ error을 잘못 생성하는 문장’이 (1) uncertainty하다면, 혹은 (2) 불필요한 문장 이라면 그 문장은 paramterize하고 꼭 필요한 문장만 생성하도록 하면 중첩적인 error를 생성할 확률을 줄일 수 있지 않을까?
Do not Generate All: Enhancing language model’s reasoning capabilities by parametrizing uncertain rationale.