Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def train(
valid_loss = valid_loss_head # consider only the last head for the checkpoint

# variable used for broadcast by rank == 0 if epoch loop is exited early, e.g. patience
exit_now = torch.zeros(1, device=device) if distributed else None
exit_now = torch.zeros(1, device=device)
while epoch < max_num_epochs:
# LR scheduler and SWA update
if swa is None or epoch < swa.start:
Expand Down Expand Up @@ -308,8 +308,7 @@ def train(
logging.info(
f"Stopping optimization after {patience_counter} epochs without improvement"
)
if exit_now is not None:
exit_now.fill_(1)
exit_now.fill_(1)
if save_all_checkpoints:
param_context = (
ema.average_parameters()
Expand Down Expand Up @@ -337,10 +336,9 @@ def train(
keep_last = False or save_all_checkpoints
if distributed:
torch.distributed.barrier()
if exit_now is not None:
torch.distributed.broadcast(exit_now, src=0)
if exit_now == 1:
break
if exit_now == 1:
break

epoch += 1

Expand Down
71 changes: 71 additions & 0 deletions tests/test_run_train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# pylint: disable=too-many-lines

import importlib
import json
import logging
import os
import subprocess
import sys
Expand Down Expand Up @@ -110,6 +112,75 @@ def fixture_pretraining_configs():
}


@pytest.mark.parametrize(
    ("save_all_checkpoints", "expected_saved_epochs"),
    [
        (False, [(0, False)]),
        (True, [(0, False), (1, True)]),
    ],
)
def test_train_non_distributed_early_stopping_exits_loop(
    monkeypatch, caplog, save_all_checkpoints, expected_saved_epochs
):
    """Check that ``train`` leaves the epoch loop on early stopping when
    ``distributed=False``.

    The evaluation, per-epoch training and validation-logging helpers of
    ``mace.tools.train`` are replaced by stubs, and the scheduler, checkpoint
    handler and logger are recording dummies. The stubbed evaluation always
    reports a loss of 1.0 and ``patience`` is 1, while ``max_num_epochs`` is
    10; the test expects only epochs 0 and 1 to be trained, a single
    scheduler step, the parametrized sequence of checkpoint saves, and the
    "Stopping optimization" INFO log message.
    """
    train_module = importlib.import_module("mace.tools.train")

    # Recorders filled in by the stubs below.
    epochs_run = []
    scheduler_calls = []
    checkpoint_calls = []

    def fake_evaluate(**_kwargs):
        # Constant loss with empty metrics on every call.
        return 1.0, {}

    def fake_train_one_epoch(*_args, epoch, **_kwargs):
        epochs_run.append(epoch)

    def fake_valid_err_log(*_args, **_kwargs):
        return None

    class DummyScheduler:
        def step(self, metrics=None):
            scheduler_calls.append(metrics)

    class DummyCheckpointHandler:
        def save(self, state, epochs, keep_last):  # pylint: disable=unused-argument
            checkpoint_calls.append((epochs, keep_last))

    class DummyLogger:
        def log(self, _metrics):
            pass

    for attr_name, stub in (
        ("evaluate", fake_evaluate),
        ("train_one_epoch", fake_train_one_epoch),
        ("valid_err_log", fake_valid_err_log),
    ):
        monkeypatch.setattr(train_module, attr_name, stub)

    # Loaders and optimizer are opaque placeholders; the stubs above never
    # touch them.
    train_kwargs = {
        "model": torch.nn.Linear(1, 1),
        "loss_fn": torch.nn.MSELoss(),
        "train_loader": object(),
        "valid_loaders": {"valid": object()},
        "optimizer": object(),
        "lr_scheduler": DummyScheduler(),
        "start_epoch": 0,
        "max_num_epochs": 10,
        "patience": 1,
        "checkpoint_handler": DummyCheckpointHandler(),
        "logger": DummyLogger(),
        "eval_interval": 1,
        "output_args": {},
        "device": torch.device("cpu"),
        "log_errors": "TotalRMSE",
        "max_grad_norm": None,
        "distributed": False,
        "save_all_checkpoints": save_all_checkpoints,
    }

    with caplog.at_level(logging.INFO):
        train_module.train(**train_kwargs)

    assert epochs_run == [0, 1]
    assert scheduler_calls == [1.0]
    assert checkpoint_calls == expected_saved_epochs

    stop_message = "Stopping optimization after 1 epochs without improvement"
    info_messages = [
        record.message
        for record in caplog.records
        if record.levelno == logging.INFO
    ]
    assert any(stop_message in message for message in info_messages)


def test_run_train(tmp_path, fitting_configs):
ase.io.write(tmp_path / "fit.xyz", fitting_configs)

Expand Down