
Language Agent Tree Search


Language Agent Tree Search (LATS), by Zhou et al., is a general LLM agent search algorithm that combines reflection/evaluation and search (specifically Monte Carlo tree search) to achieve stronger overall task performance by leveraging inference-time compute.

It has four main steps, sketched as a simple loop right after this list:

  1. Select: pick the best next state to progress from, based on its aggregate value.
  2. Expand and simulate: sample n potential actions to take and execute them in parallel.
  3. Reflect + Evaluate: observe the outcomes of these actions and score the decisions based on reflection (and possibly external feedback, if available).
  4. Backpropagate: update the scores of the root trajectories based on the outcomes.
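
Concretely, the run_lats function defined at the end of this notebook repeats these steps until a solution is found or the depth budget is exhausted. As a rough preview (illustrative only; the real implementations of these functions follow below):

# Illustrative control flow only -- the actual implementations are defined later in this notebook.
def lats_loop(question, max_iterations=10):
    state = generate_initial_response({"input": question, "root": None})  # root of the tree
    for _ in range(max_iterations):
        if should_loop(state) == "end":  # solved, or the tree is deep enough
            break
        # Select the best leaf, sample N candidates, reflect on each, backpropagate scores
        state = expand(state, {"N": 5})
    return state["root"].get_best_solution()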
import json
import logging
import os
from typing import Any, Dict, List

from autogen import AssistantAgent, ConversableAgent, UserProxyAgent

Configure logging

logging.basicConfig(level=logging.INFO)

Set environment variables

os.environ["AUTOGEN_USE_DOCKER"] = "0"  # Disable Docker usage globally for Autogen
os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"

Prerequisites

Install autogen (the LLM framework and agents used below). Please ensure this package is installed before running this notebook.

Directly create the config_list with the API key

config_list = [{"model": "gpt-4o-mini", "api_key": "YOUR_API_KEY"}]
if not config_list:
    raise ValueError("Failed to create configuration. Please check the API key.")

Reflection Class

The reflection chain will score agent outputs based on the decision and the tool responses.

from pydantic import BaseModel, Field


class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        ge=0,
        le=10,
    )
    found_solution: bool = Field(description="Whether the response has fully solved the question or task.")

    def as_message(self):
        return {"role": "user", "content": f"Reasoning: {self.reflections}\nScore: {self.score}"}

    @property
    def normalized_score(self) -> float:
        return self.score / 10.0
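
As a quick sanity check of the schema (the values here are made up for illustration):

# Illustrative values only: construct a Reflection by hand and inspect its helpers.
sample = Reflection(
    reflections="Clear explanation, but it omits why shorter wavelengths scatter more.",
    score=7,
    found_solution=False,
)
print(sample.normalized_score)  # 0.7
print(sample.as_message())  # {'role': 'user', 'content': 'Reasoning: ...\nScore: 7'}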

Tree State

LATS is based on a (greedy) Monte Carlo tree search. At each search step, it picks the node with the highest “upper confidence bound” (UCT), a metric that balances exploitation (high average reward) and exploration (few visits). Starting from that node, it generates N (5 in this case) new candidate actions and adds them to the tree. It stops searching either when it has generated a valid solution or when it has reached the maximum search depth (rollout limit).

import math
from collections import deque
from typing import Optional


class Node:
    def __init__(
        self,
        messages: List[Dict[str, str]],
        reflection: Optional[Reflection] = None,
        parent: Optional["Node"] = None,
    ):
        self.messages = messages
        self.parent = parent
        self.children: List["Node"] = []
        self.value = 0.0
        self.visits = 0
        self.reflection = reflection
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            self._mark_tree_as_solved()
        if reflection:
            self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        return (
            f"<Node value={self.value:.2f}, visits={self.visits}," f" depth={self.depth}, is_solved={self._is_solved}>"
        )

    @property
    def is_solved(self) -> bool:
        """If any solutions exist, we can end the search."""
        return self._is_solved

    @property
    def is_terminal(self):
        return not self.children

    @property
    def best_child(self):
        """Select the child with the highest UCT to search next."""
        if not self.children:
            return None
        all_nodes = self._get_all_children()
        return max(all_nodes, key=lambda child: child.upper_confidence_bound())

    @property
    def best_child_score(self):
        """Return the child with the highest value."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: int(child.is_solved) * child.value)

    @property
    def height(self) -> int:
        """Check for how far we've rolled out the tree."""
        if self.children:
            return 1 + max([child.height for child in self.children])
        return 1

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term

    def backpropagate(self, reward: float):
        """Update the score of this node and its parents."""
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def get_messages(self, include_reflections: bool = True):
        if include_reflections and self.reflection:
            return self.messages + [self.reflection.as_message()]
        return self.messages

    def get_trajectory(self, include_reflections: bool = True) -> List[Dict[str, str]]:
        """Get messages representing this search branch."""
        messages = []
        node = self
        while node:
            messages.extend(node.get_messages(include_reflections=include_reflections)[::-1])
            node = node.parent
        # Reverse the final back-tracked trajectory to return in the correct order
        return messages[::-1]  # root solution, reflection, child 1, ...

    def _get_all_children(self):
        all_nodes = []
        nodes = deque()
        nodes.append(self)
        while nodes:
            node = nodes.popleft()
            all_nodes.extend(node.children)
            for n in node.children:
                nodes.append(n)
        return all_nodes

    def get_best_solution(self):
        """Return the best solution from within the current sub-tree."""
        all_nodes = [self] + self._get_all_children()
        best_node = max(
            all_nodes,
            # We filter out all non-terminal, non-solution trajectories
            key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
        )
        return best_node

    def _mark_tree_as_solved(self):
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

The main component is the tree, represented by the root node.

from typing_extensions import TypedDict


class TreeState(TypedDict):
    # The full tree
    root: Node
    # The original input
    input: str
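
To see how visits and values flow through the tree, here is a small hand-built example (scores are made up); note that the child's reward is backpropagated into the root when the child is constructed:

# Hand-built two-node tree with made-up scores, to illustrate backpropagation and selection.
demo_root = Node(
    messages=[{"role": "assistant", "content": "draft answer"}],
    reflection=Reflection(reflections="Partially correct.", score=5, found_solution=False),
)
demo_child = Node(
    messages=[{"role": "assistant", "content": "improved answer"}],
    reflection=Reflection(reflections="More complete.", score=8, found_solution=False),
    parent=demo_root,
)
demo_root.children.append(demo_child)

print(demo_root.visits, demo_child.visits)  # 2 1 -- the child's reward reached the root
print(round(demo_child.upper_confidence_bound(), 2))  # ~1.63
print(demo_root.best_child is demo_child)  # True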

Define Language Agent

Our agent will have three primary LLM-powered processes:

  1. Reflect: score the action based on the tool response.
  2. Initial response: to create the root node and start the search.
  3. Expand: generate 5 candidate “next steps” from the best spot in the current tree.

For more “grounded” tool applications (such as code synthesis), you could integrate code execution into the reflection/reward step; this type of external feedback is very useful.
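
For instance, if the task were Python code synthesis, a reward function could run each candidate and blend the execution result with the self-reflection score. The sketch below is a hypothetical illustration of that idea (grounded_score and its 0.7/0.3 weighting are assumptions, not part of the LATS paper or of the code in this notebook):

import subprocess
import sys


# Hypothetical sketch: mix self-reflection with an external "does it run" signal.
def grounded_score(candidate_code: str, reflection: Reflection) -> float:
    """Return a reward in [0, 1] combining execution feedback with the reflection score."""
    try:
        result = subprocess.run(
            [sys.executable, "-c", candidate_code],
            capture_output=True,
            timeout=10,
        )
        execution_ok = result.returncode == 0
    except subprocess.TimeoutExpired:
        execution_ok = False
    # Weight the external signal more heavily than the self-assessment.
    return 0.7 * float(execution_ok) + 0.3 * reflection.normalized_score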

Tools

For our example, we will give the language agent tool-use capability through a UserProxyAgent (a web search tool could be registered on it here).

Define the UserProxyAgent with tool-use capability

user_proxy = UserProxyAgent(
    name="user",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    code_execution_config={
        "work_dir": "web",
        "use_docker": False,
    },
)

Create a ConversableAgent without tools

assistant_agent = ConversableAgent(
    name="assistant_agent",
    system_message="You are an AI assistant capable of helping with various tasks.",
    human_input_mode="NEVER",
    code_execution_config=False,
)

Reflection

Self-reflection allows the agent to bootstrap, improving its future responses based on the outcome of previous ones. In agents this is even more powerful, since the reflection can also incorporate external feedback.

reflection_prompt = """
Reflect and grade the assistant response to the user question below.
User question: {input}
Assistant response: {candidate}

Provide your reflection in the following format:
Reflections: [Your detailed critique and reflections]
Score: [A score from 0-10]
Found Solution: [true/false]
"""
reflection_agent = AssistantAgent(
    name="reflection_agent",
    system_message="You are an AI assistant that reflects on and grades responses.",
    llm_config={
        "config_list": config_list,
        "temperature": 0.2,
    },
)


def reflection_chain(inputs: Dict[str, Any]) -> Reflection:
    try:
        candidate_content = ""
        if "candidate" in inputs:
            candidate = inputs["candidate"]
            if isinstance(candidate, list):
                candidate_content = (
                    candidate[-1]["content"]
                    if isinstance(candidate[-1], dict) and "content" in candidate[-1]
                    else str(candidate[-1])
                )
            elif isinstance(candidate, dict):
                candidate_content = candidate.get("content", str(candidate))
            elif isinstance(candidate, str):
                candidate_content = candidate
            else:
                candidate_content = str(candidate)

        formatted_prompt = [
            {"role": "system", "content": "You are an AI assistant that reflects on and grades responses."},
            {
                "role": "user",
                "content": reflection_prompt.format(input=inputs.get("input", ""), candidate=candidate_content),
            },
        ]
        response = reflection_agent.generate_reply(formatted_prompt)

        # Parse the response
        response_str = str(response)
        lines = response_str.split("\n")
        reflections = next((line.split(": ", 1)[1] for line in lines if line.startswith("Reflections:")), "")
        score_str = next((line.split(": ", 1)[1] for line in lines if line.startswith("Score:")), "0")
        try:
            if "/" in score_str:
                numerator, denominator = map(int, score_str.split("/"))
                score = int((numerator / denominator) * 10)
            else:
                score = int(score_str)
        except ValueError:
            logging.warning(f"Invalid score value: {score_str}. Defaulting to 0.")
            score = 0

        found_solution = next(
            (line.split(": ", 1)[1].lower() == "true" for line in lines if line.startswith("Found Solution:")), False
        )

        if not reflections:
            logging.warning("No reflections found in the response. Using default values.")
            reflections = "No reflections provided."

        return Reflection(reflections=reflections, score=score, found_solution=found_solution)
    except Exception as e:
        logging.error(f"Error in reflection_chain: {str(e)}", exc_info=True)
        return Reflection(reflections=f"Error in reflection: {str(e)}", score=0, found_solution=False)

Initial Response

We start with a single root node, generated by this first step. It responds to the user input either with a tool invocation or a response.

Create Autogen agents

assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list}, code_execution_config=False)
user = UserProxyAgent(
    name="user",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    code_execution_config={"work_dir": "web", "use_docker": False},
)

Define a function to create the initial prompt

def create_initial_prompt(input_text):
    return [
        {"role": "system", "content": "You are an AI assistant."},
        {"role": "user", "content": input_text},
    ]

Function to generate initial response

def generate_initial_response(state: TreeState) -> TreeState:
    chat_messages = create_initial_prompt(state["input"])
    try:
        # Ensure chat_messages is a list of dictionaries
        if not isinstance(chat_messages, list):
            chat_messages = [{"role": "user", "content": chat_messages}]

        logging.info(f"Generating initial response for input: {state['input']}")
        logging.debug(f"Chat messages: {chat_messages}")

        response = assistant.generate_reply(chat_messages)
        logging.debug(f"Raw response from assistant: {response}")

        # Ensure response is properly formatted as a string
        if isinstance(response, str):
            content = response
        elif isinstance(response, dict) and "content" in response:
            content = response["content"]
        elif isinstance(response, list) and len(response) > 0:
            content = response[-1].get("content", str(response[-1]))
        else:
            content = str(response)

        content = content.strip()
        if not content:
            raise ValueError("Generated content is empty after processing")

        logging.debug(f"Processed content: {content[:100]}...")  # Log first 100 chars

        # Generate reflection
        reflection_input = {"input": state["input"], "candidate": content}
        logging.info("Generating reflection on the initial response")
        reflection = reflection_chain(reflection_input)
        logging.debug(f"Reflection generated: {reflection}")

        # Create Node with messages as a list containing a single dict
        messages = [{"role": "assistant", "content": content}]
        root = Node(messages=messages, reflection=reflection)

        logging.info("Initial response and reflection generated successfully")
        return TreeState(root=root, input=state["input"])

    except Exception as e:
        logging.error(f"Error in generate_initial_response: {str(e)}", exc_info=True)
        return TreeState(root=None, input=state["input"])

Example usage of the generate_initial_response function

initial_prompt = "Why is the sky blue?"
initial_state = TreeState(input=initial_prompt, root=None)
result_state = generate_initial_response(initial_state)
if result_state["root"] is not None:
    print(result_state["root"].messages[0]["content"])
else:
    print("Failed to generate initial response.")

Starting Node

We will package up the candidate generation and reflection into a single step of our search. This is represented by the following function:

Define the function to generate the initial response



def generate_initial_response(state: TreeState) -> TreeState:
    """Generate the initial candidate response using Autogen components."""
    assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list}, code_execution_config=False)

    # Generate initial response
    initial_message = [
        {"role": "system", "content": "You are an AI assistant."},
        {"role": "user", "content": state["input"]},
    ]

    try:
        logging.info(f"Generating initial response for input: {state['input']}")
        response = assistant.generate_reply(initial_message)
        logging.debug(f"Raw response from assistant: {response}")

        # Ensure response is properly formatted as a string
        if isinstance(response, str):
            content = response
        elif isinstance(response, dict):
            content = response.get("content", "")
            if not content:
                content = json.dumps(response)
        elif isinstance(response, list):
            content = " ".join(str(item) for item in response)
        else:
            content = str(response)

        # Ensure content is always a string and not empty
        content = content.strip()
        if not content:
            raise ValueError("Generated content is empty after processing")

        logging.debug(f"Final processed content (first 100 chars): {content[:100]}...")

        # Generate reflection
        logging.info("Generating reflection on the initial response")
        reflection_input = {"input": state["input"], "candidate": content}
        reflection = reflection_chain(reflection_input)
        logging.debug(f"Reflection generated: {reflection}")

        if not isinstance(reflection, Reflection):
            raise TypeError(f"Invalid reflection type: expected Reflection, got {type(reflection)}")

        # Create Node with messages as a list containing a single dict
        messages = [{"role": "assistant", "content": content}]
        logging.debug(f"Creating Node with messages: {messages}")
        root = Node(messages=messages, reflection=reflection)
        logging.info("Initial response and reflection generated successfully")
        logging.debug(f"Created root node: {root}")
        return TreeState(root=root, input=state["input"])

    except Exception as e:
        logging.error(f"Error in generate_initial_response: {str(e)}", exc_info=True)
        return TreeState(root=None, input=state["input"])

Candidate Generation

The following code prompts the same LLM to generate N additional candidates to check.

This generates N candidate responses for a single input, from which actions can be sampled in the environment.

def generate_candidates(messages: list, config: dict):
    n = config.get("N", 5)
    assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list}, code_execution_config=False)

    candidates = []
    for _ in range(n):
        try:
            # Use the assistant to generate a response
            last_message = messages[-1]["content"] if messages and isinstance(messages[-1], dict) else str(messages[-1])
            response = assistant.generate_reply([{"role": "user", "content": last_message}])
            if isinstance(response, str):
                candidates.append(response)
            elif isinstance(response, dict) and "content" in response:
                candidates.append(response["content"])
            elif (
                isinstance(response, list) and response and isinstance(response[-1], dict) and "content" in response[-1]
            ):
                candidates.append(response[-1]["content"])
            else:
                candidates.append(str(response))
        except Exception as e:
            logging.error(f"Error generating candidate: {str(e)}")
            candidates.append("Failed to generate candidate.")

    if not candidates:
        logging.warning("No candidates were generated.")

    return candidates


expansion_chain = generate_candidates
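
As a small smoke test (again assuming a valid API key; N is kept at 2 here to limit cost):

# Requires a valid API key in config_list above; N kept small to limit cost.
sample_candidates = generate_candidates(
    [{"role": "user", "content": "Why is the sky blue?"}],
    {"N": 2},
)
for i, candidate in enumerate(sample_candidates, 1):
    print(f"Candidate {i}: {candidate[:80]}...")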

Candidate generation node

We will package the candidate generation and reflection steps in the following “expand” step. Each candidate is generated and then scored by the reflection chain before the tree is grown.

def expand(state: TreeState, config: Dict[str, Any]) -> dict:
    root = state["root"]
    best_candidate: Node = root.best_child if root.children else root
    messages = best_candidate.get_trajectory()

    # Generate N candidates using Autogen's generate_candidates function
    new_candidates = generate_candidates(messages, config)

    # Reflect on each candidate using Autogen's AssistantAgent
    reflections = []
    for candidate in new_candidates:
        reflection = reflection_chain({"input": state["input"], "candidate": candidate})
        reflections.append(reflection)

    # Grow tree
    child_nodes = [
        Node([{"role": "assistant", "content": candidate}], parent=best_candidate, reflection=reflection)
        for candidate, reflection in zip(new_candidates, reflections)
    ]
    best_candidate.children.extend(child_nodes)

    # We have already extended the tree directly, so we just return the state
    return state
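
To exercise one expansion step in isolation (assuming a valid API key), you can start from a freshly generated root and grow a couple of children:

# Requires a valid API key; grows two children under a freshly generated root.
demo_state = generate_initial_response(TreeState(input="Why is the sky blue?", root=None))
if demo_state["root"] is not None:
    demo_state = expand(demo_state, {"N": 2})
    print(f"Root now has {len(demo_state['root'].children)} children")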

Create Tree

With those two steps defined, we are ready to run the search. After each step, we have the option of finishing.

from typing import Any, Dict, Literal


def should_loop(state: Dict[str, Any]) -> Literal["expand", "end"]:
    """Determine whether to continue the tree search."""
    root = state["root"]
    if root.is_solved:
        return "end"
    if root.height > 5:
        return "end"
    return "expand"


def run_lats(input_query: str, max_iterations: int = 10):
    logger = logging.getLogger(__name__)

    try:
        state = {"input": input_query, "root": None}
        try:
            state = generate_initial_response(state)
            if not isinstance(state, dict) or "root" not in state or state["root"] is None:
                logger.error("Initial response generation failed or returned invalid state")
                return "Failed to generate initial response."
            logger.info("Initial response generated successfully")
        except Exception as e:
            logger.error(f"Error generating initial response: {str(e)}", exc_info=True)
            return "Failed to generate initial response due to an unexpected error."

        for iteration in range(max_iterations):
            action = should_loop(state)
            if action == "end":
                logger.info(f"Search ended after {iteration + 1} iterations")
                break
            try:
                state = expand(
                    state,
                    {
                        "N": 5,
                        "input_query": input_query,
                    },
                )
                logger.info(f"Completed iteration {iteration + 1}")
            except Exception as e:
                logger.error(f"Error during iteration {iteration + 1}: {str(e)}", exc_info=True)
                continue

        if not isinstance(state, dict) or "root" not in state or state["root"] is None:
            return "No valid solution found due to an error in the search process."

        solution_node = state["root"].get_best_solution()
        best_trajectory = solution_node.get_trajectory(include_reflections=False)
        if not best_trajectory:
            return "No solution found in the search process."

        result = (
            best_trajectory[-1].get("content") if isinstance(best_trajectory[-1], dict) else str(best_trajectory[-1])
        )
        logger.info("LATS search completed successfully")
        return result
    except Exception as e:
        logger.error(f"An unexpected error occurred during LATS execution: {str(e)}", exc_info=True)
        return f"An unexpected error occurred: {str(e)}"

Example usage:

result = run_lats("Write a research report on deep learning.")

print(result)

Example usage of the LATS algorithm with Autogen

import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def run_lats_example(question):
    try:
        logger.info(f"Processing question: {question}")
        result = run_lats(question)
        logger.info(f"LATS algorithm completed. Result: {result[:100]}...")  # Log first 100 chars of result
        print(f"Question: {question}")
        print(f"Answer: {result}")
    except Exception as e:
        logger.error(f"An error occurred while processing the question: {str(e)}", exc_info=True)
        print(f"An error occurred: {str(e)}")
    finally:
        print("---")

List of example questions

questions = [
    "Explain how epigenetic modifications can influence gene expression across generations and the implications for evolution.",
    "Discuss the challenges of grounding ethical theories in moral realism, especially in light of the is-ought problem introduced by Hume.",
    "How does the Riemann Hypothesis relate to the distribution of prime numbers, and why is it significant in number theory?",
    "Describe the challenges and theoretical underpinnings of unifying general relativity with quantum mechanics, particularly focusing on string theory and loop quantum gravity.",
]

Run LATS algorithm for each question

for i, question in enumerate(questions, 1):
    print(f"\nExample {i}:")
    run_lats_example(question)

logger.info("All examples processed.")

Conclusion

Congrats on implementing LATS! This is a technique that can be reasonably fast and effective at solving complex agent tasks. A few notes that you probably observed above:

  1. While LATS is effective, the tree rollout process can require additional inference compute time. If you plan to integrate this into a production application, consider streaming intermediate steps to allow users to see the thought process and access intermediate results. Alternatively, you could use it to generate fine-tuning data to enhance single-shot accuracy and avoid lengthy rollouts. The cost of using LATS has significantly decreased since its initial proposal and is expected to continue decreasing.

  2. The effectiveness of the candidate selection process depends on the quality of the rewards generated. In this example, we exclusively use self-reflection as feedback, but if you have access to external feedback sources (such as code test execution), those should be incorporated as suggested above.