Language Agent Tree Search (LATS),
by Zhou et al., is a general LLM agent search algorithm that combines
reflection/evaluation and search (specifically Monte Carlo tree search)
to achieve stronger overall task performance by leveraging
inference-time compute. It has four main steps:
1. Select: pick the best next state to progress from, based on its
aggregate value.
2. Expand and simulate: sample N potential actions to take and execute
them in parallel.
3. Reflect + Evaluate: observe the outcomes of these actions and score
the decisions based on reflection (and possibly external feedback, if
available).
4. Backpropagate: update the scores of the root trajectories based on
the outcomes.
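The backpropagation step can be sketched as follows. This is a minimal, hypothetical `Node` (the names here are illustrative, not the tutorial's actual class) that tracks a visit count and a running mean reward, folding each new reward into every ancestor:

```python
class Node:
    """Minimal illustrative tree node: tracks visits and a running mean reward."""

    def __init__(self, parent=None):
        self.parent = parent
        self.visits = 0
        self.value = 0.0  # running mean of observed rewards

    def backpropagate(self, reward: float) -> None:
        """Fold a new reward into this node and every ancestor."""
        node = self
        while node is not None:
            node.visits += 1
            # incremental mean update: new_mean = old_mean + (x - old_mean) / n
            node.value += (reward - node.value) / node.visits
            node = node.parent


root = Node()
child = Node(parent=root)
child.backpropagate(0.8)
child.backpropagate(0.4)
# root and child now both hold the mean of the two rewards (0.6)
```

The incremental-mean update avoids storing every reward while keeping each node's value equal to the average reward of all simulations that passed through it.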
The reflection chain will score agent outputs based on the decision and
the tool responses.
```python
from pydantic import BaseModel, Field


class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        ge=0,
        le=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        return {"role": "user", "content": f"Reasoning: {self.reflections}\nScore: {self.score}"}

    @property
    def normalized_score(self) -> float:
        return self.score / 10.0
```
LATS is based on a (greedy) Monte Carlo tree search. At each search
step, it picks the node with the highest “upper confidence bound”,
a metric that balances exploitation (highest average reward)
and exploration (fewest visits). Starting from that node, it generates N
(5 in this case) new candidate actions to take, and adds them to the
tree. It stops searching either when it has generated a valid solution
or when it has reached the maximum number of rollouts (search tree
depth).
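This selection rule can be sketched with the standard UCB1 formula; the helper below is a hypothetical standalone version (the actual implementation's exploration constant may differ):

```python
import math


def upper_confidence_bound(mean_reward: float, visits: int,
                           parent_visits: int, c: float = 1.0) -> float:
    """UCB1: mean reward plus an exploration bonus for rarely-visited nodes."""
    if visits == 0:
        return float("inf")  # always try unvisited children first
    return mean_reward + c * math.sqrt(math.log(parent_visits) / visits)


# A frequently-visited high-reward child vs. a rarely-visited mediocre one:
exploit = upper_confidence_bound(0.9, visits=50, parent_visits=60)
explore = upper_confidence_bound(0.5, visits=2, parent_visits=60)
# here the rarely-visited child scores higher, so the search revisits it
```

The bonus term shrinks as a node accumulates visits, so the search gradually shifts from exploring uncertain branches to exploiting the best-known one.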
Our agent will have three primary LLM-powered processes:
1. Reflect: score the action based on the tool response.
2. Initial response: create the root node and start the search.
3. Expand: generate 5 candidate “next steps” from the best spot in the
current tree.
For more “grounded” tool applications (such as code synthesis), you
could integrate code execution into the reflection/reward step. This
type of external feedback is very useful.
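As a sketch of what such grounded feedback could look like (a hypothetical helper, not part of this tutorial's code), a candidate program can be scored by the fraction of test cases it passes:

```python
def grounded_score(candidate_code: str, test_cases: list) -> float:
    """Score candidate code 0-1 by the fraction of (expression, expected) pairs it passes.

    NOTE: exec/eval of model output is shown for illustration only;
    a real system should sandbox execution.
    """
    namespace: dict = {}
    try:
        exec(candidate_code, namespace)
    except Exception:
        return 0.0  # candidate does not even run
    passed = 0
    for expression, expected in test_cases:
        try:
            if eval(expression, namespace) == expected:
                passed += 1
        except Exception:
            pass  # a failing test case scores zero
    return passed / len(test_cases) if test_cases else 0.0


candidate = "def add(a, b):\n    return a + b"
score = grounded_score(candidate, [("add(1, 2)", 3), ("add(-1, 1)", 0)])  # 1.0
```

A score like this could be averaged with (or replace) the self-reflection score when constructing the `Reflection` object.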
```python
from autogen import ConversableAgent

assistant_agent = ConversableAgent(
    name="assistant_agent",
    system_message="You are an AI assistant capable of helping with various tasks.",
    human_input_mode="NEVER",
    code_execution_config=False,
)
```
Self-reflection allows the agent to bootstrap, improving its future
responses based on the outcome of previous ones. In agents this is even
more powerful, since the agent can use external feedback to improve.
```python
reflection_prompt = """Reflect and grade the assistant response to the user question below.

User question: {input}

Assistant response: {candidate}

Provide your reflection in the following format:
Reflections: [Your detailed critique and reflections]
Score: [A score from 0-10]
Found Solution: [true/false]"""
```
```python
from autogen import AssistantAgent

reflection_agent = AssistantAgent(
    name="reflection_agent",
    system_message="You are an AI assistant that reflects on and grades responses.",
    llm_config={
        "config_list": config_list,
        "temperature": 0.2,
    },
)
```
```python
import logging
from typing import Any, Dict


def reflection_chain(inputs: Dict[str, Any]) -> Reflection:
    try:
        candidate_content = ""
        if "candidate" in inputs:
            candidate = inputs["candidate"]
            if isinstance(candidate, list):
                candidate_content = (
                    candidate[-1]["content"]
                    if isinstance(candidate[-1], dict) and "content" in candidate[-1]
                    else str(candidate[-1])
                )
            elif isinstance(candidate, dict):
                candidate_content = candidate.get("content", str(candidate))
            elif isinstance(candidate, str):
                candidate_content = candidate
            else:
                candidate_content = str(candidate)

        formatted_prompt = [
            {"role": "system", "content": "You are an AI assistant that reflects on and grades responses."},
            {
                "role": "user",
                "content": reflection_prompt.format(input=inputs.get("input", ""), candidate=candidate_content),
            },
        ]
        response = reflection_agent.generate_reply(formatted_prompt)

        # Parse the response
        response_str = str(response)
        lines = response_str.split("\n")
        reflections = next((line.split(": ", 1)[1] for line in lines if line.startswith("Reflections:")), "")
        score_str = next((line.split(": ", 1)[1] for line in lines if line.startswith("Score:")), "0")
        try:
            if "/" in score_str:
                numerator, denominator = map(int, score_str.split("/"))
                score = int((numerator / denominator) * 10)
            else:
                score = int(score_str)
        except ValueError:
            logging.warning(f"Invalid score value: {score_str}. Defaulting to 0.")
            score = 0
        found_solution = next(
            (line.split(": ", 1)[1].lower() == "true" for line in lines if line.startswith("Found Solution:")),
            False,
        )
        if not reflections:
            logging.warning("No reflections found in the response. Using default values.")
            reflections = "No reflections provided."
        return Reflection(reflections=reflections, score=score, found_solution=found_solution)
    except Exception as e:
        logging.error(f"Error in reflection_chain: {e!s}", exc_info=True)
        return Reflection(reflections=f"Error in reflection: {e!s}", score=0, found_solution=False)
```
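One detail worth noting in the score parsing: the model may answer with a bare integer (“8”) or a fraction (“8/10”), and both should normalize to the 0-10 range. Isolated as a hypothetical standalone helper, the logic is:

```python
def parse_score(score_str: str) -> int:
    """Parse a 'Score:' value like '8' or '8/10' into an int 0-10; default to 0 on bad input."""
    try:
        if "/" in score_str:
            numerator, denominator = map(int, score_str.split("/"))
            return int((numerator / denominator) * 10)
        return int(score_str)
    except (ValueError, ZeroDivisionError):
        # covers non-numeric strings like "n/a" and degenerate fractions like "3/0"
        return 0


print(parse_score("8"), parse_score("7/10"), parse_score("n/a"))  # 8 7 0
```

Free-form LLM output is unreliable enough that defaulting to a zero score on a parse failure is safer than raising, since a failed parse should not abort the whole search.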
Example usage of the generate_initial_response function
```python
initial_prompt = "Why is the sky blue?"
initial_state = TreeState(input=initial_prompt, root=None)
result_state = generate_initial_response(initial_state)
if result_state["root"] is not None:
    print(result_state["root"].messages[0]["content"])
else:
    print("Failed to generate initial response.")
```
The following code prompts the same LLM to generate N additional
candidates to check. It produces N candidate values for a single input,
from which to sample actions in the environment.
```python
def generate_candidates(messages: list, config: dict):
    n = config.get("N", 5)
    assistant = AssistantAgent(
        name="assistant", llm_config={"config_list": config_list}, code_execution_config=False
    )
    candidates = []
    for _ in range(n):
        try:
            # Use the assistant to generate a response
            last_message = (
                messages[-1]["content"]
                if messages and isinstance(messages[-1], dict)
                else str(messages[-1])
            )
            response = assistant.generate_reply([{"role": "user", "content": last_message}])
            if isinstance(response, str):
                candidates.append(response)
            elif isinstance(response, dict) and "content" in response:
                candidates.append(response["content"])
            elif (
                isinstance(response, list)
                and response
                and isinstance(response[-1], dict)
                and "content" in response[-1]
            ):
                candidates.append(response[-1]["content"])
            else:
                candidates.append(str(response))
        except Exception as e:
            logging.error(f"Error generating candidate: {e!s}")
            candidates.append("Failed to generate candidate.")
    if not candidates:
        logging.warning("No candidates were generated.")
    return candidates


expansion_chain = generate_candidates
```
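Sampled candidates are often near-duplicates of each other, and each one triggers an LLM round trip in the reflection step, so it can be worth deduplicating before scoring. A small hypothetical helper (not part of the tutorial's code):

```python
def dedupe_candidates(candidates: list) -> list:
    """Drop candidates identical after whitespace/case normalization, preserving order."""
    seen: set = set()
    unique: list = []
    for text in candidates:
        key = " ".join(text.lower().split())  # collapse whitespace, ignore case
        if key not in seen:
            seen.add(key)
            unique.append(text)
    return unique


print(dedupe_candidates(["The sky is blue.", "the sky  is blue.", "Rayleigh scattering."]))
# ['The sky is blue.', 'Rayleigh scattering.']
```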
We will package the candidate generation and reflection steps in the
following “expand” node: the candidates are generated first, then each
one is scored by the reflection chain.
```python
def expand(state: TreeState, config: Dict[str, Any]) -> dict:
    root = state["root"]
    best_candidate: Node = root.best_child if root.children else root
    messages = best_candidate.get_trajectory()

    # Generate N candidates using Autogen's generate_candidates function
    new_candidates = generate_candidates(messages, config)

    # Reflect on each candidate using Autogen's AssistantAgent
    reflections = []
    for candidate in new_candidates:
        reflection = reflection_chain({"input": state["input"], "candidate": candidate})
        reflections.append(reflection)

    # Grow tree
    child_nodes = [
        Node([{"role": "assistant", "content": candidate}], parent=best_candidate, reflection=reflection)
        for candidate, reflection in zip(new_candidates, reflections)
    ]
    best_candidate.children.extend(child_nodes)

    # We have already extended the tree directly, so we just return the state
    return state
```
With those two nodes defined, we are ready to define the search loop.
After each agent step, we have the option of finishing.
```python
from typing import Any, Dict, Literal


def should_loop(state: Dict[str, Any]) -> Literal["expand", "end"]:
    """Determine whether to continue the tree search."""
    root = state["root"]
    if root.is_solved:
        return "end"
    if root.height > 5:
        return "end"
    return "expand"


def run_lats(input_query: str, max_iterations: int = 10):
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    try:
        state = {"input": input_query, "root": None}
        try:
            state = generate_initial_response(state)
            if not isinstance(state, dict) or "root" not in state or state["root"] is None:
                logger.error("Initial response generation failed or returned invalid state")
                return "Failed to generate initial response."
            logger.info("Initial response generated successfully")
        except Exception as e:
            logger.error(f"Error generating initial response: {e!s}", exc_info=True)
            return "Failed to generate initial response due to an unexpected error."

        for iteration in range(max_iterations):
            action = should_loop(state)
            if action == "end":
                logger.info(f"Search ended after {iteration + 1} iterations")
                break
            try:
                state = expand(
                    state,
                    {
                        "N": 5,
                        "input_query": input_query,
                    },
                )
                logger.info(f"Completed iteration {iteration + 1}")
            except Exception as e:
                logger.error(f"Error during iteration {iteration + 1}: {e!s}", exc_info=True)
                continue

        if not isinstance(state, dict) or "root" not in state or state["root"] is None:
            return "No valid solution found due to an error in the search process."

        solution_node = state["root"].get_best_solution()
        best_trajectory = solution_node.get_trajectory(include_reflections=False)
        if not best_trajectory:
            return "No solution found in the search process."

        result = (
            best_trajectory[-1].get("content")
            if isinstance(best_trajectory[-1], dict)
            else str(best_trajectory[-1])
        )
        logger.info("LATS search completed successfully")
        return result
    except Exception as e:
        logger.error(f"An unexpected error occurred during LATS execution: {e!s}", exc_info=True)
        return f"An unexpected error occurred: {e!s}"
```
Example usage:

```python
result = run_lats("Write a research report on deep learning.")
print(result)
```
```python
questions = [
    "Explain how epigenetic modifications can influence gene expression across generations and the implications for evolution.",
    "Discuss the challenges of grounding ethical theories in moral realism, especially in light of the is-ought problem introduced by Hume.",
    "How does the Riemann Hypothesis relate to the distribution of prime numbers, and why is it significant in number theory?",
    "Describe the challenges and theoretical underpinnings of unifying general relativity with quantum mechanics, particularly focusing on string theory and loop quantum gravity.",
]
```
Congrats on implementing LATS! This is a technique that can be
reasonably fast and effective at solving complex agent tasks. A few
notes that you probably observed above:
While LATS is effective, the tree rollout process can require
additional inference compute time. If you plan to integrate this
into a production application, consider streaming intermediate steps
to allow users to see the thought process and access intermediate
results. Alternatively, you could use it to generate fine-tuning
data to enhance single-shot accuracy and avoid lengthy rollouts. The
cost of using LATS has significantly decreased since its initial
proposal and is expected to continue decreasing.
The effectiveness of the candidate selection process depends on the
quality of the rewards generated. In this example, we exclusively
use self-reflection as feedback, but if you have access to external
feedback sources (such as code test execution), those should be
incorporated as suggested above.