@@ -193,7 +193,7 @@ def test_02_dag(self):
193193 )
194194
195195 log .debug ("get_leaves" )
196- self .assertEqual ([p .name for p in root .leaves ()], ["b2" , "c1" , "c2" , "b1" ])
196+ self .assertEqual (set ( [p .name for p in root .leaves ()]), set ( ["b2" , "c1" , "c2" , "b1" ]) )
197197 log .debug ("get_roots" )
198198 self .assertEqual ([p .name for p in c2 .roots ()], ["root" ])
199199
@@ -493,7 +493,51 @@ def test_02_dag(self):
493493 log .debug (f"Node count: { NetworkNode .objects .count ()} " )
494494 log .debug (f"Edge count: { NetworkEdge .objects .count ()} " )
495495
496- def test_03_deep_dag (self ):
496+ def test_03_multilinked_nodes (self ):
497+ log = logging .getLogger ("test_03" )
498+ log .debug ("Test deletion of nodes two nodes with multiple shared edges" )
499+
500+ shared_edge_count = 5
501+
502+ def create_multilinked_nodes (shared_edge_count ):
503+ log .debug ("Creating multiple links between a parent and child node" )
504+ child_node = NetworkNode .objects .create ()
505+ parent_node = NetworkNode .objects .create ()
506+
507+ # Call this multiple times to create multiple edges between same parent/child
508+ for _ in range (shared_edge_count ):
509+ child_node .add_parent (parent_node )
510+
511+ return child_node , parent_node
512+
513+ def delete_parents ():
514+ child_node , parent_node = create_multilinked_nodes (shared_edge_count )
515+
516+ # Refresh the related manager
517+ child_node .refresh_from_db ()
518+
519+ self .assertEqual (child_node .parents .count (), shared_edge_count )
520+ log .debug (f"Initial parents count: { child_node .parents .count ()} " )
521+ child_node .remove_parent (parent_node )
522+ self .assertEqual (child_node .parents .count (), 0 )
523+ log .debug (f"Final parents count: { child_node .parents .count ()} " )
524+
525+ def delete_children ():
526+ child_node , parent_node = create_multilinked_nodes (shared_edge_count )
527+
528+ # Refresh the related manager
529+ parent_node .refresh_from_db ()
530+
531+ self .assertEqual (parent_node .children .count (), shared_edge_count )
532+ log .debug (f"Initial children count: { parent_node .children .count ()} " )
533+ parent_node .remove_child (child_node )
534+ self .assertEqual (parent_node .children .count (), 0 )
535+ log .debug (f"Final children count: { parent_node .children .count ()} " )
536+
537+ delete_parents ()
538+ delete_children ()
539+
540+ def test_04_deep_dag (self ):
497541 """
498542 Create a deep graph and check that graph operations run in a
499543 reasonable amount of time (linear in size of graph, not
@@ -503,11 +547,10 @@ def test_03_deep_dag(self):
503547 def run_test ():
504548 # Using the graph generation algorithm below, the number of potential
505549 # paths from node 0 doubles for each increase in n.
506- # number_of_paths = 2^(n-1) WRONG!!!
507- # When n=22, there are on the order of 1 million paths through the graph
508- # from node 0, so results for intermediate nodes need to be cached
550+ # When n=22, there are many paths through the graph from node 0,
551+ # so results for intermediate nodes need to be cached
509552
510- log = logging .getLogger ("test_03 " )
553+ log = logging .getLogger ("test_04 " )
511554
512555 n = 22 # Keep it an even number
513556
@@ -547,8 +590,9 @@ def run_test():
547590 first = NetworkNode .objects .get (name = "0" )
548591 last = NetworkNode .objects .get (name = str (2 * n - 1 ))
549592
550- log .debug (f"Path exists: { first .path_exists (last , max_depth = n )} " )
551- self .assertTrue (first .path_exists (last , max_depth = n ), True )
593+ path_exists = first .path_exists (last , max_depth = n )
594+ log .debug (f"Path exists: { path_exists } " )
595+ self .assertTrue (path_exists , True )
552596 self .assertEqual (first .distance (last , max_depth = n ), n - 1 )
553597
554598 log .debug (f"Node count: { NetworkNode .objects .count ()} " )
@@ -560,8 +604,9 @@ def run_test():
560604 )
561605
562606 middle = NetworkNode .objects .get (pk = n - 1 )
563- log .debug ("Distance" )
564- self .assertEqual (first .distance (middle , max_depth = n ), n / 2 - 1 )
607+ distance = first .distance (middle , max_depth = n )
608+ log .debug (f"Distance: { distance } " )
609+ self .assertEqual (distance , n / 2 - 1 )
565610
566611 # Run the test, raising an error if the code times out
567612 p = multiprocessing .Process (target = run_test )
0 commit comments