Add SLAy auto-merge preset #4190
@@ -52,6 +52,10 @@
        "knn",
        "quality_score",
    ],
    "slay": [
        "template_similarity",
        "slay_score",
    ],
}
@@ -60,6 +64,7 @@ _required_extensions = {
    "snr": ["templates", "noise_levels"],
    "template_similarity": ["templates", "template_similarity"],
    "knn": ["templates", "spike_locations", "spike_amplitudes"],
    "slay_score": ["correlograms", "template_similarity"],
}
@@ -84,6 +89,7 @@
        "censored_period_ms": 0.3,
    },
    "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
    "slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "template_diff_thresh": 0.25},
}
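For orientation, a minimal sketch of how the new preset could be invoked. The preset name and the default parameters come from this diff; the `steps_params` keyword and the `compute` call are assumed from the existing SpikeInterface curation API and are not part of this change:

# a sketch, not part of this diff: invoking the new "slay" preset
from spikeinterface.curation import compute_merge_unit_groups

# the "slay" steps need these extensions (see _required_extensions above),
# assuming prerequisites such as random_spikes and templates already exist
sorting_analyzer.compute(["template_similarity", "correlograms"])

merge_unit_groups = compute_merge_unit_groups(
    sorting_analyzer,
    preset="slay",
    # optional: override the defaults added above
    steps_params={"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5, "template_diff_thresh": 0.25}},
)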
@@ -356,6 +362,12 @@ def compute_merge_unit_groups(
            )
            outs["pairs_decreased_score"] = pairs_decreased_score

        elif step == "slay_score":

            M_ij = compute_slay_matrix(sorting_analyzer, params["k1"], params["k2"], params["template_diff_thresh"])

            pair_mask = M_ij > params["slay_threshold"]

    # FINAL STEP : create the final list from pair_mask boolean matrix
    ind1, ind2 = np.nonzero(pair_mask)
    merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
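As a toy illustration of how the boolean pair_mask becomes merge groups (made-up values, not from the diff):

import numpy as np

unit_ids = np.array(["u0", "u1", "u2"])
M_ij = np.array([
    [0.0, 0.8, 0.1],
    [0.0, 0.0, 0.6],
    [0.0, 0.0, 0.0],
])
pair_mask = M_ij > 0.5  # the default slay_threshold
ind1, ind2 = np.nonzero(pair_mask)
# pairs above threshold: (u0, u1) and (u1, u2)
print(list(zip(unit_ids[ind1], unit_ids[ind2])))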
@@ -1506,3 +1518,235 @@ def estimate_cross_contamination(
    )

    return estimation, p_value


def compute_slay_matrix(sorting_analyzer: SortingAnalyzer, k1: float, k2: float, template_diff_thresh: float):
    """
    Computes the "merge decision metric" from the SLAy method, made by combining
    a template similarity measure, a cross-correlation significance measure and a
    sliding refractory period violation measure. A large M suggests that two
    units should be merged.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        The sorting analyzer object containing the spike sorting data
    k1 : float
        Coefficient determining the importance of the cross-correlation significance
    k2 : float
        Coefficient determining the importance of the sliding refractory period violation
    template_diff_thresh : float
        Maximum difference in template similarity for a pair to be considered for merging

    References
    ----------
    Based on the computation originally implemented in SLAy [Koukuntla]_.

    The implementation is based on one of the original implementations written by Sai Koukuntla,
    found at https://github.com/saikoukunt/SLAy.
    """

    sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
    rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, template_diff_thresh)

    M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij

    return M_ij
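For intuition, a worked example with the default coefficients (made-up inputs, not from the diff): with sigma_ij = 0.9, rho_ij = 0.5 and eta_ij = 0.1, the metric is M_ij = 0.9 + 0.25 * 0.5 - 1 * 0.1 = 0.925, which clears the default slay_threshold of 0.5, so the pair would be proposed for merging.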


def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, template_diff_thresh: float):
    """
    Computes a cross-correlation significance measure and a sliding refractory period violation
    measure for all units in the `sorting_analyzer`.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        The sorting analyzer object containing the spike sorting data
    template_diff_thresh : float
        Maximum difference in template similarity for a pair to be considered for merging
    """

    correlograms_extension = sorting_analyzer.get_extension("correlograms")
    template_similarity = sorting_analyzer.get_extension("template_similarity").get_data()

    ccgs, _ = correlograms_extension.get_data()

    # bin size in ms; converted to seconds where the SLAy functions need it
    bin_size_ms = correlograms_extension.params["bin_ms"]

    rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
    eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])

    for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids):
        for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids):

            # Don't waste time computing the other metrics if we fail the template similarity check
            if template_similarity[unit_index_1, unit_index_2] < 1 - template_diff_thresh:
                continue

            xgram = ccgs[unit_index_1, unit_index_2, :]

            rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
                xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
            )
            eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)

    return rho_ij, eta_ij


def _compute_xcorr_pair(
    xgram,
    bin_size_s: float,
    min_xcorr_rate: float,
) -> float:
    """
    Calculates a cross-correlation significance metric for a cluster pair.

    Uses the Wasserstein distance between an observed cross-correlogram and a null
    distribution as an estimate of how significant the dependence between
    two neurons is. Low spike count cross-correlograms have large Wasserstein
    distances from null by chance, so we first try to expand the window size. If
    that fails to yield enough spikes, we apply a penalty to the metric.

    Ported from https://github.com/saikoukunt/SLAy.

    Parameters
    ----------
    xgram : np.array
        The raw cross-correlogram for the cluster pair.
    bin_size_s : float
        The bin width in seconds of the input ccgs.
    min_xcorr_rate : float
        The minimum ccg firing rate in Hz.

    Returns
    -------
    sig : float
        The calculated cross-correlation significance metric.
    """

    from scipy.signal import butter, find_peaks_cwt, sosfiltfilt
    from scipy.stats import wasserstein_distance

    # calculate low-pass filtered second derivative of ccg
    fs = 1 / bin_size_s
    cutoff_freq = 100
    nyquist = fs / 2
    cutoff = cutoff_freq / nyquist
    peak_width = 0.002 / bin_size_s

    xgram_2d = np.diff(xgram, 2)
    sos = butter(4, cutoff, output="sos")
    xgram_2d = sosfiltfilt(sos, xgram_2d)

    if xgram.sum() == 0:
        return 0

    # find negative peaks of the second derivative of the ccg; these are the edges of dips in the ccg
    peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1
    # if no peaks are found, return a very low significance
    if peaks.shape[0] == 0:
        return -4
    peaks = np.abs(peaks - xgram.shape[0] / 2)
    peaks = peaks[peaks > 0.5 * peak_width]
    min_peaks = np.sort(peaks)

    # start with peaks closest to 0 and move to the next set of peaks if the event count is too low
    window_width = min_peaks * 1.5
    starts = np.maximum(xgram.shape[0] / 2 - window_width, 0)
    ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1)
    ind = 0
    xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)]
    xgram_sum = xgram_window.sum()
    window_size = xgram_window.shape[0] * bin_size_s
    while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]):
        xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)]
        xgram_sum = xgram_window.sum()
        window_size = xgram_window.shape[0] * bin_size_s
        ind += 1
    # use the whole ccg if peak finding fails
    if ind == starts.shape[0]:
        xgram_window = xgram

    # TODO: was getting error messages when xgram_window was all zero. Why was this happening?
    if np.abs(xgram_window).sum() == 0:
        return 0

    sig = (
        wasserstein_distance(
            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
            xgram_window,
            np.ones_like(xgram_window),
        )
        * 4
    )

    if xgram_window.sum() < (min_xcorr_rate * window_size):
        sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2

    # if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size):
    if xgram_window.sum() < (min_xcorr_rate / 4 * window_size):
        sig = -4  # don't merge if the event count is way too low

    return sig
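To illustrate the idea behind this metric, a standalone sketch using scipy with made-up correlograms (not part of the diff): a ccg with a central dip sits measurably further from a uniform null than a flat ccg does.

import numpy as np
from scipy.stats import wasserstein_distance

n_bins = 100
support = np.arange(n_bins) / n_bins

flat_ccg = np.ones(n_bins)      # no interaction between the units
dipped_ccg = np.ones(n_bins)
dipped_ccg[45:55] = 0.0         # central dip, as for a pair that should be merged

# distance from a uniform null, weighted by the ccg values (as in _compute_xcorr_pair)
for name, ccg in [("flat", flat_ccg), ("dipped", dipped_ccg)]:
    d = wasserstein_distance(support, support, ccg, np.ones_like(ccg))
    print(name, d)  # flat gives 0.0; the dipped ccg gives a positive distance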


def _sliding_RP_viol_pair(
    correlogram,
    bin_size_ms: float,
    acceptThresh: float = 0.15,
) -> float:
| """ | ||
| Calculate the sliding refractory period violation confidence for a cluster. | ||
|
|
||
| Ported from https://github.com/saikoukunt/SLAy. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| correlogram : np.array | ||
| The auto-correlogram of the cluster. | ||
| bin_size_ms : float | ||
| The width in ms of the bin size of the input ccgs. | ||
| acceptThresh : float, default: 0.15 | ||
| The minimum ccg firing rate in Hz. | ||
|
|
||
| Returns | ||
| ------- | ||
| sig : float | ||
| The refractory period violation confidence for the cluster. | ||
| """ | ||
    from scipy.signal import butter, sosfiltfilt
    from scipy.stats import poisson

    # create various refractory period sizes to test (between 0 and 20x bin size)
    b = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000
    bTestIdx = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8")
    bTest = [b[i] for i in bTestIdx]

    # average the two halves of the acg to ensure symmetry, then keep only the second half,
    # since refractory period violations are counted from the center of the acg
    half_len = int(correlogram.shape[0] / 2)
    correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2

    acg_cumsum = np.cumsum(correlogram)
    sum_res = acg_cumsum[bTestIdx - 1]  # -1 because the 0th bin corresponds to 0-bin_size ms

    # low-pass filter acg and use max as baseline event rate
    order = 4
    cutoff_freq = 250  # Hz
    fs = 1 / bin_size_ms * 1000
    nyquist = fs / 2
    cutoff = cutoff_freq / nyquist
    sos = butter(order, cutoff, btype="low", output="sos")
    smoothed_acg = sosfiltfilt(sos, correlogram)

    bin_rate_max = np.max(smoothed_acg)
    max_conts_max = np.array(bTest) / bin_size_ms * 1000 * (bin_rate_max * acceptThresh)
    # compute confidence of less than acceptThresh contamination at each refractory period
    confs = 1 - poisson.cdf(sum_res, max_conts_max)
    rp_viol = 1 - confs.max()

    return rp_viol
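As a rough illustration of the Poisson confidence step, a self-contained sketch with made-up numbers that mirrors the logic above rather than calling the diff's helper:

import numpy as np
from scipy.stats import poisson

# suppose the smoothed acg peaks at 4 events per bin and we test a
# refractory period spanning 2 bins, with acceptThresh = 0.15
bin_rate_max = 4.0
n_bins_rp = 2
acceptThresh = 0.15

expected_viol = n_bins_rp * bin_rate_max * acceptThresh  # 1.2 events allowed
observed_viol = 0                                        # an empty refractory period

conf = 1 - poisson.cdf(observed_viol, expected_viol)  # confidence contamination < acceptThresh
print(1 - conf)  # rp_viol = poisson.cdf(0, 1.2), about 0.30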