From 16b626c28ab4b429a53f6dc131205b576d43d378 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 14 Apr 2026 21:12:02 -0700 Subject: [PATCH] Collect peak memory directly in integration tests Add direct peak-memory collection for integration runs by having MetricsProcessor write a JSON summary at the end of training. The test runner now passes TORCHTITAN_PEAK_MEMORY_JSON into each launched training job, forces metrics logging every step, and reads the emitted summary file back for reporting. MetricsProcessor tracks the maximum reserved and active CUDA memory it observes across log and validation calls and writes a single summary on close from the metrics rank. This keeps the measurement path local to the training run, avoids depending on TensorBoard event parsing for memory collection, and preserves the integration-test UX via --collect_peak_memory. The graph-trainer integration entrypoint and 8-GPU workflow are wired to use the flag. [ghstack-poisoned] --- .../integration_test_8gpu_graph_trainer.yaml | 2 +- tests/integration_tests/run_tests.py | 29 +++++++++++ torchtitan/components/metrics.py | 51 +++++++++++++++++++ .../graph_trainer/tests/integration_tests.py | 5 ++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/.github/workflows/integration_test_8gpu_graph_trainer.yaml b/.github/workflows/integration_test_8gpu_graph_trainer.yaml index 676db95abe..aa9fe32094 100644 --- a/.github/workflows/integration_test_8gpu_graph_trainer.yaml +++ b/.github/workflows/integration_test_8gpu_graph_trainer.yaml @@ -61,7 +61,7 @@ jobs: sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - python -m torchtitan.experiments.graph_trainer.tests.integration_tests --test_suite graph_trainer_default --gpu_arch_type cuda $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + python -m torchtitan.experiments.graph_trainer.tests.integration_tests --test_suite graph_trainer_default --gpu_arch_type cuda $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 --collect_peak_memory # Run the numerics unit tests (dense models only; MoE tests run in the H100 workflow) diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index a23f95160c..920cd70c60 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -5,9 +5,11 @@ # LICENSE file in the root directory of this source tree. import argparse +import json import os import subprocess import time +from pathlib import Path from torchtitan.tools.logging import logger @@ -40,6 +42,8 @@ def run_single_test( output_dir: str, module: str | None = None, config: str | None = None, + *, + collect_peak_memory: bool = False, ): # run_test supports sequence of tests. test_name = test_flavor.test_name @@ -48,6 +52,7 @@ def run_single_test( all_ranks = ",".join(map(str, range(test_flavor.ngpu))) for idx, override_arg in enumerate(test_flavor.override_args): + peak_memory_path = Path(output_dir) / test_name / f"peak_memory_{idx}.json" cmd = "" if module is not None: cmd += f"MODULE={module} " @@ -59,6 +64,9 @@ def run_single_test( cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd cmd += " " + dump_folder_arg + if collect_peak_memory: + cmd = f'TORCHTITAN_PEAK_MEMORY_JSON="{peak_memory_path}" ' + cmd + cmd += " --metrics.log_freq=1" if override_arg: cmd += " " + " ".join(override_arg) logger.info( @@ -89,6 +97,19 @@ def run_single_test( f"Command: {cmd}\n" f"stderr: {result.stderr}\n" ) + if collect_peak_memory: + if not peak_memory_path.exists(): + raise FileNotFoundError( + f"Peak memory summary not found: {peak_memory_path}" + ) + peak_memory = json.loads(peak_memory_path.read_text()) + logger.info( + f"Peak memory for {test_name}[{idx}]: " + f"reserved={peak_memory['max_reserved_gib']:.3f} GiB " + f"at step {peak_memory['max_reserved_step']}, " + f"active={peak_memory['max_active_gib']:.3f} GiB " + f"at step {peak_memory['max_active_step']}" + ) def run_tests( @@ -98,6 +119,7 @@ def run_tests( config=None, ): """Run all integration tests to test the core features of TorchTitan""" + collect_peak_memory = getattr(args, "collect_peak_memory", False) exclude_set = set() if hasattr(args, "exclude") and args.exclude: exclude_set = {name.strip() for name in args.exclude.split(",")} @@ -132,6 +154,7 @@ def run_tests( args.output_dir, module, config, + collect_peak_memory=collect_peak_memory, ) except Exception as e: logger.error(str(e)) @@ -202,6 +225,12 @@ def main(): default=None, help="Comma-separated list of test names to skip", ) + parser.add_argument( + "--collect_peak_memory", + default=False, + action="store_true", + help="Collect peak reserved/active CUDA memory directly from the run.", + ) args = parser.parse_args() if not os.path.exists(args.output_dir): diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 2199b083b3..31d6c01d12 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json import os import time from collections import namedtuple @@ -353,6 +354,9 @@ def __init__( self.data_loading_times = [] self.time_last_log = time.perf_counter() self.device_memory_monitor.reset_peak_stats() + self.peak_memory_summary_path = os.getenv("TORCHTITAN_PEAK_MEMORY_JSON") + self.pp_schedule = pp_schedule + self.peak_memory_summary: dict[str, float | int] | None = None self.has_quantization = has_quantization @@ -362,6 +366,48 @@ def __init__( self.lr_schedulers = None self.model_parts = None + def _should_write_peak_memory(self) -> bool: + return self.peak_memory_summary_path is not None and ( + not torch.distributed.is_initialized() + or torch.distributed.get_rank() + == _get_metrics_rank( + parallel_dims=self.parallel_dims, pp_schedule=self.pp_schedule + ) + ) + + def _update_peak_memory_summary( + self, device_mem_stats: DeviceMemStats, step: int + ) -> None: + if not self._should_write_peak_memory(): + return + + if self.peak_memory_summary is None: + self.peak_memory_summary = { + "max_reserved_gib": device_mem_stats.max_reserved_gib, + "max_reserved_step": step, + "max_active_gib": device_mem_stats.max_active_gib, + "max_active_step": step, + } + return + + if ( + device_mem_stats.max_reserved_gib + > self.peak_memory_summary["max_reserved_gib"] + ): + self.peak_memory_summary["max_reserved_gib"] = ( + device_mem_stats.max_reserved_gib + ) + self.peak_memory_summary["max_reserved_step"] = step + + if ( + device_mem_stats.max_active_gib + > self.peak_memory_summary["max_active_gib"] + ): + self.peak_memory_summary["max_active_gib"] = ( + device_mem_stats.max_active_gib + ) + self.peak_memory_summary["max_active_step"] = step + def should_log(self, step: int) -> bool: return step == 1 or step % self.config.log_freq == 0 @@ -496,6 +542,7 @@ def log( time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta device_mem_stats = self.device_memory_monitor.get_peak_stats() + self._update_peak_memory_summary(device_mem_stats, step) metrics = { "loss_metrics/global_avg_loss": global_avg_loss, @@ -545,6 +592,7 @@ def log_validation( time_delta = time.perf_counter() - self.time_last_log device_mem_stats = self.device_memory_monitor.get_peak_stats() + self._update_peak_memory_summary(device_mem_stats, step) # tokens per second per device, abbreviated as tps tps = self.ntokens_since_last_log / ( @@ -579,4 +627,7 @@ def log_validation( self.device_memory_monitor.reset_peak_stats() def close(self): + if self._should_write_peak_memory() and self.peak_memory_summary is not None: + with open(self.peak_memory_summary_path, "w") as f: + json.dump(self.peak_memory_summary, f, indent=2) self.logger.close() diff --git a/torchtitan/experiments/graph_trainer/tests/integration_tests.py b/torchtitan/experiments/graph_trainer/tests/integration_tests.py index ebd90bc60c..36ef08fdc8 100644 --- a/torchtitan/experiments/graph_trainer/tests/integration_tests.py +++ b/torchtitan/experiments/graph_trainer/tests/integration_tests.py @@ -547,6 +547,11 @@ def main(): help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", ) parser.add_argument("--ngpu", default=8, type=int) + parser.add_argument( + "--collect_peak_memory", + action="store_true", + help="Collect peak reserved/active CUDA memory directly from the run.", + ) args = parser.parse_args() if not os.path.exists(args.output_dir):