11from __future__ import annotations
2+ from collections import defaultdict
23
34import numpy as np
45
@@ -17,7 +18,7 @@ class UnitSummaryWidget(BaseWidget):
1718 """
1819 Plot a unit summary.
1920
20- If amplitudes are alreday computed they are displayed.
21+ If amplitudes are alreday computed, they are displayed.
2122
2223 Parameters
2324 ----------
@@ -30,6 +31,14 @@ class UnitSummaryWidget(BaseWidget):
3031 sparsity : ChannelSparsity or None, default: None
3132 Optional ChannelSparsity to apply.
3233 If SortingAnalyzer is already sparse, the argument is ignored
34+ subwidget_kwargs : dict or None, default: None
35+ Parameters for the subwidgets in a nested dictionary
36+ unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details)
37+ unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details)
38+ unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details)
39+ autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details)
40+ amplitudes : AmplitudesWidget (see AmplitudesWidget for details)
41+ Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary.
3342 """
3443
3544 # possible_backends = {}
@@ -40,21 +49,29 @@ def __init__(
4049 unit_id ,
4150 unit_colors = None ,
4251 sparsity = None ,
43- radius_um = 100 ,
52+ subwidget_kwargs = None ,
4453 backend = None ,
4554 ** backend_kwargs ,
4655 ):
47-
4856 sorting_analyzer = self .ensure_sorting_analyzer (sorting_analyzer )
4957
5058 if unit_colors is None :
5159 unit_colors = get_unit_colors (sorting_analyzer )
5260
61+ if subwidget_kwargs is None :
62+ subwidget_kwargs = dict ()
63+ for kwargs in subwidget_kwargs .values ():
64+ if "unit_colors" in kwargs :
65+ raise ValueError (
66+ "unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary"
67+ )
68+
5369 plot_data = dict (
5470 sorting_analyzer = sorting_analyzer ,
5571 unit_id = unit_id ,
5672 unit_colors = unit_colors ,
5773 sparsity = sparsity ,
74+ subwidget_kwargs = subwidget_kwargs ,
5875 )
5976
6077 BaseWidget .__init__ (self , plot_data , backend = backend , ** backend_kwargs )
@@ -70,6 +87,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
7087 unit_colors = dp .unit_colors
7188 sparsity = dp .sparsity
7289
90+ # defaultdict returns empty dict if key not found in subwidget_kwargs
91+ subwidget_kwargs = defaultdict (lambda : dict (), dp .subwidget_kwargs )
92+ unitlocationswidget_kwargs = subwidget_kwargs ["unit_locations" ]
93+ unitwaveformswidget_kwargs = subwidget_kwargs ["unit_waveforms" ]
94+ unitwaveformdensitymapwidget_kwargs = subwidget_kwargs ["unit_waveform_density_map" ]
95+ autocorrelogramswidget_kwargs = subwidget_kwargs ["autocorrelograms" ]
96+ amplitudeswidget_kwargs = subwidget_kwargs ["amplitudes" ]
97+
7398 # force the figure without axes
7499 if "figsize" not in backend_kwargs :
75100 backend_kwargs ["figsize" ] = (18 , 7 )
@@ -99,6 +124,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
99124 plot_legend = False ,
100125 backend = "matplotlib" ,
101126 ax = ax1 ,
127+ ** unitlocationswidget_kwargs ,
102128 )
103129
104130 unit_locations = sorting_analyzer .get_extension ("unit_locations" ).get_data (outputs = "by_unit" )
@@ -121,6 +147,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
121147 sparsity = sparsity ,
122148 backend = "matplotlib" ,
123149 ax = ax2 ,
150+ ** unitwaveformswidget_kwargs ,
124151 )
125152
126153 ax2 .set_title (None )
@@ -134,6 +161,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
134161 same_axis = False ,
135162 backend = "matplotlib" ,
136163 ax = ax3 ,
164+ ** unitwaveformdensitymapwidget_kwargs ,
137165 )
138166 ax3 .set_ylabel (None )
139167
@@ -145,6 +173,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
145173 unit_colors = unit_colors ,
146174 backend = "matplotlib" ,
147175 ax = ax4 ,
176+ ** autocorrelogramswidget_kwargs ,
148177 )
149178
150179 ax4 .set_title (None )
@@ -162,6 +191,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
162191 plot_histograms = True ,
163192 backend = "matplotlib" ,
164193 axes = axes ,
194+ ** amplitudeswidget_kwargs ,
165195 )
166196
167197 fig .suptitle (f"unit_id: { dp .unit_id } " )
0 commit comments