1414import warnings
1515
1616from collections import defaultdict
17- from collections .abc import Iterable , Sequence
17+ from collections .abc import Callable , Iterable , Sequence
18+ from enum import Enum
1819from os import path
20+ from typing import Any
1921
2022from pytensor import function
2123from pytensor .graph import Apply
@@ -41,6 +43,119 @@ def fast_eval(var):
4143 return function ([], var , mode = "FAST_COMPILE" )()
4244
4345
46+ class NodeType (str , Enum ):
47+ """Enum for the types of nodes in the graph."""
48+
49+ POTENTIAL = "Potential"
50+ FREE_RV = "Free Random Variable"
51+ OBSERVED_RV = "Observed Random Variable"
52+ DETERMINISTIC = "Deterministic"
53+ DATA = "Data"
54+
55+
56+ GraphvizNodeKwargs = dict [str , Any ]
57+ NodeFormatter = Callable [[TensorVariable ], GraphvizNodeKwargs ]
58+
59+
60+ def default_potential (var : TensorVariable ) -> GraphvizNodeKwargs :
61+ """Default data for potential in the graph."""
62+ return {
63+ "shape" : "octagon" ,
64+ "style" : "filled" ,
65+ "label" : f"{ var .name } \n ~\n Potential" ,
66+ }
67+
68+
69+ def random_variable_symbol (var : TensorVariable ) -> str :
70+ """Get the symbol of the random variable."""
71+ symbol = var .owner .op .__class__ .__name__
72+
73+ if symbol .endswith ("RV" ):
74+ symbol = symbol [:- 2 ]
75+
76+ return symbol
77+
78+
79+ def default_free_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
80+ """Default data for free RV in the graph."""
81+ symbol = random_variable_symbol (var )
82+
83+ return {
84+ "shape" : "ellipse" ,
85+ "style" : None ,
86+ "label" : f"{ var .name } \n ~\n { symbol } " ,
87+ }
88+
89+
90+ def default_observed_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
91+ """Default data for observed RV in the graph."""
92+ symbol = random_variable_symbol (var )
93+
94+ return {
95+ "shape" : "ellipse" ,
96+ "style" : "filled" ,
97+ "label" : f"{ var .name } \n ~\n { symbol } " ,
98+ }
99+
100+
101+ def default_deterministic (var : TensorVariable ) -> GraphvizNodeKwargs :
102+ """Default data for the deterministic in the graph."""
103+ return {
104+ "shape" : "box" ,
105+ "style" : None ,
106+ "label" : f"{ var .name } \n ~\n Deterministic" ,
107+ }
108+
109+
110+ def default_data (var : TensorVariable ) -> GraphvizNodeKwargs :
111+ """Default data for the data in the graph."""
112+ return {
113+ "shape" : "box" ,
114+ "style" : "rounded, filled" ,
115+ "label" : f"{ var .name } \n ~\n Data" ,
116+ }
117+
118+
119+ def get_node_type (var_name : VarName , model ) -> NodeType :
120+ """Return the node type of the variable in the model."""
121+ v = model [var_name ]
122+
123+ if v in model .deterministics :
124+ return NodeType .DETERMINISTIC
125+ elif v in model .free_RVs :
126+ return NodeType .FREE_RV
127+ elif v in model .observed_RVs :
128+ return NodeType .OBSERVED_RV
129+ elif v in model .data_vars :
130+ return NodeType .DATA
131+ else :
132+ return NodeType .POTENTIAL
133+
134+
135+ NodeTypeFormatterMapping = dict [NodeType , NodeFormatter ]
136+
137+ DEFAULT_NODE_FORMATTERS : NodeTypeFormatterMapping = {
138+ NodeType .POTENTIAL : default_potential ,
139+ NodeType .FREE_RV : default_free_rv ,
140+ NodeType .OBSERVED_RV : default_observed_rv ,
141+ NodeType .DETERMINISTIC : default_deterministic ,
142+ NodeType .DATA : default_data ,
143+ }
144+
145+
146+ def update_node_formatters (node_formatters : NodeTypeFormatterMapping ) -> NodeTypeFormatterMapping :
147+ node_formatters = {** DEFAULT_NODE_FORMATTERS , ** node_formatters }
148+
149+ unknown_keys = set (node_formatters .keys ()) - set (NodeType )
150+ if unknown_keys :
151+ raise ValueError (
152+ f"Node formatters must be of type NodeType. Found: { list (unknown_keys )} ."
153+ f" Please use one of { [node_type .value for node_type in NodeType ]} ."
154+ )
155+
156+ return node_formatters
157+
158+
44159class ModelGraph :
45160 def __init__ (self , model ):
46161 self .model = model
@@ -148,42 +263,23 @@ def make_compute_graph(
148263
149264 return input_map
150265
151- def _make_node (self , var_name , graph , * , nx = False , cluster = False , formatting : str = "plain" ):
266+ def _make_node (
267+ self ,
268+ var_name ,
269+ graph ,
270+ * ,
271+ node_formatters : NodeTypeFormatterMapping ,
272+ nx = False ,
273+ cluster = False ,
274+ formatting : str = "plain" ,
275+ ):
152276 """Attaches the given variable to a graphviz or networkx Digraph"""
153277 v = self .model [var_name ]
154278
155- shape = None
156- style = None
157- label = str (v )
158-
159- if v in self .model .potentials :
160- shape = "octagon"
161- style = "filled"
162- label = f"{ var_name } \n ~\n Potential"
163- elif v in self .model .basic_RVs :
164- shape = "ellipse"
165- if v in self .model .observed_RVs :
166- style = "filled"
167- else :
168- style = None
169- symbol = v .owner .op .__class__ .__name__
170- if symbol .endswith ("RV" ):
171- symbol = symbol [:- 2 ]
172- label = f"{ var_name } \n ~\n { symbol } "
173- elif v in self .model .deterministics :
174- shape = "box"
175- style = None
176- label = f"{ var_name } \n ~\n Deterministic"
177- else :
178- shape = "box"
179- style = "rounded, filled"
180- label = f"{ var_name } \n ~\n Data"
181-
182- kwargs = {
183- "shape" : shape ,
184- "style" : style ,
185- "label" : label ,
186- }
279+ node_type = get_node_type (var_name , self .model )
280+ node_formatter = node_formatters [node_type ]
281+
282+ kwargs = node_formatter (v )
187283
188284 if cluster :
189285 kwargs ["cluster" ] = cluster
@@ -240,6 +336,7 @@ def make_graph(
240336 save = None ,
241337 figsize = None ,
242338 dpi = 300 ,
339+ node_formatters : NodeTypeFormatterMapping | None = None ,
243340 ):
244341 """Make graphviz Digraph of PyMC model
245342
@@ -255,18 +352,26 @@ def make_graph(
255352 "The easiest way to install all of this is by running\n \n "
256353 "\t conda install -c conda-forge python-graphviz"
257354 )
355+
356+ node_formatters = node_formatters or {}
357+ node_formatters = update_node_formatters (node_formatters )
358+
258359 graph = graphviz .Digraph (self .model .name )
259360 for plate_label , all_var_names in self .get_plates (var_names ).items ():
260361 if plate_label :
261362 # must be preceded by 'cluster' to get a box around it
262363 with graph .subgraph (name = "cluster" + plate_label ) as sub :
263364 for var_name in all_var_names :
264- self ._make_node (var_name , sub , formatting = formatting )
365+ self ._make_node (
366+ var_name , sub , formatting = formatting , node_formatters = node_formatters
367+ )
265368 # plate label goes bottom right
266369 sub .attr (label = plate_label , labeljust = "r" , labelloc = "b" , style = "rounded" )
267370 else :
268371 for var_name in all_var_names :
269- self ._make_node (var_name , graph , formatting = formatting )
372+ self ._make_node (
373+ var_name , graph , formatting = formatting , node_formatters = node_formatters
374+ )
270375
271376 for child , parents in self .make_compute_graph (var_names = var_names ).items ():
272377 # parents is a set of rv names that precede child rv nodes
@@ -287,7 +392,12 @@ def make_graph(
287392
288393 return graph
289394
290- def make_networkx (self , var_names : Iterable [VarName ] | None = None , formatting : str = "plain" ):
395+ def make_networkx (
396+ self ,
397+ var_names : Iterable [VarName ] | None = None ,
398+ formatting : str = "plain" ,
399+ node_formatters : NodeTypeFormatterMapping | None = None ,
400+ ):
291401 """Make networkx Digraph of PyMC model
292402
293403 Returns
@@ -302,6 +412,10 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
302412 "The easiest way to install all of this is by running\n \n "
303413 "\t conda install networkx"
304414 )
415+
416+ node_formatters = node_formatters or {}
417+ node_formatters = update_node_formatters (node_formatters )
418+
305419 graphnetwork = networkx .DiGraph (name = self .model .name )
306420 for plate_label , all_var_names in self .get_plates (var_names ).items ():
307421 if plate_label :
@@ -314,6 +428,7 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
314428 var_name ,
315429 subgraphnetwork ,
316430 nx = True ,
431+ node_formatters = node_formatters ,
317432 cluster = "cluster" + plate_label ,
318433 formatting = formatting ,
319434 )
@@ -332,7 +447,13 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
332447 graphnetwork .graph ["name" ] = self .model .name
333448 else :
334449 for var_name in all_var_names :
335- self ._make_node (var_name , graphnetwork , nx = True , formatting = formatting )
450+ self ._make_node (
451+ var_name ,
452+ graphnetwork ,
453+ nx = True ,
454+ formatting = formatting ,
455+ node_formatters = node_formatters ,
456+ )
336457
337458 for child , parents in self .make_compute_graph (var_names = var_names ).items ():
338459 # parents is a set of rv names that precede child rv nodes
@@ -346,6 +467,7 @@ def model_to_networkx(
346467 * ,
347468 var_names : Iterable [VarName ] | None = None ,
348469 formatting : str = "plain" ,
470+ node_formatters : NodeTypeFormatterMapping | None = None ,
349471):
350472 """Produce a networkx Digraph from a PyMC model.
351473
@@ -367,6 +489,10 @@ def model_to_networkx(
367489 Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
368490 formatting : str, optional
369491 one of { "plain" }
492+ node_formatters : dict, optional
493+ A dictionary mapping node types to functions that return a dictionary of node attributes.
494+ Check out the networkx documentation for more information
495+ how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html
370496
371497 Examples
372498 --------
@@ -392,6 +518,17 @@ def model_to_networkx(
392518 obs = Normal("obs", theta, sigma=sigma, observed=y)
393519
394520 model_to_networkx(schools)
521+
522+ Add custom attributes to Free Random Variables and Observed Random Variables nodes.
523+
524+ .. code-block:: python
525+
526+ node_formatters = {
527+ "Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
528+ "Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
529+ }
530+ model_to_networkx(schools, node_formatters=node_formatters)
531+
395532 """
396533 if "plain" not in formatting :
397534 raise ValueError (f"Unsupported formatting for graph nodes: '{ formatting } '. See docstring." )
@@ -403,7 +540,9 @@ def model_to_networkx(
403540 stacklevel = 2 ,
404541 )
405542 model = pm .modelcontext (model )
406- return ModelGraph (model ).make_networkx (var_names = var_names , formatting = formatting )
543+ return ModelGraph (model ).make_networkx (
544+ var_names = var_names , formatting = formatting , node_formatters = node_formatters
545+ )
407546
408547
409548def model_to_graphviz (
@@ -414,6 +553,7 @@ def model_to_graphviz(
414553 save : str | None = None ,
415554 figsize : tuple [int , int ] | None = None ,
416555 dpi : int = 300 ,
556+ node_formatters : NodeTypeFormatterMapping | None = None ,
417557):
418558 """Produce a graphviz Digraph from a PyMC model.
419559
@@ -441,6 +581,10 @@ def model_to_graphviz(
441581 the size of the saved figure.
442582 dpi : int, optional
443583 Dots per inch. It only affects the resolution of the saved figure. The default is 300.
584+ node_formatters : dict, optional
585+ A dictionary mapping node types to functions that return a dictionary of node attributes.
586+ Check out graphviz documentation for more information on available
587+ attributes. https://graphviz.org/docs/nodes/
444588
445589 Examples
446590 --------
@@ -475,6 +619,16 @@ def model_to_graphviz(
475619
476620 # creates the file `schools.pdf`
477621 model_to_graphviz(schools).render("schools")
622+
623+ Display Free Random Variables and Observed Random Variables nodes with custom formatting.
624+
625+ .. code-block:: python
626+
627+ node_formatters = {
628+ "Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
629+ "Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
630+ }
631+ model_to_graphviz(schools, node_formatters=node_formatters)
478632 """
479633 if "plain" not in formatting :
480634 raise ValueError (f"Unsupported formatting for graph nodes: '{ formatting } '. See docstring." )
@@ -491,4 +645,5 @@ def model_to_graphviz(
491645 save = save ,
492646 figsize = figsize ,
493647 dpi = dpi ,
648+ node_formatters = node_formatters ,
494649 )
0 commit comments