카테고리 없음

Training Language Models to Self-Correction viaReinforcement Learning (SCoRe), 논문리뷰

jinuklee 2024. 9. 21. 14:20

https://arxiv.org/pdf/2409.12917

point : 모델의 distribution에서 가능한 가장 좋은 final answer을 생성해내기 위함 + 모델 collapse를 막기

<- multi turn 온라인 RL, 완전히 스스로 생성한 데이터를 통해 self-correct 능력을 향상

방식

스스로 생성한 데이터 - distribution mistmatch 회피

두단계로 훈련 stage - minimal edit strategy 의 실패 경우의 모델 collapse 회피를 위한것 (STaR) 

 

LLM의 self correction 능력은 비효율적이다 (e.g llm cannot self correct yet 논문)

Existing approach는 self-correct을 위해 여러개의 모델, more capable LLM, 혹은 다른 형태의 어떤 supervision을 필요로한다

 

기존의 SFT를 통한 훈련은 model's own response와 훈련 데이터간의 distribution mismatch가 발생, 특정 mode 의 correction만을 취하는 것을 선호한다 ( 이는 often not effective at test time )

 

1. introduction

Prior attempts toward self-correcting LLMs에는 

- 프롬프트 엔지니어링 ( Language models can solve computer tasks ,self-refine) 

- 모델 미세조정 specifically for self-correction ( Glore,  Recursive introspection , Generating sequences by learning to self-correct Advancing llm reasoning generalists with preference trees).

 

기존의 미세조정 접근은 require running multiple models upon inference, e.g., a separate verifier or refinement model(glore, generating sequences)

or require oracle “teacher” supervision to guide the process of self-correction (recursive introspection)

 

우리의 접근법, SCoRe는 trains only a single model that can both produce a response to a reasoning problem and also correct errors despite not receiving any oracle feedback

 

We observe that running supervised fine-tuning on multi-turn self-correction traces coupled with rejection sampling (i.e., a “multi-turn” variant of STaR (Zelikman et al., 2022)) often amplifies the model’s bias to not make any error corrections.

위의 문제와 일맥상통

 

Concretely, SCoRe runs multi-turn RL on self-generated data to avoid challenges with distribution mismatch between training and inference.

 

To avoid the failure mode of learning a minimal edit strategy when training on on-policy data(STaR의 이야기), we train SCoRe in two stages, with each stage regularizing the learning process to not collapse its behavior.

The first stage replaces SFT in conventional LLM fine-tuning workflows by training a model initialization that optimizes correction performance while constraining the first attempt to be close to the base model.

The second stage runs multi-turn RL to optimize reward at both attempts, while using a reward bonus term that encourages improving responses from the first attempt to the second. Both the initialization and the reward bonus ensure that the model cannot simply learn to produce the best first-attempt response and only minorly edit it.

 

첫 번째 단계에서는 SFT를 대체하여, 수정 성능을 최적화하는 모델 initialization 를 훈련

이 과정에서 첫 번째 시도가 base 모델과 유사하도록 제한합니다.

두 번째 단계에서는 다중 턴 RL을 실행하여 두 번의 시도에서 보상을 최적화하며, 첫 번째 시도의 응답을 두 번째 시도로 개선하도록 장려하는 보너스 보상 항목을 사용합니다.

초기화와 보상 보너스 모두 모델이 첫 번째 시도에서 최고의 응답을 만들어내고 이를 소폭 수정하는 방식으로 학습하지 않도록 보장합니다. 전반적으로 SCoRe는 기본 모델의 지식을 활용하여 긍정적인 자기 수정을 가능하게 합니

구조도
개선점 , SFT의 distribution mistmatch와 STAR의 collapse를 일으키는 minor change 문제

 

 

기존의 문제점

 

 

Issues with Existing Strategies: When fine-tuning models using multi-turn self-correction data, the model tends to become more biased against making any corrections.

기존 멀티턴으로는 self-corrction하지 않으려함

Minimal Edit Strategy: A strategy that only makes small edits can be somewhat effective, but it doesn’t help the model learn to correct itself.

적게 수정하는것은 효과적이나 , 이걸로는  self-correct 능력을 학습하지 못함

Data Weight Adjustment: Adjusting the training dataset to down-weight certain minor edits helps avoid issues, but it still faces the challenge of distributional shift, meaning the correction strategies learned from off-policy data might not enable the model to correct its own mistakes effectively.

적게 수정하는 것의 가중치(SFT데이터셋)를 줄이면 distribution shift의 문제가 생김 -> self-correct이 효과적이지 못함

 

 

multi-turn objective

x는 프롬프트 , 즉 질문

y*는 데이터셋의 oracle response 즉 정확한 답

y_hat은 모델 LLM(파이)이 생성해낸 optimal response

p는 self-correct instruction

y_hat(l+1) LLM에 x, y_hat, p 를 넣은후 next token prediction으로 생성해낸 답변

이를 reward function/verifier (string-matching based answer checking function와 같은) 을 통해 생성한 reward maximize

 

stage 1

+ KL divergence 를 사용해  첫번째 attempt response가 가능한 바뀌지 않게 하면서 두번째 attempt에 high-reward revisions 생성하도록 함

stage 2

self - correct strategy에 leverage , incentize하기 위해 shaped reward 를 사용

이는 첫번째에 best를 얻고 두번째에 minor edit을 거치는 전략 대신에 사용되는것