{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Parallel MCTS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MCTS is an iterative sampling algorithm, where the reward found in each iteration affects sampling in later iterations. While perfect parallel execution isn't possible, we can achieve quite good performance using the so-called lock-free method [[1]](https://doi.org/10.1007/978-3-642-12993-3_2), where multiple multiple search threads in the same CPU (the *main node*) are running MCTS concurrently, each one taking turns editing the search graph. We will implement this in detail later in the tutorial, but in brief, instead of computing the (usually expensive) reward function, each search thread on the main node sends a request to a group of worker CPUs (the *worker node*) somewhere else that will do the actual computation, and while that thread is waiting for the result, other search threads can use the main CPU. As long as our execution time is significantly longer than the time spent sending and receiving those signals, we should see a performance boost!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parallel search on a single machine\n", "\n", "In order to parallelize the search on a local machine, we can nominate a group of CPUs in our own computer to be the worker node that performs reward function evaluations. We can coordinate the main and worker nodes using a *producer-consumer* queue. The main node will produce tasks (calls to the reward function) that get added to the queue, and the worker node will consume tasks from the queue and return the result to a shared database where the main node can look up the result. We'll manage this task queue with the Python utility `celery`. \n", "\n", "Here's a schematic of how that infrastructure looks.\n", "\n", "![Local-Infrastructure](./local_parallel_infrastructure.png)\n", "\n", "__Steps to running a parallel search__\n", "1) Set up an in-memory database.\n", "2) Package the reward function into a `celery` app.\n", "3) Define a `CircuiTree` subclass that calls the reward function in (2).\n", "4) Launch some workers.\n", "5) Run the search script." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Database installation\n", "\n", "For instance, we can a lightweight database called Redis (https://redis.io/). Follow the instructions [here](https://redis.io/docs/latest/operate/oss_and_stack/install/install-redis/) to install the database and command line utility, and test your installation by running \n", "\n", "```\n", "redis-cli ping\n", "```\n", "\n", "If you are using a Redis server hosted somewhere other than the default location (`redis://localhost:6379/`), you can set the `CELERY_BROKER_URL` environment variable to point to your server." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Making a `celery` app with the reward function \n", "The app is a Python script that tells `celery` where the database is and which tasks it will be managing. For instance, here is an app script for the bistability design problem in the Getting Started tutorial." 
, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Making a `celery` app with the reward function\n", "\n", "The app is a Python script that tells `celery` where the database is and which tasks it will be managing. For instance, here is an app script for the bistability design problem from the Getting Started tutorial." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "# bistability_app.py\n", "\n", "from celery import Celery\n", "from circuitree.models import SimpleNetworkGrammar\n", "import numpy as np\n", "import redis\n", "import os\n", "from bistability import get_bistability_reward\n", "\n", "# Address of the database Celery should use\n", "database_url = os.environ.get(\"CELERY_BROKER_URL\", \"redis://localhost:6379/0\")\n", "database = redis.Redis.from_url(database_url)\n", "if database.ping():\n", "    print(f\"Connected to Redis database at {database_url}\")\n", "else:\n", "    raise ConnectionError(f\"Could not connect to Redis database at {database_url}\")\n", "\n", "# Create the app\n", "app = Celery(\"bistability\", broker=database_url, backend=database_url)\n", "\n", "grammar = SimpleNetworkGrammar([\"A\", \"B\"], [\"activates\", \"inhibits\"])\n", "\n", "\n", "@app.task\n", "def get_reward_celery(state: str, seed: int, expensive: bool = False) -> float:\n", "    \"\"\"Returns a reward value for the given state based on how many types of positive\n", "    feedback loops (PFLs) it contains. Same as `BistabilityTree.get_reward()`,\n", "    except this function is evaluated by a Celery worker.\"\"\"\n", "\n", "    # Celery cannot pass Numpy random generators as arguments, so we pass a unique\n", "    # integer and use it to seed a high-quality random generator\n", "    hq_seed = np.random.SeedSequence(seed).generate_state(1)[0]\n", "    rg = np.random.default_rng(hq_seed)\n", "\n", "    return get_bistability_reward(state, grammar, rg, expensive)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the `Celery` class to create an app that uses the `Redis` database to pass messages (the `broker` argument) and store results (the `backend` argument). The URL here points to the default location for a local database (port `6379` on `localhost`). Any function with the `@app.task` decorator becomes a `celery` *task* that can be executed by a worker - we'll see how this looks in the next section." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Calling the reward function as a `celery` task\n", "\n", "Unlike a normal function call, a call to a `celery` task is *asynchronous*. This means that when the main node calls the function, it dispatches a task to the workers, and the result can be requested later. This uses different syntax - instead of running `reward = get_reward(...)` directly, we run `result = get_reward_celery.delay(...)` to dispatch the task from the main node to the workers. This immediately returns an `AsyncResult` object that can be inspected to monitor progress. Then, once we need the result, we call `result.get()` and wait for the reward to arrive. While one thread is waiting for the reply, another thread can take over the main node and run a search iteration.\n", "\n",
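"For example, the dispatch pattern looks like this on its own. This is just a sketch for illustration (not one of the tutorial files), and it assumes the Redis server from step 1 is running and a worker (step 4) has been launched - otherwise `result.get()` will wait indefinitely. The state string is simply the root state used later in the tutorial.\n", "\n", "```\n", "from bistability_app import get_reward_celery\n", "\n", "# Dispatch the task to the worker pool and get a handle to the pending result\n", "result = get_reward_celery.delay(\"ABC::\", seed=2024, expensive=False)\n", "print(result.status)  # e.g. \"PENDING\" until a worker finishes the task\n", "\n", "# Block until the worker returns the reward\n", "reward = result.get()\n", "print(reward)\n", "```\n", "\n",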
"All we need to do in this step is make a Python file declaring a new subclass of `CircuiTree` that uses the app. Here's what that looks like in our bistability example - we'll call it `bistability_parallel.py`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "# bistability_parallel.py\n", "\n", "from gevent import monkey\n", "\n", "monkey.patch_all()\n", "\n", "from bistability import BistabilityTree\n", "from bistability_app import get_reward_celery\n", "\n", "\n", "class ParallelBistabilityTree(BistabilityTree):\n", "    \"\"\"This class is identical to BistabilityTree except that it uses Celery to compute\n", "    rewards in parallel. This allows other threads to continue performing MCTS steps\n", "    while one thread waits for its reward calculation to finish.\"\"\"\n", "\n", "    def get_reward(self, state, expensive=True):\n", "        # Generate a random seed and run the task in a Celery worker\n", "        seed = int(self.rg.integers(0, 2**32))\n", "        result = get_reward_celery.delay(state, seed, expensive=expensive)\n", "        reward = result.get()\n", "        return reward\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rather than use the built-in `threading` module, which scales poorly beyond a few dozen threads, we will use the `gevent` module, which can support thousands. To achieve this, `gevent` re-defines (\"monkey-patches\") many of the built-in Python commands in order to support highly scalable \"green threads.\"\n", "\n", "***WARNING:***  \n", "Monkey-patching can have some sharp corners when combined with Celery. The lines `from gevent import monkey` and `monkey.patch_all()` have to be the first lines in the file where we define the class, but they *cannot* be in the same file where we define the Celery app. For this reason, we make a separate file just for this class." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Launching a worker node" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can launch a worker node using `celery`'s command line interface. To do so, open a separate terminal, activate your virtual environment if you have one, `cd` to the folder with the app, and run the following command, replacing `XX` with the number of CPUs to use. It's good practice to use one or two fewer CPUs than the total number on your machine, since performance can paradoxically degrade if you try to use every single CPU at once." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "# Launch a worker called 'worker1' with 'XX' CPUs, specifying the app with the '.app' suffix.\n", "# If you supply the 'logfile' flag, the worker will write its logs there\n", "celery --app bistability_app.app multi start \"worker1\" --concurrency=XX --loglevel=INFO #--logfile=\"./worker1.log\"\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can specify logging information like the log level and the location of the log file as shown. You can also use the flag `--detach` to run the worker as a background process, but beware that **Celery will not monitor it**. You will need to find and kill the process yourself (on Linux, you can run `ps aux | grep 'celery'`, note the process ID (`pid`) of any running workers, and kill them with `sudo kill -9 {pid}`)." ] }
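, { "cell_type": "markdown", "metadata": {}, "source": [ "Before starting a long search, it can be reassuring to confirm that the workers are actually online. Here is a minimal sketch (not one of the tutorial files) that uses Celery's control API; run it from the folder containing `bistability_app.py`. Each live worker replies with a \"pong\", and an empty list means no workers were found.\n", "\n", "```\n", "from bistability_app import app\n", "\n", "# Ask all workers to respond within one second\n", "print(app.control.ping(timeout=1.0))\n", "```" ] }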
, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Running a parallel search\n", "\n", "Now we can run the search in parallel by running a script from the main node. For the bistability example, you could run `python ./run_search_parallel.py`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "# run_search_parallel.py\n", "\n", "from datetime import datetime\n", "from pathlib import Path\n", "from bistability_parallel import ParallelBistabilityTree\n", "\n", "\n", "def main(\n", "    n_steps: int = 10_000,\n", "    n_threads: int = 8,\n", "    expensive: bool = True,\n", "    save_dir: str | Path = Path(\"./tree-backups\"),\n", "):\n", "    \"\"\"Finds bistable circuit designs using parallel MCTS.\"\"\"\n", "\n", "    # Make a folder for backups\n", "    save_dir = Path(save_dir)\n", "    save_dir.mkdir(exist_ok=True)\n", "\n", "    print(\"Running an MCTS search in parallel (see tutorial notebook #2)...\")\n", "    tree = ParallelBistabilityTree(root=\"ABC::\")\n", "    tree.search_mcts_parallel(\n", "        n_steps=n_steps, n_threads=n_threads, run_kwargs={\"expensive\": expensive}\n", "    )\n", "    print(\"Search complete!\")\n", "\n", "    # Save the search graph to a GML file and the other attributes to a JSON file\n", "    today = datetime.now().strftime(\"%y%m%d\")\n", "    save_stem = save_dir.joinpath(f\"{today}_parallel_bistability_search_step{n_steps}\")\n", "\n", "    print(f\"Saving final tree to {save_stem}.{{gml,json}}\")\n", "    tree.to_file(f\"{save_stem}.gml\", f\"{save_stem}.json\")\n", "\n", "    print(\"Done\")\n", "\n", "\n", "if __name__ == \"__main__\":\n", "    main()\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's it! To analyze the results, we can read the object from the saved files using `CircuiTree.from_file()`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "from pathlib import Path\n", "\n", "from circuitree.models import SimpleNetworkGrammar\n", "from bistability import BistabilityTree\n", "\n", "# Get the file paths for the data\n", "data_dir = Path(\"./tree-backups\")\n", "gml_file = list(data_dir.glob(\"*parallel_bistability_search*.gml\"))[0]\n", "json_file = list(data_dir.glob(\"*parallel_bistability_search*.json\"))[0]\n", "\n", "# Read from file. Note that we need to specify the class of the grammar\n", "tree = BistabilityTree.from_file(gml_file, json_file, grammar_cls=SimpleNetworkGrammar)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Distributed search in the cloud\n", "\n", "The same framework we use to run the search in parallel on a local machine can be used to run a search across many machines, in the cloud!\n", "\n", "![Distributed-Infrastructure](./distributed_infrastructure.png)\n", "\n", "\n", "There are a few differences. Notably, the in-memory database now lives on a remote machine. Communication between the main and worker nodes can become a bottleneck as this infrastructure scales up, so it is important that your database has fast, high-bandwidth networking and is on the same network as the main and worker nodes. Most cloud providers already have a solution for this (for example, as of May 2024, Amazon ElastiCache + EC2).\n", "\n", "The main node holds the search graph and makes backups, so it will generally need more memory, while worker nodes should have more computing power. Also, because Celery does not make it obvious where each task is run, you should take care that your backups are being saved to the correct location on the correct machine. Celery provides robust logging as well, so be sure to specify the `--logfile` option in the command used to launch the workers to take advantage of it." ] }
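, { "cell_type": "markdown", "metadata": {}, "source": [ "Conveniently, no code changes are needed to switch to this setup: because `bistability_app.py` reads the `CELERY_BROKER_URL` environment variable, you can point both the main node and the worker nodes at the hosted database by setting that variable before launching your processes. Below is a sketch for a worker node; the URL is a placeholder for your own database endpoint, and the launch command is the same one used in step 4.\n", "\n", "```\n", "# Point Celery at the hosted database (placeholder URL), then start the workers\n", "export CELERY_BROKER_URL=\"redis://my-redis-endpoint.example.com:6379/0\"\n", "celery --app bistability_app.app multi start \"worker1\" --concurrency=XX --loglevel=INFO --logfile=\"./worker1.log\"\n", "```" ] }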
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": ".tutorial-venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 2 }