@@ -107,82 +107,90 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
107107 # and use custum grid spec
108108 fig = self .figure
109109 nrows = 2
110- ncols = 3
111- if sorting_analyzer .has_extension ("correlograms" ) or sorting_analyzer .has_extension ("spike_amplitudes" ):
110+ ncols = 2
111+ if sorting_analyzer .has_extension ("correlograms" ):
112+ ncols += 1
113+ if sorting_analyzer .has_extension ("waveforms" ):
112114 ncols += 1
113115 if sorting_analyzer .has_extension ("spike_amplitudes" ):
114116 nrows += 1
115117 gs = fig .add_gridspec (nrows , ncols )
118+ col_counter = 0
116119
117- if sorting_analyzer . has_extension ( "unit_locations" ):
118- ax1 = fig .add_subplot (gs [:2 , 0 ])
119- # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1)
120- w = UnitLocationsWidget (
121- sorting_analyzer ,
122- unit_ids = [ unit_id ] ,
123- unit_colors = unit_colors ,
124- plot_legend = False ,
125- backend = "matplotlib" ,
126- ax = ax1 ,
127- ** unitlocationswidget_kwargs ,
128- )
129-
130- unit_locations = sorting_analyzer .get_extension ("unit_locations" ).get_data (outputs = "by_unit" )
131- unit_location = unit_locations [unit_id ]
132- x , y = unit_location [0 ], unit_location [1 ]
133- ax1 .set_xlim (x - 80 , x + 80 )
134- ax1 .set_ylim (y - 250 , y + 250 )
135- ax1 .set_xticks ([])
136- ax1 .set_xlabel (None )
137- ax1 .set_ylabel (None )
138-
139- ax2 = fig .add_subplot (gs [:2 , 1 ])
140- w = UnitWaveformsWidget (
120+ # Unit locations and unit waveform plots are always generated
121+ ax_unit_locations = fig .add_subplot (gs [:2 , col_counter ])
122+ _ = UnitLocationsWidget (
123+ sorting_analyzer ,
124+ unit_ids = [ unit_id ] ,
125+ unit_colors = unit_colors ,
126+ plot_legend = False ,
127+ backend = "matplotlib" ,
128+ ax = ax_unit_locations ,
129+ ** unitlocationswidget_kwargs ,
130+ )
131+ col_counter += 1
132+
133+ unit_locations = sorting_analyzer .get_extension ("unit_locations" ).get_data (outputs = "by_unit" )
134+ unit_location = unit_locations [unit_id ]
135+ x , y = unit_location [0 ], unit_location [1 ]
136+ ax_unit_locations .set_xlim (x - 80 , x + 80 )
137+ ax_unit_locations .set_ylim (y - 250 , y + 250 )
138+ ax_unit_locations .set_xticks ([])
139+ ax_unit_locations .set_xlabel (None )
140+ ax_unit_locations .set_ylabel (None )
141+
142+ ax_unit_waveforms = fig .add_subplot (gs [:2 , col_counter ])
143+ _ = UnitWaveformsWidget (
141144 sorting_analyzer ,
142145 unit_ids = [unit_id ],
143146 unit_colors = unit_colors ,
144147 plot_templates = True ,
148+ plot_waveforms = sorting_analyzer .has_extension ("waveforms" ),
145149 same_axis = True ,
146150 plot_legend = False ,
147151 sparsity = sparsity ,
148152 backend = "matplotlib" ,
149- ax = ax2 ,
153+ ax = ax_unit_waveforms ,
150154 ** unitwaveformswidget_kwargs ,
151155 )
156+ col_counter += 1
152157
153- ax2 .set_title (None )
158+ ax_unit_waveforms .set_title (None )
154159
155- ax3 = fig .add_subplot (gs [:2 , 2 ])
156- UnitWaveformDensityMapWidget (
157- sorting_analyzer ,
158- unit_ids = [unit_id ],
159- unit_colors = unit_colors ,
160- use_max_channel = True ,
161- same_axis = False ,
162- backend = "matplotlib" ,
163- ax = ax3 ,
164- ** unitwaveformdensitymapwidget_kwargs ,
165- )
166- ax3 .set_ylabel (None )
160+ if sorting_analyzer .has_extension ("waveforms" ):
161+ ax_waveform_density = fig .add_subplot (gs [:2 , col_counter ])
162+ UnitWaveformDensityMapWidget (
163+ sorting_analyzer ,
164+ unit_ids = [unit_id ],
165+ unit_colors = unit_colors ,
166+ use_max_channel = True ,
167+ same_axis = False ,
168+ backend = "matplotlib" ,
169+ ax = ax_waveform_density ,
170+ ** unitwaveformdensitymapwidget_kwargs ,
171+ )
172+ col_counter += 1
173+ ax_waveform_density .set_ylabel (None )
167174
168175 if sorting_analyzer .has_extension ("correlograms" ):
169- ax4 = fig .add_subplot (gs [:2 , 3 ])
176+ ax_correlograms = fig .add_subplot (gs [:2 , col_counter ])
170177 AutoCorrelogramsWidget (
171178 sorting_analyzer ,
172179 unit_ids = [unit_id ],
173180 unit_colors = unit_colors ,
174181 backend = "matplotlib" ,
175- ax = ax4 ,
182+ ax = ax_correlograms ,
176183 ** autocorrelogramswidget_kwargs ,
177184 )
185+ col_counter += 1
178186
179- ax4 .set_title (None )
180- ax4 .set_yticks ([])
187+ ax_correlograms .set_title (None )
188+ ax_correlograms .set_yticks ([])
181189
182190 if sorting_analyzer .has_extension ("spike_amplitudes" ):
183- ax5 = fig .add_subplot (gs [2 , :3 ])
184- ax6 = fig .add_subplot (gs [2 , 3 ])
185- axes = np .array ([ax5 , ax6 ])
191+ ax_spike_amps = fig .add_subplot (gs [2 , : col_counter - 1 ])
192+ ax_amps_distribution = fig .add_subplot (gs [2 , col_counter - 1 ])
193+ axes = np .array ([ax_spike_amps , ax_amps_distribution ])
186194 AmplitudesWidget (
187195 sorting_analyzer ,
188196 unit_ids = [unit_id ],
0 commit comments