@@ -61,7 +61,7 @@ jobs:
             sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded"
             sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded"
 
-            python -m torchtitan.experiments.graph_trainer.tests.integration_tests --test_suite graph_trainer_default --gpu_arch_type cuda $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8
+            python -m torchtitan.experiments.graph_trainer.tests.integration_tests --test_suite graph_trainer_default --gpu_arch_type cuda $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 --collect_peak_memory
 
 
     # Run the numerics unit tests (dense models only; MoE tests run in the H100 workflow)
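The new `--collect_peak_memory` flag makes the test runner set `TORCHTITAN_PEAK_MEMORY_JSON` for each test, and the metrics processor (see torchtitan/components/metrics.py below) writes a small JSON summary to that path when it closes. A minimal sketch of the summary's shape; the four key names come from the diff below, while the values are invented for illustration:

    import json

    # Hypothetical contents of a peak_memory_0.json file; keys mirror what
    # MetricsProcessor writes, numbers are made up.
    example_summary = {
        "max_reserved_gib": 42.317,   # peak reserved CUDA memory over the run, GiB
        "max_reserved_step": 12,      # training step where the reserved peak occurred
        "max_active_gib": 38.904,     # peak active CUDA memory over the run, GiB
        "max_active_step": 12,        # training step where the active peak occurred
    }
    print(json.dumps(example_summary, indent=2))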
29 changes: 29 additions & 0 deletions tests/integration_tests/run_tests.py
@@ -5,9 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import json
 import os
 import subprocess
 import time
+from pathlib import Path
 
 from torchtitan.tools.logging import logger
 
@@ -40,6 +42,8 @@ def run_single_test(
     output_dir: str,
     module: str | None = None,
     config: str | None = None,
+    *,
+    collect_peak_memory: bool = False,
 ):
     # run_test supports sequence of tests.
     test_name = test_flavor.test_name
@@ -48,6 +52,7 @@
     all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
 
     for idx, override_arg in enumerate(test_flavor.override_args):
+        peak_memory_path = Path(output_dir) / test_name / f"peak_memory_{idx}.json"
         cmd = ""
         if module is not None:
             cmd += f"MODULE={module} "
@@ -59,6 +64,9 @@
             cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
 
         cmd += " " + dump_folder_arg
+        if collect_peak_memory:
+            cmd = f'TORCHTITAN_PEAK_MEMORY_JSON="{peak_memory_path}" ' + cmd
+            cmd += " --metrics.log_freq=1"
         if override_arg:
             cmd += " " + " ".join(override_arg)
         logger.info(
@@ -89,6 +97,19 @@ def run_single_test(
                 f"Command: {cmd}\n"
                 f"stderr: {result.stderr}\n"
             )
+        if collect_peak_memory:
+            if not peak_memory_path.exists():
+                raise FileNotFoundError(
+                    f"Peak memory summary not found: {peak_memory_path}"
+                )
+            peak_memory = json.loads(peak_memory_path.read_text())
+            logger.info(
+                f"Peak memory for {test_name}[{idx}]: "
+                f"reserved={peak_memory['max_reserved_gib']:.3f} GiB "
+                f"at step {peak_memory['max_reserved_step']}, "
+                f"active={peak_memory['max_active_gib']:.3f} GiB "
+                f"at step {peak_memory['max_active_step']}"
+            )
 
 
 def run_tests(
@@ -98,6 +119,7 @@
     config=None,
 ):
     """Run all integration tests to test the core features of TorchTitan"""
+    collect_peak_memory = getattr(args, "collect_peak_memory", False)
     exclude_set = set()
     if hasattr(args, "exclude") and args.exclude:
         exclude_set = {name.strip() for name in args.exclude.split(",")}
@@ -132,6 +154,7 @@ def run_tests(
                 args.output_dir,
                 module,
                 config,
+                collect_peak_memory=collect_peak_memory,
             )
         except Exception as e:
             logger.error(str(e))
@@ -202,6 +225,12 @@ def main():
         default=None,
         help="Comma-separated list of test names to skip",
     )
+    parser.add_argument(
+        "--collect_peak_memory",
+        default=False,
+        action="store_true",
+        help="Collect peak reserved/active CUDA memory directly from the run.",
+    )
     args = parser.parse_args()
 
     if not os.path.exists(args.output_dir):
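To see how the pieces of `run_single_test` fit together, here is a self-contained sketch of how the env var and log frequency are prepended to a test command. The base command string and the placeholder values are stand-ins, not the runner's real invocation; `--metrics.log_freq=1` makes the metrics processor log every step, so the peak is attributed to a precise step:

    from pathlib import Path

    # Placeholder values standing in for the runner's arguments.
    output_dir, test_name, idx = "outputs", "default", 0
    peak_memory_path = Path(output_dir) / test_name / f"peak_memory_{idx}.json"

    cmd = "NGPU=8 ./run_train.sh"  # placeholder base command
    # Same prefixing pattern as the diff above: the env var points the metrics
    # processor at the JSON output path.
    cmd = f'TORCHTITAN_PEAK_MEMORY_JSON="{peak_memory_path}" ' + cmd
    cmd += " --metrics.log_freq=1"
    print(cmd)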
51 changes: 51 additions & 0 deletions torchtitan/components/metrics.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import os
 import time
 from collections import namedtuple
@@ -353,6 +354,9 @@ def __init__(
         self.data_loading_times = []
         self.time_last_log = time.perf_counter()
         self.device_memory_monitor.reset_peak_stats()
+        self.peak_memory_summary_path = os.getenv("TORCHTITAN_PEAK_MEMORY_JSON")
+        self.pp_schedule = pp_schedule
+        self.peak_memory_summary: dict[str, float | int] | None = None
 
         self.has_quantization = has_quantization
 
@@ -362,6 +366,48 @@
         self.lr_schedulers = None
         self.model_parts = None
 
+    def _should_write_peak_memory(self) -> bool:
+        return self.peak_memory_summary_path is not None and (
+            not torch.distributed.is_initialized()
+            or torch.distributed.get_rank()
+            == _get_metrics_rank(
+                parallel_dims=self.parallel_dims, pp_schedule=self.pp_schedule
+            )
+        )
+
+    def _update_peak_memory_summary(
+        self, device_mem_stats: DeviceMemStats, step: int
+    ) -> None:
+        if not self._should_write_peak_memory():
+            return
+
+        if self.peak_memory_summary is None:
+            self.peak_memory_summary = {
+                "max_reserved_gib": device_mem_stats.max_reserved_gib,
+                "max_reserved_step": step,
+                "max_active_gib": device_mem_stats.max_active_gib,
+                "max_active_step": step,
+            }
+            return
+
+        if (
+            device_mem_stats.max_reserved_gib
+            > self.peak_memory_summary["max_reserved_gib"]
+        ):
+            self.peak_memory_summary["max_reserved_gib"] = (
+                device_mem_stats.max_reserved_gib
+            )
+            self.peak_memory_summary["max_reserved_step"] = step
+
+        if (
+            device_mem_stats.max_active_gib
+            > self.peak_memory_summary["max_active_gib"]
+        ):
+            self.peak_memory_summary["max_active_gib"] = (
+                device_mem_stats.max_active_gib
+            )
+            self.peak_memory_summary["max_active_step"] = step
+
     def should_log(self, step: int) -> bool:
         return step == 1 or step % self.config.log_freq == 0
 
@@ -496,6 +542,7 @@ def log(
         time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta
 
         device_mem_stats = self.device_memory_monitor.get_peak_stats()
+        self._update_peak_memory_summary(device_mem_stats, step)
 
         metrics = {
             "loss_metrics/global_avg_loss": global_avg_loss,
@@ -545,6 +592,7 @@ def log_validation(
         time_delta = time.perf_counter() - self.time_last_log
 
         device_mem_stats = self.device_memory_monitor.get_peak_stats()
+        self._update_peak_memory_summary(device_mem_stats, step)
 
         # tokens per second per device, abbreviated as tps
         tps = self.ntokens_since_last_log / (
@@ -579,4 +627,7 @@
         self.device_memory_monitor.reset_peak_stats()
 
     def close(self):
+        if self._should_write_peak_memory() and self.peak_memory_summary is not None:
+            with open(self.peak_memory_summary_path, "w") as f:
+                json.dump(self.peak_memory_summary, f, indent=2)
         self.logger.close()
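Because `device_memory_monitor.reset_peak_stats()` zeroes the allocator's peak counters after every log, the summary has to keep its own running maximum across logging intervals. A condensed, runnable sketch of that bookkeeping, with rank gating omitted and a stand-in for the real DeviceMemStats (which carries more fields):

    from collections import namedtuple

    MemStats = namedtuple("MemStats", ["max_reserved_gib", "max_active_gib"])

    def update_summary(summary, stats, step):
        """Track the max reserved/active memory seen so far and the step it occurred."""
        if summary is None:
            return {
                "max_reserved_gib": stats.max_reserved_gib,
                "max_reserved_step": step,
                "max_active_gib": stats.max_active_gib,
                "max_active_step": step,
            }
        if stats.max_reserved_gib > summary["max_reserved_gib"]:
            summary["max_reserved_gib"] = stats.max_reserved_gib
            summary["max_reserved_step"] = step
        if stats.max_active_gib > summary["max_active_gib"]:
            summary["max_active_gib"] = stats.max_active_gib
            summary["max_active_step"] = step
        return summary

    summary = None
    for step, stats in enumerate([MemStats(40.1, 36.2), MemStats(42.3, 38.9)], start=1):
        summary = update_summary(summary, stats, step)
    print(summary)  # both peaks attributed to step 2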
@@ -547,6 +547,11 @@ def main():
         help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)",
     )
     parser.add_argument("--ngpu", default=8, type=int)
+    parser.add_argument(
+        "--collect_peak_memory",
+        action="store_true",
+        help="Collect peak reserved/active CUDA memory directly from the run.",
+    )
     args = parser.parse_args()
 
     if not os.path.exists(args.output_dir):