Skip to content

Commit c2ba4da

Browse files
committed
refact use of avg_func in objs utils
1 parent b09f9db commit c2ba4da

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

fooof/objs/utils.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
6565
If there are no model fit results available to average across.
6666
"""
6767

68-
if avg_method not in ['mean', 'median']:
69-
raise ValueError("Requested average method not understood.")
7068
if not fg.has_model:
7169
raise NoModelError("No model fit results are available, can not proceed.")
7270

73-
if avg_method == 'mean':
74-
avg_func = np.nanmean
75-
elif avg_method == 'median':
76-
avg_func = np.nanmedian
71+
avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
72+
if avg_method not in avg_funcs.keys():
73+
raise ValueError("Requested average method not understood.")
7774

7875
# Aperiodic parameters: extract & average
79-
ap_params = avg_func(fg.get_params('aperiodic_params'), 0)
76+
ap_params = avg_funcs[avg_method](fg.get_params('aperiodic_params'), 0)
8077

8178
# Periodic parameters: extract & average
8279
peak_params = []
@@ -90,15 +87,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
9087
# Check if there are any extracted peaks - if not, don't add
9188
# Note that we only check peaks, but gauss should be the same
9289
if not np.all(np.isnan(peaks)):
93-
peak_params.append(avg_func(peaks, 0))
94-
gauss_params.append(avg_func(gauss, 0))
90+
peak_params.append(avg_funcs[avg_method](peaks, 0))
91+
gauss_params.append(avg_funcs[avg_method](gauss, 0))
9592

9693
peak_params = np.array(peak_params)
9794
gauss_params = np.array(gauss_params)
9895

9996
# Goodness of fit measures: extract & average
100-
r2 = avg_func(fg.get_params('r_squared'))
101-
error = avg_func(fg.get_params('error'))
97+
r2 = avg_funcs[avg_method](fg.get_params('r_squared'))
98+
error = avg_funcs[avg_method](fg.get_params('error'))
10299

103100
# Collect all results together, to be added to FOOOF object
104101
results = FOOOFResults(ap_params, peak_params, r2, error, gauss_params)
@@ -135,21 +132,18 @@ def average_reconstructions(fg, avg_method='mean'):
135132
Note that power values are in log10 space.
136133
"""
137134

138-
if avg_method not in ['mean', 'median']:
139-
raise ValueError("Requested average method not understood.")
140135
if not fg.has_model:
141136
raise NoModelError("No model fit results are available, can not proceed.")
142137

143-
if avg_method == 'mean':
144-
avg_func = np.nanmean
145-
elif avg_method == 'median':
146-
avg_func = np.nanmedian
138+
avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
139+
if avg_method not in avg_funcs.keys():
140+
raise ValueError("Requested average method not understood.")
147141

148142
models = np.zeros(shape=fg.power_spectra.shape)
149143
for ind in range(len(fg)):
150144
models[ind, :] = fg.get_fooof(ind, regenerate=True).fooofed_spectrum_
151145

152-
avg_model = avg_func(models, 0)
146+
avg_model = avg_funcs[avg_method](models, 0)
153147

154148
return fg.freqs, avg_model
155149

0 commit comments

Comments
 (0)