inference-time, RLHF/search at inference (GoT,MCTS, A*)

MCTS(monte carlo tree search) + LLM

jinuklee 2024. 6. 22. 22:04

https://arxiv.org/abs/2406.07394

 

Monte Carlo Tree Search (MCTS)는 게임 및 복잡한 결정 과정에서 널리 사용되는 의사 결정 알고리즘으로, 탐색 트리를 구축하고 결과를 시뮬레이션하여 행동의 가치를 추정하는 방식 일반적으로 네 가지 주요 단계로 구성됨 (Browne 등, 2012):

 

선택 (Selection): 루트에서 시작하여 UCT(상한 신뢰 구간) 전략을 기반으로 promising 자식 노드를 탐색 리프 노드에 도달할 때까지 진행

 

확장 (Expansion): 리프 노드에서는 게임의 종료 상태가 아닌 경우 새로운 자식 노드를 추가하여 잠재적인 미래의 움직임을 illustrate

 

시뮬레이션 또는 평가 (Simulation or Evaluation): 새로 추가된 노드에서 알고리즘은 무작위 시뮬레이션(rollout이라고도 함)을 통해 게임의 종료까지 임의의 움직임을 선택하며 노드의 잠재력을 평가합니다.

 

역전파 (Backpropagation): 시뮬레이션 후 결과(승리, 패배 또는 무승부)가 루트로 전파되어 각 통과한 노드의 통계 데이터(예: 승리, 패배 횟수)를 업데이트하고, 이를 통해 미래의 결정에 정보를 제공

 

MCTS는 이러한 단계들을 반복적으로 거쳐 결정 트리를 점진적으로 구축하여, 상태 공간이 방대하여 직접적인 최적 전략 계산이 불가능한 상황에서 최적의 의사 결정을 위한 전략을 개선


1. Initialization
모델의 적절한 답변, dummy 답변들 포함해서 root 노드 구성

2. Selection
value 함수 Q를 통해 답변들 중 greedy 하게 선택

이런식의 프롬프트를 통해 reward 값을 부여
f'Question: {question}\nAnswer:{ans}\n "이 답변을 엄격하게 분석하고 비평하세요. 모든 가능한 불완전한 점에 대해 모든 결함을 지적하여 가능한 모든 점수를 감점하세요! 점수를 계산할 때 매우 가혹하고 냉정해야 하며, 점수의 권위를 보장하기 위해 절대로 만점을 주지 마세요. \nOutput a score between [-100,+100], ig. from -100 to +100. \nResponse format:\n[Analyst]...[Score]...'


3. Self-Refine
선택된 답변 self-refine (피드백 + prompt)



4. Self-Evaluation
향상된 답변의 value Q 값 구하기


5. Backpropagation
새로운 Q value로 모든 노드 업데이트, 다음 selection을

Upper Confidence Bound

j = 행동

X = 그에 따른 평균 보상

N = father 노드의 총 방문 횟수

n = 노드 j 가 시뮬레이션 동안 방문된 횟수

C = 탐사-활용의 밸런스를 맞추기 위한 상수값

P: 다루고 있는 문제 인스턴스

A: 각각이 P에 대한 잠재적인 답변을 나타내는 노드의 집합

M: 각 노드에서 가능한 행동의 집합, 가능한 self-refine modifications 를 나타냄

R: modification의 효율성과 퀄리티를 기반으로 노드의 self-보상을 샘플링하는 함수

Ra: 노드 a의 자기 보상 함수 R을 사용하여 샘플링한 모든 자기 보상 결과를 저장하는 집합

T: 탐색 프로세스의 종료를 결정하는 함수(기준 : 최대 반복횟수 달성 or 목표한 답변 품질 달성)

 

 

  • Q(a): 누적된 보상 Ra와 자식 노드로부터의 역전파를 통해 도출된, 답변 노드 a의 가치를 추정하는 가치 함수
  • U(a): 노드 a의 Q 값에 대한 상한 신뢰 구간(Upper Confidence Bound)으로, 탐색(exploitation)과 탐험(exploration) 사이의 균형을 맞추는 데 사용
  • Father(a): 주어진 노드 a의 부모 노드를 반환하는 함수. a가 루트 노드인 경우, 이 함수는 null 또는 특정 식별자를 반환
  • Children(a): 주어진 노드 a에 대한 모든 자식 노드의 집합을 반환하는 함수로, m ∈ M인 행동들을 실행하여 a로부터 파생될 수 있는 모든 가능한 상태를 나타낸다.
  • N(a): 노드 a를 방문한 총 횟수로, 이 노드의 UCB 값을 계산하고 탐험과 탐색 상태를 평가하는 데 사용. 각 방문마다 보상을 샘플링할 것이므로, 이 값은 |Ra|와 같다.

 

root node 를 i don't know

evalution 평가

level 1~5까지의 (숫자가 커질수록 어려워지는 수학문제) 에서의 MTCSr 적용시의 변화인데 

level 1에서 가장 극적으로 8번의 rollout을 진행하니 437개의 문제중 394개의 문제의 답을 맞추는 상당한 성능향상을 보였다.