diff --git a/specparam/compare/__init__.py b/specparam/compare/__init__.py new file mode 100644 index 00000000..309337e8 --- /dev/null +++ b/specparam/compare/__init__.py @@ -0,0 +1,3 @@ +"""Model comparison sub-module.""" + +from .compare import ModelComparison diff --git a/specparam/compare/compare.py b/specparam/compare/compare.py new file mode 100644 index 00000000..01a9bd60 --- /dev/null +++ b/specparam/compare/compare.py @@ -0,0 +1,118 @@ +"""Model comparison object.""" + +from copy import deepcopy + +from specparam.models import SpectralModel +from specparam.plts.compare import plot_model_comparison +from specparam.reports.strings import gen_model_comparison_str +from specparam.modutils.docs import (copy_func_docstring_drop_first, + docs_get_section, replace_docstring_sections) + +################################################################################################### +################################################################################################### + +class ModelComparison(): + """Model a power spectrum with multiple models. + + Parameters + ---------- + models : list of SpectralModel + Model definitions to fit and compare. + """ + + def __init__(self, models=None): + """Initialize model comparison object.""" + + self.models = [] + if models: + self.add_models(models) + + + def __len__(self): + """Define length of object as the number of defined models.""" + + return len(self.models) + + + def __iter__(self): + """Define iteration as stepping across models within the object.""" + + for model in self.models: + yield(model) + + + def copy(self): + """Return a copy of the current object.""" + + return deepcopy(self) + + + @replace_docstring_sections([docs_get_section(SpectralModel.fit.__doc__, 'Parameters')]) + def fit(self, freqs=None, data=None, freq_range=None, prechecks=True): + """Fit models to a power spectrum. + + Parameters + ---------- + % copied in from SpectralModel object + """ + + self.models[0].fit(freqs, data, freq_range, prechecks) + for model in self.models[1:]: + model.fit(prechecks=False) + + + @replace_docstring_sections([docs_get_section(SpectralModel.report.__doc__, 'Parameters')]) + def report(self, freqs=None, data=None, freq_range=None, + plt_log=False, plot_full_range=False, **plot_kwargs): + """Run model fit, and display a report, which includes a plot, and printed results. + + Parameters + ---------- + % copied in from SpectralModel object + """ + + self.fit(freqs, data, freq_range) + self.print('comparison') + self.plot(plt_log=plt_log, + freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None), + power_spectrum=power_spectrum if \ + plot_full_range else plot_kwargs.pop('plot_power_spectrum', None), + freq_range=plot_kwargs.pop('plot_freq_range', None), + **plot_kwargs) + + + def add_models(self, models, clear=False): + """Add model definitions. + + Parameters + ---------- + models : list of SpectralModel + Model definitions to add to the object. + clear : bool, optional, default: False + Whether to clear the object of previous model definitions before adding. + """ + + for model in models: + self.models.append(deepcopy(model)) + + if self.models: + self.data = deepcopy(models[0].data) + for model in self.models: + model.data = self.data + model.algorithm._reset_subobjects(data=self.data) + + + @copy_func_docstring_drop_first(plot_model_comparison) + def plot(self, ax=None, **plot_kwargs): + + plot_model_comparison(self, ax=ax, **plot_kwargs) + + + def print(self, info='comparison'): + """Print out result information.""" + + if info == 'comparison': + print(gen_model_comparison_str(self)) + else: + for model in self.models: + model.print(info) diff --git a/specparam/metrics/metrics.py b/specparam/metrics/metrics.py index c7c1d9f1..b82c280d 100644 --- a/specparam/metrics/metrics.py +++ b/specparam/metrics/metrics.py @@ -34,6 +34,12 @@ def __len__(self): return len(self.labels) + def __iter__(self): + """Make object iterable, across individual metrics.""" + + return iter(self.metrics) + + def __getitem__(self, label): """Index into the object based on metric label. @@ -43,11 +49,11 @@ def __getitem__(self, label): Label of the metric to access. """ - for ind, clabel in enumerate(self.labels): - if label == clabel: - return self.metrics[ind] - - raise ValueError('Requested label not found.') + for metric in self: + if metric.label == label: + return metric + else: + raise ValueError('Requested label not found.') def add_metric(self, metric): diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py index b75874cf..39ca80d5 100644 --- a/specparam/modes/modes.py +++ b/specparam/modes/modes.py @@ -34,6 +34,13 @@ def __init__(self, aperiodic, periodic, model=None): self.model = model + @property + def label(self): + """Define label for the current set of modes.""" + + return 'ap-{:s}_pe-{:s}'.format(self.aperiodic.name, self.periodic.name) + + def get_modes(self): """Get the modes definition. diff --git a/specparam/plts/compare.py b/specparam/plts/compare.py new file mode 100644 index 00000000..5600c1d0 --- /dev/null +++ b/specparam/plts/compare.py @@ -0,0 +1,31 @@ +"""Plots for model comparison.""" + +from specparam.plts import plot_spectra +from specparam.plts.utils import check_ax +from specparam.plts.settings import PLT_FIGSIZES + +################################################################################################### +################################################################################################### + +def plot_model_comparison(modelcomp, ax=None, **plot_kwargs): + """Plot and compare multiple model fits. + + Parameters + ---------- + modelcomp : ModelComparison + Model comparison object. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional plot related keyword arguments, with styling options managed by ``style_plot``. + """ + + import matplotlib as mpl + cmap = mpl.colormaps['Set1'] + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) + plot_spectra(modelcomp.data.freqs, modelcomp.data.power_spectrum, + label='Original Spectrum', lw=3, color='black', ax=ax) + for ind, model in enumerate(modelcomp.models): + plot_spectra(modelcomp.data.freqs, model.results.model.modeled_spectrum, + color=cmap.colors[ind], alpha=0.65, label=model.modes.label, ax=ax) diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index cbb3bb35..9439135f 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -732,6 +732,50 @@ def gen_event_results_str(event, concise=False): return _format(str_lst, concise) +## MODEL COMPARISONS + +def gen_model_comparison_str(models, concise=False): + """Generate a string representation of model comparisons. + + Parameters + ---------- + models : ModelComparison + Set of model objects to access results from. + concise : bool, optional, default: False + Whether to generate a concise version of the string. + + Returns + ------- + str + Formatted string of model comparisons. + """ + + str_lst = [ + 'SPECTRUM MODEL COMPARISON', + '', + ] + + for ind, model in enumerate(models.models): + str_lst.append('Model {}: '.format(ind + 1) + model.modes.label) + + metric_str = 'Model metrics - ' + for m_ind in range(len(model.results.metrics)): + metric_str = metric_str + '{}: {:1.4f} '.format(\ + model.results.metrics.flabels[m_ind], + list(model.results.metrics.results.values())[m_ind]) + str_lst.append(metric_str) + + str_lst.append(\ + 'Model size: {:d} parameters [{:d} ap + {:d} pe ({:d} peak(s) x {:d})].'.format(\ + model.results.n_params, model.modes.aperiodic.n_params, + model.results.n_peaks * model.modes.periodic.n_params, + model.results.n_peaks, model.modes.periodic.n_params)) + + str_lst.append('') + + return _format(str_lst, concise) + + ## HELPER SUB-FUNCTIONS FOR MODEL REPORT STRINGS def _report_str_algo(model): diff --git a/specparam/tests/compare/__init__.py b/specparam/tests/compare/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/specparam/tests/compare/test_compare.py b/specparam/tests/compare/test_compare.py new file mode 100644 index 00000000..82793324 --- /dev/null +++ b/specparam/tests/compare/test_compare.py @@ -0,0 +1,47 @@ +"""Test functions for specparam.compare.compare.""" + +from specparam import SpectralModel + +from specparam.compare.compare import * + +################################################################################################### +################################################################################################### + +def test_model_comparison(): + + mc1 = ModelComparison() + assert isinstance(mc1, ModelComparison) + + # Test initializing with some models + models = [SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'), + SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian')] + mc2 = ModelComparison(models) + assert isinstance(mc2, ModelComparison) + assert len(mc2) == len(models) + for model in mc2: + assert model + + # Test data gets linked / results do not + for ind, model in enumerate(mc2.models[:-1]): + model.data is mc2.models[ind+1].data + model.results is not mc2.models[ind+1].results + +def test_model_comparison_fit(tdata): + + mc = ModelComparison(\ + [SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'), + SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian'), + ]) + + mc.fit(tdata.freqs, tdata.get_data('full', 'linear')) + for model in mc: + assert model.results.has_model + +def test_model_comparison_reporting(tmodelcomp, tdata, skip_if_no_mpl): + + mc = tmodelcomp.copy() + + mc.report(tdata.freqs, tdata.get_data('full', 'linear')) + + mc.print() + mc.plot() diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index d3b3a9d9..2b0a67ec 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -10,7 +10,8 @@ from specparam.tests.tdata import (get_tdata, get_tdata2d, get_tdata2dt, get_tdata3d, get_tfm, get_tfm2, get_tfg, get_tfg2, get_tft, get_tfe, - get_tbands, get_tresults, get_tmodes, get_tdocstring) + get_tmodelcomp, get_tbands, get_tresults, get_tmodes, + get_tdocstring) from specparam.tests.tsettings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -99,6 +100,10 @@ def tft(): def tfe(): yield get_tfe() +@pytest.fixture(scope='session') +def tmodelcomp(): + yield get_tmodelcomp() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/metrics/test_metrics.py b/specparam/tests/metrics/test_metrics.py index 90a7ee81..ec6274d9 100644 --- a/specparam/tests/metrics/test_metrics.py +++ b/specparam/tests/metrics/test_metrics.py @@ -27,6 +27,10 @@ def test_metrics_obj(tfm): metrics.compute_metrics(tfm.data, tfm.results) assert isinstance(metrics.results, dict) + # Check iteration + for metric in metrics: + assert metric + # Check indexing met_out = metrics['error_mae'] assert isinstance(met_out, Metric) diff --git a/specparam/tests/modes/test_modes.py b/specparam/tests/modes/test_modes.py index 18da825f..aa5da546 100644 --- a/specparam/tests/modes/test_modes.py +++ b/specparam/tests/modes/test_modes.py @@ -15,6 +15,7 @@ def test_modes(): assert modes assert isinstance(modes.aperiodic, Mode) assert isinstance(modes.periodic, Mode) + assert modes.label modes.print() def test_modes_gets(): diff --git a/specparam/tests/plts/test_compare.py b/specparam/tests/plts/test_compare.py new file mode 100644 index 00000000..2b70dc5b --- /dev/null +++ b/specparam/tests/plts/test_compare.py @@ -0,0 +1,14 @@ +"""Tests for specparam.plts.compare.""" + +from specparam.tests.tutils import plot_test +from specparam.tests.tsettings import TEST_PLOTS_PATH + +from specparam.plts.compare import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_model_comparison(tmodelcomp, skip_if_no_mpl): + + plot_model_comparison(tmodelcomp) diff --git a/specparam/tests/reports/test_strings.py b/specparam/tests/reports/test_strings.py index fcceef13..91f23e13 100644 --- a/specparam/tests/reports/test_strings.py +++ b/specparam/tests/reports/test_strings.py @@ -104,6 +104,12 @@ def test_no_model_str(): assert _no_model_str() +## MODEL COMPARISONS + +def test_gen_model_comparison_str(tmodelcomp): + + assert gen_model_comparison_str(tmodelcomp) + ## UTILITIES def test_format(): diff --git a/specparam/tests/tdata.py b/specparam/tests/tdata.py index 85140d7f..c68018a7 100644 --- a/specparam/tests/tdata.py +++ b/specparam/tests/tdata.py @@ -8,6 +8,7 @@ from specparam.data.stores import FitResults from specparam.models import (SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel) +from specparam.compare import ModelComparison from specparam.sim.params import param_sampler from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram @@ -137,6 +138,20 @@ def get_tfe(): return tfe +## TEST MODEL COMPARISON OBJECTS + +def get_tmodelcomp(): + """Get a model comparison object, with fit power spectra, for testing.""" + + modelcomp = ModelComparison([ + SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'), + SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian'), + ]) + + modelcomp.fit(*sim_power_spectrum(*default_spectrum_params())) + + return modelcomp + ## TEST OTHER OBJECTS def get_tbands():