From 1890bc01cf16ac35cdbef1853f77f5f76297e68d Mon Sep 17 00:00:00 2001 From: Richard Tomsett Date: Tue, 16 Jun 2026 22:14:09 +0100 Subject: [PATCH 1/2] fix: fixes the bug preventing early stopping from exiting the training loop when not running in distributed mode (issue 1384), and adds a regression test --- mace/tools/train.py | 10 +++--- tests/test_run_train.py | 71 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 443801d7c..0253b9a64 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -205,7 +205,7 @@ def train( valid_loss = valid_loss_head # consider only the last head for the checkpoint # variable used for broadcast by rank == 0 if epoch loop is exited early, e.g. patience - exit_now = torch.zeros(1, device=device) if distributed else None + exit_now = torch.zeros(1, device=device) while epoch < max_num_epochs: # LR scheduler and SWA update if swa is None or epoch < swa.start: @@ -308,8 +308,7 @@ def train( logging.info( f"Stopping optimization after {patience_counter} epochs without improvement" ) - if exit_now is not None: - exit_now.fill_(1) + exit_now.fill_(1) if save_all_checkpoints: param_context = ( ema.average_parameters() @@ -337,10 +336,9 @@ def train( keep_last = False or save_all_checkpoints if distributed: torch.distributed.barrier() - if exit_now is not None: torch.distributed.broadcast(exit_now, src=0) - if exit_now == 1: - break + if exit_now == 1: + break epoch += 1 diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 35baf332c..3e53de067 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -1,6 +1,8 @@ # pylint: disable=too-many-lines +import importlib import json +import logging import os import subprocess import sys @@ -110,6 +112,75 @@ def fixture_pretraining_configs(): } +@pytest.mark.parametrize( + ("save_all_checkpoints", "expected_saved_epochs"), + [ + (False, [(0, False)]), + (True, [(0, False), (1, True)]), + ], +) +def test_train_non_distributed_early_stopping_exits_loop( + monkeypatch, caplog, save_all_checkpoints, expected_saved_epochs +): + train_module = importlib.import_module("mace.tools.train") + trained_epochs = [] + scheduler_metrics = [] + saved_epochs = [] + + def fake_evaluate(**_kwargs): + return 1.0, {} + + def fake_train_one_epoch(*_args, epoch, **_kwargs): + trained_epochs.append(epoch) + + class DummyScheduler: + def step(self, metrics=None): + scheduler_metrics.append(metrics) + + class DummyCheckpointHandler: + def save(self, state, epochs, keep_last): # pylint: disable=unused-argument + saved_epochs.append((epochs, keep_last)) + + class DummyLogger: + def log(self, _metrics): + pass + + monkeypatch.setattr(train_module, "evaluate", fake_evaluate) + monkeypatch.setattr(train_module, "train_one_epoch", fake_train_one_epoch) + monkeypatch.setattr(train_module, "valid_err_log", lambda *_args, **_kwargs: None) + + with caplog.at_level(logging.INFO): + train_module.train( + model=torch.nn.Linear(1, 1), + loss_fn=torch.nn.MSELoss(), + train_loader=object(), + valid_loaders={"valid": object()}, + optimizer=object(), + lr_scheduler=DummyScheduler(), + start_epoch=0, + max_num_epochs=100, + patience=1, + checkpoint_handler=DummyCheckpointHandler(), + logger=DummyLogger(), + eval_interval=1, + output_args={}, + device=torch.device("cpu"), + log_errors="TotalRMSE", + max_grad_norm=None, + distributed=False, + save_all_checkpoints=save_all_checkpoints, + ) + + assert trained_epochs == [0, 1] + assert scheduler_metrics == [1.0] + assert saved_epochs == expected_saved_epochs + assert any( + record.levelno == logging.INFO + and "Stopping optimization after 1 epochs without improvement" in record.message + for record in caplog.records + ) + + def test_run_train(tmp_path, fitting_configs): ase.io.write(tmp_path / "fit.xyz", fitting_configs) From 238c9e2290e16f99bade17cd8fe5c24740a88ace Mon Sep 17 00:00:00 2001 From: Richard Tomsett Date: Wed, 17 Jun 2026 09:02:03 +0100 Subject: [PATCH 2/2] Reduce max_num_epochs in test from 100 to 10 --- tests/test_run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 3e53de067..79ac89e2f 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -158,7 +158,7 @@ def log(self, _metrics): optimizer=object(), lr_scheduler=DummyScheduler(), start_epoch=0, - max_num_epochs=100, + max_num_epochs=10, patience=1, checkpoint_handler=DummyCheckpointHandler(), logger=DummyLogger(),