Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,9 @@ def compute_auto_merge(self, **params):

merge_unit_groups, extra = compute_merge_unit_groups(
self.analyzer,
preset=params['preset'],
extra_outputs=True,
resolve_graph=False
resolve_graph=False,
**params
)

return merge_unit_groups, extra
Expand Down
240 changes: 129 additions & 111 deletions spikeinterface_gui/mergeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,87 @@

from .view_base import ViewBase

from spikeinterface.curation.auto_merge import _compute_merge_presets, _default_step_params


default_preset_list = ["similarity"] + list(_compute_merge_presets.keys())

all_presets = _compute_merge_presets.copy()
all_presets["similarity"] = ["template_similarity"]

class MergeView(ViewBase):
_supported_backend = ['qt', 'panel']

_settings = None

_methods = [{"name": "method", "type": "list", "limits": ["similarity", "automerge"]}]

_method_params = {
"similarity": [
{"name": "similarity_threshold", "type": "float", "value": .9, "step": 0.01},
{"name": "similarity_method", "type": "list", "limits": ["l1", "l2", "cosine"]},
],
"automerge": [
{"name": "automerge_preset", "type": "list", "limits": [
'similarity_correlograms',
'temporal_splits',
'x_contaminations',
'feature_neighbors'
]}
]
}

_presets = [
{
"name": "preset",
"type": "list",
# set similarity to default
"limits": default_preset_list
}
]

_preset_params = {}
# add similarity preset parameters
for preset_name, preset_params in all_presets.items():
_preset_params[preset_name] = []
for step_name in preset_params:
for step_parameter_name, step_parameter_ in _default_step_params[step_name].items():
parameter_dict = {
"name": step_name + "/" + step_parameter_name,
"value": step_parameter_,
}
if step_parameter_name == "similarity_method":
parameter_dict["type"] = "list"
parameter_dict["limits"] = ["l1", "l2", "cosine"]
else:
parameter_dict["type"] = type(step_parameter_).__name__
_preset_params[preset_name].append(parameter_dict)
_need_compute = False

def __init__(self, controller=None, parent=None, backend="qt"):
if controller.has_extension("template_similarity"):
similarity_ext = controller.analyzer.get_extension("template_similarity")
similarity_method = similarity_ext.params["method"]
self._method_params["similarity"][1]["value"] = similarity_method
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)

def get_potential_merges(self):
method = self.method
preset = self.preset
if self.controller.verbose:
print(f"Computing potential merges using {method} method")
if method == 'similarity':
similarity_params = self.method_params['similarity']
similarity = self.controller.get_similarity(similarity_params['similarity_method'])
if similarity is None:
similarity = self.controller.compute_similarity(similarity_params['similarity_method'])
th_sim = similarity > similarity_params['similarity_threshold']
unit_ids = self.controller.unit_ids
self.proposed_merge_unit_groups = [[unit_ids[i], unit_ids[j]] for i, j in zip(*np.nonzero(th_sim)) if i < j]
self.merge_info = {'similarity': similarity}
elif method == 'automerge':
automerge_params = self.method_params['automerge']
params = {
'preset': automerge_params['automerge_preset']
}
self.proposed_merge_unit_groups, self.merge_info = self.controller.compute_auto_merge(**params)
else:
raise ValueError(f"Unknown method: {method}")
print(f"Computing potential merges using {preset} method")
params_dict = {}
params_dict["preset"] = preset

preset_params = self.preset_params[preset]

steps_params = {}
for name in preset_params.keys():
step_name, step_param = name.split("/")
if steps_params.get(step_name) is None:
steps_params[step_name] = {}
steps_params[step_name][step_param] = preset_params[name]
params_dict["steps_params"] = steps_params

# define steps for similarity preset
if preset == "similarity":
params_dict["preset"] = None
params_dict["steps"] = all_presets["similarity"]
self.proposed_merge_unit_groups, self.merge_info = self.controller.compute_auto_merge(**params_dict)

if self.controller.verbose:
print(f"Found {len(self.proposed_merge_unit_groups)} merge groups using {method} method")
print(f"Found {len(self.proposed_merge_unit_groups)} merge groups using {preset} preset")

def get_table_data(self, include_deleted=False):
"""Get data for displaying in table"""
if not self.proposed_merge_unit_groups:
return [], []

max_group_size = max(len(g) for g in self.proposed_merge_unit_groups)
potential_labels = {"similarity", "correlogram_diff", "templates_diff"}
more_labels = []
for lbl in self.merge_info.keys():
if lbl in potential_labels:
if max_group_size == 2:
more_labels.append(lbl)
else:
more_labels.append([lbl + "_min", lbl + "_max"])
if max_group_size == 2:
more_labels.append(lbl)
else:
more_labels.append([lbl + "_min", lbl + "_max"])

labels = [f"unit_id{i}" for i in range(max_group_size)] + more_labels + ["group_ids"]

Expand All @@ -91,20 +102,29 @@ def get_table_data(self, include_deleted=False):
# row[f"unit_id{i}_color"] = self.controller.get_unit_color(unit_id)
row["group_ids"] = group_ids

# Add metrics information
# Add pairwise metric information
for info_name in more_labels:
values = []
for unit_id1, unit_id2 in itertools.combinations(group_ids, 2):
unit_ind1 = unit_ids.index(unit_id1)
unit_ind2 = unit_ids.index(unit_id2)
values.append(self.merge_info[info_name][unit_ind1][unit_ind2])

if max_group_size == 2:
row[info_name] = f"{values[0]:.2f}"
merge_info = self.merge_info[info_name]
if isinstance(merge_info, np.ndarray) and \
merge_info.shape == (len(unit_ids), len(unit_ids)):
for unit_id1, unit_id2 in itertools.combinations(group_ids, 2):
unit_ind1 = unit_ids.index(unit_id1)
unit_ind2 = unit_ids.index(unit_id2)
values.append(merge_info[unit_ind1][unit_ind2])

if max_group_size == 2:
row[info_name] = f"{values[0]:.2f}"
else:
min_, max_ = min(values), max(values)
row[f"{info_name}_min"] = f"{min_:.2f}"
row[f"{info_name}_max"] = f"{max_:.2f}"
else:
min_, max_ = min(values), max(values)
row[f"{info_name}_min"] = f"{min_:.2f}"
row[f"{info_name}_max"] = f"{max_:.2f}"
if info_name in labels:
labels.remove(info_name)
elif f"{info_name}_min" in labels:
labels.remove(f"{info_name}_min")
labels.remove(f"{info_name}_max")
Comment on lines +108 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For @samuelgarcia

Changed the logic here to be able to show any extra outs with shape (num_units, num_units), instead of pre-selecting potential labels

rows.append(row)
return labels, rows

Expand Down Expand Up @@ -155,10 +175,10 @@ def _qt_on_item_selection_changed(self):
def _qt_on_double_click(self, item):
self.accept_group_merge(item.group_ids)

def _qt_on_method_change(self):
self.method = self.method_selector['method']
for method in self.method_params_selectors:
self.method_params_selectors[method].setVisible(method == self.method)
def _qt_on_preset_change(self):
self.preset = self.preset_selector['preset']
for preset in self.preset_params_selectors:
self.preset_params_selectors[preset].setVisible(preset == self.preset)


def _qt_make_layout(self):
Expand All @@ -167,33 +187,33 @@ def _qt_make_layout(self):

self.proposed_merge_unit_groups = []

# create method and arguments layout
self.method_selector = pg.parametertree.Parameter.create(name="method", type='group', children=self._methods)
method_select = pg.parametertree.ParameterTree(parent=None)
method_select.header().hide()
method_select.setParameters(self.method_selector, showTop=True)
method_select.setWindowTitle(u'View options')
method_select.setFixedHeight(50)
self.method_selector.sigTreeStateChanged.connect(self._qt_on_method_change)
# create presets and arguments layout
self.preset_selector = pg.parametertree.Parameter.create(name="preset", type='group', children=self._presets)
preset_select = pg.parametertree.ParameterTree(parent=None)
preset_select.header().hide()
preset_select.setParameters(self.preset_selector, showTop=True)
preset_select.setWindowTitle(u'View options')
preset_select.setFixedHeight(50)
self.preset_selector.sigTreeStateChanged.connect(self._qt_on_preset_change)

self.merge_info = {}
self.layout = QT.QVBoxLayout()
self.layout.addWidget(method_select)

self.method_params_selectors = {}
self.method_params = {}
for method, params in self._method_params.items():
method_params = pg.parametertree.Parameter.create(name="params", type='group', children=params)
method_tree_settings = pg.parametertree.ParameterTree(parent=None)
method_tree_settings.header().hide()
method_tree_settings.setParameters(method_params, showTop=True)
method_tree_settings.setWindowTitle(u'View options')
method_tree_settings.setFixedHeight(100)
self.method_params_selectors[method] = method_tree_settings
self.method_params[method] = method_params
self.layout.addWidget(method_tree_settings)
self.method = self.method_selector['method']
self._qt_on_method_change()
self.layout.addWidget(preset_select)

self.preset_params_selectors = {}
self.preset_params = {}
for preset, params in self._preset_params.items():
preset_params = pg.parametertree.Parameter.create(name="params", type='group', children=params)
preset_tree_settings = pg.parametertree.ParameterTree(parent=None)
preset_tree_settings.header().hide()
preset_tree_settings.setParameters(preset_params, showTop=True)
preset_tree_settings.setWindowTitle(u'View options')
preset_tree_settings.setFixedHeight(100)
self.preset_params_selectors[preset] = preset_tree_settings
self.preset_params[preset] = preset_params
self.layout.addWidget(preset_tree_settings)
self.preset = self.preset_selector['preset']
self._qt_on_preset_change()

row_layout = QT.QHBoxLayout()

Expand Down Expand Up @@ -260,7 +280,7 @@ def _qt_refresh(self):
self.table.setItem(r, c, item)
item.setIcon(icon)
item.group_ids = row.get("group_ids", [])
elif "_color" not in label:
elif "_color" not in label and label in row:
value = row[label]
item = CustomItem(value)
self.table.setItem(r, c, item)
Expand All @@ -273,7 +293,7 @@ def _compute_merges(self):
with self.busy_cursor():
self.get_potential_merges()
if len(self.proposed_merge_unit_groups) == 0:
self.warning(f"No potential merges found with method {self.method}")
self.warning(f"No potential merges found with preset {self.preset}")
self.refresh()

def _qt_on_spike_selection_changed(self):
Expand All @@ -292,20 +312,20 @@ def _panel_make_layout(self):

self.proposed_merge_unit_groups = []

# Create method and arguments layout
method_settings = SettingsProxy(create_dynamic_parameterized(self._methods))
self.method_selector = pn.Param(method_settings._parameterized, sizing_mode="stretch_width", name="Method")
for setting_data in self._methods:
method_settings._parameterized.param.watch(self._panel_on_method_change, setting_data["name"])

self.method_params = {}
self.method_params_selectors = {}
for method, params in self._method_params.items():
method_params = SettingsProxy(create_dynamic_parameterized(params))
self.method_params[method] = method_params
self.method_params_selectors[method] = pn.Param(method_params._parameterized, sizing_mode="stretch_width",
name=f"{method.capitalize()} parameters")
self.method = list(self.method_params.keys())[0]
# Create presets and arguments layout
preset_settings = SettingsProxy(create_dynamic_parameterized(self._presets))
self.preset_selector = pn.Param(preset_settings._parameterized, sizing_mode="stretch_width", name="Preset")
for setting_data in self._presets:
preset_settings._parameterized.param.watch(self._panel_on_preset_change, setting_data["name"])

self.preset_params = {}
self.preset_params_selectors = {}
for preset, params in self._preset_params.items():
preset_params = SettingsProxy(create_dynamic_parameterized(params))
self.preset_params[preset] = preset_params
self.preset_params_selectors[preset] = pn.Param(preset_params._parameterized, sizing_mode="stretch_width",
name=f"{preset.capitalize()} parameters")
self.preset = list(self.preset_params.keys())[0]

# shortcuts
shortcuts = [
Expand All @@ -332,8 +352,8 @@ def _panel_make_layout(self):

self.layout = pn.Column(
# add params
self.method_selector,
self.method_params_selectors[self.method],
self.preset_selector,
self.preset_params_selectors[self.preset],
calculate_row,
self.table_area,
shortcuts_component,
Expand Down Expand Up @@ -384,9 +404,9 @@ def _panel_refresh(self):
def _panel_compute_merges(self, event):
self._compute_merges()

def _panel_on_method_change(self, event):
self.method = event.new
self.layout[1] = self.method_params_selectors[self.method]
def _panel_on_preset_change(self, event):
self.preset = event.new
self.layout[1] = self.preset_params_selectors[self.preset]

def _panel_on_click(self, event):
# set unit visibility
Expand Down Expand Up @@ -432,10 +452,8 @@ def _panel_on_unit_visibility_changed(self):
## Merge View

This view allows you to compute potential merges between units based on their similarity or using the auto merge function.
Select the method to use for merging units.
The available methods are:
- similarity: Computes the similarity between units based on their features.
- automerge: uses the auto merge function in SpikeInterface to find potential merges.
Select the preset to use for merging units.
The available presets are inherited from spikeinterface.

Click "Calculate merges" to compute the potential merges. When finished, the table will be populated
with the potential merges.
Expand Down
15 changes: 4 additions & 11 deletions spikeinterface_gui/tests/test_mainwindow_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import sys


test_folder = Path(__file__).parent / 'my_dataset_small'
# test_folder = Path(__file__).parent / 'my_dataset_big'
# test_folder = Path(__file__).parent / 'my_dataset_multiprobe'

# yep is for testing
yep_layout = dict(
Expand All @@ -30,6 +27,7 @@


def setup_module():
global test_folder
case = test_folder.stem.split('_')[-1]
make_analyzer_folder(test_folder, case=case)

Expand Down Expand Up @@ -123,14 +121,9 @@ def test_launcher(verbose=True):
if __name__ == '__main__':
args = parser.parse_args()
dataset = args.dataset
if dataset == "small":
test_folder = Path(__file__).parent / 'my_dataset_small'
elif dataset == "big":
test_folder = Path(__file__).parent / 'my_dataset_big'
elif dataset == "multiprobe":
test_folder = Path(__file__).parent / 'my_dataset_multiprobe'
else:
test_folder = Path(dataset)
global test_folder
if dataset is not None:
test_folder = Path(dataset).parent / f"my_dataset_{dataset}"
if not test_folder.is_dir():
setup_module()

Expand Down
Loading