circuitree.models.DimerNetworkTree
- class circuitree.models.DimerNetworkTree(components: Iterable[str], regulators: Iterable[str], interactions: Iterable[str], max_interactions: int | None = None, max_interactions_per_promoter: int = 2, root: str | None = None, exploration_constant: float | None = None, seed: int = 2023, graph: DiGraph | None = None, tree_shape: Literal['tree', 'dag'] | None = None, compute_unique: bool = True, **kwargs)
A convenience subclass of CircuiTree that uses DimersGrammar by default.
See CircuiTree and DimersGrammar for more details.
- __init__(components: Iterable[str], regulators: Iterable[str], interactions: Iterable[str], max_interactions: int | None = None, max_interactions_per_promoter: int = 2, root: str | None = None, exploration_constant: float | None = None, seed: int = 2023, graph: DiGraph | None = None, tree_shape: Literal['tree', 'dag'] | None = None, compute_unique: bool = True, **kwargs)
Methods
__init__(components, regulators, interactions)
backpropagate_reward(selection_path, reward) – Update the reward for each node and edge in selection_path.
backpropagate_visit(selection_path) – Increment the visit count for each node and edge in selection_path.
bfs_iterator([root, shuffle]) – Iterate over all terminal nodes in breadth-first (BFS) order starting from a root node.
copy_graph() – Return a shallow copy of the search graph (the graph attribute).
enumerate_terminal_states([root, progress, ...]) – Enumerate all terminal states reachable from the given root state.
expand_edge(parent, child) – Expands the search graph by adding the child node and/or the parent-child edge to the search graph if they do not already exist.
from_file(graph_gml, attrs_json[, ...]) – Load a CircuiTree object from a JSON file/string containing the object's attributes and an optional GML file/string of the search graph.
generate_gml() – Convert the graph attribute to a GML formatted string.
get_attributes(attrs_copy) – Returns a dictionary of the object's attributes.
get_random_terminal_descendant(start[, rg]) – Sample a random terminal state by following random actions from a given start state.
get_reward(state, **kwargs) – Abstract method that calculates the reward for a given state.
get_ucb_score(parent, child) – Calculates the UCB score for a child node given its parent.
grow_tree([root, n_visits, print_updates, ...]) – Exhaustively expand the search tree from a root node (not recommended for large spaces).
grow_tree_from_leaves(leaves) – Returns the tree (or DAG) of all paths that start at the root and end at a node in leaves.
is_success(terminal_state) – Determines whether a state represents a successful outcome in the search.
sample_successful_circuits_by_enumeration(n_samples[, ...]) – Sample a random successful state by first creating a new graph that contains all possible paths from the root to a successful terminal state.
sample_successful_circuits_by_rejection(n_samples[, ...]) – Sample a random successful state with rejection sampling.
sample_terminal_states(n_samples[, ...]) – Sample n_samples random terminal states from the grammar.
search_bfs([n_steps, n_repeats, n_cycles, ...]) – Performs a breadth-first search (BFS) traversal of the search graph.
search_mcts(n_steps[, callback_every, ...]) – Performs a Monte Carlo Tree Search (MCTS) traversal of circuit topology space.
search_mcts_parallel(n_steps, n_threads[, ...]) – Performs a Monte Carlo Tree Search (MCTS) in parallel using multiple threads.
select_and_expand([rg]) – Selects a path through the search graph using the UCB score.
test_pattern_significance(patterns, n_samples) – Test whether a pattern is successful by sampling random paths from the design space.
to_complexity_graph([successes]) – Generates a directed acyclic graph (DAG) representing the search graph as a "complexity atlas".
to_file(gml_file[, json_file, save_attrs, ...]) – Save the CircuiTree object to a GML file and optionally a JSON file containing the object's attributes.
to_string() – Return a GML-formatted string of the graph attribute and a JSON-formatted string of the other serializable attributes.
traverse(**kwargs) – Performs a single iteration of the MCTS algorithm.
Attributes
default_attrs – Default attributes for nodes and edges in the search graph.
terminal_states – Generator that yields terminal states in the search graph.
- backpropagate_reward(selection_path: list, reward: float | int)
Update the reward for each node and edge in selection_path.
Notes:
The reward function (get_reward) is called after backpropagate_visit and before backpropagate_reward.
In parallel mode, each node and edge in the selection path incurs “virtual loss” until the reward is computed and backpropagated. This is because the visit count is incremented before the actual reward is known.
- Parameters:
selection_path (list) – The list of nodes in the selection path.
reward (float | int) – The reward value to add to each node and edge.
- backpropagate_visit(selection_path: list) → None
Increment the visit count for each node and edge in selection_path.
Notes:
The reward function (get_reward) is called after backpropagate_visit and before backpropagate_reward.
In parallel mode, each node and edge in the selection path incurs “virtual loss” until the reward is computed and backpropagated. This is because the visit count is incremented before the actual reward is known.
- Parameters:
selection_path (list) – The list of nodes in the selection path.
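The ordering described in the Notes, with the visit counted first and the reward added later, is what produces virtual loss. The sketch below illustrates this with plain dicts, assuming hypothetical "visits" and "reward" keys; the real search graph is a networkx DiGraph and its attribute names may differ. The functions mirror the documented names but are not the library's implementation.

```python
from collections import defaultdict

# Hypothetical storage: per-node and per-edge {"visits": int, "reward": float}
node_stats = defaultdict(lambda: {"visits": 0, "reward": 0.0})
edge_stats = defaultdict(lambda: {"visits": 0, "reward": 0.0})


def path_edges(selection_path):
    """Consecutive (parent, child) pairs along the path."""
    return list(zip(selection_path[:-1], selection_path[1:]))


def backpropagate_visit(selection_path):
    # Count the visit immediately, before the reward is known
    for node in selection_path:
        node_stats[node]["visits"] += 1
    for edge in path_edges(selection_path):
        edge_stats[edge]["visits"] += 1


def backpropagate_reward(selection_path, reward):
    # Add the reward once it has been computed
    for node in selection_path:
        node_stats[node]["reward"] += reward
    for edge in path_edges(selection_path):
        edge_stats[edge]["reward"] += reward


def mean_reward(stats):
    return stats["reward"] / stats["visits"] if stats["visits"] else 0.0


path = ["root", "a", "a.end"]
backpropagate_visit(path)
backpropagate_reward(path, 1.0)
print(mean_reward(node_stats["a"]))  # 1.0

# A second iteration whose reward is still pending: the extra visit lowers the
# mean reward along the path ("virtual loss") until the reward arrives.
backpropagate_visit(path)
print(mean_reward(node_stats["a"]))  # 0.5
backpropagate_reward(path, 1.0)
print(mean_reward(node_stats["a"]))  # 1.0
```

In parallel MCTS generally, this temporary drop in mean reward makes a path that is already being evaluated less attractive to other threads.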
- bfs_iterator(root=None, shuffle=False) → Iterator[Hashable]
Iterate over all terminal nodes in breadth-first (BFS) order starting from a root node.
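The traversal order can be sketched in pure Python. This sketch assumes the search graph is a plain {parent: [children]} dict and that terminal states are recognized by a caller-supplied predicate; in CircuiTree that role belongs to the grammar, and the real method works on the networkx graph. The function bfs_terminal_nodes and the toy graph are hypothetical. Because the search graph can be a DAG, a seen set ensures that a node with several parents is visited only once.

```python
import random
from typing import Callable, Dict, Hashable, Iterator, List, Optional


def bfs_terminal_nodes(
    children: Dict[Hashable, List[Hashable]],
    root: Hashable,
    is_terminal: Callable[[Hashable], bool],
    shuffle: bool = False,
    rng: Optional[random.Random] = None,
) -> Iterator[Hashable]:
    """Yield terminal nodes reachable from `root`, one BFS layer at a time."""
    rng = rng if rng is not None else random.Random()
    seen = {root}
    layer = [root]
    while layer:
        if shuffle:
            rng.shuffle(layer)  # randomize order within the current layer
        next_layer = []
        for node in layer:
            if is_terminal(node):
                yield node
            for child in children.get(node, []):
                # In a DAG a node can have several parents; visit it only once
                if child not in seen:
                    seen.add(child)
                    next_layer.append(child)
        layer = next_layer


graph = {
    "root": ["a", "b"],
    "a": ["a.end", "ab"],
    "b": ["b.end", "ab"],
    "ab": ["ab.end"],
}


def ends_with_end(state):
    return state.endswith(".end")


print(list(bfs_terminal_nodes(graph, "root", ends_with_end)))
# ['a.end', 'b.end', 'ab.end']
```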
- copy_graph() → DiGraph
Return a shallow copy of the search graph (the graph attribute). Use copy.deepcopy() for a deep copy.
- enumerate_terminal_states(root: Hashable | None = None, progress: bool = False, max_iter: int | None = None) → Iterable[Hashable]
Enumerate all terminal states reachable from the given root state.
- expand_edge(parent: Hashable, child: Hashable)
Expands the search graph by adding the child node and/or the parent-child edge to the search graph if they do not already exist.
- Parameters:
parent (Hashable) – The parent node of the edge.
child (Hashable) – The child node.
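The "add only if missing" behavior can be sketched with dict-based storage. This sketch assumes hypothetical default statistics {"visits": 0, "reward": 0} for new nodes and edges; the actual defaults are given by the default_attrs property. The key point is that existing statistics are never reset when an edge is expanded a second time.

```python
# Hypothetical default statistics; the real defaults come from `default_attrs`
DEFAULT_ATTRS = {"visits": 0, "reward": 0}

nodes = {"root": dict(DEFAULT_ATTRS)}
edges = {}


def expand_edge(parent, child):
    """Add `child` and the (parent, child) edge only if they are missing."""
    # setdefault leaves any existing statistics untouched
    nodes.setdefault(child, dict(DEFAULT_ATTRS))
    edges.setdefault((parent, child), dict(DEFAULT_ATTRS))


expand_edge("root", "a")
nodes["a"]["visits"] = 3
expand_edge("root", "a")  # no-op: the node and edge already exist
print(nodes["a"]["visits"])  # 3
```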
- classmethod from_file(graph_gml: str | Path | None, attrs_json: str | Path, grammar_cls: CircuitGrammar | None = None, grammar_kwargs: dict | None = None, **kwargs)
Load a CircuiTree object from a JSON file/string containing the object’s attributes and an optional GML file/string of the search graph.
These files are typically saved with the to_file() method.
The grammar attribute is loaded by looking for a key “grammar” in the JSON file, whose value should be a dict of keyword arguments (grammar_kwargs) used to create a grammar object. The grammar_cls keyword argument can be used to directly specify the grammar class constructor. Alternatively, the JSON file can include a key “__grammar_cls__” that specifies the class name as a string (which will be looked up in globals()).
- Parameters:
graph_gml (str | Path | None, optional) – The path to the GML file or the serialized GML string (optional).
attrs_json (str | Path) – The path to the JSON file containing the object’s attributes or the serialized JSON string.
grammar_cls (Optional[CircuitGrammar], optional) – The grammar class constructor to use. If None, the class is loaded from the JSON file. Defaults to None.
grammar_kwargs (Optional[dict], optional) – Keyword arguments used to create the grammar object.
**kwargs – Keyword arguments to pass to the CircuiTree constructor.
- Returns:
A CircuiTree object loaded from the provided files or strings.
- Return type:
CircuiTree
- generate_gml() → str
Convert the graph attribute to a GML formatted string.
- Returns:
The GML representation of self.graph.
- Return type:
str
- get_attributes(attrs_copy: Iterable[str] | None) → dict
Returns a dictionary of the object’s attributes.
This method allows controlled access to the object’s attributes, potentially excluding ones that should not be serialized when writing to a JSON file. For any attribute with a to_dict() method, the output of that method is used. Otherwise, the attribute is copied directly.
- Parameters:
attrs_copy (Optional[Iterable[str]]) – An optional list of attribute names to include. If None, all attributes except those in _non_serializable_attrs are returned.
- Raises:
ValueError – If attrs_copy contains attributes from _non_serializable_attrs.
- Returns:
A dictionary containing the requested attributes. Attributes with a to_dict() method are converted using that method; otherwise they are copied directly.
- Return type:
dict
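The filtering rule can be sketched with a small stand-in class. Everything here is hypothetical: the Example and Grammar classes, their attributes, and the contents of _non_serializable_attrs. Only the behavior follows the description above: use to_dict() when available, keep the value otherwise, and reject non-serializable names.

```python
from typing import Iterable, Optional


class Grammar:
    """Hypothetical attribute that can serialize itself."""

    def __init__(self, components):
        self.components = list(components)

    def to_dict(self):
        return {"components": self.components}


class Example:
    """Hypothetical stand-in for a CircuiTree-like object."""

    # Hypothetical attributes that must never be written to JSON
    _non_serializable_attrs = ["graph", "rg"]

    def __init__(self):
        self.grammar = Grammar(["A", "B"])
        self.seed = 2023
        self.graph = object()  # placeholder for a networkx graph
        self.rg = object()  # placeholder for a random number generator

    def get_attributes(self, attrs_copy: Optional[Iterable[str]] = None) -> dict:
        if attrs_copy is None:
            names = [k for k in vars(self) if k not in self._non_serializable_attrs]
        else:
            names = list(attrs_copy)
            forbidden = sorted(set(names) & set(self._non_serializable_attrs))
            if forbidden:
                raise ValueError(f"Cannot serialize attributes: {forbidden}")
        attrs = {}
        for name in names:
            value = getattr(self, name)
            # Prefer the attribute's own serializer; otherwise keep the value as-is
            attrs[name] = value.to_dict() if hasattr(value, "to_dict") else value
        return attrs


obj = Example()
print(obj.get_attributes())  # {'grammar': {'components': ['A', 'B']}, 'seed': 2023}
```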
- get_random_terminal_descendant(start: Hashable, rg: Generator | None = None) → Hashable
Sample a random terminal state by following random actions from a given start state.
- Parameters:
start (Hashable) – The starting state from which to begin the random walk.
rg (Optional[np.random.Generator], optional) – Random number generator. Defaults to None, in which case the rg attribute of the CircuiTree instance is used.
- Returns:
The randomly sampled terminal state.
- Return type:
Hashable
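The random walk can be sketched with a toy grammar. This sketch assumes a state is a string of components, the hypothetical get_actions, do_action and is_terminal helpers stand in for the grammar, and Python's random.Random stands in for np.random.Generator. Choosing actions uniformly is this sketch's choice.

```python
import random
from typing import List, Optional

COMPONENTS = ("A", "B", "C")  # hypothetical toy grammar


def get_actions(state: str) -> List[str]:
    # Add any component not yet present, or terminate with "*"
    return [c for c in COMPONENTS if c not in state] + ["*"]


def do_action(state: str, action: str) -> str:
    return state + action


def is_terminal(state: str) -> bool:
    return state.endswith("*")


def get_random_terminal_descendant(start: str, rg: Optional[random.Random] = None) -> str:
    """Follow uniformly random actions from `start` until a terminal state."""
    if rg is None:
        rg = random.Random()
    state = start
    while not is_terminal(state):
        state = do_action(state, rg.choice(get_actions(state)))
    return state


print(get_random_terminal_descendant("A", random.Random(0)))
```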
- abstract get_reward(state: Hashable, **kwargs) → float | int
Abstract method that calculates the reward for a given state.
This method must be implemented by a subclass of CircuiTree. It defines the interface for reward calculation within the framework.
- Parameters:
state (Hashable) – The state for which to calculate the reward.
- Returns:
The reward for the given state. Reward values can be deterministic or stochastic. The range of possible outputs should be finite and ideally normalized to the range [0, 1].
- Return type:
float | int
- get_ucb_score(parent: Hashable, child: Hashable)
Calculates the UCB score for a child node given its parent.
- Parameters:
parent (Hashable) – The parent node.
child (Hashable) – The child node.
- Returns:
The UCB score of the child node.
- Return type:
float
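The exact formula used by get_ucb_score is not spelled out here. As an assumption, the sketch below uses the textbook UCB1 score: mean reward plus an exploration bonus scaled by an exploration constant. The default of sqrt(2) is the common textbook choice, not necessarily CircuiTree's. Unvisited children score infinity, so each child is tried at least once.

```python
import math


def ucb1_score(
    parent_visits: int,
    child_visits: int,
    child_reward: float,
    exploration_constant: float = math.sqrt(2),
) -> float:
    """UCB1: mean reward (exploitation) plus an exploration bonus."""
    if child_visits == 0:
        # Unvisited children are tried before any visited child
        return math.inf
    mean_reward = child_reward / child_visits
    bonus = exploration_constant * math.sqrt(math.log(parent_visits) / child_visits)
    return mean_reward + bonus


# child -> (visits, accumulated reward); the parent has been visited 12 times
children = {"a": (10, 7.0), "b": (2, 1.0)}
scores = {name: ucb1_score(12, v, r) for name, (v, r) in children.items()}
print(max(scores, key=scores.get))  # b
```

In this example, the less-visited child "b" wins despite its lower mean reward, which reflects the exploration side of the trade-off. With an exploration constant of 0, "a" would win on mean reward alone. The constructor's exploration_constant argument presumably controls this balance.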
- grow_tree(root: Hashable | None = None, n_visits: int = 0, print_updates: bool = False, print_every: int = 1000)
Exhaustively expand the search tree from a root node (not recommended for large spaces).
This method performs a depth-first search expansion of the search graph, adding all possible nodes and edges based on the grammar until no new states can be reached.
Warning: This method can be computationally expensive and memory-intensive for large search spaces.
- Parameters:
root (Hashable, optional) – The starting node for the expansion. Defaults to None.
n_visits (int, optional) – The initial visit count for all added nodes. Defaults to 0.
print_updates (bool, optional) – Whether to print information about the number of added nodes during growth. Defaults to False.
print_every (int, optional) – The frequency at which to print updates (if print_updates is True). Defaults to 1000.
- grow_tree_from_leaves(leaves: Iterable[Hashable]) → DiGraph
Returns the tree (or DAG) of all paths that start at the root and end at a node in leaves.
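One way to collect all root-to-leaf paths is to walk parent links backwards from each leaf and keep every edge encountered. The sketch below assumes the search graph is a plain {parent: [children]} dict in which every node is reachable from the root. The helper edges_to_leaves is hypothetical and returns a set of edges, whereas the documented method returns a networkx DiGraph.

```python
from collections import defaultdict, deque
from typing import Dict, Hashable, Iterable, List, Set, Tuple


def edges_to_leaves(
    children: Dict[Hashable, List[Hashable]], leaves: Iterable[Hashable]
) -> Set[Tuple[Hashable, Hashable]]:
    """Edges lying on at least one path that ends at a node in `leaves`."""
    parents = defaultdict(list)
    for parent, kids in children.items():
        for kid in kids:
            parents[kid].append(parent)

    seen = set(leaves)
    queue = deque(seen)
    edges = set()
    # Walk backwards from the leaves, collecting every parent edge
    while queue:
        node = queue.popleft()
        for parent in parents[node]:
            edges.add((parent, node))
            if parent not in seen:
                seen.add(parent)
                queue.append(parent)
    return edges


graph = {
    "root": ["a", "b"],
    "a": ["a.end", "ab"],
    "b": ["b.end", "ab"],
    "ab": ["ab.end"],
}
print(sorted(edges_to_leaves(graph, ["ab.end"])))
# [('a', 'ab'), ('ab', 'ab.end'), ('b', 'ab'), ('root', 'a'), ('root', 'b')]
```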
- is_success(terminal_state: Hashable) → bool
Determines whether a state represents a successful outcome in the search.
Designed to be implemented in a subclass, since the definition of success depends on the specific search problem. It takes a terminal state as input, which represents a potential solution for the design problem.
- Parameters:
terminal_state (Hashable) – The state to evaluate for success.
- Raises:
NotImplementedError – This base class implementation is intended to be overridden in subclasses.
- Returns:
Whether the provided state represents a successful outcome (True) or not (False).
- Return type:
bool
- sample_successful_circuits_by_enumeration(n_samples: int, progress: bool = False, nprocs: int = 1, chunksize: int = 100) → list[Hashable]
Sample a random successful state by first creating a new graph that contains all possible paths from the root to a successful terminal state. Then, sample paths by random traversal from the root.
- sample_successful_circuits_by_rejection(n_samples: int, max_iter: int = 10000000, progress: bool = False, nprocs: int = 1, chunksize: int = 100) → list[Hashable]
Sample a random successful state with rejection sampling. Starts from the root state, selects random actions until termination, and accepts the sample if it is successful.
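Rejection sampling can be sketched with a toy grammar. This sketch assumes hypothetical random_terminal_state and is_success functions in place of get_random_terminal_descendant and a subclass's is_success. The max_iter cap keeps the loop from running forever when successes are rare. Raising an error once the cap is reached is this sketch's choice, not documented library behavior.

```python
import random
from typing import Callable, List, Optional

COMPONENTS = "ABC"  # hypothetical toy grammar


def random_terminal_state(rng: random.Random) -> str:
    """Random walk from the empty root: add unused components until '*' is drawn."""
    state = ""
    while True:
        action = rng.choice([c for c in COMPONENTS if c not in state] + ["*"])
        state += action
        if action == "*":
            return state


def is_success(state: str) -> bool:
    # Hypothetical success criterion
    return "A" in state and "B" in state


def sample_by_rejection(
    n_samples: int,
    draw: Callable[[random.Random], str],
    accept: Callable[[str], bool],
    max_iter: int = 10_000_000,
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Draw random terminal states, keeping only the successful ones."""
    if n_samples <= 0:
        return []
    rng = rng if rng is not None else random.Random()
    accepted = []
    for _ in range(max_iter):
        state = draw(rng)
        if accept(state):
            accepted.append(state)
            if len(accepted) == n_samples:
                return accepted
    raise RuntimeError(f"accepted {len(accepted)}/{n_samples} samples in {max_iter} draws")


samples = sample_by_rejection(5, random_terminal_state, is_success, rng=random.Random(0))
print(samples)
```

Accepted samples follow the random walk's distribution conditioned on success. If a single draw succeeds with probability p, the expected number of draws per accepted sample is 1/p.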
- sample_terminal_states(n_samples: int, progress: bool = False, nprocs: int = 1, chunksize: int = 100) → list[Hashable]
Sample n_samples random terminal states from the grammar.
- search_bfs(n_steps: int | None = None, n_repeats: int | None = None, n_cycles: int | None = None, callback: Callable | None = None, callback_every: int = 1, shuffle: bool = False, progress: bool = False, run_kwargs: dict | None = None) → None
Performs a breadth-first search (BFS) traversal of the search graph.
This method iterates over the search graph in a breadth-first manner, visiting nodes layer by layer. The number of iterations can be controlled using various parameters:
n_steps: Stop the search after a fixed number of total iterations over all explored nodes (considering repeats and cycles).
n_repeats: For each node encountered during BFS traversal, repeat the visit n_repeats times before moving on to the next node in the layer.
n_cycles: Repeat the entire BFS traversal n_cycles times.
- Parameters:
n_steps (Optional[int], optional) – Number of total iterations (repeats and cycles considered). Defaults to None.
n_repeats (Optional[int], optional) – Number of repeats per node. Defaults to None.
n_cycles (Optional[int], optional) – Number of BFS traversal cycles. Defaults to None.
callback (Optional[Callable], optional) –
A callback function to be executed at specific points during the search. Defaults to None. The callback is called with three arguments:
tree: The CircuiTree instance calling the callback.
node: The current node being visited (value is None during initialization).
reward: The reward obtained from the node (value is None during initialization).
callback_every (int, optional) – How often to call the callback (in terms of iterations). Defaults to 1.
shuffle (bool, optional) – If True, shuffles the order of nodes within each BFS layer. Defaults to False.
progress (bool, optional) – If True, displays a progress bar during the search. Defaults to False.
run_kwargs (Optional[dict], optional) – A dictionary of additional keyword arguments passed to the get_reward method during node evaluation. Defaults to None.
- Raises:
ValueError – If both n_steps and n_cycles are specified (exactly one should be provided).
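The way n_steps, n_repeats and n_cycles combine can be sketched as an evaluation schedule. This sketch assumes the BFS layers have already been computed as lists of nodes. The bfs_schedule helper is hypothetical and only illustrates the counting rules described above, not the library's loop. It also rejects the case where neither n_steps nor n_cycles is given, which is an assumption based on "exactly one should be provided".

```python
from itertools import cycle, islice
from typing import Hashable, List, Optional


def bfs_schedule(
    layers: List[List[Hashable]],
    n_steps: Optional[int] = None,
    n_repeats: Optional[int] = None,
    n_cycles: Optional[int] = None,
) -> List[Hashable]:
    """Return the order in which nodes would be evaluated."""
    if (n_steps is None) == (n_cycles is None):
        # Rejects both "both given" and (as an assumption) "neither given"
        raise ValueError("Provide exactly one of n_steps or n_cycles")
    repeats = 1 if n_repeats is None else n_repeats
    # One full BFS pass, repeating each node `repeats` times in place
    one_pass = [node for layer in layers for node in layer for _ in range(repeats)]
    if n_cycles is not None:
        return one_pass * n_cycles
    # n_steps counts individual evaluations across repeats and cycles
    return list(islice(cycle(one_pass), n_steps))


layers = [["a", "b"], ["ab"]]
print(bfs_schedule(layers, n_cycles=2))  # ['a', 'b', 'ab', 'a', 'b', 'ab']
print(bfs_schedule(layers, n_steps=4, n_repeats=2))  # ['a', 'a', 'b', 'b']
```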
- search_mcts(n_steps: int, callback_every: int = 1, callback: Callable | None = None, progress_bar: bool = False, run_kwargs: dict | None = None, callback_before_start: bool = True) → None
Performs a Monte Carlo Tree Search (MCTS) traversal of circuit topology space.
This method implements the core MCTS algorithm for exploring and exploiting the search tree. It performs a sequence of n_steps iterations, each consisting of selection, expansion, simulation, and backpropagation steps. If provided, a callback function is called at specific points during the search. This can be used for various purposes:
Logging search progress
Backing up intermediate results
Recording search statistics
Checking for convergence or early stopping conditions
- Parameters:
n_steps (int) – The total number of MCTS iterations to perform.
callback_every (int, optional) – How often to call the callback function (in terms of iterations). Defaults to 1 (every iteration).
callback (Optional[Callable], optional) –
A callback function to be executed at specific points during the search. Defaults to None. The callback is called with five arguments:
tree: The CircuiTree instance calling the callback.
step: The current MCTS iteration (0-based index).
path: A list of nodes representing the selected path in the tree.
sim_node: The state used for the simulation step. Chosen by following random actions from the last node in the path until a terminal state is reached.
reward: The reward obtained from the simulation step.
progress_bar (bool, optional) – If True, displays a progress bar during the search. Defaults to False. Requires the tqdm package.
run_kwargs (Optional[dict], optional) – A dictionary of additional keyword arguments passed to the get_reward method during node evaluation. Defaults to None.
callback_before_start (bool, optional) – Whether to call the callback before starting the search. If so, the callback is called with step = -1. Defaults to True.
- Returns:
None
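A callback needs to accept the five positional arguments listed above. The recorder below is a hypothetical example. It assumes the pre-start call (step = -1) carries no simulation result and ignores it. So that it runs standalone, it is exercised with None in place of a real CircuiTree.

```python
from typing import Any, Hashable, List, Optional


class RewardRecorder:
    """Hypothetical callback that records the reward of each reported step."""

    def __init__(self):
        self.steps: List[int] = []
        self.rewards: List[float] = []

    def __call__(self, tree: Any, step: int, path: Optional[List[Hashable]],
                 sim_node: Optional[Hashable], reward: Optional[float]) -> None:
        if step < 0:
            # Pre-start call (callback_before_start=True): nothing to record yet
            return
        self.steps.append(step)
        self.rewards.append(reward)

    def mean_reward(self) -> float:
        return sum(self.rewards) / len(self.rewards) if self.rewards else 0.0


recorder = RewardRecorder()
# Simulated calls; a real search passes the CircuiTree instance as `tree`
recorder(None, -1, None, None, None)
recorder(None, 0, ["root", "a"], "a.end", 1.0)
recorder(None, 1, ["root", "b"], "b.end", 0.0)
print(recorder.steps, recorder.mean_reward())  # [0, 1] 0.5
```

With a concrete subclass instance, the recorder would be passed using the documented parameters, for example tree.search_mcts(n_steps=1000, callback=recorder, callback_every=10).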
- search_mcts_parallel(n_steps: int, n_threads: int, callback: Callable | None = None, callback_every: int = 1, callback_before_start: bool = True, run_kwargs: dict | None = None, logger: Any | None = None) → None
Performs a Monte Carlo Tree Search (MCTS) in parallel using multiple threads. This method leverages the gevent library (included with circuitree[distributed]) to execute the MCTS search algorithm across multiple execution threads on the same search graph.
Key differences from search_mcts:
This function utilizes multiple threads for parallel execution, whereas search_mcts runs sequentially on a single thread.
To achieve the intended performance, reward computations should be performed by a separate pool of worker processes (see User Guide > Parallelization).
Requires the gevent library.
- Parameters:
n_steps (int) – The total number of MCTS iterations to perform (divided among threads).
n_threads (int) – The number of threads to use for parallel MCTS. Must be at least 1.
callback (Optional[Callable], optional) – A callback function to be executed at specific points during the search. Defaults to None. (See the search_mcts docstring for details.)
callback_every (int, optional) – How often to call the callback function (in terms of iterations). Defaults to 1 (every iteration).
callback_before_start (bool, optional) – Whether to call the callback before starting the search (step=-1). Defaults to True.
run_kwargs (Optional[dict], optional) – A dictionary of additional keyword arguments passed to the get_reward method during node evaluation. Defaults to None.
logger (Optional[Any], optional) – A logger object to be used for logging messages during the search. Can be useful for monitoring progress. Defaults to None.
- Raises:
ImportError – If gevent is not installed.
ValueError – If the number of threads is less than 1.
- select_and_expand(rg: Generator | None = None) → list[Hashable]
Selects a path through the search graph using the UCB score. Adds the last node to the graph if it is not already present.
This method implements the core selection and expansion step of the UCB algorithm. It iteratively selects the child node with the highest UCB score until a terminal state or an unexpanded edge is encountered.
- Parameters:
rg (Optional[np.random.Generator], optional) – Random number generator. Defaults to None, in which case the rg attribute of the CircuiTree instance is used.
- Returns:
A list of states representing the selected path through the search graph.
- Return type:
list[Hashable]
- test_pattern_significance(patterns: Iterable[Any], n_samples: int, confidence: float | None = 0.95, correction: bool = True, progress: bool = False, null_samples: list[Hashable] | None = None, succ_samples: list[Hashable] | None = None, sampling_method: Literal['rejection', 'enumeration'] = 'rejection', nprocs_sampling: int = 1, nprocs_testing: int = 1, max_iter: int = 10000000, null_kwargs: dict | None = None, succ_kwargs: dict | None = None, barnard_ok: bool = True, exclude_self: bool = True) → DataFrame
Test whether a pattern is successful by sampling random paths from the design space. Returns the contingency table (a Pandas DataFrame) containing test statistics and p-values.
Samples n_samples paths from the overall design space and uses rejection sampling to sample n_samples paths that terminate in a successful circuit as determined by the is_success method.
If exclude_self is True, the pattern being tested is excluded from the null and successful samples. This is to properly evaluate the significance of rare patterns.
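The core of such a test is a 2×2 contingency table that compares how often a pattern occurs among successful samples versus null samples. The statistics used by test_pattern_significance are not described here. As an assumed substitute, the sketch below computes a two-sided Fisher's exact test using only the standard library and applies a Bonferroni correction across patterns. The pattern_significance helper, the has_pattern callable, and the set-of-motifs sample representation are all hypothetical, and the result is a list of dicts rather than a DataFrame.

```python
from math import comb
from typing import Callable, Hashable, List


def fisher_exact_two_sided(a: int, b: int, c: int, d: int) -> float:
    """Two-sided Fisher's exact test for the 2x2 table [[a, b], [c, d]]."""
    row1, col1, n = a + b, a + c, a + b + c + d
    total = comb(n, col1)

    def prob(k: int) -> float:
        # Hypergeometric probability of the table whose top-left cell is k
        return comb(row1, k) * comb(n - row1, col1 - k) / total

    p_observed = prob(a)
    lo, hi = max(0, col1 - (n - row1)), min(row1, col1)
    probs = [prob(k) for k in range(lo, hi + 1)]
    # Sum the probabilities of all tables no more likely than the observed one
    return min(1.0, sum(p for p in probs if p <= p_observed * (1 + 1e-7)))


def pattern_significance(
    patterns: List[Hashable],
    succ_samples: List[Hashable],
    null_samples: List[Hashable],
    has_pattern: Callable[[Hashable, Hashable], bool],
    alpha: float = 0.05,
) -> List[dict]:
    """One row per pattern: contingency counts plus raw and corrected p-values."""
    rows = []
    for pattern in patterns:
        succ_with = sum(1 for s in succ_samples if has_pattern(s, pattern))
        null_with = sum(1 for s in null_samples if has_pattern(s, pattern))
        succ_without = len(succ_samples) - succ_with
        null_without = len(null_samples) - null_with
        p = fisher_exact_two_sided(succ_with, succ_without, null_with, null_without)
        # Bonferroni correction across all tested patterns
        p_corrected = min(1.0, p * len(patterns))
        rows.append({
            "pattern": pattern,
            "succ_with": succ_with, "succ_without": succ_without,
            "null_with": null_with, "null_without": null_without,
            "p_value": p, "p_corrected": p_corrected,
            "significant": p_corrected < alpha,
        })
    return rows


# Toy samples: each circuit is represented by the set of motifs it contains
succ = [{"AB"}] * 5
null = [set()] * 5
print(pattern_significance(["AB"], succ, null, lambda circuit, motif: motif in circuit))
```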
- to_complexity_graph(successes: bool | Iterable[Hashable] = True) → DiGraph
Generates a directed acyclic graph (DAG) representing the search graph as a “complexity atlas”.
A complexity atlas [1] is a subgraph of the search graph that includes only certain terminal states (determined by the successes argument) and their parent nodes. The returned graph can be used to visualize the search space and identify clusters of topologically similar solutions that occur, for example, due to motifs.
- Parameters:
successes (bool | Iterable[Hashable], optional) –
A flag or an iterable of states representing successful terminal states.
If True (default), the is_success method is used to identify the successful states, and only these are included.
If False, all terminal states are included.
If an iterable, the states in the iterable are included.
- Raises:
ValueError – If an invalid value is provided for successes.
NotImplementedError – If successes is True and is_success is not implemented. The intended usage is to subclass CircuiTree and implement the is_success method.
- Returns:
A directed acyclic graph representing the complexity atlas.
- Return type:
nx.DiGraph
References
- to_file(gml_file: str | Path, json_file: str | Path | None = None, save_attrs: Iterable[str] | None = None, compress: bool = False, **kwargs)
Save the CircuiTree object to a GML file and optionally a JSON file containing the object’s attributes.
This method allows saving the CircuiTree object to disk in a GML format for the graph and optionally a JSON format for the serializable attributes. The saved object can be loaded later using the from_file class method (see CircuiTree.from_file).
- Parameters:
gml_file (str | Path) – The path to the GML file as a str or Path object.
json_file (Optional[str | Path], optional) – The path to the optional JSON file for saving other attributes. Defaults to None.
save_attrs (Optional[Iterable[str]], optional) – An optional list of attribute names to include in the JSON file. If None, all attributes except those in _non_serializable_attrs are saved. Defaults to None.
compress (bool, optional) – If True, the GML file will be compressed with gzip. Defaults to False.
**kwargs – Additional keyword arguments passed to networkx.write_gml.
- Returns:
Returns the path to the saved GML file and, optionally, the JSON file.
- Return type:
Path | Tuple[Path, Path]
- to_string() → tuple[str, str]
Return a GML-formatted string of the graph attribute and a JSON-formatted string of the other serializable attributes.
- traverse(**kwargs) → tuple[list[Hashable], float | int, Hashable]
Performs a single iteration of the MCTS algorithm.
This method implements the core traversal step of the UCB algorithm. It selects a path through the search tree using UCB scores, expands the tree if necessary, obtains a reward estimate by simulating a random downstream terminal state, and updates the search graph based on the reward value.
Notes:
The reward function (get_reward) is called after backpropagate_visit and before backpropagate_reward.
In parallel mode, each node and edge in the selection path incurs “virtual loss” until the reward is computed and backpropagated. This is because the visit count is incremented before the actual reward is known.
- Parameters:
**kwargs – Additional keyword arguments passed to the get_reward method.
- Returns:
selection_path (list): The selected path through the search graph.
reward (float | int): The reward obtained from the simulation.
sim_node (Hashable): The simulated terminal state.
- Return type:
tuple[list[Hashable], float | int, Hashable]
- property default_attrs
Default attributes for nodes and edges in the search graph.
- property terminal_states: Iterator[Hashable]
Generator that yields terminal states in the search graph.
This property provides an iterator that traverses the terminal states present in the search graph associated with the CircuiTree instance.
- Yields:
Hashable – Each element returned by the iterator represents a terminal state within the search graph (see CircuitGrammar.is_terminal).
Example:
    # Print terminal states whose mean reward exceeds 0.5 (skipping unvisited states)
    for state in tree.terminal_states:
        visits = tree.graph.nodes[state]["visits"]
        if visits > 0 and tree.graph.nodes[state]["reward"] / visits > 0.5:
            print(state)