Parallel MCTS

MCTS is an iterative sampling algorithm, where the reward found in each iteration affects sampling in later iterations. While perfectly parallel execution isn’t possible, we can achieve quite good performance using the so-called lock-free method [1], where multiple search threads on the same machine (the main node) run MCTS concurrently, each one taking turns editing the search graph. We will implement this in detail later in the tutorial, but in brief, instead of computing the (usually expensive) reward function itself, each search thread on the main node sends a request to a group of worker CPUs (the worker node) elsewhere that do the actual computation. While a thread is waiting for its result, other search threads can use the main CPU. As long as the reward computation takes significantly longer than the time spent sending and receiving those messages, we should see a performance boost!

Parallel search on a single machine

In order to parallelize the search on a local machine, we can nominate a group of CPUs in our own computer to be the worker node that performs reward function evaluations. We can coordinate the main and worker nodes using a producer-consumer queue. The main node will produce tasks (calls to the reward function) that get added to the queue, and the worker node will consume tasks from the queue and return the result to a shared database where the main node can look up the result. We’ll manage this task queue with the Python utility celery.
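To make the pattern concrete, here is a toy producer-consumer example using only the Python standard library (in the tutorial itself, Celery manages the queue and Redis stores the tasks and results). The reward function here is a trivial stand-in.

# Toy illustration of the producer-consumer pattern (not part of the tutorial files)
import threading
import queue

tasks = queue.Queue()
results = {}

def worker():
    while True:
        state = tasks.get()
        if state is None:  # sentinel value: no more tasks
            break
        results[state] = len(state)  # stand-in for an expensive reward function
        tasks.task_done()

# The "worker node": a few consumer threads
threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()

# The "main node": produce tasks, then wait for them to finish
for state in ["A", "AB", "ABC"]:
    tasks.put(state)
tasks.join()

for _ in threads:
    tasks.put(None)
for t in threads:
    t.join()

print(results)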

Here’s a schematic of how that infrastructure looks.

[Figure: Local infrastructure]

Steps to running a parallel search:

1) Set up an in-memory database.
2) Package the reward function into a celery app.
3) Define a CircuiTree subclass that calls the reward function in (2).
4) Launch some workers.
5) Run the search script.

1. Database installation

For instance, we can use a lightweight database called Redis (https://redis.io/). Follow the instructions here to install the database and command line utility, and test your installation by running

redis-cli ping

If you are using a Redis server hosted somewhere other than the default location (redis://localhost:6379/), you can set the CELERY_BROKER_URL environment variable to point to your server.

2. Making a celery app with the reward function

The app is a Python script that tells celery where the database is and which tasks it will be managing. For instance, here is an app script for the bistability design problem in the Getting Started tutorial.

# bistability_app.py

from celery import Celery
from circuitree.models import SimpleNetworkGrammar
import numpy as np
import redis
import os
from bistability import get_bistability_reward

# Address of the database Celery should use
database_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/0")
database = redis.Redis.from_url(database_url)
if database.ping():
    print(f"Connected to Redis database at {database_url}")
else:
    raise ConnectionError(f"Could not connect to Redis database at {database_url}")

# Create the app
app = Celery("bistability", broker=database_url, backend=database_url)

grammar = SimpleNetworkGrammar(["A", "B"], ["activates", "inhibits"])


@app.task
def get_reward_celery(state: str, seed: int, expensive: bool = False) -> float:
    """Returns a reward value for the given state based on how many types of positive
    feedback loops (PFLs) it contains. Same as `BistabilityTree.get_reward()`,
    except this function is evaluated by a Celery worker."""

    # Celery cannot pass Numpy random generators as arguments, so we pass a unique
    # integer and use it to seed a high-quality random generator
    hq_seed = np.random.SeedSequence(seed).generate_state(1)[0]
    rg = np.random.default_rng(hq_seed)

    return get_bistability_reward(state, grammar, rg, expensive)

We use the Celery class to create an app that uses the Redis database to pass messages (the broker argument) and store results (the backend argument). The URL here points to the default location for a local database (port 6379 on localhost). Any function with the @app.task decorator becomes a celery task that can be executed by a worker - we’ll see how this looks in the next section.

3. Calling the reward function as a celery task

Unlike a normal function call, a call to a celery task is asynchronous. This means that when the main node calls the function, it dispatches a task to the workers, and the result can be requested later. This uses different syntax: instead of running reward = get_reward(...) directly, we run result = get_reward_celery.delay(...) to dispatch the task from the main node to the workers. This immediately returns an AsyncResult object that can be inspected to monitor progress. Then, once we need the reward, we call result.get() and wait for it to arrive. While one thread is waiting for the reply, another thread can take over the main node and run a search iteration.
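For example, assuming the app from step 2 is importable and at least one worker is running, a single task can be dispatched and inspected like this (the state value below is only a placeholder; in practice it comes from the search tree):

# Dispatch a single reward evaluation and wait for the result
from bistability_app import get_reward_celery

state = "..."  # placeholder: use a circuit state string from your search
result = get_reward_celery.delay(state, seed=2024, expensive=True)

print(result.ready())   # False until a worker has finished the task
print(result.status)    # e.g. "PENDING", then "SUCCESS"

reward = result.get(timeout=600)  # block until the reward arrives (timeout in seconds)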

All we need to do in this step is make a Python file declaring a new subclass of CircuiTree that uses the app. Here’s what that looks like in our bistability example - we’ll call it bistability_parallel.py.

# bistability_parallel.py

from gevent import monkey

monkey.patch_all()

from bistability import BistabilityTree
from bistability_app import get_reward_celery


class ParallelBistabilityTree(BistabilityTree):
    """This class is identical to BistabilityTree except that it uses Celery to compute
    rewards in parallel. This allows other threads to continue performing MCTS steps
    while one thread waits for its reward calculation to finish."""

    def get_reward(self, state, expensive=True):
        # Generate a random seed and run the task in a Celery worker
        seed = int(self.rg.integers(0, 2**32))
        result = get_reward_celery.delay(state, seed, expensive=expensive)
        reward = result.get()
        return reward

Rather than use the built-in threading module, which can only manage up to a few dozen threads, we will use the gevent module, which can support thousands. To achieve this, gevent re-defines (“monkey-patches”) many of the built-in Python functions in order to support highly scalable “green threads.”

WARNING: Monkey-patching can have some sharp corners when combined with Celery. The lines from gevent import monkey and monkey.patch_all() have to be the first lines in the file where we define the class, but they cannot be in the same file where we define the Celery app. For this reason, we make a separate file just for this class.
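For reference, here is a minimal, self-contained gevent example (separate from the tutorial files) showing how thousands of green threads can run concurrently. Each thread sleeps to mimic waiting on a remote reward computation; while one thread waits, the others proceed.

# Minimal gevent green-threads demo, for illustration only
from gevent import monkey

monkey.patch_all()

import gevent

def fake_search_iteration(i):
    gevent.sleep(0.1)  # stands in for blocking on result.get()
    return i

threads = [gevent.spawn(fake_search_iteration, i) for i in range(1000)]
gevent.joinall(threads)
print(f"Completed {sum(t.value is not None for t in threads)} iterations")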

4. Launching a worker node

We can launch a worker node using celery’s command line interface. To do so, open a separate terminal, activate your virtual environment if you have one, cd to the folder with the app, and run the following command, replacing the XX with the number of CPUs to use. It’s good practice to use one or two fewer CPUs than the total number on your machine, since performance can paradoxically degrade if you try to use every single CPU at once.

# Launch a worker called 'worker1' with 'XX' CPUs, specifying the app with the '.app' suffix.
# If you supply the 'logfile' flag, the worker will write its logs there
celery --app bistability_app.app multi start "worker1" --concurrency=XX --loglevel=INFO #--logfile="./worker1.log"

You can specify logging information like the log level and location of the log file as shown. You can also use the flag --detach to run the worker as a background process, but beware that Celery will not monitor it. You will need to find and kill the process yourself (on Linux, you can run ps aux | grep 'celery', note the process ID (pid) of any running workers, and kill them with sudo kill -9 {pid}).

Distributed search in the cloud

The same framework we use to run the search in parallel on a local machine can be used to run a search across many machines, in the cloud!

[Figure: Distributed infrastructure]

There are a few differences. Notably, the in-memory database now lives on a remote machine. The communication between main and worker nodes can be a bottleneck in scaling up this infrastructure, so it is important that your database has fast, high-bandwidth networking and sits on the same network as the main and worker nodes. Most cloud providers already have a solution for this (for example, as of May 2024, Amazon ElastiCache + EC2). The main node holds the search graph and makes backups, so it will generally need more memory, while worker nodes should have more computing resources. Also, because Celery does not make it obvious where each task is run, you should take care that your backups are being saved to the correct location on the correct machine. Celery also provides very robust logging, so be sure to specify the --logfile option in the celery worker command to take advantage of it.
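Since bistability_app.py already reads CELERY_BROKER_URL from the environment, pointing the main and worker nodes at a remote database only requires setting that variable before the app is imported (or exporting it in the shell before launching the search script and the celery workers). The hostname below is a placeholder; substitute your own database endpoint.

# Hypothetical example: use a remote Redis database instead of localhost
import os

os.environ["CELERY_BROKER_URL"] = "redis://my-redis-endpoint.example.com:6379/0"

from bistability_app import app  # the app reads the URL at import time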