
Smaller, Weaker, Yet Better: Training LLM Reasoners via Compute-Optimal Sampling (Paper Review)

jinuklee 2024. 10. 3. 18:08

https://arxiv.org/pdf/2408.16737

Training on high-quality synthetic data from strong language models (LMs) is a common strategy to improve the reasoning performance of LMs.

 

In this work, we revisit whether this strategy is compute-optimal under a fixed inference budget (e.g., FLOPs).

 

To do so, we investigate the trade-offs between generating synthetic data using a stronger but more expensive (SE) model versus a weaker but cheaper (WC) model.

 

We evaluate the generated data across three key metrics: coverage, diversity, and false positive rate, and show that the data from WC models may have higher coverage and diversity, but also exhibit higher false positive rates.
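For concreteness, here is a minimal sketch (my own, not the paper's code) of how coverage and diversity could be computed from a batch of sampled solutions, under one simple reading of the metrics. The false positive rate is different in kind: it counts solutions whose final answer is correct but whose reasoning is flawed, so it additionally requires judging the reasoning chain itself (e.g., manually or with an LLM judge) rather than only comparing final answers.

```python
from collections import defaultdict

def coverage_and_diversity(samples, gold_answers):
    """samples: iterable of (question_id, final_answer, reasoning_text) tuples.
    gold_answers: dict mapping question_id -> gold final answer.
    Returns (coverage, diversity) under one simple reading of the metrics:
      coverage  = fraction of questions with at least one correct sample,
      diversity = average number of unique correct solutions per question."""
    correct = defaultdict(set)  # question_id -> set of distinct correct reasoning texts
    for qid, answer, reasoning in samples:
        if answer == gold_answers.get(qid):  # "correct" = final answer matches the gold answer
            correct[qid].add(reasoning)
    n = len(gold_answers)
    coverage = sum(1 for qid in gold_answers if correct[qid]) / n
    diversity = sum(len(correct[qid]) for qid in gold_answers) / n
    return coverage, diversity
```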

 

We then finetune LMs on data from SE and WC models in different settings: knowledge distillation, self-improvement, and a novel weak-to-strong improvement setup where a weaker LM teaches reasoning to a stronger LM.

 

Our findings reveal that models finetuned on WC-generated data consistently outperform those trained on SE-generated data across multiple benchmarks and multiple choices of WC and SE models.

 

These results challenge the prevailing practice of relying on SE models for synthetic data generation, suggesting that WC may be the compute-optimal approach for training advanced LM reasoners.

 

3. Compute-Matched Sampling and Training

 

To generate a dataset D_G with synthetic solutions for the questions in D, one can leverage different models for generating solutions.

 

Specifically, at a fixed sampling budget (FLOPs), one can generate more samples from a weaker but cheaper (WC) model or fewer samples from a stronger but more expensive (SE) model.

 

Given a WC model with P_WC parameters and an SE model with P_SE parameters, we compute the sampling ratio at a fixed budget for the two models, focusing on decoder-only transformer models (Vaswani, 2017). Following Kaplan et al. (2020), the FLOPs per inference token is 2P for a model with P parameters, so the FLOPs for T inference tokens is 2PT. Further, we assume that generating each solution requires an average of W inference tokens for both models. Let S_WC and S_SE denote the number of samples generated per question for the two models. The total cost of generating samples for the n questions in dataset D is then Cost_WC = n × S_WC × W × (2 P_WC) and Cost_SE = n × S_SE × W × (2 P_SE) for the cheap and expensive models, respectively. Equating the two at a fixed sampling budget, we have:

n × S_WC × W × (2 P_WC) = n × S_SE × W × (2 P_SE)   ⟹   S_WC = (P_SE / P_WC) × S_SE     (Equation 1)
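As a quick illustration of this compute-matched rule, here is a minimal sketch using the notation above (not the paper's code; the model sizes in the example are hypothetical):

```python
def compute_matched_samples(p_se: float, p_wc: float, s_se: int) -> int:
    """Number of solutions per question the WC model can generate at the same
    FLOPs budget as s_se solutions per question from the SE model, assuming the
    same average solution length W for both models (W and n cancel out)."""
    return int(s_se * p_se / p_wc)

# Hypothetical sizes: a 27B-parameter SE model vs. a 9B-parameter WC model.
print(compute_matched_samples(27e9, 9e9, 10))  # -> 30 samples/question from WC
```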

Equation 1 indicates that, at a fixed sampling budget, for each question we can generate P_SE / P_WC times more samples from the WC model; the ratio scales linearly with the ratio of model parameters. Sampling more solutions from WC may increase the likelihood of correctly solving a larger subset of the problems (higher coverage) and of obtaining more correct solutions per question (higher diversity).

Given a fixed budget, we can therefore either generate fewer samples from the SE model or more samples from the WC model, and then finetune models for a fixed number of steps on the data from each to measure and compare the utility of each dataset. Specifically, we generate P_SE / P_WC times more samples from the WC model than from the SE model.

We consider three finetuning setups that cover diverse finetuning paradigms: the widely used knowledge distillation, the emerging framework of self-improvement, and a novel weak-to-strong improvement paradigm introduced in this work. We define weak-to-strong improvement (W2S-I) as enhancing the reasoning capabilities of a strong model using samples generated from a weaker model. The three setups are as follows (a summary of the setups and the finetuning paradigm each case corresponds to is given in Table 1 of the paper).

Student-LM finetuning: Conventionally, the supervised finetuning data for training a student LM is acquired from SE models to ensure high quality (Teknium, 2023). Here, we ask whether WC models can replace SE models for distillation at a fixed sampling budget. To do so, we finetune a student LM, separate from the WC and SE models, on the WC and SE data; this corresponds to distillation in both cases.

WC-LM finetuning: Prior work (Singh et al., 2023) has shown that finetuning a WC model on its self-generated data lags behind distillation from SE data. However, that setup spends a higher sampling budget on collecting data from SE than from WC. In this work, we revisit this setup under a fixed sampling budget and finetune the WC model on WC and SE data generated at the same budget. Training the WC model on its own data corresponds to self-improvement, whereas training it on data from SE corresponds to distillation; hence, this setup compares self-improvement on WC data with distillation from SE data.

SE-LM finetuning: It is commonly believed that improving an SE model requires synthetic data either from the SE model itself or from an even stronger (and likely more expensive) model. Here, we test the alternative: whether synthetic data from the WC model can improve the SE model. To this end, we finetune the SE model on the WC and SE data. Training SE on data from WC corresponds to W2S-I, and training SE on data from SE corresponds to self-improvement. Overall, this setup compares W2S-I with WC data against self-improvement with SE data.
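Since Table 1 is not reproduced in this post, the mapping between setups and paradigms described above can be summarized roughly as follows (a sketch of my own; the labels are not from the paper's code):

```python
# (model being finetuned, source of the synthetic data) -> finetuning paradigm,
# as described for the three setups above.
FINETUNING_SETUPS = {
    ("student LM", "WC data"): "knowledge distillation",
    ("student LM", "SE data"): "knowledge distillation",
    ("WC model",   "WC data"): "self-improvement",
    ("WC model",   "SE data"): "knowledge distillation",
    ("SE model",   "WC data"): "weak-to-strong improvement (W2S-I)",
    ("SE model",   "SE data"): "self-improvement",
}
```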