From 809351ac2fe82238423dd19e96ee84b1bef4830b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 21 Jun 2026 20:45:08 +0100 Subject: [PATCH 01/10] make metrics iterable & use for getitem --- specparam/metrics/metrics.py | 16 +++++++++++----- specparam/tests/metrics/test_metrics.py | 4 ++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/specparam/metrics/metrics.py b/specparam/metrics/metrics.py index c7c1d9f1..b82c280d 100644 --- a/specparam/metrics/metrics.py +++ b/specparam/metrics/metrics.py @@ -34,6 +34,12 @@ def __len__(self): return len(self.labels) + def __iter__(self): + """Make object iterable, across individual metrics.""" + + return iter(self.metrics) + + def __getitem__(self, label): """Index into the object based on metric label. @@ -43,11 +49,11 @@ def __getitem__(self, label): Label of the metric to access. """ - for ind, clabel in enumerate(self.labels): - if label == clabel: - return self.metrics[ind] - - raise ValueError('Requested label not found.') + for metric in self: + if metric.label == label: + return metric + else: + raise ValueError('Requested label not found.') def add_metric(self, metric): diff --git a/specparam/tests/metrics/test_metrics.py b/specparam/tests/metrics/test_metrics.py index 90a7ee81..ec6274d9 100644 --- a/specparam/tests/metrics/test_metrics.py +++ b/specparam/tests/metrics/test_metrics.py @@ -27,6 +27,10 @@ def test_metrics_obj(tfm): metrics.compute_metrics(tfm.data, tfm.results) assert isinstance(metrics.results, dict) + # Check iteration + for metric in metrics: + assert metric + # Check indexing met_out = metrics['error_mae'] assert isinstance(met_out, Metric) From b26080738c01a8c29c0de547584d60a278f2c9da Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 11:24:04 +0100 Subject: [PATCH 02/10] add label property attribute to modes --- specparam/modes/modes.py | 7 +++++++ specparam/tests/modes/test_modes.py | 1 + 2 files changed, 8 insertions(+) diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py index b75874cf..39ca80d5 100644 --- a/specparam/modes/modes.py +++ b/specparam/modes/modes.py @@ -34,6 +34,13 @@ def __init__(self, aperiodic, periodic, model=None): self.model = model + @property + def label(self): + """Define label for the current set of modes.""" + + return 'ap-{:s}_pe-{:s}'.format(self.aperiodic.name, self.periodic.name) + + def get_modes(self): """Get the modes definition. diff --git a/specparam/tests/modes/test_modes.py b/specparam/tests/modes/test_modes.py index 18da825f..aa5da546 100644 --- a/specparam/tests/modes/test_modes.py +++ b/specparam/tests/modes/test_modes.py @@ -15,6 +15,7 @@ def test_modes(): assert modes assert isinstance(modes.aperiodic, Mode) assert isinstance(modes.periodic, Mode) + assert modes.label modes.print() def test_modes_gets(): From 492d8e4feb10ce91b396e9ba04a7ba2842d94fe9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 11:29:06 +0100 Subject: [PATCH 03/10] initialize model compare object --- specparam/compare/__init__.py | 1 + specparam/compare/compare.py | 53 +++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 specparam/compare/__init__.py create mode 100644 specparam/compare/compare.py diff --git a/specparam/compare/__init__.py b/specparam/compare/__init__.py new file mode 100644 index 00000000..4ca54a63 --- /dev/null +++ b/specparam/compare/__init__.py @@ -0,0 +1 @@ +"""Model comparison sub-module.""" diff --git a/specparam/compare/compare.py b/specparam/compare/compare.py new file mode 100644 index 00000000..5272d669 --- /dev/null +++ b/specparam/compare/compare.py @@ -0,0 +1,53 @@ +""" """ + +from copy import deepcopy + +from specparam.plts.compare import plot_comparison +from specparam.reports.strings import gen_model_comparison_str + +################################################################################################### +################################################################################################### + +class ModelComparison(): + """ """ + + def __init__(self, models=None): + """ """ + + self.models = [] + self.add_models(models) + + def fit(self, freqs=None, data=None, freq_range=None, prechecks=True): + """ """ + + self.models[0].fit(freqs, data, freq_range) + for model in self.models[1:]: + model.fit(prechecks=False) + + def report(self, freqs=None, data=None, freq_range=None, prechecks=True): + """ """ + + self.fit(freqs, data, freq_range, prechecks) + self.print() + self.plot() + + def add_models(self, models, clear=False): + """ """ + + for model in models: + self.models.append(deepcopy(model)) + + if self.models: + self.data = deepcopy(models[0].data) + for model in self.models: + model.data = self.data + + def plot(self): + """ """ + + plot_comparison(self) + + def print(self): + """ """ + + print(gen_model_comparison_str(self)) From b34fe75d8632024fef3e0e21a14466dbd3d06e17 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 12:07:09 +0100 Subject: [PATCH 04/10] add test modelcomp object --- specparam/tests/conftest.py | 7 ++++++- specparam/tests/tdata.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index d3b3a9d9..2b0a67ec 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -10,7 +10,8 @@ from specparam.tests.tdata import (get_tdata, get_tdata2d, get_tdata2dt, get_tdata3d, get_tfm, get_tfm2, get_tfg, get_tfg2, get_tft, get_tfe, - get_tbands, get_tresults, get_tmodes, get_tdocstring) + get_tmodelcomp, get_tbands, get_tresults, get_tmodes, + get_tdocstring) from specparam.tests.tsettings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -99,6 +100,10 @@ def tft(): def tfe(): yield get_tfe() +@pytest.fixture(scope='session') +def tmodelcomp(): + yield get_tmodelcomp() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/tdata.py b/specparam/tests/tdata.py index 85140d7f..c68018a7 100644 --- a/specparam/tests/tdata.py +++ b/specparam/tests/tdata.py @@ -8,6 +8,7 @@ from specparam.data.stores import FitResults from specparam.models import (SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel) +from specparam.compare import ModelComparison from specparam.sim.params import param_sampler from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram @@ -137,6 +138,20 @@ def get_tfe(): return tfe +## TEST MODEL COMPARISON OBJECTS + +def get_tmodelcomp(): + """Get a model comparison object, with fit power spectra, for testing.""" + + modelcomp = ModelComparison([ + SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'), + SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian'), + ]) + + modelcomp.fit(*sim_power_spectrum(*default_spectrum_params())) + + return modelcomp + ## TEST OTHER OBJECTS def get_tbands(): From d372a621201e3f638cca7021ff6c051b50ea2457 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 12:07:38 +0100 Subject: [PATCH 05/10] add gen_model_comparison_str --- specparam/reports/strings.py | 44 +++++++++++++++++++++++++ specparam/tests/reports/test_strings.py | 6 ++++ 2 files changed, 50 insertions(+) diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index cbb3bb35..9439135f 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -732,6 +732,50 @@ def gen_event_results_str(event, concise=False): return _format(str_lst, concise) +## MODEL COMPARISONS + +def gen_model_comparison_str(models, concise=False): + """Generate a string representation of model comparisons. + + Parameters + ---------- + models : ModelComparison + Set of model objects to access results from. + concise : bool, optional, default: False + Whether to generate a concise version of the string. + + Returns + ------- + str + Formatted string of model comparisons. + """ + + str_lst = [ + 'SPECTRUM MODEL COMPARISON', + '', + ] + + for ind, model in enumerate(models.models): + str_lst.append('Model {}: '.format(ind + 1) + model.modes.label) + + metric_str = 'Model metrics - ' + for m_ind in range(len(model.results.metrics)): + metric_str = metric_str + '{}: {:1.4f} '.format(\ + model.results.metrics.flabels[m_ind], + list(model.results.metrics.results.values())[m_ind]) + str_lst.append(metric_str) + + str_lst.append(\ + 'Model size: {:d} parameters [{:d} ap + {:d} pe ({:d} peak(s) x {:d})].'.format(\ + model.results.n_params, model.modes.aperiodic.n_params, + model.results.n_peaks * model.modes.periodic.n_params, + model.results.n_peaks, model.modes.periodic.n_params)) + + str_lst.append('') + + return _format(str_lst, concise) + + ## HELPER SUB-FUNCTIONS FOR MODEL REPORT STRINGS def _report_str_algo(model): diff --git a/specparam/tests/reports/test_strings.py b/specparam/tests/reports/test_strings.py index fcceef13..91f23e13 100644 --- a/specparam/tests/reports/test_strings.py +++ b/specparam/tests/reports/test_strings.py @@ -104,6 +104,12 @@ def test_no_model_str(): assert _no_model_str() +## MODEL COMPARISONS + +def test_gen_model_comparison_str(tmodelcomp): + + assert gen_model_comparison_str(tmodelcomp) + ## UTILITIES def test_format(): From 0170d4c6d931d7be671b158e0e1d36bff0fa3cf5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 12:13:24 +0100 Subject: [PATCH 06/10] add model comparison plot --- specparam/plts/compare.py | 31 ++++++++++++++++++++++++++++ specparam/tests/plts/test_compare.py | 14 +++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 specparam/plts/compare.py create mode 100644 specparam/tests/plts/test_compare.py diff --git a/specparam/plts/compare.py b/specparam/plts/compare.py new file mode 100644 index 00000000..5600c1d0 --- /dev/null +++ b/specparam/plts/compare.py @@ -0,0 +1,31 @@ +"""Plots for model comparison.""" + +from specparam.plts import plot_spectra +from specparam.plts.utils import check_ax +from specparam.plts.settings import PLT_FIGSIZES + +################################################################################################### +################################################################################################### + +def plot_model_comparison(modelcomp, ax=None, **plot_kwargs): + """Plot and compare multiple model fits. + + Parameters + ---------- + modelcomp : ModelComparison + Model comparison object. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional plot related keyword arguments, with styling options managed by ``style_plot``. + """ + + import matplotlib as mpl + cmap = mpl.colormaps['Set1'] + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) + plot_spectra(modelcomp.data.freqs, modelcomp.data.power_spectrum, + label='Original Spectrum', lw=3, color='black', ax=ax) + for ind, model in enumerate(modelcomp.models): + plot_spectra(modelcomp.data.freqs, model.results.model.modeled_spectrum, + color=cmap.colors[ind], alpha=0.65, label=model.modes.label, ax=ax) diff --git a/specparam/tests/plts/test_compare.py b/specparam/tests/plts/test_compare.py new file mode 100644 index 00000000..2b70dc5b --- /dev/null +++ b/specparam/tests/plts/test_compare.py @@ -0,0 +1,14 @@ +"""Tests for specparam.plts.compare.""" + +from specparam.tests.tutils import plot_test +from specparam.tests.tsettings import TEST_PLOTS_PATH + +from specparam.plts.compare import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_model_comparison(tmodelcomp, skip_if_no_mpl): + + plot_model_comparison(tmodelcomp) From b4dec5e25ec7b627ecc2418da4da712986ca8c7f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 12:13:42 +0100 Subject: [PATCH 07/10] initialize ModelComparison object --- specparam/compare/__init__.py | 2 ++ specparam/compare/compare.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/specparam/compare/__init__.py b/specparam/compare/__init__.py index 4ca54a63..309337e8 100644 --- a/specparam/compare/__init__.py +++ b/specparam/compare/__init__.py @@ -1 +1,3 @@ """Model comparison sub-module.""" + +from .compare import ModelComparison diff --git a/specparam/compare/compare.py b/specparam/compare/compare.py index 5272d669..2c88b14c 100644 --- a/specparam/compare/compare.py +++ b/specparam/compare/compare.py @@ -2,7 +2,7 @@ from copy import deepcopy -from specparam.plts.compare import plot_comparison +from specparam.plts.compare import plot_model_comparison from specparam.reports.strings import gen_model_comparison_str ################################################################################################### @@ -20,7 +20,7 @@ def __init__(self, models=None): def fit(self, freqs=None, data=None, freq_range=None, prechecks=True): """ """ - self.models[0].fit(freqs, data, freq_range) + self.models[0].fit(freqs, data, freq_range, prechecks) for model in self.models[1:]: model.fit(prechecks=False) @@ -41,11 +41,12 @@ def add_models(self, models, clear=False): self.data = deepcopy(models[0].data) for model in self.models: model.data = self.data + model.algorithm._reset_subobjects(data=self.data) def plot(self): """ """ - plot_comparison(self) + plot_model_comparison(self) def print(self): """ """ From cdb5260d964d5cc26dc05ce507371fdfacd6a2e5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 12:28:46 +0100 Subject: [PATCH 08/10] add docs to ModelComparison --- specparam/compare/compare.py | 65 ++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/specparam/compare/compare.py b/specparam/compare/compare.py index 2c88b14c..8560e6ac 100644 --- a/specparam/compare/compare.py +++ b/specparam/compare/compare.py @@ -1,38 +1,73 @@ -""" """ +"""Model comparison object.""" from copy import deepcopy +from specparam.models import SpectralModel from specparam.plts.compare import plot_model_comparison from specparam.reports.strings import gen_model_comparison_str +from specparam.modutils.docs import (copy_func_docstring_drop_first, + docs_get_section, replace_docstring_sections) ################################################################################################### ################################################################################################### class ModelComparison(): - """ """ + """Model a power spectrum with multiple models. + + Parameters + ---------- + models : list of SpectralModel + Model definitions to fit and compare. + """ def __init__(self, models=None): - """ """ + """Initialize model comparison object.""" self.models = [] self.add_models(models) + @replace_docstring_sections([docs_get_section(SpectralModel.fit.__doc__, 'Parameters')]) def fit(self, freqs=None, data=None, freq_range=None, prechecks=True): - """ """ + """Fit models to a power spectrum. + + Parameters + ---------- + % copied in from SpectralModel object + """ self.models[0].fit(freqs, data, freq_range, prechecks) for model in self.models[1:]: model.fit(prechecks=False) - def report(self, freqs=None, data=None, freq_range=None, prechecks=True): - """ """ + @replace_docstring_sections([docs_get_section(SpectralModel.report.__doc__, 'Parameters')]) + def report(self, freqs=None, data=None, freq_range=None, + plt_log=False, plot_full_range=False, **plot_kwargs): + """Run model fit, and display a report, which includes a plot, and printed results. + + Parameters + ---------- + % copied in from SpectralModel object + """ self.fit(freqs, data, freq_range, prechecks) - self.print() - self.plot() + self.print('comparison') + self.plot(plt_log=plt_log, + freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None), + power_spectrum=power_spectrum if \ + plot_full_range else plot_kwargs.pop('plot_power_spectrum', None), + freq_range=plot_kwargs.pop('plot_freq_range', None), + **plot_kwargs) def add_models(self, models, clear=False): - """ """ + """Add model definitions. + + Parameters + ---------- + models : list of SpectralModel + Model definitions to add to the object. + clear : bool, optional, default: False + Whether to clear the object of previous model definitions before adding. + """ for model in models: self.models.append(deepcopy(model)) @@ -43,12 +78,16 @@ def add_models(self, models, clear=False): model.data = self.data model.algorithm._reset_subobjects(data=self.data) + @copy_func_docstring_drop_first(plot_model_comparison) def plot(self): - """ """ plot_model_comparison(self) - def print(self): - """ """ + def print(self, info='comparison'): + """Print out result information.""" - print(gen_model_comparison_str(self)) + if info == 'comparison': + print(gen_model_comparison_str(self)) + else: + for model in self.models: + model.print(info) From 1ab33885f8d976285a4ab3a230cbf9d68e6613f3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 16:31:45 +0100 Subject: [PATCH 09/10] add dunders / copy / pass throughs --- specparam/compare/compare.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/specparam/compare/compare.py b/specparam/compare/compare.py index 8560e6ac..01a9bd60 100644 --- a/specparam/compare/compare.py +++ b/specparam/compare/compare.py @@ -24,7 +24,28 @@ def __init__(self, models=None): """Initialize model comparison object.""" self.models = [] - self.add_models(models) + if models: + self.add_models(models) + + + def __len__(self): + """Define length of object as the number of defined models.""" + + return len(self.models) + + + def __iter__(self): + """Define iteration as stepping across models within the object.""" + + for model in self.models: + yield(model) + + + def copy(self): + """Return a copy of the current object.""" + + return deepcopy(self) + @replace_docstring_sections([docs_get_section(SpectralModel.fit.__doc__, 'Parameters')]) def fit(self, freqs=None, data=None, freq_range=None, prechecks=True): @@ -39,6 +60,7 @@ def fit(self, freqs=None, data=None, freq_range=None, prechecks=True): for model in self.models[1:]: model.fit(prechecks=False) + @replace_docstring_sections([docs_get_section(SpectralModel.report.__doc__, 'Parameters')]) def report(self, freqs=None, data=None, freq_range=None, plt_log=False, plot_full_range=False, **plot_kwargs): @@ -49,7 +71,7 @@ def report(self, freqs=None, data=None, freq_range=None, % copied in from SpectralModel object """ - self.fit(freqs, data, freq_range, prechecks) + self.fit(freqs, data, freq_range) self.print('comparison') self.plot(plt_log=plt_log, freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None), @@ -58,6 +80,7 @@ def report(self, freqs=None, data=None, freq_range=None, freq_range=plot_kwargs.pop('plot_freq_range', None), **plot_kwargs) + def add_models(self, models, clear=False): """Add model definitions. @@ -78,10 +101,12 @@ def add_models(self, models, clear=False): model.data = self.data model.algorithm._reset_subobjects(data=self.data) + @copy_func_docstring_drop_first(plot_model_comparison) - def plot(self): + def plot(self, ax=None, **plot_kwargs): + + plot_model_comparison(self, ax=ax, **plot_kwargs) - plot_model_comparison(self) def print(self, info='comparison'): """Print out result information.""" From 1907937c640d69fcf27d801bd839861f9052215d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 22 Jun 2026 16:32:00 +0100 Subject: [PATCH 10/10] add model comparison object tests --- specparam/tests/compare/__init__.py | 0 specparam/tests/compare/test_compare.py | 47 +++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 specparam/tests/compare/__init__.py create mode 100644 specparam/tests/compare/test_compare.py diff --git a/specparam/tests/compare/__init__.py b/specparam/tests/compare/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/specparam/tests/compare/test_compare.py b/specparam/tests/compare/test_compare.py new file mode 100644 index 00000000..82793324 --- /dev/null +++ b/specparam/tests/compare/test_compare.py @@ -0,0 +1,47 @@ +"""Test functions for specparam.compare.compare.""" + +from specparam import SpectralModel + +from specparam.compare.compare import * + +################################################################################################### +################################################################################################### + +def test_model_comparison(): + + mc1 = ModelComparison() + assert isinstance(mc1, ModelComparison) + + # Test initializing with some models + models = [SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'), + SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian')] + mc2 = ModelComparison(models) + assert isinstance(mc2, ModelComparison) + assert len(mc2) == len(models) + for model in mc2: + assert model + + # Test data gets linked / results do not + for ind, model in enumerate(mc2.models[:-1]): + model.data is mc2.models[ind+1].data + model.results is not mc2.models[ind+1].results + +def test_model_comparison_fit(tdata): + + mc = ModelComparison(\ + [SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'), + SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian'), + ]) + + mc.fit(tdata.freqs, tdata.get_data('full', 'linear')) + for model in mc: + assert model.results.has_model + +def test_model_comparison_reporting(tmodelcomp, tdata, skip_if_no_mpl): + + mc = tmodelcomp.copy() + + mc.report(tdata.freqs, tdata.get_data('full', 'linear')) + + mc.print() + mc.plot()