MindStar: Enhancing Math Reasoning in Pre-trained LLMs at Inference Time

Noah's Ark Laboratory

*Equal Contribution
Figure: M*, a search framework for inference-time step-level reasoning. (A) At each iteration, we feed the question and the previous reasoning steps to the LLM and sample N next reasoning steps. (B) We organize the reasoning process as a tree: the root node is the question, leaf nodes are answers, and all other nodes are intermediate reasoning steps. A search method traverses the reasoning tree and selects a node to expand; the reasoning step of the selected node is appended to the prompt for the next query. Generation stops once an answer is found or the maximum compute budget is reached.

Abstract

Although Large Language Models (LLMs) achieve remarkable performance across various tasks, they often struggle with complex reasoning tasks, such as answering mathematical questions. Recent efforts to address this issue have primarily focused on leveraging mathematical datasets through supervised fine-tuning or self-improvement techniques. However, these methods often depend on high-quality datasets that are difficult to prepare, or they require substantial computational resources for fine-tuning. Inspired by findings that LLMs know how to produce the right answer but struggle to select the correct reasoning path, we propose a purely inference-based search method, MindStar (M*). This method formulates reasoning tasks as search problems and proposes two search ideas to identify optimal reasoning paths. We evaluate the M* framework on both the GSM8K and MATH datasets, comparing its performance with existing open- and closed-source LLMs. Our results demonstrate that M* significantly enhances the reasoning abilities of open-source models, such as Llama-2-13B and Mistral-7B, and achieves performance comparable to GPT-3.5 and Grok-1, but with substantially reduced model size and computational costs.

Figure: MATH accuracy of different LLMs. M* on LLaMA-2-13B achieves performance comparable to GPT-3.5 (4-shot) while using approximately 200x less computation.

Can we enhance LLM reasoning with output selection?

To explore this, we conduct an experiment that uses different reward models to assist the LLM with output selection. We leverage the Outcome-supervised Reward Model (ORM), which scores an entire reasoning solution, and the Process-supervised Reward Model (PRM), which scores each individual reasoning step. Initially, we apply both the ORM and the PRM to select the final answer from multiple sampled chain-of-thought (CoT) solutions. The figure below shows that the PRM selects better reasoning answers than the ORM. Additionally, we employ the PRM to assist the LLM in a tree-of-thought setting: rather than generating a complete solution, the LLM produces multiple intermediate steps; the PRM then scores these steps and selects the best one, allowing the LLM to continue generation from a promising step. Our results demonstrate that step-level selection significantly outperforms both CoT selection baselines.

Figure: Different reward models for LLM output selection on the MATH dataset. The x-axis denotes the total number of generated outputs.
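To make the two selection strategies concrete, the sketch below illustrates ORM-style best-of-N selection over complete CoT solutions and PRM-guided step-level selection. This is a minimal sketch under assumed interfaces: `sample_solution`, `score_solution`, `sample_step`, `score_step`, and `is_final` are hypothetical helpers, not the authors' code.

```python
# Illustrative sketch (not the released implementation) of reward-model-based
# output selection. All callables below are assumed interfaces.
from typing import Callable, List

def orm_select(question: str,
               sample_solution: Callable[[str], str],
               score_solution: Callable[[str, str], float],
               n: int = 16) -> str:
    """ORM-style best-of-N: score each complete CoT solution, keep the best."""
    candidates = [sample_solution(question) for _ in range(n)]
    return max(candidates, key=lambda sol: score_solution(question, sol))

def prm_step_select(question: str,
                    sample_step: Callable[[str, List[str]], str],
                    score_step: Callable[[str, List[str], str], float],
                    is_final: Callable[[str], bool],
                    n: int = 16,
                    max_depth: int = 10) -> List[str]:
    """PRM-style step-level selection: at each level, sample n candidate
    steps, keep the one the PRM scores highest, and continue from it."""
    path: List[str] = []
    for _ in range(max_depth):
        steps = [sample_step(question, path) for _ in range(n)]
        best = max(steps, key=lambda s: score_step(question, path, s))
        path.append(best)
        if is_final(best):
            break
    return path
```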

M*: Think and Reflect Step by Step

Reasoning Node Expansion

Given a selected reasoning node \(n_d\) to expand, we design a prompt template (Example 3.1) to collect next steps from the LLM. As shown in the example, the LLM receives the original question as \{question\} and the current reasoning path as \{answer\} in the prompt. Note that in the first iteration of the algorithm the selected node is the root, which contains only the question, so \{answer\} is empty. For the reasoning path \(n_d\), the LLM generates \(N\) intermediate steps \(e^1_{d}, e^2_{d}, \dots, e^N_{d}\) for the given prompt, and we append them as children of the current node. In the next step of the algorithm, the new child nodes are assessed and a new node is selected for further expansion. We also acknowledge that an alternative for generating steps is to fine-tune the LLM with step tokens. However, this could degrade the LLM's reasoning ability and, more importantly, is not aligned with the focus of this paper, which is enhancing the LLM without any weight modification.

[Example 3.1: prompt template for reasoning node expansion.]
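A minimal sketch of the expansion step, assuming a simple tree-node representation; the `Node` class, the `generate` callable, and the exact wording of `PROMPT_TEMPLATE` are illustrative stand-ins for Example 3.1 rather than the actual implementation.

```python
# Sketch of reasoning node expansion; prompt wording and interfaces are assumed.
from dataclasses import dataclass, field
from typing import Callable, List, Optional

PROMPT_TEMPLATE = (
    "Question: {question}\n"
    "Answer so far: {answer}\n"
    "Next step:"
)

@dataclass
class Node:
    question: str                        # the same question for every node
    step: str = ""                       # empty for the root node
    parent: Optional["Node"] = None
    children: List["Node"] = field(default_factory=list)

    def reasoning_path(self) -> str:
        """Concatenate the steps from the root down to this node."""
        steps, node = [], self
        while node is not None and node.step:
            steps.append(node.step)
            node = node.parent
        return "\n".join(reversed(steps))

def expand(node: Node,
           generate: Callable[[str], str],
           n_candidates: int = 16) -> List[Node]:
    """Sample N candidate next steps for the selected node and attach them
    as children, as described in the paragraph above."""
    prompt = PROMPT_TEMPLATE.format(question=node.question,
                                    answer=node.reasoning_path())
    for _ in range(n_candidates):
        child = Node(question=node.question,
                     step=generate(prompt),
                     parent=node)
        node.children.append(child)
    return node.children
```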

Reasoning Path Selection

Following the reasoning node expansion, we use the pre-trained PRM \(\mathcal{P}\) to reflect on each newly generated step. The PRM takes the path \(n_d\) and a step \(e_d\) as inputs and returns the corresponding reward value. After this evaluation, a tree search algorithm selects the next node for expansion. Our framework is agnostic to the search algorithm; in this work, we instantiate it with two tree search methods, beam search and Levin tree search. Additionally, we introduce an ensemble extension of M* search, Forest Search.

[Algorithm: the M* search procedure.]
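As one possible instantiation of the search component, the sketch below shows step-level beam search guided by PRM rewards. It is a sketch under assumed interfaces (`sample_step`, `score_step`, `is_final`); a Levin tree search variant would differ mainly in how candidate nodes are prioritized.

```python
# Sketch of PRM-guided step-level beam search; not the authors' exact algorithm.
from typing import Callable, List, Tuple

def beam_search(question: str,
                sample_step: Callable[[str, List[str]], str],
                score_step: Callable[[str, List[str], str], float],
                is_final: Callable[[str], bool],
                beam_width: int = 4,
                n_candidates: int = 16,
                max_depth: int = 10) -> List[str]:
    """Keep the `beam_width` highest-reward partial reasoning paths, expanding
    each with `n_candidates` sampled steps per level."""
    # Each beam entry is (cumulative PRM reward, list of reasoning steps).
    beams: List[Tuple[float, List[str]]] = [(0.0, [])]
    for _ in range(max_depth):
        expanded: List[Tuple[float, List[str]]] = []
        for reward, path in beams:
            if path and is_final(path[-1]):
                expanded.append((reward, path))   # carry finished paths forward
                continue
            for _ in range(n_candidates):
                step = sample_step(question, path)
                expanded.append((reward + score_step(question, path, step),
                                 path + [step]))
        # Retain the top partial paths by cumulative reward.
        beams = sorted(expanded, key=lambda b: b[0], reverse=True)[:beam_width]
        if all(path and is_final(path[-1]) for _, path in beams):
            break
    return max(beams, key=lambda b: b[0])[1]
```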

Math Reasoning Benchmarks

Table: Comparison results of various schemes on the GSM8K and MATH reasoning benchmarks. Each entry is the percentage of problems solved. SC@32 denotes self-consistency over 32 candidate results, while \(n\)-shot indicates results from few-shot examples. CoT-SC@16 refers to self-consistency over 16 Chain of Thought (CoT) candidate results. BS@16 represents the beam search method with 16 candidates at each step level, and LevinTS@16 denotes the Levin Tree Search method with the same number of candidates. Notably, the most recent GPT-4 result on the MATH dataset is reported for GPT-4-turbo-0409, which we highlight as it represents the best performance within the GPT-4 family.
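For reference, a brief sketch of the self-consistency baseline (SC@k) mentioned in the caption: sample k CoT solutions and take a majority vote over the extracted final answers. The `sample_solution` and `extract_answer` helpers are assumptions, not the authors' code.

```python
# Illustrative sketch of self-consistency (SC@k) via majority voting.
from collections import Counter
from typing import Callable

def self_consistency(question: str,
                     sample_solution: Callable[[str], str],
                     extract_answer: Callable[[str], str],
                     k: int = 32) -> str:
    """Return the most frequent final answer among k sampled CoT solutions."""
    answers = [extract_answer(sample_solution(question)) for _ in range(k)]
    return Counter(answers).most_common(1)[0][0]
```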

Scaling Results

Figure: How M* performance scales with different parameters. (a) Performance as a function of the number of step-level candidates, using Llama-2-13B as the base model and beam search as the search algorithm. (b) Base LLM model size vs. PRM model size: the red dots represent performance across base model sizes with PRM-13B, the purple dots with PRM-7B, and the grey area shows the improvement achieved by increasing the PRM size. (c) Forest search results.

Llama Scaling Results

In our investigation of scaling laws within the Llama family of models, notably Llama-2 and Llama-3, we applied the M* method to observe its impact on performance relative to model size. As illustrated below, applying M* substantially enhances the performance of the Llama-2 models, bringing their scaling trajectory closer to that of the Llama-3 models. This improvement in scaling efficiency is significant because it suggests that the reasoning capabilities of LLMs can be enhanced without necessarily increasing the volume of high-quality training data. Instead, the focus shifts toward selecting the right responses, thereby conserving resources while still achieving competitive performance.

Figure: Scaling laws for the Llama-2 and Llama-3 model families on the MATH dataset. All results are reported from their original sources. We fit the curves with a logarithmic function using SciPy.
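As a rough illustration of the fitting procedure described in the caption, the snippet below fits a logarithmic curve with SciPy's `curve_fit`; the data points are placeholders, not the reported accuracies.

```python
# Sketch of a logarithmic fit with SciPy; the data values are placeholders.
import numpy as np
from scipy.optimize import curve_fit

def log_fit(x, a, b):
    """Logarithmic trend: accuracy ~ a * log(model_size) + b."""
    return a * np.log(x) + b

model_sizes = np.array([7.0, 13.0, 70.0])   # placeholder sizes (billions of params)
accuracies = np.array([10.0, 18.0, 30.0])   # placeholder MATH accuracies (%)

params, _ = curve_fit(log_fit, model_sizes, accuracies)
print("fit parameters:", params)
print("fitted values:", log_fit(model_sizes, *params))
```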

BibTeX

@article{kang2024mindstar,
  title={MindStar: Enhancing Math Reasoning in Pre-trained LLMs at Inference Time},
  author={Kang, Jikun and Li, Xin Zhe and Chen, Xi and Kazemi, Amirreza and Sun, Qianyi and Chen, Boxing and Li, Dong and He, Xu and He, Quan and Wen, Feng and Hao, Jianye and Yao, Jun},
  journal={arXiv preprint arXiv:2405.16265},
  year={2024}
}