diff --git a/arctic_training/config/trainer.py b/arctic_training/config/trainer.py index f501e93b..dd1e8f54 100644 --- a/arctic_training/config/trainer.py +++ b/arctic_training/config/trainer.py @@ -106,8 +106,11 @@ class TrainerConfig(BaseConfig): loss_log_interval: HumanInt = Field(default=1, ge=0) """ Number of steps between logging loss. """ - train_log_iter_interval: Literal[0, 1] = 1 - """ Iters between training metric log outputs. `0` is off, only intervals of `1` currently supported. """ + train_log_iter_interval: HumanInt = Field(default=1, ge=0) + """ Iters between training metric log outputs. `0` disables metrics logging. """ + + metrics_display_order: List[str] = [] + """ Optional display order for metrics in the log line. Unlisted metrics are appended in their default order. """ # XXX: fixme: the default output dir is broken # train_log_metrics_path: Path = Field( diff --git a/arctic_training/debug.py b/arctic_training/debug.py index 457bbd23..c2839194 100644 --- a/arctic_training/debug.py +++ b/arctic_training/debug.py @@ -134,27 +134,19 @@ def see_memory_usage(message, force=False, ranks=[0]): get_accelerator().reset_peak_memory_stats() -def get_mem_metrics(): - +def get_mem_metrics() -> tuple[float, float, float]: + """Return (memory_allocated_gb, max_memory_allocated_gb, nvml_mem_gb).""" gc.collect() - # torch.cuda.empty_cache() nv_mem = get_nvml_mem() - - summary = " | ".join( - [ - f"MA {round(get_accelerator().memory_allocated() / 2**30, 2):0.2f} GB", - f"Max_MA {round(get_accelerator().max_memory_allocated() / 2**30, 2):0.2f} GB", - f"NV {round(nv_mem / 2**30, 2):0.2f} GB", - ] - ) + ma_gb = round(get_accelerator().memory_allocated() / 2**30, 2) + max_ma_gb = round(get_accelerator().max_memory_allocated() / 2**30, 2) + nv_gb = round(nv_mem / 2**30, 2) # get the peak memory to report correct data, so reset the counter for the next call - # this will lead to wrong peak reports if `see_mem_usage` is also used during the run, - # as it resets the peak counter and there is only one counter get_accelerator().reset_peak_memory_stats() - return summary + return (ma_gb, max_ma_gb, nv_gb) # fcntl.flock can be slow on shared fs, so if things are too slow especially when many ranks are diff --git a/arctic_training/metrics.py b/arctic_training/metrics.py index 4e2885e8..9b7fdda9 100644 --- a/arctic_training/metrics.py +++ b/arctic_training/metrics.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations
+
 from collections import defaultdict
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import Dict
 from typing import List
+from typing import Optional
 from typing import Union
-from typing import cast
 
 import torch
 from deepspeed.utils.timer import SynchronizedWallClockTimer
 
-from arctic_training.debug import get_mem_metrics
 from arctic_training.trainer.flops_counter import estimate_decoder_transformer_tflos
 from arctic_training.utils import human_format_base10_number
 from arctic_training.utils import human_format_secs
@@ -32,170 +34,303 @@
     from arctic_training.trainer.trainer import Trainer
 
 
-def gather_object(number: Union[float, int, list], world_size: int) -> List[Union[float, int]]:
-    output = [None] * world_size
-    torch.distributed.all_gather_object(output, number)
-    output = [v for ll in output for v in (ll if isinstance(ll, list) else [ll])]
-    return cast(List[Union[float, int]], output)
+def _gather_object(value: Union[float, int, list], world_size: int) -> List[Union[float, int]]:
+    """All-gather a scalar or list across ranks, returning a flat list."""
+    output: list = [None] * world_size
+    torch.distributed.all_gather_object(output, value)
+    return [v for item in output for v in (item if isinstance(item, list) else [item])]
+
+
+def _compute_tflos(metrics: Metrics, ctx: Dict) -> Optional[float]:
+    """Compute and cache tflos_total from raw seqlens in the accumulator."""
+    if "tflos_total" in ctx:
+        return ctx["tflos_total"]
+    seqlens_raw = metrics._accum.get("seqlens", [])
+    if not seqlens_raw:
+        return None
+    batch_of_seqlens = [s for batch in seqlens_raw for s in batch]
+    ws = metrics.trainer.world_size
+    tflos_sub, _ = estimate_decoder_transformer_tflos(
+        hf_model_config=metrics.trainer.model_unwrapped.config,
+        model_size=metrics._model_size,
+        batch_of_seqlens=batch_of_seqlens,
+        enable_gradient_checkpointing=not metrics.trainer.config.model.disable_activation_checkpoint,
+    )
+    # the estimate needs the full seqlen because attention is O(seqlen**2), so each SP rank
+    # reports the whole sequence's tflos; divide the gathered sum by sequence_parallel_size
+    # to correct the overcount
+    ctx["tflos_total"] = sum(_gather_object(tflos_sub, ws)) / metrics.trainer.config.sequence_parallel_size
+    return ctx["tflos_total"]
+
+
+def _derive_tflops(time_key: str) -> Callable:
+    """Factory for tflops derive functions: tflos_total / raw gathered time total."""
+
+    def _derive(metrics: Metrics, ctx: Dict) -> Optional[float]:
+        tflos = _compute_tflos(metrics, ctx)
+        time_total = ctx.get(f"{time_key}_total")
+        if tflos and time_total:
+            return tflos / time_total
+        return None
+
+    return _derive
+
+
+def _compute_interval_token_sum(metrics: Metrics) -> Optional[int]:
+    """Total tokens since last log (summed over all steps and DP)."""
+    seqlens_raw = metrics._accum.get("seqlens", [])
+    if not seqlens_raw:
+        return None
+    local_total = sum(s for batch in seqlens_raw for sublist in batch for s in sublist)
+    gathered = _gather_object(local_total, metrics.trainer.world_size)
+    return int(sum(gathered))
+
+
+def _derive_seqlen(metrics: Metrics, ctx: Dict) -> Optional[float]:
+    """Average seqlen per step per DP (since last log)."""
+    total_tokens = _compute_interval_token_sum(metrics)
+    if total_tokens is None:
+        return None
+    seqlens_raw = metrics._accum.get("seqlens", [])
+    ws = metrics.trainer.world_size
+    num_steps = len(seqlens_raw) * ws
+    return total_tokens / num_steps if num_steps else None
+
+
+def _derive_seqlen_total_since_log(metrics: Metrics, ctx: Dict) -> Optional[int]:
+    """Total tokens since last log (all steps, all DP ranks)."""
+    return _compute_interval_token_sum(metrics)
+
+
+class MetricDef:
+    """Defines how a metric is reduced, 
formatted, and displayed.""" + + __slots__ = ("reduce", "fmt", "display_name", "wandb", "accumulate", "derive") + + def __init__( + self, + reduce: str = "mean", + fmt: Optional[Union[str, Callable]] = None, + display_name: Optional[str] = None, + wandb: bool = True, + accumulate: bool = False, + derive: Optional[Callable] = None, + ): + self.reduce = reduce + self.fmt = fmt + self.display_name = display_name + self.wandb = wandb + self.accumulate = accumulate + self.derive = derive + + def format_value(self, value: Union[int, float]) -> str: + if callable(self.fmt): + return self.fmt(value) + if isinstance(self.fmt, str): + return f"{value:{self.fmt}}" + return str(value) class Metrics: - """Class for measuring, tracking, and reporting training metrics.""" + """Tracks, accumulates, and reports training metrics with GAS support. + + All values accumulate continuously. When ``report()`` is called, the + ``accumulate`` flag on each metric controls which values are used: + + - ``accumulate=False`` (default): Only the latest value is used. + - ``accumulate=True``: All values since the previous ``report()`` are + aggregated, giving averages over the full log interval. + + After reporting, all accumulators are cleared. - def __init__(self, trainer: "Trainer") -> None: + Metrics can also be *derived* — computed from other metrics via a + ``derive(metrics, ctx)`` callable rather than recorded directly. + + Standard metrics are registered by default. Trainers can register + additional metrics via ``register()``. + """ + + def __init__(self, trainer: Trainer) -> None: self.enabled = trainer.config.train_log_iter_interval > 0 if not self.enabled: return self.trainer = trainer + self._defs: Dict[str, MetricDef] = {} + self._accum: Dict[str, list] = defaultdict(list) + self._timers: Dict[str, SynchronizedWallClockTimer.Timer] = {} + self._display_order: List[str] = trainer.config.metrics_display_order self.summary_dict: Dict[str, Union[int, float]] = {} - self.timers: Dict[str, SynchronizedWallClockTimer.Timer] = {} - self.values: Dict[str, Union[int, float]] = defaultdict(float) - self.seqlens = None - - self.losses: list = [] - # Store model size values for quickly calculating tflos later - def numel_fn(p): + # Register standard metrics -- display order follows registration order. 
+ self.register("epoch", derive=lambda m, _: m.trainer.epoch_idx, display_name="epoch") + self.register("loss", reduce="mean", fmt=".4f", display_name="loss") + self.register("eval_loss", reduce="mean", fmt=".4f", display_name="eval loss") + self.register( + "lr", derive=lambda m, _: m.trainer.model.lr_scheduler.get_last_lr()[0], fmt=".3E", display_name="lr" + ) + self.register( + "seqlens", derive=_derive_seqlen, fmt=human_format_base10_number, display_name="seqlen (avg/step/DP)" + ) + self.register( + "seqlen_total_since_log", + derive=_derive_seqlen_total_since_log, + fmt=human_format_base10_number, + display_name="seqlen total (since log)", + ) + self.register("step_time", reduce="mean", fmt=human_format_secs, display_name="step time", accumulate=True) + self.register("step_tflops", derive=_derive_tflops("step_time"), fmt=".1f", display_name="step tflops") + self.register("iter_time", reduce="mean", fmt=human_format_secs, display_name="iter time", accumulate=True) + self.register("iter_tflops", derive=_derive_tflops("iter_time"), fmt=".1f", display_name="iter tflops") + self.register("mem_ma", reduce="mean", fmt=lambda v: f"{v:.2f} GB", display_name="MA") + self.register("mem_max_ma", reduce="mean", fmt=lambda v: f"{v:.2f} GB", display_name="Max_MA") + self.register("mem_nv", reduce="mean", fmt=lambda v: f"{v:.2f} GB", display_name="NV") + + def numel(p): return p.ds_numel if hasattr(p, "ds_tensor") else p.numel() - model = self.trainer.model_unwrapped - self.model_size = sum(numel_fn(p) for p in model.parameters()) - self.model_num_layers = model.config.num_hidden_layers - self.model_hidden_size = model.config.hidden_size - - # Set max_iter based on when we expect to exit training - if self.trainer.config.exit_iteration > 0: - self.max_iter = min(self.trainer.config.exit_iteration, self.trainer.training_horizon) - else: - self.max_iter = self.trainer.training_horizon - self.max_iter_pad = len(str(self.max_iter)) + self._model_size = sum(numel(p) for p in trainer.model_unwrapped.parameters()) + + horizon = trainer.training_horizon + if trainer.config.exit_iteration > 0: + horizon = min(trainer.config.exit_iteration, horizon) + self._max_iter = horizon + self._max_iter_pad = len(str(horizon)) + + def register( + self, + name: str, + reduce: str = "mean", + fmt: Optional[Union[str, Callable]] = None, + display_name: Optional[str] = None, + wandb: bool = True, + accumulate: bool = False, + derive: Optional[Callable] = None, + ) -> None: + """Register a new metric (or override an existing one). + + Args: + name: Key used with ``record()``. + reduce: ``"mean"`` or ``"sum"`` — how to reduce across GAS micro-steps. + fmt: Format spec string (e.g. ``".4f"``) or callable for display. + display_name: Label shown in the log line. ``None`` hides it from display. + wandb: Whether to include in wandb logs. + accumulate: If ``True``, ``report()`` aggregates all values since + the previous report. If ``False`` (default), only the latest + GAS cycle's values are used. + derive: A callable ``(metrics, ctx) -> value`` for metrics computed + from other metrics rather than recorded directly. *metrics* is the + ``Metrics`` instance (gives access to ``_accum``, ``trainer``, + etc.). *ctx* contains ``{key}_total`` (raw gathered sum) for + every reduced metric. 
+ """ + if not self.enabled: + return + self._defs[name] = MetricDef( + reduce=reduce, fmt=fmt, display_name=display_name, wandb=wandb, accumulate=accumulate, derive=derive + ) - def record(self, key: str, value: Union[int, float]) -> None: - """Records a value in the metrics dictionary.""" + def record(self, key: str, value) -> None: + """Record a metric value. Always appends to the accumulator.""" if not self.enabled: return - if key in self.values: - raise KeyError( - f"Key {key} already exists. You are trying to write a value that has" - " not yet been reported. This can happen if you try to write to a" - " given value more than once in a training iteration loop." - ) - self.values[key] = value + self._accum[key].append(value) def start_timer(self, key: str) -> None: - """Starts a timer identified by `key`. If timer does not exist, one is created.""" + """Start (or create) a wall-clock timer.""" if not self.enabled: return - if key not in self.timers: - self.timers[key] = SynchronizedWallClockTimer().Timer(key) - self.timers[key].start() + if key not in self._timers: + self._timers[key] = SynchronizedWallClockTimer().Timer(key) + self._timers[key].start() def stop_timer(self, key: str) -> None: - """Stops a timer identfied by `key` and records the elapsed time in seconds to the metrics dictionary.""" + """Stop a timer and accumulate its elapsed time (seconds).""" if not self.enabled: return - if key not in self.timers: + if key not in self._timers: raise KeyError(f"Timer {key} not started") - self.timers[key].stop() - self.values[f"{key}_time"] = self.timers[key].elapsed() / 1000 + self._timers[key].stop() + self._accum[f"{key}_time"].append(self._timers[key].elapsed() / 1000) def restart_timer(self, key: str) -> None: + """Stop and immediately restart a timer.""" self.stop_timer(key) self.start_timer(key) - def get_value(self, key: str) -> Union[int, float]: - """Returns the value stored in the metrics dictionary for the given key.""" - return self.values[key] + def should_log(self) -> bool: + """Whether metrics should be logged on the current global step.""" + return ( + self.enabled + and self.trainer.global_step > 0 + and self.trainer.global_step % self.trainer.config.train_log_iter_interval == 0 + ) - def print_summary(self, prefix: str = "train") -> None: - """Prints a summary of the metrics. If a value is not recorded by the Trainer, it is not included in the summary.""" + def clear(self) -> None: + """Clear all accumulated values.""" if not self.enabled: return + self._accum.clear() + + def report(self, prefix: str = "train") -> Dict[str, Union[int, float]]: + """Reduce, gather, print, and return the metrics summary. + + For ``accumulate=False`` metrics, only the latest GAS cycle's values + are used. For ``accumulate=True`` metrics, all values since the last + report are aggregated. + + After reporting, all accumulators are cleared. 
+ """ + if not self.enabled: + return {} self.summary_dict.clear() - self.summary_dict["epoch"] = self.trainer.epoch_idx self.summary_dict["iter"] = self.trainer.global_step - self.summary_dict["lr"] = self.trainer.model.lr_scheduler.get_last_lr()[0] - - tflos_total: float = 0.0 - if self.seqlens is not None: - - tflos_subtotal, seqlen_subtotal = estimate_decoder_transformer_tflos( - hf_model_config=self.trainer.model_unwrapped.config, - model_size=self.model_size, - batch_of_seqlens=self.seqlens, - enable_gradient_checkpointing=not self.trainer.config.model.disable_activation_checkpoint, - ) - - # need total seqlen for tflos calculation because of O(n**2), but then divide by sp_world_size because each rank calculated its fraction of these tflos - tflos_total = ( - sum(gather_object(tflos_subtotal, self.trainer.world_size)) - / self.trainer.config.sequence_parallel_size - ) - self.values["seqlen_total"] = seqlen_subtotal - - if "loss" in self.values: - loss = sum(gather_object(self.values["loss"], self.trainer.world_size)) / self.trainer.world_size - self.summary_dict["loss"] = loss - - # XXX: short term partial GAS support for reporting the correct loss - if self.trainer.config.gradient_accumulation_steps > 1: - if len(self.losses) < self.trainer.config.gradient_accumulation_steps: - self.losses += [loss] - - if len(self.losses) == self.trainer.config.gradient_accumulation_steps: - self.summary_dict["loss"] = sum(self.losses) / self.trainer.config.gradient_accumulation_steps - self.losses = [] - - if "loss/eval" in self.values: - losses = gather_object(self.values["loss/eval"], self.trainer.world_size) - self.summary_dict["loss/eval"] = sum(losses) / len(losses) - - if "iter_time" in self.values: - iter_time_total = sum(gather_object(self.values["iter_time"], self.trainer.world_size)) - self.summary_dict["iter_time"] = iter_time_total / self.trainer.world_size - if tflos_total > 0: - self.summary_dict["iter_tflops"] = tflos_total / iter_time_total - - if "seqlen_total" in self.values: - seq_len_total = sum(gather_object(self.values["seqlen_total"], self.trainer.world_size)) - self.summary_dict["seqlen"] = seq_len_total / self.trainer.world_size - - if "step_time" in self.values: - step_time_total = sum(gather_object(self.values["step_time"], self.trainer.world_size)) - self.summary_dict["step_time"] = step_time_total / self.trainer.world_size - if tflos_total > 0: - self.summary_dict["step_tflops"] = tflos_total / step_time_total - - self.values.clear() - - summary_str = ( - f"{prefix.title():>{len('train')}} iter:" - f" {self.summary_dict['iter']:>{self.max_iter_pad}}/{self.max_iter}" - f" {100*self.summary_dict['iter']//self.max_iter:>3}%" - ) - if "loss" in self.summary_dict: - summary_str += f" | loss: {self.summary_dict['loss']:.4f}" - if "loss/eval" in self.summary_dict: - summary_str += f" | loss: {self.summary_dict['loss/eval']:.4f}" - if "iter_time" in self.summary_dict: - summary_str += f" | iter time: {human_format_secs(self.summary_dict['iter_time'])}" - if "iter_tflops" in self.summary_dict: - summary_str += f" | iter tflops: {self.summary_dict['iter_tflops']:.1f}" - summary_str += f" | lr: {self.summary_dict['lr']:.3E}" - if "seqlen" in self.summary_dict: - summary_str += f" | seqlen: {human_format_base10_number(self.summary_dict['seqlen'])}" - if "step_time" in self.summary_dict: - summary_str += f" | step time: {human_format_secs(self.summary_dict['step_time'])}" - if "step_tflops" in self.summary_dict: - summary_str += f" | step tflops: 
{self.summary_dict['step_tflops']:.1f}" - summary_str += f" | epoch: {self.summary_dict['epoch']}" + ws = self.trainer.world_size + + # Context dict passed to derive functions -- populated below with + # raw gathered totals ({key}_total) for every reduced scalar metric. + ctx: Dict[str, float] = {} + + # Reduce each scalar accumulator and gather across ranks. + for key, values in list(self._accum.items()): + if not values or not isinstance(values[0], (int, float)): + continue + defn = self._defs.get(key, MetricDef()) + if not defn.accumulate: + values = [values[-1]] + local_sum = sum(values) + gathered = _gather_object(local_sum, ws) + total = sum(gathered) + if defn.reduce == "mean": + self.summary_dict[key] = total / (len(values) * len(gathered)) + else: + self.summary_dict[key] = total / len(gathered) + ctx[f"{key}_total"] = total + + # Evaluate derived metrics + for key, defn in self._defs.items(): + if defn.derive is not None and key not in self.summary_dict: + value = defn.derive(self, ctx) + if value is not None: + self.summary_dict[key] = value if self.trainer.global_rank == 0: - # XXX: make configurable via yaml - mem_metrics = get_mem_metrics() - summary_str += f" | {mem_metrics}" + self._print(prefix) - if self.trainer.global_rank == 0: - print(summary_str) + self.clear() + return dict(self.summary_dict) + + def _print(self, prefix: str) -> None: + s = self.summary_dict + parts = [ + f"{prefix.title():>{len('train')}} iter:" + f" {s['iter']:>{self._max_iter_pad}}/{self._max_iter}" + f" {100 * s['iter'] // self._max_iter:>3}%" + ] + if self._display_order: + keys = list(self._display_order) + [k for k in self._defs if k not in self._display_order] + else: + keys = list(self._defs) + for key in keys: + defn = self._defs.get(key) + if defn and defn.display_name and key in s: + parts.append(f"{defn.display_name}: {defn.format_value(s[key])}") + print(" | ".join(parts)) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 29686c14..67e3aa1b 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -47,6 +47,7 @@ from arctic_training.config.trainer import TrainerConfig from arctic_training.data.factory import DataFactory from arctic_training.data.utils import OverfitOneBatchDataLoader +from arctic_training.debug import get_mem_metrics from arctic_training.logging import logger from arctic_training.metrics import Metrics from arctic_training.model.factory import ModelFactory @@ -477,45 +478,34 @@ def epoch(self) -> None: [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] for idx in range(len(batch["input_ids"])) ] - self.metrics.seqlens = sample_seqlens + self.metrics.record("seqlens", sample_seqlens) self.metrics.start_timer("step") self.step(batch) self.metrics.stop_timer("step") - self.metrics.restart_timer("iter") - if self.config.train_log_iter_interval != 0: - self.metrics.print_summary() - if self.gas_boundary: - if ( - self.global_rank == 0 - and self.config.train_log_iter_interval != 0 - and self.global_step % self.config.train_log_iter_interval == 0 - ): - metrics = {k: v for k, v in self.metrics.summary_dict.items()} - if self.ds_wall_clock_available: - ds_timers = self.model.get_wall_clock_timers() - metrics.update(ds_timers) - - append_json_file(self.config.train_log_metrics_path, metrics) - - # do not log the first train iteration to wandb, since it's a massive outlier - # on all performance metrics, which messes up the scale of the report - if self.wandb_experiment is not None and 
self.global_step > 1:
-                        metrics = {k: v for k, v in metrics.items() if k not in ["iter"]}
-                        self.wandb_experiment.log(metrics, step=self.global_step)
+            if self.metrics.should_log():
+                ma, max_ma, nv = get_mem_metrics()
+                self.metrics.record("mem_ma", ma)
+                self.metrics.record("mem_max_ma", max_ma)
+                self.metrics.record("mem_nv", nv)
+                summary = self.metrics.report()
+                if self.global_rank == 0:
+                    if self.ds_wall_clock_available:
+                        summary.update(self.model.get_wall_clock_timers())
+                    append_json_file(self.config.train_log_metrics_path, summary)
+                    if self.wandb_experiment is not None and self.global_step > 1:
+                        wandb_metrics = {k: v for k, v in summary.items() if k != "iter"}
+                        self.wandb_experiment.log(wandb_metrics, step=self.global_step)
 
             if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0:
                 self.evaluate()
-                if self.is_eval_log_iter():
-                    self.metrics.print_summary(prefix="eval")
-
+                eval_summary = self.metrics.report(prefix="eval")
                 if self.wandb_experiment is not None:
-                    metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]}
-                    self.wandb_experiment.log(metrics, step=self.global_step)
+                    if "eval_loss" in eval_summary:
+                        self.wandb_experiment.log({"eval_loss": eval_summary["eval_loss"]}, step=self.global_step)
 
             self.metrics.stop_timer("iter")
 
         self.epoch_finished = True
@@ -561,8 +551,8 @@ def evaluate(self) -> None:
         """
         self.model.eval()
         with torch.no_grad():
-            losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches]
-            self.metrics.record("loss/eval", losses)  # type: ignore
+            for eval_batch in self.eval_batches:
+                self.metrics.record("eval_loss", self.loss(eval_batch).item())
 
     @callback_wrapper("checkpoint")
     def checkpoint(self) -> None:
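
Usage sketch for reviewers: a minimal example of hooking a custom metric into the new register/record/report API, not part of the patch itself. The `grad_norm` metric name and the `trainer` handle are illustrative assumptions; `get_global_grad_norm()` is DeepSpeed's engine accessor.

# Illustrative only -- assumes an already-constructed Trainer instance `trainer`;
# "grad_norm" is a hypothetical metric name.
metrics = trainer.metrics

metrics.register(
    "grad_norm",
    reduce="mean",       # mean over micro-steps (then averaged across ranks)
    fmt=".4f",
    display_name="grad norm",
    accumulate=True,     # aggregate every value recorded since the last report
)

# Inside the training loop, once per micro-step:
gnorm = trainer.model.get_global_grad_norm()
if gnorm is not None:  # None until DeepSpeed has run an optimizer step
    metrics.record("grad_norm", gnorm)

# At the configured interval, report() reduces and all-gathers every metric,
# prints one line on rank 0, clears the accumulators, and returns the summary:
#   Train iter:   42/1000   4% | epoch: 0 | loss: 1.2345 | ... | grad norm: 0.9876
if metrics.should_log():
    summary = metrics.report()

# The printed order can be overridden from the YAML config, e.g.:
#   metrics_display_order: [loss, lr, iter_time, iter_tflops]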