@@ -33,7 +33,7 @@ def run(self, **job_kwargs):
3333 sorting ["unit_index" ] = spikes ["cluster_index" ]
3434 sorting ["segment_index" ] = spikes ["segment_index" ]
3535 sorting = NumpySorting (sorting , self .recording .sampling_frequency , unit_ids )
36- self .result = {"sorting" : sorting }
36+ self .result = {"sorting" : sorting , "spikes" : spikes }
3737 self .result ["templates" ] = self .templates
3838
3939 def compute_result (self , with_collision = False , ** result_params ):
@@ -45,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params):
4545
4646 _run_key_saved = [
4747 ("sorting" , "sorting" ),
48+ ("spikes" , "npy" ),
4849 ("templates" , "zarr_templates" ),
4950 ]
5051 _result_key_saved = [("gt_collision" , "pickle" ), ("gt_comparison" , "pickle" )]
@@ -71,9 +72,15 @@ def plot_performances_vs_snr(self, **kwargs):
7172
7273 return plot_performances_vs_snr (self , ** kwargs )
7374
75+ def plot_performances_comparison (self , ** kwargs ):
76+ from .benchmark_plot_tools import plot_performances_comparison
77+
78+ return plot_performances_comparison (self , ** kwargs )
79+
7480 def plot_collisions (self , case_keys = None , figsize = None ):
7581 if case_keys is None :
7682 case_keys = list (self .cases .keys ())
83+ import matplotlib .pyplot as plt
7784
7885 fig , axs = plt .subplots (ncols = len (case_keys ), nrows = 1 , figsize = figsize , squeeze = False )
7986
@@ -90,70 +97,6 @@ def plot_collisions(self, case_keys=None, figsize=None):
9097
9198 return fig
9299
93- def plot_comparison_matching (
94- self ,
95- case_keys = None ,
96- performance_names = ["accuracy" , "recall" , "precision" ],
97- colors = ["g" , "b" , "r" ],
98- ylim = (- 0.1 , 1.1 ),
99- figsize = None ,
100- ):
101-
102- if case_keys is None :
103- case_keys = list (self .cases .keys ())
104-
105- num_methods = len (case_keys )
106- import pylab as plt
107-
108- fig , axs = plt .subplots (ncols = num_methods , nrows = num_methods , figsize = (10 , 10 ))
109- for i , key1 in enumerate (case_keys ):
110- for j , key2 in enumerate (case_keys ):
111- if len (axs .shape ) > 1 :
112- ax = axs [i , j ]
113- else :
114- ax = axs [j ]
115- comp1 = self .get_result (key1 )["gt_comparison" ]
116- comp2 = self .get_result (key2 )["gt_comparison" ]
117- if i <= j :
118- for performance , color in zip (performance_names , colors ):
119- perf1 = comp1 .get_performance ()[performance ]
120- perf2 = comp2 .get_performance ()[performance ]
121- ax .plot (perf2 , perf1 , "." , label = performance , color = color )
122-
123- ax .plot ([0 , 1 ], [0 , 1 ], "k--" , alpha = 0.5 )
124- ax .set_ylim (ylim )
125- ax .set_xlim (ylim )
126- ax .spines [["right" , "top" ]].set_visible (False )
127- ax .set_aspect ("equal" )
128-
129- label1 = self .cases [key1 ]["label" ]
130- label2 = self .cases [key2 ]["label" ]
131- if j == i :
132- ax .set_ylabel (f"{ label1 } " )
133- else :
134- ax .set_yticks ([])
135- if i == j :
136- ax .set_xlabel (f"{ label2 } " )
137- else :
138- ax .set_xticks ([])
139- if i == num_methods - 1 and j == num_methods - 1 :
140- patches = []
141- import matplotlib .patches as mpatches
142-
143- for color , name in zip (colors , performance_names ):
144- patches .append (mpatches .Patch (color = color , label = name ))
145- ax .legend (handles = patches , bbox_to_anchor = (1.05 , 1 ), loc = "upper left" , borderaxespad = 0.0 )
146- else :
147- ax .spines ["bottom" ].set_visible (False )
148- ax .spines ["left" ].set_visible (False )
149- ax .spines ["top" ].set_visible (False )
150- ax .spines ["right" ].set_visible (False )
151- ax .set_xticks ([])
152- ax .set_yticks ([])
153- plt .tight_layout (h_pad = 0 , w_pad = 0 )
154-
155- return fig
156-
157100 def get_count_units (self , case_keys = None , well_detected_score = None , redundant_score = None , overmerged_score = None ):
158101 import pandas as pd
159102
@@ -196,6 +139,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None):
196139 plot_study_unit_counts (self , case_keys , figsize = figsize )
197140
198141 def plot_unit_losses (self , before , after , metric = ["precision" ], figsize = None ):
142+ import matplotlib .pyplot as plt
199143
200144 fig , axs = plt .subplots (ncols = 1 , nrows = len (metric ), figsize = figsize , squeeze = False )
201145
0 commit comments