Parallel MCTS
MCTS is an iterative sampling algorithm, where the reward found in each iteration affects sampling in later iterations. While perfectly parallel execution isn't possible, we can achieve quite good performance using the so-called lock-free method [1], where multiple search threads on the same CPU (the main node) run MCTS concurrently, each one taking turns editing the search graph. We will implement this in detail later in the tutorial, but in brief: instead of computing the (usually expensive) reward function itself, each search thread on the main node sends a request to a group of worker CPUs (the worker node) elsewhere that perform the actual computation, and while that thread is waiting for the result, other search threads can use the main CPU. As long as each reward computation takes significantly longer than the round trip of sending the request and receiving the result, we should see a performance boost!
Parallel search on a single machine
In order to parallelize the search on a local machine, we can nominate a group of CPUs in our own computer to be the worker node that performs reward function evaluations. We can coordinate the main and worker nodes using a producer-consumer queue: the main node produces tasks (calls to the reward function) that get added to the queue, and the worker node consumes tasks from the queue and returns each result to a shared database where the main node can look it up. We'll manage this task queue with the Python utility celery.
Here’s a schematic of how that infrastructure looks.
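In miniature, the producer-consumer pattern looks something like the sketch below, which uses Python's built-in queue and threading modules purely for illustration. In the actual setup, celery plays the role of the queue, Redis plays the role of the shared results database, and the reward function here is a trivial stand-in.

import queue
import threading
import time

def expensive_reward(state: str) -> float:
    """Stand-in for an expensive reward function."""
    time.sleep(0.1)
    return float(len(state))

task_queue = queue.Queue()  # the main node adds (produces) tasks here
results = {}                # stands in for the shared results database

def worker():
    # The worker node consumes tasks from the queue and writes results back
    while True:
        task_id, state = task_queue.get()
        results[task_id] = expensive_reward(state)
        task_queue.task_done()

threading.Thread(target=worker, daemon=True).start()
task_queue.put((0, "ABC::"))  # the main node produces a task...
task_queue.join()             # ...and later waits for it to be consumed
print(results[0])             # then looks up the result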
Steps to running a parallel search
1) Set up an in-memory database.
2) Package the reward function into a celery app.
3) Define a CircuiTree subclass that calls the reward function in (2).
4) Launch some workers.
5) Run the search script.
1. Database installation
For instance, we can use a lightweight database called Redis (https://redis.io/). Follow the instructions on that site to install the database and command line utility, and test your installation by running:

redis-cli ping

If the server is up and running, it should reply with PONG.
If you are using a Redis server hosted somewhere other than the default location (redis://localhost:6379/), you can set the CELERY_BROKER_URL environment variable to point to your server.
2. Making a celery app with the reward function
The app is a Python script that tells celery where the database is and which tasks it will be managing. For instance, here is an app script for the bistability design problem in the Getting Started tutorial.
# bistability_app.py
from celery import Celery
from circuitree.models import SimpleNetworkGrammar
import numpy as np
import redis
import os

from bistability import get_bistability_reward

# Address of the database Celery should use
database_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/0")
database = redis.Redis.from_url(database_url)
if database.ping():
    print(f"Connected to Redis database at {database_url}")
else:
    raise ConnectionError(f"Could not connect to Redis database at {database_url}")

# Create the app
app = Celery("bistability", broker=database_url, backend=database_url)
grammar = SimpleNetworkGrammar(["A", "B"], ["activates", "inhibits"])

@app.task
def get_reward_celery(state: str, seed: int, expensive: bool = False) -> float:
    """Returns a reward value for the given state based on how many types of positive
    feedback loops (PFLs) it contains. Same as `BistabilityTree.get_reward()`,
    except this function is evaluated by a Celery worker."""
    # Celery cannot pass Numpy random generators as arguments, so we pass a unique
    # integer and use it to seed a high-quality random generator
    hq_seed = np.random.SeedSequence(seed).generate_state(1)[0]
    rg = np.random.default_rng(hq_seed)
    return get_bistability_reward(state, grammar, rg, expensive)
We use the Celery class to create an app that uses the Redis database to pass messages (the broker argument) and store results (the backend argument). The URL here points to the default location for a local database (port 6379 on the localhost network). Any function with the @app.task decorator becomes a celery task that can be executed by a worker - we'll see how this looks in the next section.
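As an aside, a Celery task can still be called like a plain function, which executes it locally in the current process with no worker or queue involved. This can be a convenient way to smoke-test the reward function itself (a sketch; the example input is just the root state used later in this tutorial):

from bistability_app import get_reward_celery

# Calling the task directly runs it locally, bypassing the queue entirely
reward = get_reward_celery("ABC::", seed=0, expensive=False)
print(reward)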
3. Calling the reward function as a celery task
Unlike a normal function call, a call to a celery task is asynchronous. This means that when the main node calls the function, it dispatches a task to the workers, and the result can be requested later. This uses different syntax - instead of running reward = get_reward(...) directly, we run result = get_reward_celery.delay(...) to dispatch the task from the main node to the workers. This immediately returns an AsyncResult object that can be inspected to monitor progress. Then, once we need the result, we call result.get() and wait for the reward to arrive. While one thread is waiting for the reply, another thread can take over the main node and run a search iteration.
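For example, a full dispatch-and-wait round trip looks something like this (a minimal sketch; it assumes a worker is already running, and uses the root state from this tutorial as an arbitrary input):

from bistability_app import get_reward_celery

# Dispatch the task to the workers; this returns an AsyncResult immediately
result = get_reward_celery.delay("ABC::", 42, expensive=False)
print(result.ready())  # False until a worker has finished the task

# Block until the worker finishes and the reward arrives in the results backend
reward = result.get()
print(reward)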
All we need to do in this step is make a Python file declaring a new subclass of CircuiTree that uses the app. Here's what that looks like in our bistability example - we'll call it bistability_parallel.py.
# bistability_parallel.py
from gevent import monkey

monkey.patch_all()

from bistability import BistabilityTree
from bistability_app import get_reward_celery

class ParallelBistabilityTree(BistabilityTree):
    """This class is identical to BistabilityTree except that it uses Celery to compute
    rewards in parallel. This allows other threads to continue performing MCTS steps
    while one thread waits for its reward calculation to finish."""

    def get_reward(self, state, expensive=True):
        # Generate a random seed and run the task in a Celery worker
        seed = int(self.rg.integers(0, 2**32))
        result = get_reward_celery.delay(state, seed, expensive=expensive)
        reward = result.get()
        return reward
Rather than use the built-in threading module, which can only manage up to a few dozen threads, we will use the gevent module, which can support thousands. To achieve this, gevent re-defines ("monkey-patches") many of the built-in Python commands in order to support highly scalable "green threads."
WARNING: Monkey-patching can have some sharp corners when combined with Celery. The lines from gevent import monkey and monkey.patch_all() have to be the first lines in the file where we define the class, but they cannot be in the same file where we define the Celery app. For this reason, we make a separate file just for this class.
4. Launching a worker node
We can launch a worker node using celery's command line interface. To do so, open a separate terminal, activate your virtual environment if you have one, cd to the folder with the app, and run the following command, replacing XX with the number of CPUs to use. It's good practice to use one or two fewer CPUs than the total number on your machine, since performance can paradoxically degrade if you try to use every single CPU at once.
# Launch a worker called 'worker1' with 'XX' CPUs, specifying the app with the '.app' suffix.
# If you supply the '--logfile' flag, the worker will write its logs there
celery --app bistability_app.app multi start "worker1" --concurrency=XX --loglevel=INFO #--logfile="./worker1.log"
You can specify logging information like the log level and location of the log file as shown. You can also use the flag --detach to run the worker as a background process, but beware that Celery will not monitor it. You will need to find and kill the process yourself (on Linux, you can run ps aux | grep 'celery', note the process ID (pid) of any running workers, and kill them with sudo kill -9 {pid}).
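You can also check from Python that workers are up by pinging them through the app, which is a quick sanity check before starting a long search (a sketch, assuming bistability_app is importable from the current directory):

from bistability_app import app

# Each live worker replies with {'ok': 'pong'} within the timeout
replies = app.control.ping(timeout=2.0)
if replies:
    print(f"{len(replies)} worker(s) responded: {replies}")
else:
    print("No workers responded - is the worker node running?")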
5. Running a parallel search
Now we can run the search in parallel by running a script from the main node. For the bistability example, you could run python ./run_search_parallel.py, where the script looks like this:
# run_search_parallel.py
from datetime import datetime
from pathlib import Path

from bistability_parallel import ParallelBistabilityTree

def main(
    n_steps: int = 10_000,
    n_threads: int = 8,
    expensive: bool = True,
    save_dir: str | Path = Path("./tree-backups"),
):
    """Finds bistable circuit designs using parallel MCTS."""
    # Make a folder for backups
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    print("Running an MCTS search in parallel (see tutorial notebook #2)...")
    tree = ParallelBistabilityTree(root="ABC::")
    tree.search_mcts_parallel(
        n_steps=n_steps, n_threads=n_threads, run_kwargs={"expensive": expensive}
    )
    print("Search complete!")

    # Save the search graph to a GML file and the other attributes to a JSON file
    today = datetime.now().strftime("%y%m%d")
    save_stem = save_dir.joinpath(f"{today}_parallel_bistability_search_step{n_steps}")
    print(f"Saving final tree to {save_stem}.{{gml,json}}")
    tree.to_file(f"{save_stem}.gml", f"{save_stem}.json")
    print("Done")

if __name__ == "__main__":
    main()
That's it! To analyze the results, we can read the object from the saved files using CircuiTree.from_file().
from pathlib import Path

from circuitree import SimpleNetworkGrammar
from bistability import BistabilityTree

# Get the file paths for the data
data_dir = Path("./tree-backups")
gml_file = list(data_dir.glob("*parallel_bistability_search*.gml"))[0]
json_file = list(data_dir.glob("*parallel_bistability_search*.json"))[0]

# Read from file. Note that we need to specify the class of the grammar
tree = BistabilityTree.from_file(gml_file, json_file, grammar_cls=SimpleNetworkGrammar)
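From here, the results can be analyzed just like a serial search. For instance, assuming (as in the Getting Started tutorial) that the search graph is stored as a networkx DiGraph in the tree.graph attribute, a quick sanity check might look like:

# Confirm the loaded search graph has the expected size
print(f"Loaded {len(tree.graph.nodes)} nodes and {len(tree.graph.edges)} edges")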
Distributed search in the cloud
The same framework we use to run the search in parallel on a local machine can be used to run a search across many machines, in the cloud!
There are a few differences. Notably, the in-memory database now lives on a remote machine. Communication between the main and worker nodes can become a bottleneck when scaling up this infrastructure, so it is important that your database has fast, high-bandwidth networking and sits on the same network as the main and worker nodes. Most cloud providers already have a solution for this (for example, as of May 2024, Amazon ElastiCache + EC2). The main node holds the search graph and makes backups, so it will generally need more memory, while worker nodes should have more computing power. Also, because Celery does not make it obvious where each task is run, you should take care that your backups are being saved to the correct location on the correct machine. Celery generally provides very robust logging as well, so be sure to specify the --logfile option in the celery worker command to take advantage of it.
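Concretely, since the app in step 2 reads the database address from the CELERY_BROKER_URL environment variable, pointing the whole pipeline at a remote database only requires setting that variable on the main node and every worker node before the app is imported. Here is a sketch (the endpoint is a placeholder for your own database's address):

import os

# Must be set before importing the app, which reads the variable at import time
os.environ["CELERY_BROKER_URL"] = "redis://my-remote-cache.example.com:6379/0"

from bistability_app import app  # now brokered through the remote database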