Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 0 additions & 68 deletions unicore/ema.py

This file was deleted.

16 changes: 6 additions & 10 deletions unicore/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from itertools import chain
from typing import Any, Dict, List
import torch
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
from unicore import checkpoint_utils, models, optim, utils
from unicore.distributed import utils as distributed_utils
from unicore.logging import meters, metrics
from unicore.nan_detector import NanDetector
from unicore.optim import lr_scheduler
from unicore.ema import ExponentialMovingAverageModel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -119,11 +119,10 @@ def __init__(self, args, task, model, loss):
if args.ema_decay > 0 and (
self.data_parallel_rank == 0 or args.validate_with_ema
):
self.ema = ExponentialMovingAverageModel(
self.ema = AveragedModel(
model,
args.ema_decay,
multi_avg_fn=get_ema_multi_avg_fn(args.ema_decay),
)

else:
self.ema = None
metrics.log_start_time("wall", priority=790, round=2)
Expand Down Expand Up @@ -393,8 +392,8 @@ def load_checkpoint(
logger.info(
f"Cannot find EMA state in checkpoint, load model weight to ema directly"
)
self.ema = ExponentialMovingAverageModel(
self._model, decay=self.ema.decay
self.ema = AveragedModel(
self._model, multi_avg_fn=get_ema_multi_avg_fn(self.args.ema_decay)
)

loaded_train_itr = False
Expand Down Expand Up @@ -712,10 +711,7 @@ def maybe_no_sync():
)
if self.ema is not None:
with torch.autograd.profiler.record_function("ema"):
if self.args.fp16 or self.args.bf16:
self.ema.update(self.optimizer.fp32_params, is_flattened=True)
else:
self.ema.update(self.model.named_parameters(), is_flattened=False)
self.ema.update_parameters(self.model)

except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print
Expand Down
2 changes: 1 addition & 1 deletion unicore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def validate_with_ema(trainer, ema=False):
yield
return
_wrapped_model = trainer._wrapped_model
trainer._wrapped_model = deepcopy(trainer.ema.model_ema)
trainer._wrapped_model = deepcopy(trainer.ema.module)
if trainer.args.fp16:
trainer._wrapped_model.half()
elif trainer.args.bf16:
Expand Down