circuitree.models.SimpleNetworkTree

class circuitree.models.SimpleNetworkTree(grammar: SimpleNetworkGrammar | None = None, components: Iterable[Iterable[str]] | None = None, interactions: Iterable[str] | None = None, max_interactions: int | None = None, root: str | None = None, exploration_constant: float | None = None, seed: int = 2023, graph: DiGraph | None = None, tree_shape: Literal['tree', 'dag'] | None = None, compute_unique: bool = True, fixed_components: list[str] | None = None, **kwargs)

A convenience subclass of CircuiTree that uses SimpleNetworkGrammar by default.

See CircuiTree and SimpleNetworkGrammar for more details.

__init__(grammar: SimpleNetworkGrammar | None = None, components: Iterable[Iterable[str]] | None = None, interactions: Iterable[str] | None = None, max_interactions: int | None = None, root: str | None = None, exploration_constant: float | None = None, seed: int = 2023, graph: DiGraph | None = None, tree_shape: Literal['tree', 'dag'] | None = None, compute_unique: bool = True, fixed_components: list[str] | None = None, **kwargs)

Methods

__init__([grammar, components, ...])

backpropagate_reward(selection_path, reward)

Update the reward for each node and edge in selection_path.

backpropagate_visit(selection_path)

Increment the visit count for each node and edge in selection_path.

bfs_iterator([root, shuffle])

Iterate over all terminal nodes in breadth-first (BFS) order starting from a root node.

copy_graph()

Return a shallow copy of the search graph (the graph attribute).

enumerate_terminal_states([root, progress, ...])

Enumerate all terminal states reachable from the given root state.

expand_edge(parent, child)

Expands the search graph by adding the child node and/or the parent-child edge to the search graph if they do not already exist.

from_file(graph_gml, attrs_json[, ...])

Load a CircuiTree object from a JSON file/string containing the object's attributes and an optional GML file/string of the search graph.

generate_gml()

Convert the graph attribute to a GML formatted string.

get_attributes(attrs_copy)

Returns a dictionary of the object's attributes.

get_random_terminal_descendant(start[, rg])

Sample a random terminal state by following random actions from a given start state.

get_reward(state, **kwargs)

Abstract method that calculates the reward for a given state.

get_ucb_score(parent, child)

Calculates the UCB score for a child node given its parent.

grow_tree([root, n_visits, print_updates, ...])

Exhaustively expand the search tree from a root node (not recommended for large spaces).

grow_tree_from_leaves(leaves)

Returns the tree (or DAG) of all paths that start at the root and end at a node in leaves.

is_success(terminal_state)

Determines whether a state represents a successful outcome in the search.

sample_successful_circuits_by_enumeration(...)

Sample a random successful state by first creating a new graph that contains all possible paths from the root to a successful terminal state.

sample_successful_circuits_by_rejection(...)

Sample a random successful state with rejection sampling.

sample_terminal_states(n_samples[, ...])

Sample n_samples random terminal states from the grammar.

search_bfs([n_steps, n_repeats, n_cycles, ...])

Performs a breadth-first search (BFS) traversal of the search graph.

search_mcts(n_steps[, callback_every, ...])

Performs a Monte Carlo Tree Search (MCTS) traversal of circuit topology space.

search_mcts_parallel(n_steps, n_threads[, ...])

Performs a Monte Carlo Tree Search (MCTS) in parallel using multiple threads.

select_and_expand([rg])

Selects a path through the search graph using the UCB score.

test_pattern_significance(patterns, n_samples)

Test whether a pattern is successful by sampling random paths from the design space.

to_complexity_graph([successes])

Generates a directed acyclic graph (DAG) representing the search graph as a "complexity atlas".

to_file(gml_file[, json_file, save_attrs, ...])

Save the CircuiTree object to a gml file and optionally a JSON file containing the object's attributes.

to_string()

Return a GML-formatted string of the graph attribute and a JSON-formatted string of the other serializable attributes.

traverse(**kwargs)

Performs a single iteration of the MCTS algorithm.

Attributes

default_attrs

Default attributes for nodes and edges in the search graph.

terminal_states

Generator that yields terminal states in the search graph.

backpropagate_reward(selection_path: list, reward: float | int)

Update the reward for each node and edge in selection_path.

Notes:

  • The reward function (get_reward) is called after backpropagate_visit and before backpropagate_reward.

  • In parallel mode, each node and edge in the selection path incurs “virtual loss” until the reward is computed and backpropagated. This is because the visit count is incremented before the actual reward is known.

Parameters:
  • selection_path (list) – The list of nodes in the selection path.

  • reward (float | int) – The reward value to add to each node and edge.
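The ordering described in the notes above, and the resulting "virtual loss", can be illustrated with a minimal sketch. It assumes a toy dictionary-based store of per-node and per-edge statistics (node_stats, edge_stats, edges_of and mean_reward are hypothetical names); in CircuiTree these values live as "visits" and "reward" attributes on the nodes and edges of the search graph.

```python
from collections import defaultdict

# Toy stand-ins for the node and edge attributes of the search graph.
node_stats = defaultdict(lambda: {"visits": 0, "reward": 0.0})
edge_stats = defaultdict(lambda: {"visits": 0, "reward": 0.0})


def edges_of(path):
    return list(zip(path[:-1], path[1:]))


def backpropagate_visit(selection_path):
    # Called first: counts the visit before the reward is known.
    for node in selection_path:
        node_stats[node]["visits"] += 1
    for edge in edges_of(selection_path):
        edge_stats[edge]["visits"] += 1


def backpropagate_reward(selection_path, reward):
    # Called after get_reward: adds the reward along the same path.
    for node in selection_path:
        node_stats[node]["reward"] += reward
    for edge in edges_of(selection_path):
        edge_stats[edge]["reward"] += reward


def mean_reward(node):
    stats = node_stats[node]
    return stats["reward"] / stats["visits"] if stats["visits"] else 0.0


path = ["root", "A", "AB"]
backpropagate_visit(path)
backpropagate_reward(path, 1.0)
print(mean_reward("AB"))  # 1.0

backpropagate_visit(path)  # a second traversal is in flight
print(mean_reward("AB"))  # 0.5 until its reward arrives ("virtual loss")
backpropagate_reward(path, 1.0)
print(mean_reward("AB"))  # 1.0
```

While a reward is pending, the node's mean reward is temporarily diluted, which is the usual motivation for virtual loss in parallel MCTS: other threads are nudged away from a path that is already being evaluated.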

backpropagate_visit(selection_path: list) → None

Increment the visit count for each node and edge in selection_path.

Notes:

  • The reward function (get_reward) is called after backpropagate_visit and before backpropagate_reward.

  • In parallel mode, each node and edge in the selection path incurs “virtual loss” until the reward is computed and backpropagated. This is because the visit count is incremented before the actual reward is known.

Parameters:

selection_path (list) – The list of nodes in the selection path.

bfs_iterator(root=None, shuffle=False) → Iterator[Hashable]

Iterate over all terminal nodes in breadth-first (BFS) order starting from a root node.
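The traversal order can be sketched on a toy graph. This is not the library's implementation: the children mapping and the is_terminal predicate are hypothetical stand-ins for the grammar, and shuffle here simply reorders each BFS layer with a seeded random.Random.

```python
import random


def bfs_terminal_iterator(children, is_terminal, root, shuffle=False, seed=None):
    """Yield terminal nodes reachable from root, layer by layer.

    States reachable by several paths (a DAG) are visited only once.
    """
    rng = random.Random(seed)
    seen = {root}
    layer = [root]
    while layer:
        if shuffle:
            rng.shuffle(layer)  # randomize the order within this layer only
        next_layer = []
        for node in layer:
            if is_terminal(node):
                yield node
            for child in children.get(node, ()):
                if child not in seen:
                    seen.add(child)
                    next_layer.append(child)
        layer = next_layer


# Hypothetical toy graph: "*" marks a terminal state; "AB" is reachable two ways.
children = {"root": ["A", "B"], "A": ["A*", "AB"], "B": ["B*", "AB"], "AB": ["AB*"]}
is_terminal = lambda state: state.endswith("*")
print(list(bfs_terminal_iterator(children, is_terminal, "root")))
# ['A*', 'B*', 'AB*']
```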

copy_graph() → DiGraph

Return a shallow copy of the search graph (the graph attribute). Use copy.deepcopy() for a deep copy.

enumerate_terminal_states(root: Hashable | None = None, progress: bool = False, max_iter: int | None = None) → Iterable[Hashable]

Enumerate all terminal states reachable from the given root state.

expand_edge(parent: Hashable, child: Hashable)

Expands the search graph by adding the child node and/or the parent-child edge to the search graph if they do not already exist.

Parameters:
  • parent (Hashable) – The parent node of the edge.

  • child (Hashable) – The child node.

classmethod from_file(graph_gml: str | Path | None, attrs_json: str | Path, grammar_cls: CircuitGrammar | None = None, grammar_kwargs: dict | None = None, **kwargs)

Load a CircuiTree object from a JSON file/string containing the object’s attributes and an optional GML file/string of the search graph.

These files are typically saved with the to_file() method.

The grammar attribute is loaded by looking for a key “grammar” in the JSON file, whose value should be a dict of keyword arguments (grammar_kwargs) used to create the grammar object. The grammar_cls keyword argument can be used to specify the grammar class constructor directly. Alternatively, the JSON file can include a key “__grammar_cls__” that specifies the class name as a string (which is looked up in globals()).

Parameters:
  • graph_gml (str | Path | None, optional) – The path to the GML file or the serialized GML string (optional).

  • attrs_json (str | Path) – The path to the JSON file containing the object’s attributes or the serialized JSON string.

  • grammar_cls (Optional[CircuitGrammar], optional) – The grammar class constructor to use. If None, the class is loaded from the JSON file. Defaults to None.

  • grammar_kwargs (Optional[dict], optional) – Keyword arguments used to create the grammar object.

  • **kwargs – Keyword arguments to pass to the CircuiTree constructor.

Returns:

A CircuiTree object loaded from the provided files or strings.

Return type:

CircuiTree

generate_gml() → str

Convert the graph attribute to a GML formatted string.

Returns:

The GML representation of self.graph.

Return type:

str

get_attributes(attrs_copy: Iterable[str] | None) → dict

Returns a dictionary of the object’s attributes.

This method allows controlled access to the object’s attributes, potentially excluding ones that should not be serialized when writing to a JSON file. For any attribute with a to_dict() method, the output of that method is used. Otherwise, the attribute is copied directly.

Parameters:

attrs_copy (Optional[Iterable[str]]) – An optional list of attribute names to include. If None, all attributes except those in _non_serializable_attrs are returned.

Raises:

ValueError – If attrs_copy contains attributes from _non_serializable_attrs.

Returns:

A dictionary containing the requested attributes. Attributes with a to_dict() method are converted using that method, otherwise they are copied directly.

Return type:

dict
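The serialization rule above (use to_dict() when available, otherwise copy the value, and refuse names listed in _non_serializable_attrs) can be sketched with a toy class. Grammar, Searcher and the contents of _non_serializable_attrs are hypothetical, and the sketch assumes "copied directly" means a shallow copy.copy.

```python
import copy


class Grammar:
    """Hypothetical attribute that knows how to serialize itself."""

    def __init__(self, components):
        self.components = components

    def to_dict(self):
        return {"components": self.components}


class Searcher:
    # Hypothetical: attributes that cannot be written to JSON.
    _non_serializable_attrs = ["graph", "rg"]

    def __init__(self):
        self.graph = object()  # stands in for the search graph
        self.rg = object()     # stands in for a random number generator
        self.seed = 2023
        self.grammar = Grammar([["A", "a"], ["B", "b"]])

    def get_attributes(self, attrs_copy=None):
        if attrs_copy is None:
            attrs_copy = [k for k in vars(self) if k not in self._non_serializable_attrs]
        else:
            bad = set(attrs_copy) & set(self._non_serializable_attrs)
            if bad:
                raise ValueError(f"Cannot serialize attributes: {sorted(bad)}")
        attrs = {}
        for name in attrs_copy:
            value = getattr(self, name)
            # Prefer the attribute's own serializer when it has one.
            attrs[name] = value.to_dict() if hasattr(value, "to_dict") else copy.copy(value)
        return attrs


print(Searcher().get_attributes())
# {'seed': 2023, 'grammar': {'components': [['A', 'a'], ['B', 'b']]}}
```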

get_random_terminal_descendant(start: Hashable, rg: Generator | None = None) → Hashable

Sample a random terminal state by following random actions from a given start state.

Parameters:
  • start (Hashable) – The starting state from which to begin the random walk.

  • rg (Optional[np.random.Generator], optional) – Random number generator. Defaults to None, in which case the rg attribute of the CircuiTree instance is used.

Returns:

The randomly sampled terminal state.

Return type:

Hashable
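A minimal sketch of the random walk, assuming a hypothetical toy grammar given as three plain functions (get_actions, do_action, is_terminal) and using the standard library's random.Random in place of np.random.Generator.

```python
import random


def get_random_terminal_descendant(start, get_actions, do_action, is_terminal, rng=None):
    """Follow uniformly random actions from start until a terminal state."""
    rng = rng or random.Random()
    state = start
    while not is_terminal(state):
        action = rng.choice(get_actions(state))
        state = do_action(state, action)
    return state


# Hypothetical toy grammar: append "a" or "b" (at most two), or terminate with "*".
def get_actions(state):
    return ["*"] if len(state) >= 2 else ["a", "b", "*"]


def do_action(state, action):
    return state + action


def is_terminal(state):
    return state.endswith("*")


print(get_random_terminal_descendant("", get_actions, do_action, is_terminal, random.Random(0)))
```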

abstract get_reward(state: Hashable, **kwargs) → float | int

Abstract method that calculates the reward for a given state.

This method must be implemented by a subclass of CircuiTree. It defines the interface for reward calculation within the framework.

Parameters:

state (Hashable) – The state for which to calculate the reward.

Returns:

The reward for the given state. Reward values can be deterministic or stochastic. The range of possible outputs should be finite and ideally normalized to the range [0, 1].

Return type:

float | int

get_ucb_score(parent: Hashable, child: Hashable)

Calculates the UCB score for a child node given its parent.

Parameters:
  • parent (Hashable) – The parent node.

  • child (Hashable) – The child node.

Returns:

The UCB score of the child node.

Return type:

float
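For reference, a sketch of the standard UCB1 score: the child's mean reward plus an exploration bonus scaled by an exploration constant. The exact formula used by CircuiTree is not stated here, so the sqrt(2) default and the infinite score for unvisited children are assumptions of this sketch.

```python
import math


def ucb_score(mean_reward, parent_visits, child_visits, exploration_constant=math.sqrt(2)):
    """UCB1: exploitation (mean reward) plus an exploration bonus."""
    if child_visits == 0:
        return math.inf  # assumption: unvisited children are tried first
    bonus = exploration_constant * math.sqrt(math.log(parent_visits) / child_visits)
    return mean_reward + bonus


# Same mean reward, but the rarely visited child gets the larger bonus.
print(ucb_score(0.2, 100, 5) > ucb_score(0.2, 100, 50))  # True
```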

grow_tree(root: Hashable | None = None, n_visits: int = 0, print_updates: bool = False, print_every: int = 1000)

Exhaustively expand the search tree from a root node (not recommended for large spaces).

This method performs a depth-first search expansion of the search graph, adding all possible nodes and edges based on the grammar until no new states can be reached.

Warning: This method can be computationally expensive and memory-intensive for large search spaces.

Parameters:
  • root (Hashable, optional) – The starting node for the expansion. Defaults to None.

  • n_visits (int, optional) – The initial visit count for all added nodes. Defaults to 0.

  • print_updates (bool, optional) – Whether to print information about the number of added nodes during growth. Defaults to False.

  • print_every (int, optional) – The frequency at which to print updates (if print_updates is True). Defaults to 1000.
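A depth-first expansion can be sketched with an explicit stack on a hypothetical toy grammar (get_actions, do_action and is_terminal are illustrative functions, not the library's API). The sketch returns plain sets of nodes and edges instead of updating a networkx graph, and it omits the n_visits and progress-printing options.

```python
def grow_tree(root, get_actions, do_action, is_terminal):
    """Exhaustively expand from root; return the discovered nodes and edges."""
    nodes = {root}
    edges = set()
    stack = [root]
    while stack:
        state = stack.pop()  # LIFO order gives a depth-first expansion
        if is_terminal(state):
            continue
        for action in get_actions(state):
            child = do_action(state, action)
            edges.add((state, child))
            if child not in nodes:  # expand each state only once
                nodes.add(child)
                stack.append(child)
    return nodes, edges


# Hypothetical toy grammar: append "a" or "b" (at most two), or terminate with "*".
def get_actions(state):
    return ["*"] if len(state) >= 2 else ["a", "b", "*"]


def do_action(state, action):
    return state + action


def is_terminal(state):
    return state.endswith("*")


nodes, edges = grow_tree("", get_actions, do_action, is_terminal)
print(len(nodes), len(edges))  # 14 13
```

Even this tiny grammar produces 14 states; real circuit grammars grow combinatorially, which is why exhaustive expansion is discouraged for large spaces.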

grow_tree_from_leaves(leaves: Iterable[Hashable]) → DiGraph

Returns the tree (or DAG) of all paths that start at the root and end at a node in leaves.
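One way to build this graph is to walk backward from each leaf to the root. The sketch assumes a hypothetical parents mapping that lists, for each state, the states it can be reached from; the method itself may instead work forward from the root.

```python
from collections import deque


def paths_to_leaves(parents, leaves):
    """Collect every edge that lies on some path from the root to a leaf.

    Walking backward from the leaves visits each ancestor once.
    """
    edges = set()
    seen = set(leaves)
    queue = deque(leaves)
    while queue:
        node = queue.popleft()
        for parent in parents.get(node, ()):
            edges.add((parent, node))
            if parent not in seen:
                seen.add(parent)
                queue.append(parent)
    return edges


# Hypothetical DAG: "AB" can be reached from both "A" and "B".
parents = {"A": ["root"], "B": ["root"], "AB": ["A", "B"], "AB*": ["AB"], "A*": ["A"]}
print(sorted(paths_to_leaves(parents, ["AB*"])))
```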

is_success(terminal_state: Hashable) → bool

Determines whether a state represents a successful outcome in the search.

Designed to be implemented in a subclass, since the definition of success depends on the specific search problem. It takes a terminal state as input, which represents a potential solution for the design problem.

Parameters:

terminal_state (Hashable) – The state to evaluate for success.

Raises:
  • NotImplementedError – This base class implementation is intended to be overridden in subclasses.

Returns:

Whether the provided state represents a successful outcome (True) or not (False).

Return type:

bool

sample_successful_circuits_by_enumeration(n_samples: int, progress: bool = False, nprocs: int = 1, chunksize: int = 100) → list[Hashable]

Sample a random successful state by first creating a new graph that contains all possible paths from the root to a successful terminal state. Then, sample paths by random traversal from the root.

sample_successful_circuits_by_rejection(n_samples: int, max_iter: int = 10000000, progress: bool = False, nprocs: int = 1, chunksize: int = 100) → list[Hashable]

Sample a random successful state with rejection sampling. Starts from the root state, selects random actions until termination, and accepts the sample if it is successful.
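The rejection loop can be sketched as follows. sample_terminal and is_success are hypothetical callables standing in for a random root-to-terminal walk and the subclass's is_success; raising RuntimeError once max_iter draws are spent is an assumption, since the method's behavior in that case is not described here.

```python
import random


def sample_successes_by_rejection(n_samples, sample_terminal, is_success,
                                  max_iter=10_000_000, rng=None):
    """Draw random terminal states until n_samples of them are successful."""
    rng = rng or random.Random()
    successes = []
    if n_samples <= 0:
        return successes
    for _ in range(max_iter):
        state = sample_terminal(rng)  # one random root-to-terminal walk
        if is_success(state):         # accept only successful states
            successes.append(state)
            if len(successes) == n_samples:
                return successes
    # Assumption: give up loudly once the iteration budget is exhausted.
    raise RuntimeError(f"only {len(successes)} successes found in {max_iter} draws")


# Toy stand-ins: a "terminal state" is an integer, and multiples of 10 succeed.
draws = sample_successes_by_rejection(
    5, lambda rng: rng.randrange(100), lambda x: x % 10 == 0, rng=random.Random(1)
)
print(draws)
```

The expected number of draws is n_samples divided by the success rate, so rejection sampling becomes expensive when successful circuits are rare.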

sample_terminal_states(n_samples: int, progress: bool = False, nprocs: int = 1, chunksize: int = 100) → list[Hashable]

Sample n_samples random terminal states from the grammar.

search_bfs(n_steps: int | None = None, n_repeats: int | None = None, n_cycles: int | None = None, callback: Callable | None = None, callback_every: int = 1, shuffle: bool = False, progress: bool = False, run_kwargs: dict | None = None) → None

Performs a breadth-first search (BFS) traversal of the search graph.

This method iterates over the search graph in a breadth-first manner, visiting nodes layer by layer. The number of iterations can be controlled using various parameters:

  • n_steps: Stop the search after a fixed number of total iterations over all explored nodes (considering repeats and cycles).

  • n_repeats: For each node encountered during BFS traversal, repeat the visit n_repeats times before moving on to the next node in the layer.

  • n_cycles: Repeat the entire BFS traversal n_cycles times.
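How the three limits combine can be sketched as a pure function that returns the evaluation order. This is an illustration under assumptions, not the library's code: the name bfs_schedule is hypothetical, n_repeats defaults to 1 here, and supplying neither n_steps nor n_cycles is treated as an error.

```python
def bfs_schedule(bfs_order, n_steps=None, n_repeats=1, n_cycles=None):
    """Return the sequence of nodes a BFS search would evaluate.

    bfs_order: nodes in breadth-first order. Exactly one of n_steps or
    n_cycles must be given.
    """
    if (n_steps is None) == (n_cycles is None):
        raise ValueError("Provide exactly one of n_steps or n_cycles")
    # Each node is repeated n_repeats times before moving on to the next.
    one_cycle = [node for node in bfs_order for _ in range(n_repeats)]
    if not one_cycle:
        return []
    if n_cycles is not None:
        return one_cycle * n_cycles
    # n_steps: keep cycling until the total number of evaluations is reached.
    full, rest = divmod(n_steps, len(one_cycle))
    return one_cycle * full + one_cycle[:rest]


print(bfs_schedule(["A", "B"], n_repeats=2, n_cycles=2))
# ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
print(bfs_schedule(["A", "B", "C"], n_steps=5))
# ['A', 'B', 'C', 'A', 'B']
```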

Parameters:
  • n_steps (Optional[int], optional) – Number of total iterations (repeats and cycles considered). Defaults to None.

  • n_repeats (Optional[int], optional) – Number of repeats per node. Defaults to None.

  • n_cycles (Optional[int], optional) – Number of BFS traversal cycles. Defaults to None.

  • callback (Optional[Callable], optional) –

    A callback function to be executed at specific points during the search. Defaults to None. The callback is called with three arguments:

    • tree: The CircuiTree instance calling the callback.

    • node: The current node being visited (value is None during initialization).

    • reward: The reward obtained from the node (value is None during initialization).

  • callback_every (int, optional) – How often to call the callback (in terms of iterations). Defaults to 1.

  • shuffle (bool, optional) – If True, shuffles the order of nodes within each BFS layer. Defaults to False.

  • progress (bool, optional) – If True, displays a progress bar during the search. Defaults to False.

  • run_kwargs (Optional[dict], optional) – A dictionary of additional keyword arguments passed to the get_reward method during node evaluation. Defaults to None.

Raises:

ValueError – If both n_steps and n_cycles are specified (exactly one should be provided).

search_mcts(n_steps: int, callback_every: int = 1, callback: Callable | None = None, progress_bar: bool = False, run_kwargs: dict | None = None, callback_before_start: bool = True) → None

Performs a Monte Carlo Tree Search (MCTS) traversal of circuit topology space.

This method implements the core MCTS algorithm for exploring and exploiting the search tree. It performs a sequence of n_steps iterations, each consisting of selection, expansion, simulation, and backpropagation steps. If provided, a callback function is called at specific points during the search. This can be used for various purposes:

  • Logging search progress

  • Backing up intermediate results

  • Recording search statistics

  • Checking for convergence or early stopping conditions

Parameters:
  • n_steps (int) – The total number of MCTS iterations to perform.

  • callback_every (int, optional) – How often to call the callback function (in terms of iterations). Defaults to 1 (every iteration).

  • callback (Optional[Callable], optional) –

    A callback function to be executed at specific points during the search. Defaults to None. The callback is called with five arguments:

    • tree: The CircuiTree instance calling the callback.

    • step: The current MCTS iteration (0-based index).

    • path: A list of nodes representing the selected path in the tree.

    • sim_node: The state used for the simulation step. Chosen by following random actions from the last node in the path until a terminal state is reached.

    • reward: The reward obtained from the simulation step.

  • progress_bar (bool, optional) – If True, displays a progress bar during the search. Defaults to False. Requires the tqdm package.

  • run_kwargs (Optional[dict], optional) – A dictionary of additional keyword arguments passed to the get_reward method during node evaluation. Defaults to None.

  • callback_before_start (bool, optional) – Whether to call the callback before starting the search. If so, the callback is called with step = -1. Defaults to True.

Returns:

None
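A callback with the five-argument signature listed above might look like the sketch below. The logger name and the history list are hypothetical; the sketch also assumes nothing about the path, sim_node and reward values passed in the pre-search call (step = -1) and simply skips them.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcts_progress")
history = []  # (step, sim_node, reward) for every call after the start


def log_progress(tree, step, path, sim_node, reward):
    """Callback taking tree, step, path, sim_node and reward."""
    if step < 0:
        # Pre-search call made when callback_before_start=True.
        logger.info("search starting")
        return
    history.append((step, sim_node, reward))
    logger.info("step=%d depth=%d sim_node=%s reward=%s", step, len(path), sim_node, reward)


# With a concrete subclass instance `tree` (hypothetical), it would be passed as:
#   tree.search_mcts(n_steps=1000, callback=log_progress, callback_every=10)
log_progress(None, -1, None, None, None)
log_progress(None, 0, ["root", "A"], "A*", 1.0)
```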

search_mcts_parallel(n_steps: int, n_threads: int, callback: Callable | None = None, callback_every: int = 1, callback_before_start: bool = True, run_kwargs: dict | None = None, logger: Any | None = None) → None

Performs a Monte Carlo Tree Search (MCTS) in parallel using multiple threads. This method leverages the gevent library (included with circuitree[distributed]) to execute the MCTS search algorithm across multiple execution threads on the same search graph.

Key differences from search_mcts:

  • This function uses multiple threads for parallel execution, whereas search_mcts runs sequentially on a single thread.

  • For the intended performance gains, reward computations should be performed by a separate pool of worker processes (see User Guide > Parallelization).

  • Requires the gevent library.

Parameters:
  • n_steps (int) – The total number of MCTS iterations to perform (divided among threads).

  • n_threads (int) – The number of threads to use for parallel MCTS. Must be at least 1.

  • callback (Optional[Callable], optional) – A callback function to be executed at specific points during the search. Defaults to None. (See search_mcts docstring for details).

  • callback_every (int, optional) – How often to call the callback function (in terms of iterations). Defaults to 1 (every iteration).

  • callback_before_start (bool, optional) – Whether to call the callback before starting the search (step=-1). Defaults to True.

  • run_kwargs (Optional[dict], optional) – A dictionary of additional keyword arguments passed to the get_reward method during node evaluation. Defaults to None.

  • logger (Optional[Any], optional) – A logger object to be used for logging messages during the search. Can be useful for monitoring progress. Defaults to None.

Raises:
  • ImportError – If gevent is not installed.

  • ValueError – If the number of threads is less than 1.

select_and_expand(rg: Generator | None = None) → list[Hashable]

Selects a path through the search graph using the UCB score. Adds the last node to the graph if it is not already present.

This method implements the core selection and expansion step of the UCB algorithm. It iteratively selects the child node with the highest UCB score until a terminal state or an unexpanded edge is encountered.

Parameters:

rg (Optional[np.random.Generator], optional) – Random number generator. Defaults to None, in which case the rg attribute of the CircuiTree instance is used.

Returns:

A list of states representing the selected path through the search graph.

Return type:

list[Hashable]

test_pattern_significance(patterns: Iterable[Any], n_samples: int, confidence: float | None = 0.95, correction: bool = True, progress: bool = False, null_samples: list[Hashable] | None = None, succ_samples: list[Hashable] | None = None, sampling_method: Literal['rejection', 'enumeration'] = 'rejection', nprocs_sampling: int = 1, nprocs_testing: int = 1, max_iter: int = 10000000, null_kwargs: dict | None = None, succ_kwargs: dict | None = None, barnard_ok: bool = True, exclude_self: bool = True) → DataFrame

Test whether a pattern is successful by sampling random paths from the design space. Returns the contingency table (a Pandas DataFrame) containing test statistics and p-values.

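Samples n_samples paths from the overall design space, plus n_samples paths that terminate in a successful circuit as determined by the is_success method. Successful paths are drawn by rejection sampling by default (sampling_method='rejection'); sampling_method='enumeration' selects the enumeration-based sampler instead.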

If exclude_self is True, the pattern being tested is excluded from the null and successful samples. This is needed to properly evaluate the significance of rare patterns.
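The idea of the test can be sketched for a single pattern with the standard library. The sketch is hypothetical and simplified: it builds the 2x2 contingency table of pattern presence in successful versus null samples and computes a one-sided Fisher exact p-value from the hypergeometric distribution. The method itself may use Barnard's test (see barnard_ok), applies a multiple-testing correction when correction is True, and returns a DataFrame; treating exclude_self as dropping samples equal to the pattern is also an assumption.

```python
from math import comb


def contingency_counts(has_pattern, succ_samples, null_samples, exclude=None):
    """2x2 counts: (successes with, successes without, null with, null without).

    States equal to `exclude` are dropped, mimicking exclude_self=True.
    """
    succ = [s for s in succ_samples if s != exclude]
    null = [s for s in null_samples if s != exclude]
    a = sum(1 for s in succ if has_pattern(s))
    c = sum(1 for s in null if has_pattern(s))
    return a, len(succ) - a, c, len(null) - c


def fisher_enrichment_pvalue(a, b, c, d):
    """One-sided Fisher exact p-value for enrichment of the pattern in successes."""
    n_succ, n_with, total = a + b, a + c, a + b + c + d
    # Hypergeometric tail with fixed margins: P(at least `a` pattern carriers are successes).
    tail = sum(
        comb(n_with, k) * comb(total - n_with, n_succ - k)
        for k in range(a, min(n_with, n_succ) + 1)
    )
    return tail / comb(total, n_succ)


# Toy samples: the "AB" pattern is common among successes, rare in the null set.
succ = ["AB*"] * 8 + ["A*"] * 2
null = ["AB*"] * 2 + ["B*"] * 8
table = contingency_counts(lambda s: "AB" in s, succ, null)
print(table, fisher_enrichment_pvalue(*table))
```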

to_complexity_graph(successes: bool | Iterable[Hashable] = True) → DiGraph

Generates a directed acyclic graph (DAG) representing the search graph as a “complexity atlas”.

A complexity atlas [1] is a subgraph of the search graph that includes only certain terminal states (determined by the successes argument) and their parent nodes. The returned graph can be used to visualize the search space and identify clusters of topologically similar solutions that occur, for example, due to motifs.

Parameters:

successes (bool | Iterable[Hashable], optional) –

A flag or an iterable of states representing successful terminal states.

  • If True (default), the is_success method is used to identify the successful states, and only these are included.

  • If False, all terminal states are included.

  • If an iterable, the states in the iterable are included.

Raises:
  • ValueError – If an invalid value is provided for successes.

  • NotImplementedError – If successes is True and is_success is not implemented. The intended usage is to subclass CircuiTree and implement the is_success method.

Returns:

A directed acyclic graph representing the complexity atlas.

Return type:

nx.DiGraph

References

to_file(gml_file: str | Path, json_file: str | Path | None = None, save_attrs: Iterable[str] | None = None, compress: bool = False, **kwargs)

Save the CircuiTree object to a gml file and optionally a JSON file containing the object’s attributes.

This method allows saving the CircuiTree object to disk in a GML format for the graph and optionally a JSON format for the serializable attributes. The saved object can be loaded later using the from_file class method (see CircuiTree.from_file).

Parameters:
  • gml_file (str | Path) – The path to the GML file as a str or Path object.

  • json_file (Optional[str | Path], optional) – The path to the optional JSON file for saving other attributes. Defaults to None.

  • save_attrs (Optional[Iterable[str]], optional) – An optional list of attribute names to include in the JSON file. If None, all attributes except those in _non_serializable_attrs are saved. Defaults to None.

  • compress (bool, optional) – If True, the GML file will be compressed with gzip. Defaults to False.

  • **kwargs – Additional keyword arguments passed to networkx.write_gml.

Returns:

Returns the path to the saved GML file and, optionally, the JSON file.

Return type:

Path | Tuple[Path, Path]

to_string() → tuple[str, str]

Return a GML-formatted string of the graph attribute and a JSON-formatted string of the other serializable attributes.

traverse(**kwargs) → tuple[list[Hashable], float | int, Hashable]

Performs a single iteration of the MCTS algorithm.

This method implements the core traversal step of the UCB algorithm. It selects a path through the search tree using UCB scores, expands the tree if necessary, obtains a reward estimate by simulating a random downstream terminal state, and updates the search graph based on the reward value.

Notes:

  • The reward function (get_reward) is called after backpropagate_visit and before backpropagate_reward.

  • In parallel mode, each node and edge in the selection path incurs “virtual loss” until the reward is computed and backpropagated. This is because the visit count is incremented before the actual reward is known.

Parameters:

**kwargs – Additional keyword arguments passed to the get_reward method.

Returns:

  • selection_path (list): The selected path through the search graph.

  • reward (float | int): The reward obtained from the simulation.

  • sim_node (Hashable): The simulated terminal state.

Return type:

tuple[list[Hashable], float | int, Hashable]
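A single iteration can be sketched end to end on a toy problem. Everything here is hypothetical (the string-based grammar, the reward, the dictionaries standing in for node attributes, and the assumed exploration constant); only the order of operations follows the description above: select and expand, backpropagate the visit, simulate a random terminal descendant, compute the reward, then backpropagate the reward. Edge statistics are omitted for brevity, and an unexpanded edge is chosen uniformly at random rather than by score.

```python
import math
import random
from collections import defaultdict


# Hypothetical toy grammar: append "a" or "b" (at most two), or terminate with "*".
def get_actions(state):
    return ["*"] if len(state) >= 2 else ["a", "b", "*"]


def do_action(state, action):
    return state + action


def is_terminal(state):
    return state.endswith("*")


def get_reward(state):
    # Toy deterministic reward in [0, 1]: terminal states containing "ab" succeed.
    return 1.0 if "ab" in state else 0.0


visits = defaultdict(int)     # node -> visit count
rewards = defaultdict(float)  # node -> accumulated reward
expanded = defaultdict(set)   # parent -> children whose edge has been expanded
C = math.sqrt(2)              # exploration constant (assumed value)


def ucb(parent, child):
    # Sequential search only: every expanded child has at least one visit.
    mean = rewards[child] / visits[child]
    return mean + C * math.sqrt(math.log(visits[parent]) / visits[child])


def select_and_expand(root, rng):
    path, state = [root], root
    while not is_terminal(state):
        candidates = [do_action(state, a) for a in get_actions(state)]
        new = [c for c in candidates if c not in expanded[state]]
        if new:  # expand one unexpanded edge and stop descending
            child = rng.choice(new)
            expanded[state].add(child)
            path.append(child)
            return path
        parent = state
        state = max(candidates, key=lambda c: ucb(parent, c))
        path.append(state)
    return path


def traverse(root, rng):
    path = select_and_expand(root, rng)
    for node in path:                 # backpropagate_visit
        visits[node] += 1
    sim_node = path[-1]
    while not is_terminal(sim_node):  # random rollout to a terminal descendant
        sim_node = do_action(sim_node, rng.choice(get_actions(sim_node)))
    reward = get_reward(sim_node)
    for node in path:                 # backpropagate_reward
        rewards[node] += reward
    return path, reward, sim_node


rng = random.Random(0)
for _ in range(500):
    traverse("", rng)
best = max(expanded[""], key=lambda c: rewards[c] / visits[c])
print(best)  # a
```

Only the "a" branch can reach the rewarded state "ab*", so after a few hundred iterations it has the highest mean reward among the root's children.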

property default_attrs

Default attributes for nodes and edges in the search graph.

property terminal_states: Iterator[Hashable]

Generator that yields terminal states in the search graph.

This property provides an iterator that traverses the terminal states present in the search graph associated with the CircuiTree instance.

Yields:

Hashable – Each element returned by the iterator represents a terminal state within the search graph (see CircuitGrammar.is_terminal).

Example:

# Find terminal states with mean reward > 0.5 (skipping states with no visits)
for state in tree.terminal_states:
    data = tree.graph.nodes[state]
    if data["visits"] > 0 and data["reward"] / data["visits"] > 0.5:
        print(state)