7 changes: 6 additions & 1 deletion doc/references.rst
@@ -88,7 +88,10 @@ important for your research:

Curation Module
---------------
If you use the :code:`get_potential_auto_merge` method from the curation module, please cite [Llobet]_

If you use the default "similarity_correlograms" preset in the :code:`compute_merge_unit_groups` method from the curation module, please cite [Llobet]_

If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_
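
A minimal sketch of how these presets might be used (assuming :code:`sorting_analyzer` is a :code:`SortingAnalyzer` with the required extensions computed; the exact signature is documented in the curation module):

.. code-block:: python

    from spikeinterface.curation import compute_merge_unit_groups

    merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset="slay")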

If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_

@@ -140,6 +143,8 @@ References

.. [Jia] `High-density extracellular probes reveal dendritic backpropagation and facilitate neuron classification. 2019 <https://journals.physiology.org/doi/full/10.1152/jn.00680.2018>`_

.. [Koukuntla] `SLAy-ing oversplitting errors in high-density electrophysiology spike sorting. 2025. <https://www.biorxiv.org/content/10.1101/2025.06.20.660590v1>`_

.. [Lee] `YASS: Yet another spike sorter. 2017. <https://www.biorxiv.org/content/10.1101/151928v1>`_

.. [Lemon] Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.
244 changes: 244 additions & 0 deletions src/spikeinterface/curation/auto_merge.py
@@ -52,6 +52,10 @@
"knn",
"quality_score",
],
"slay": [
"template_similarity",
"slay_score",
],
}

_required_extensions = {
@@ -60,6 +64,7 @@
"snr": ["templates", "noise_levels"],
"template_similarity": ["templates", "template_similarity"],
"knn": ["templates", "spike_locations", "spike_amplitudes"],
"slay_score": ["correlograms", "template_similarity"],
}


@@ -84,6 +89,7 @@
"censored_period_ms": 0.3,
},
"quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "template_diff_thresh": 0.25},
}


@@ -356,6 +362,12 @@ def compute_merge_unit_groups(
)
outs["pairs_decreased_score"] = pairs_decreased_score

elif step == "slay_score":

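# SLAy "merge decision metric": template similarity plus a weighted cross-correlogram
# significance term minus a weighted refractory period violation term (see compute_slay_matrix)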
M_ij = compute_slay_matrix(sorting_analyzer, params["k1"], params["k2"], params["template_diff_thresh"])

pair_mask = M_ij > params["slay_threshold"]

# FINAL STEP : create the final list from pair_mask boolean matrix
ind1, ind2 = np.nonzero(pair_mask)
merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
@@ -1506,3 +1518,235 @@ def estimate_cross_contamination(
)

return estimation, p_value


def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, template_diff_thresh: float):
"""
Computes the "merge decision metric" from the SLAy method, which combines
a template similarity measure, a cross-correlation significance measure and a
sliding refractory period violation measure. A large M suggests that two
units should be merged.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting data
k1 : float
Coefficient determining the importance of the cross-correlation significance
k2 : float
Coefficient determining the importance of the sliding rp violation
template_diff_thresh : float
Maximum allowed template dissimilarity: pairs with template similarity below 1 - template_diff_thresh are not considered for merging


References
----------
Based on computation originally implemented in SLAy [Koukuntla]_.

This implementation is based on one of the original implementations written by Sai Koukuntla,
found at https://github.com/saikoukunt/SLAy.
"""

sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh)

M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij

return M_ij

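# A minimal usage sketch (illustrative only; assumes a SortingAnalyzer with the
# "correlograms" and "template_similarity" extensions already computed, and uses
# the default "slay_score" parameters defined above):
#
#     M_ij = compute_slay_matrix(sorting_analyzer, k1=0.25, k2=1, template_diff_thresh=0.25)
#     candidate_pairs = np.argwhere(M_ij > 0.5)  # 0.5 is the default "slay_threshold"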

def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh: float):
"""
Computes a cross-correlation significance measure and a sliding refractory period violation
measure for all units in the `sorting_analyzer`.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting data
template_diff_thresh : float
Maximum allowed template dissimilarity: pairs with template similarity below 1 - template_diff_thresh are not considered for merging
"""

correlograms_extension = sorting_analyzer.get_extension("correlograms")
template_similarity = sorting_analyzer.get_extension("template_similarity").get_data()

ccgs, _ = correlograms_extension.get_data()

# bin size is stored in ms; it is converted to seconds where the SLAy helpers expect it
bin_size_ms = correlograms_extension.params["bin_ms"]

rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])

for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids):
for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids):

# Don't waste time computing the other metrics if we fail the template similarity check
if template_similarity[unit_index_1, unit_index_2] < 1 - template_diff_thresh:
continue

xgram = ccgs[unit_index_1, unit_index_2, :]

rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
)
eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)

return rho_ij, eta_ij


def _compute_xcorr_pair(
xgram,
bin_size_s: float,
min_xcorr_rate: float,
) -> float:
"""
Calculates a cross-correlation significance metric for a cluster pair.

Uses the Wasserstein distance between an observed cross-correlogram and a null
distribution as an estimate of how significant the dependence between
two neurons is. Low spike count cross-correlograms have large Wasserstein
distances from the null by chance, so we first try to expand the window size. If
that fails to yield enough spikes, we apply a penalty to the metric.

Ported from https://github.com/saikoukunt/SLAy.

Parameters
----------
xgram : np.array
The raw cross-correlogram for the cluster pair.
bin_size_s : float
The bin width of the input ccgs, in seconds.
min_xcorr_rate : float
The minimum ccg firing rate in Hz.

Returns
-------
sig : float
The calculated cross-correlation significance metric.
"""

from scipy.signal import butter, find_peaks_cwt, sosfiltfilt
from scipy.stats import wasserstein_distance

# calculate low-pass filtered second derivative of ccg
fs = 1 / bin_size_s
cutoff_freq = 100  # Hz
nyquist = fs / 2
cutoff = cutoff_freq / nyquist
peak_width = 0.002 / bin_size_s

xgram_2d = np.diff(xgram, 2)
sos = butter(4, cutoff, output="sos")
xgram_2d = sosfiltfilt(sos, xgram_2d)

if xgram.sum() == 0:
return 0

# find negative peaks of second derivative of ccg, these are the edges of dips in ccg
peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1
# if no peaks are found, return a very low significance
if peaks.shape[0] == 0:
return -4
peaks = np.abs(peaks - xgram.shape[0] / 2)
peaks = peaks[peaks > 0.5 * peak_width]
min_peaks = np.sort(peaks)

# start with peaks closest to 0 and move to the next set of peaks if the event count is too low
window_width = min_peaks * 1.5
starts = np.maximum(xgram.shape[0] / 2 - window_width, 0)
ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1)
ind = 0
xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)]
xgram_sum = xgram_window.sum()
window_size = xgram_window.shape[0] * bin_size_s
while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]):
xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)]
xgram_sum = xgram_window.sum()
window_size = xgram_window.shape[0] * bin_size_s
ind += 1
# use the whole ccg if peak finding fails
if ind == starts.shape[0]:
xgram_window = xgram

# TODO: was getting error messages when xgram_window was all zero. Why was this happening?
if np.abs(xgram_window).sum() == 0:
return 0

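# Wasserstein distance between the distribution of normalized bin positions weighted by
# the ccg window and a uniform distribution over the same positions; the factor of 4
# follows the original SLAy implementation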
sig = (
wasserstein_distance(
np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
xgram_window,
np.ones_like(xgram_window),
)
* 4
)

if xgram_window.sum() < (min_xcorr_rate * window_size):
sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2

# if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size):
if xgram_window.sum() < (min_xcorr_rate / 4 * window_size):
sig = -4 # don't merge if the event count is way too low

return sig


def _sliding_RP_viol_pair(
correlogram,
bin_size_ms: float,
acceptThresh: float = 0.15,
) -> float:
"""
Calculate the sliding refractory period violation confidence for a cluster.

Ported from https://github.com/saikoukunt/SLAy.

Parameters
----------
correlogram : np.array
The auto-correlogram of the cluster.
bin_size_ms : float
The bin width of the input correlogram, in ms.
acceptThresh : float, default: 0.15
The maximum acceptable contamination, expressed as a fraction of the baseline event rate estimated from the acg.

Returns
-------
sig : float
The refractory period violation confidence for the cluster.
"""
from scipy.signal import butter, sosfiltfilt
from scipy.stats import poisson

# create various refractory period sizes to test (between 0 and 20x bin size)
b = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000
bTestIdx = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8")
bTest = [b[i] for i in bTestIdx]

# average the two halves of the acg to enforce symmetry
# keep only the second half of the acg; refractory period violations are measured from the center of the acg
half_len = int(correlogram.shape[0] / 2)
correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2

acg_cumsum = np.cumsum(correlogram)
sum_res = acg_cumsum[bTestIdx - 1]  # -1 because the 0th bin corresponds to 0-bin_size ms

# low-pass filter acg and use max as baseline event rate
order = 4  # filter order
cutoff_freq = 250  # Hz
fs = 1 / bin_size_ms * 1000
nyquist = fs / 2
cutoff = cutoff_freq / nyquist
sos = butter(order, cutoff, btype="low", output="sos")
smoothed_acg = sosfiltfilt(sos, correlogram)

bin_rate_max = np.max(smoothed_acg)
max_conts_max = np.array(bTest) / bin_size_ms * 1000 * (bin_rate_max * acceptThresh)
# compute confidence of less than acceptThresh contamination at each refractory period
confs = 1 - poisson.cdf(sum_res, max_conts_max)
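# rp_viol is small when at least one tested refractory period gives high confidence that
# contamination is below acceptThresh, and close to 1 otherwise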
rp_viol = 1 - confs.max()
return rp_viol