Skip to content

Commit 3635f34

Browse files
authored
Merge pull request #289 from fooof-tools/avgg
[ENH] - Add `average_reconstructions` function
2 parents 118d200 + bb96d8e commit 3635f34

File tree

5 files changed

+56
-13
lines changed

5 files changed

+56
-13
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Functions to manipulate, examine and analyze FOOOF objects, and related utilitie
5252

5353
compare_info
5454
average_fg
55+
average_reconstructions
5556
combine_fooofs
5657

5758
.. currentmodule:: fooof

fooof/objs/utils.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
6565
If there are no model fit results available to average across.
6666
"""
6767

68-
if avg_method not in ['mean', 'median']:
69-
raise ValueError("Requested average method not understood.")
7068
if not fg.has_model:
7169
raise NoModelError("No model fit results are available, can not proceed.")
7270

73-
if avg_method == 'mean':
74-
avg_func = np.nanmean
75-
elif avg_method == 'median':
76-
avg_func = np.nanmedian
71+
avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
72+
if avg_method not in avg_funcs.keys():
73+
raise ValueError("Requested average method not understood.")
7774

7875
# Aperiodic parameters: extract & average
79-
ap_params = avg_func(fg.get_params('aperiodic_params'), 0)
76+
ap_params = avg_funcs[avg_method](fg.get_params('aperiodic_params'), 0)
8077

8178
# Periodic parameters: extract & average
8279
peak_params = []
@@ -90,15 +87,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
9087
# Check if there are any extracted peaks - if not, don't add
9188
# Note that we only check peaks, but gauss should be the same
9289
if not np.all(np.isnan(peaks)):
93-
peak_params.append(avg_func(peaks, 0))
94-
gauss_params.append(avg_func(gauss, 0))
90+
peak_params.append(avg_funcs[avg_method](peaks, 0))
91+
gauss_params.append(avg_funcs[avg_method](gauss, 0))
9592

9693
peak_params = np.array(peak_params)
9794
gauss_params = np.array(gauss_params)
9895

9996
# Goodness of fit measures: extract & average
100-
r2 = avg_func(fg.get_params('r_squared'))
101-
error = avg_func(fg.get_params('error'))
97+
r2 = avg_funcs[avg_method](fg.get_params('r_squared'))
98+
error = avg_funcs[avg_method](fg.get_params('error'))
10299

103100
# Collect all results together, to be added to FOOOF object
104101
results = FOOOFResults(ap_params, peak_params, r2, error, gauss_params)
@@ -116,6 +113,41 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
116113
return fm
117114

118115

116+
def average_reconstructions(fg, avg_method='mean'):
117+
"""Average across model reconstructions for a group of power spectra.
118+
119+
Parameters
120+
----------
121+
fg : FOOOFGroup
122+
Object with model fit results to average across.
123+
avg : {'mean', 'median'}
124+
Averaging function to use.
125+
126+
Returns
127+
-------
128+
freqs : 1d array
129+
Frequency values for the average model reconstruction.
130+
avg_model : 1d array
131+
Power values for the average model reconstruction.
132+
Note that power values are in log10 space.
133+
"""
134+
135+
if not fg.has_model:
136+
raise NoModelError("No model fit results are available, can not proceed.")
137+
138+
avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
139+
if avg_method not in avg_funcs.keys():
140+
raise ValueError("Requested average method not understood.")
141+
142+
models = np.zeros(shape=fg.power_spectra.shape)
143+
for ind in range(len(fg)):
144+
models[ind, :] = fg.get_fooof(ind, regenerate=True).fooofed_spectrum_
145+
146+
avg_model = avg_funcs[avg_method](models, 0)
147+
148+
return fg.freqs, avg_model
149+
150+
119151
def combine_fooofs(fooofs):
120152
"""Combine a group of FOOOF and/or FOOOFGroup objects into a single FOOOFGroup object.
121153

fooof/plts/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
"""Plots sub-module for FOOOF."""
22

3-
from .spectra import plot_spectra
4-
from .spectra import plot_spectra as plot_spectrum
3+
from .spectra import plot_spectrum, plot_spectra

fooof/plts/spectra.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False,
7777
style_spectrum_plot(ax, log_freqs, log_powers)
7878

7979

80+
# Alias `plot_spectrum` to `plot_spectra` for backwards compatibility
81+
plot_spectrum = plot_spectra
82+
83+
8084
@savefig
8185
@check_dependency(plt, 'matplotlib')
8286
def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',

fooof/tests/objs/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def test_average_fg(tfg, tbands):
4545
with raises(NoModelError):
4646
average_fg(ntfg, tbands)
4747

48+
def test_average_reconstructions(tfg):
49+
50+
freqs, avg_model = average_reconstructions(tfg)
51+
assert isinstance(freqs, np.ndarray)
52+
assert isinstance(avg_model, np.ndarray)
53+
assert freqs.shape == avg_model.shape
54+
4855
def test_combine_fooofs(tfm, tfg):
4956

5057
tfm2 = tfm.copy()

0 commit comments

Comments
 (0)