From a4376a06639f879bb9a50446cb5299c8ca1706f3 Mon Sep 17 00:00:00 2001 From: Ariel Xu Date: Fri, 6 Jun 2025 18:10:13 +0100 Subject: [PATCH 1/2] read channel position and calculate site spacing, compatible for npx 1.0 and 2.0 --- ecephys_spike_sorting/common/utils.py | 20 +++++++++++++ .../modules/mean_waveforms/__main__.py | 5 +++- .../mean_waveforms/extract_waveforms.py | 4 ++- .../mean_waveforms/waveform_metrics.py | 3 +- .../mean_waveforms/waveform_visualisation.py | 28 +++++++++++++++++++ 5 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py diff --git a/ecephys_spike_sorting/common/utils.py b/ecephys_spike_sorting/common/utils.py index c1e0fd2b..1012e0a5 100644 --- a/ecephys_spike_sorting/common/utils.py +++ b/ecephys_spike_sorting/common/utils.py @@ -313,6 +313,26 @@ def load_kilosort_data(folder, else: return spike_times, spike_clusters, spike_templates, amplitudes, unwhitened_temps, channel_map, cluster_ids, cluster_quality, pc_features, pc_feature_ind +def load_channel_positions(folder): + """ + Loads Kilosort output files from a directory + + Inputs: + ------- + folder : String + Location of Kilosort output directory + + Outputs: + -------- + channel_positions : numpy.ndarray + x,y positions of each channel + """ + + channel_positions = np.squeeze(load(folder,'channel_positions.npy')) + + return channel_positions + + def get_spike_depths(spike_templates, pc_features, pc_feature_ind): diff --git a/ecephys_spike_sorting/modules/mean_waveforms/__main__.py b/ecephys_spike_sorting/modules/mean_waveforms/__main__.py index e9a13887..99a79320 100644 --- a/ecephys_spike_sorting/modules/mean_waveforms/__main__.py +++ b/ecephys_spike_sorting/modules/mean_waveforms/__main__.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from ...common.utils import load_kilosort_data +from ...common.utils import load_kilosort_data,load_channel_positions from .extract_waveforms import extract_waveforms, writeDataAsNpy from .waveform_metrics import calculate_waveform_metrics @@ -26,6 +26,8 @@ def calculate_mean_waveforms(args): load_kilosort_data(args['directories']['kilosort_output_directory'], \ args['ephys_params']['sample_rate'], \ convert_to_seconds = False) + + channel_positions = load_channel_positions(args['directories']['kilosort_output_directory']) print("Calculating mean waveforms...") @@ -34,6 +36,7 @@ def calculate_mean_waveforms(args): spike_templates, templates, channel_map, + channel_positions, args['ephys_params']['bit_volts'], \ args['ephys_params']['sample_rate'], \ args['ephys_params']['vertical_site_spacing'], \ diff --git a/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py b/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py index 4e5d9812..afbf942f 100644 --- a/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py +++ b/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py @@ -16,7 +16,8 @@ def extract_waveforms(raw_data, spike_clusters, spike_templates, templates, - channel_map, + channel_map, + channel_positions, bit_volts, sample_rate, site_spacing, @@ -130,6 +131,7 @@ def extract_waveforms(raw_data, cluster_id, peak_channels[target_template_id], channel_map, + channel_positions, sample_rate, upsampling_factor, spread_threshold, diff --git a/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py b/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py index c6067723..322d5672 100644 --- a/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py +++ b/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py @@ -9,6 +9,7 @@ def calculate_waveform_metrics(waveforms, cluster_id, peak_channel, channel_map, + channel_positions, sample_rate, upsampling_factor, spread_threshold, @@ -75,7 +76,7 @@ def calculate_waveform_metrics(waveforms, mean_1D_waveform, timestamps) recovery_slope = calculate_waveform_recovery_slope( mean_1D_waveform, timestamps) - + site_spacing = (channel_positions[2,1] - channel_positions[0,1])/2 * 10e-7 # calculate site spacing, compatible for both npx 1.0 and npx 2.0 amplitude, spread, velocity_above, velocity_below = calculate_2D_features( mean_2D_waveform, timestamps, local_peak, spread_threshold, site_range, site_spacing) diff --git a/ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py b/ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py new file mode 100644 index 00000000..e6b0bded --- /dev/null +++ b/ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py @@ -0,0 +1,28 @@ +import numpy as np +import matplotlib.pyplot as plt + +def plot_mean_2d_waveform(mean_2d_waveform, title="Mean 2D Waveform", time_axis=None): + """ + Plot the mean 2D waveform: each channel's waveform over time. + + Parameters: + - mean_2d_waveform: np.ndarray of shape (n_channels, n_timepoints) + - title: Optional title for the plot + - time_axis: Optional 1D array of time values (same length as n_timepoints) + """ + n_channels, n_timepoints = mean_2d_waveform.shape + + if time_axis is None: + time_axis = np.arange(n_timepoints) + + plt.figure(figsize=(12, 6)) + for ch in range(n_channels): + plt.plot(time_axis, mean_2d_waveform[ch, :] + ch * 1000, label=f'Channel {ch}') + # Offset each waveform vertically for clarity (adjust 100 as needed) + + plt.xlabel("Time (samples)") + plt.ylabel("Amplitude + Channel offset") + plt.title(title) + plt.grid(True) + plt.tight_layout() + plt.show() \ No newline at end of file From 85479d4b420ee2c343191cd6cbeb67cf627ce262 Mon Sep 17 00:00:00 2001 From: Ariel Xu Date: Fri, 6 Jun 2025 18:16:56 +0100 Subject: [PATCH 2/2] remove graph plotting --- .../mean_waveforms/waveform_visualisation.py | 28 ------------------- 1 file changed, 28 deletions(-) delete mode 100644 ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py diff --git a/ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py b/ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py deleted file mode 100644 index e6b0bded..00000000 --- a/ecephys_spike_sorting/modules/mean_waveforms/waveform_visualisation.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt - -def plot_mean_2d_waveform(mean_2d_waveform, title="Mean 2D Waveform", time_axis=None): - """ - Plot the mean 2D waveform: each channel's waveform over time. - - Parameters: - - mean_2d_waveform: np.ndarray of shape (n_channels, n_timepoints) - - title: Optional title for the plot - - time_axis: Optional 1D array of time values (same length as n_timepoints) - """ - n_channels, n_timepoints = mean_2d_waveform.shape - - if time_axis is None: - time_axis = np.arange(n_timepoints) - - plt.figure(figsize=(12, 6)) - for ch in range(n_channels): - plt.plot(time_axis, mean_2d_waveform[ch, :] + ch * 1000, label=f'Channel {ch}') - # Offset each waveform vertically for clarity (adjust 100 as needed) - - plt.xlabel("Time (samples)") - plt.ylabel("Amplitude + Channel offset") - plt.title(title) - plt.grid(True) - plt.tight_layout() - plt.show() \ No newline at end of file