diff --git a/torchtitan/experiments/llm_trainer/INSTRUCTIONS.md b/torchtitan/experiments/llm_trainer/INSTRUCTIONS.md new file mode 100644 index 0000000000..a7fd6d7778 --- /dev/null +++ b/torchtitan/experiments/llm_trainer/INSTRUCTIONS.md @@ -0,0 +1,249 @@ +# LLM Trainer: Optimizing Flattened PyTorch Models + +## Overview + +You are tasked with optimizing a flattened PyTorch training step. The training +step (forward + loss + backward) has been traced into a straight-line sequence +of `torch.ops.aten.*` operations. Your job is to rewrite these operations to +run faster while maintaining convergence — the candidate's validation loss +after 100 training steps must be **less than or equal to** the baseline's. + +## Directory Layout + +Models are organized under `targets/` by **fingerprint** — a string that +encodes both the hardware and the parallelism configuration. You provide +the hardware label via `HARDWARE=` (e.g. `h100-sm90`), and the +tools automatically append non-trivial parallelism dimensions from the +config (e.g. `tp2`, `fsdp4`, `ep8`). + +Examples: +- Single GPU: `HARDWARE=h100-sm90` → `targets/h100-sm90/` +- TP=2, FSDP=4: `HARDWARE=h100-sm90` → `targets/h100-sm90_tp2_fsdp4/` +- TP=2, EP=8: `HARDWARE=h100-sm90` → `targets/h100-sm90_tp2_ep8/` + +``` +torchtitan/experiments/llm_trainer/ +└── targets/ + └── h100-sm90_tp2_fsdp4/ # Fingerprint = hardware + parallelism + ├── flattened_models/ # Original traced models (DO NOT EDIT) + │ ├── llama3_8b_rank0.py + │ ├── llama3_8b_rank0_meta.json + │ ├── llama3_8b_rank1.py + │ └── ... + ├── optimized_models/ # Current best version (auto-managed) + │ ├── llama3_8b_rank0.py + │ └── ... + └── candidate_models/ # YOUR workspace — put optimized code here + ├── llama3_8b_rank0.py + └── ... +``` + +- `flattened_models/` — The raw traced output. Never modify these files. + They serve as a reference for what the original model does. +- `optimized_models/` — The current best-performing version. 
Initially a + copy of the flattened model. Updated when a candidate is promoted. +- `candidate_models/` — Your working directory. Copy an optimized model here, + optimize it, then benchmark. + +## Workflow + +### Step 1: Generate the flattened model (one-time setup) + +```bash +NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_flattener.sh \ + --module graph_trainer.llama3 \ + --config graph_trainer_llama3_debugmodel +``` + +This traces the model across GPUs via torchrun. The fingerprint directory +is auto-constructed from `HARDWARE` + the parallelism config. Each rank +writes its own files: +- `targets//flattened_models/_rank{i}.py` +- `targets//flattened_models/_rank{i}_meta.json` +- Copies both to `targets//optimized_models/` as the initial baseline + +### Step 2: Create your candidate + +```bash +# Check the printed fingerprint from step 1, then copy the optimized model +cp torchtitan/experiments/llm_trainer/targets/h100-sm90_tp2_fsdp4/optimized_models/llama3_8b_rank0.py \ + torchtitan/experiments/llm_trainer/targets/h100-sm90_tp2_fsdp4/candidate_models/llama3_8b_rank0.py +``` + +Now edit the candidate file with your optimizations. + +### Step 3: Benchmark + +```bash +NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_benchmarker.sh \ + --module graph_trainer.llama3 \ + --config graph_trainer_llama3_debugmodel +``` + +This will: +1. Load both the optimized and candidate models +2. Run both for 100 training steps (SGD updates) from identical initial state +3. Check that the candidate's final loss is **<= the baseline's** +4. Measure execution time and MFU for both +5. 
Print a comparison report + +### Step 4: Promote (if candidate passes) + +```bash +NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_benchmarker.sh \ + --promote \ + --module graph_trainer.llama3 \ + --config graph_trainer_llama3_debugmodel +``` + +`--promote` copies the candidate to `optimized_models/` **only if** the +candidate's final validation loss is <= the baseline's AND it is at least +1% faster on every benchmark run (default: 3 consecutive runs; override +with `--promote-runs N`). A comment is inserted at the top of the promoted +file recording the MFU and timestamp, making the optimization history +self-documenting. + +## Understanding the Model File + +The generated `.py` file contains a single `GraphModule` class with a +`forward` method. Here's the structure: + +```python +class GraphModule(torch.nn.Module): + def forward(self, arg0_1: "f32[32768, 256]", arg1_1: "f32[256, 256]", ...): + # Each line is one ATen operation + _to_copy = torch.ops.aten._to_copy.default(arg0_1, dtype=torch.float32) + embedding = torch.ops.aten.embedding.default(_to_copy, arg42_1) + view = torch.ops.aten.view.default(embedding, [256, 256]) + mm = torch.ops.aten.mm.default(view, arg1_1) + ... + return (loss, grad_0, grad_1, ...) +``` + +**Inputs** (arguments to `forward`): +- The first N arguments are model state (parameters and buffers), listed + in the docstring at the top of the file +- The remaining arguments are user data: input tokens, labels, positional + embeddings (e.g., `freqs_cis`), etc. + +**Outputs** (return tuple): +- `[0]` is the loss (scalar tensor) +- `[1..N]` are gradients for each trainable parameter + +## Optimization Rules + +### MUST follow: +1. **Validation loss <= baseline.** When both models are run for 100 training + steps (SGD updates from identical initial state), the candidate's final + loss must be less than or equal to the baseline's. Numerical approximations + are allowed as long as they do not hurt convergence. + +2. 
**Same function signature.** The `forward` method must accept the same + arguments in the same order with the same types. Do not add, remove, or + reorder arguments. + +3. **Same output count and order.** Return the same number of tensors in the + same order. + +4. **Keep it a single `GraphModule` class.** The benchmarker imports + `GraphModule` from the file. + +### CAN do: +- **Fuse operations.** Replace sequences of elementwise ops with fused + alternatives. E.g., replace separate `mul` + `add` with `addcmul`, or + use `F.silu` instead of `sigmoid` + `mul`. + +- **Use higher-level APIs.** Replace manual attention patterns with + `F.scaled_dot_product_attention`. Replace manual RMS norm with a + fused kernel. + +- **Eliminate redundant ops.** Remove no-op views, unnecessary copies, + identity slices, redundant type casts. + +- **Reorder independent ops.** If two operations don't depend on each + other, you can reorder them for better memory locality or to enable + kernel fusion. + +- **Use custom Triton kernels.** Write inline Triton kernels for hot + sequences of operations. Define them at module level and call from + `forward`. + +- **Add helper functions.** Define helper functions or classes in the + same file if it helps organize the code. + +### CANNOT do: +- Introduce changes that cause the final validation loss to exceed the baseline +- Change input/output signatures +- Import external packages not already available (torch, triton are fine) +- Remove or skip gradient computations +- **Modify files outside of `candidate_models/` (and helper files you + create alongside them).** All optimizations — including custom CUDA/Triton + kernels, helper functions, and binding code — must live within or next to + the candidate model files under `targets//candidate_models/`. + Never modify core torchtitan source, `flattened_models/`, or any other + files in the repository. The `optimized_models/` directory is managed + exclusively by the `--promote` workflow. 
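For example, the `sigmoid` + `mul` pattern mentioned under "Fuse operations" can be collapsed into a single `F.silu` call. A minimal sketch of such a rewrite — the tensor and weight names here are hypothetical, not taken from a real traced file:

```python
import torch
import torch.nn.functional as F

def flattened_mlp_gate(x, w_gate, w_up):
    # Straight-line ATen form, roughly as the tracer would emit it
    # (hypothetical variable names).
    mm = torch.ops.aten.mm.default(x, w_gate)
    sigmoid = torch.ops.aten.sigmoid.default(mm)
    mul = torch.ops.aten.mul.Tensor(mm, sigmoid)   # silu(x @ w_gate)
    mm_1 = torch.ops.aten.mm.default(x, w_up)
    return torch.ops.aten.mul.Tensor(mul, mm_1)

def fused_mlp_gate(x, w_gate, w_up):
    # Candidate rewrite: identical math, fewer kernel launches.
    return F.silu(x @ w_gate) * (x @ w_up)
```

Before benchmarking, it is worth sanity-checking any such rewrite against the original sequence with `torch.allclose` on random inputs; the benchmarker only catches convergence regressions after 100 steps, not per-op mismatches.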
+ +## Common Options + +Both shell scripts require: + +``` +HARDWARE= Hardware label (e.g. h100-sm90). Combined with parallelism + config to form the fingerprint directory name. +NGPU= Number of GPUs (default: 1) +``` + +### Benchmarker-specific Options + +``` +--promote Auto-promote candidate if valid AND >=1% faster + on all benchmark runs +--promote-runs N Number of consecutive runs that must all pass (default: 3) +--num-model-warmup N Model warmup calls BEFORE validation check (default: 3) +--num-validation-steps N Training steps for validation (default: 100) +--validation-lr LR SGD learning rate for validation steps (default: 1e-3) +--num-warmup N Warmup iterations before timing (default: 5) +--num-bench N Benchmark iterations per run (default: 20) +``` + +### How Benchmarking Works + +The benchmarker runs each model through three phases: + +1. **Model warmup** (`--num-model-warmup`, default 3): The model is called + N times *before* the validation check or any timing. These calls are not + timed and not checked for correctness. This phase exists so that models + can initialize internal state — for example, populating caches, warming + up JIT compilers, or recording CUDA graphs. Your candidate model can do + arbitrary work during these calls (build lookup tables, capture graphs, + profile and specialize) as long as it produces correct outputs from call + N+1 onward. + +2. **Validation check** (`--num-validation-steps`, default 100): Both models + are run for N training steps from identical initial state. On each step + the model is called, the loss is recorded, and the state parameters are + updated via SGD (`param -= lr * grad`) using `--validation-lr` (default + 1e-3). The candidate passes if its final loss is <= the baseline's. + +3. **Benchmark loop** (`--num-warmup` + `--num-bench`): Additional warmup + iterations (not timed), then timed iterations measured with CUDA events. + The median time across `--num-bench` iterations is reported. 
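The validation phase above amounts to a plain SGD loop around the flattened call. A simplified sketch (the real logic lives in `benchmarker.py`; this version omits shape checks and non-tensor state handling):

```python
import torch

def run_validation(model, state, data, num_steps=100, lr=1e-3):
    """Call the flattened model each step, record the loss, and apply
    param -= lr * grad to the state inputs. Simplified sketch."""
    losses = []
    for _ in range(num_steps):
        with torch.no_grad():
            # Flattened models return (loss, grad_0, grad_1, ...).
            outputs = model(*state, *data)
        losses.append(outputs[0].item())
        state = [p - lr * g for p, g in zip(state, outputs[1:])]
    return losses
```

A candidate passes when its final entry is `<=` the baseline's final entry from the same loop run with identical initial state.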
+ +Because model warmup calls are never timed or checked, you are free to use +compile-like transformations in your candidate. For example, you can cache +CPU-side scalar values during warmup and replay them in later calls, or +capture the entire forward pass as a CUDA graph during warmup and replay it +during benchmark. The only requirement is that the candidate's validation +loss does not exceed the baseline's after the configured number of steps. + +## Supported Models + +The flattener works with any model supported by graph_trainer. Currently: + +| Model | Module flag | Config flag | +|-------|-----------|-------------| +| Llama 3 (debug) | `graph_trainer.llama3` | `graph_trainer_llama3_debugmodel` | +| Llama 3 (8B) | `graph_trainer.llama3` | `graph_trainer_llama3_8b` | +| DeepSeek V3 (debug) | `graph_trainer.deepseek_v3` | `graph_trainer_deepseek_v3_debugmodel` | +| Qwen 3 (debug) | `graph_trainer.qwen3` | `graph_trainer_qwen3_debugmodel` | diff --git a/torchtitan/experiments/llm_trainer/__init__.py b/torchtitan/experiments/llm_trainer/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtitan/experiments/llm_trainer/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/llm_trainer/benchmarker.py b/torchtitan/experiments/llm_trainer/benchmarker.py new file mode 100644 index 0000000000..e539959bd7 --- /dev/null +++ b/torchtitan/experiments/llm_trainer/benchmarker.py @@ -0,0 +1,472 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Benchmarker: validates convergence and compares performance between the +current optimized model and a candidate model. + +Usage (via shell script): + NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_benchmarker.sh \ + --module graph_trainer.llama3 \ + --config graph_trainer_llama3_debugmodel + + # Auto-promote if valid AND >= 1% faster: + NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_benchmarker.sh \ + --promote \ + --module graph_trainer.llama3 \ + --config graph_trainer_llama3_debugmodel + +This will: + 1. Load metadata (input shapes, FLOPs info) from optimized_models/ + 2. Import both the optimized and candidate GraphModules + 3. Run both for N training steps (default 100) and compare final loss + 4. Benchmark execution time and compute MFU for both + 5. Print a comparison report + 6. With --promote: promote candidate if valid AND >= 1% faster +""" + +import argparse +import importlib.util +import json +import os +import sys +from datetime import datetime +from pathlib import Path +from statistics import median + +import torch +import torch.distributed as dist + +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.utils import get_peak_flops + + +def _build_fingerprint(hardware, parallel_dims): + """Build a directory fingerprint encoding hardware + parallelism. + + Only includes parallelism dimensions > 1 to keep the name concise. + E.g. 
"h100-sm90" with tp=2, dp_shard=4 -> "h100-sm90_tp2_fsdp4" + """ + parts = [hardware] + dims = [ + ("tp", parallel_dims.tp), + ("fsdp", parallel_dims.dp_shard), + ("dp", parallel_dims.dp_replicate), + ("pp", parallel_dims.pp), + ("cp", parallel_dims.cp), + ("ep", parallel_dims.ep), + ] + for name, val in dims: + if val > 1: + parts.append(f"{name}{val}") + return "_".join(parts) + + +_PROMOTE_SPEEDUP_THRESHOLD = 1.01 +_PROMOTE_RUNS_DEFAULT = 3 + + +def _parse_args(): + """Parse benchmarker-specific args, forwarding the rest to ConfigManager.""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--promote", action="store_true", default=False) + parser.add_argument( + "--promote-runs", + type=int, + default=_PROMOTE_RUNS_DEFAULT, + help=( + "Number of consecutive benchmark runs that must all show " + f">={(_PROMOTE_SPEEDUP_THRESHOLD - 1) * 100:.0f}%% speedup " + f"for promotion (default: {_PROMOTE_RUNS_DEFAULT})." + ), + ) + parser.add_argument("--num-warmup", type=int, default=5) + parser.add_argument("--num-bench", type=int, default=20) + parser.add_argument( + "--num-model-warmup", + type=int, + default=3, + help=( + "Number of warmup calls to each model BEFORE the validation " + "check and benchmarking. Allows models to initialize internal " + "state (e.g. CUDA graph capture). Default: 3." + ), + ) + parser.add_argument( + "--num-validation-steps", + type=int, + default=100, + help=( + "Number of training steps to run for validation. The candidate's " + "final loss must be <= the optimized model's final loss. Default: 100." + ), + ) + parser.add_argument( + "--validation-lr", + type=float, + default=1e-3, + help="Learning rate for SGD updates during validation steps. Default: 1e-3.", + ) + parser.add_argument( + "--hardware", + type=str, + required=True, + help="Hardware/software specialization namespace (e.g. 
h100-sm90).", + ) + known, remaining = parser.parse_known_args() + return known, remaining + + +def _load_metadata(meta_path): + """Load the metadata JSON generated by the flattener.""" + with open(meta_path) as f: + return json.load(f) + + +def _create_inputs_from_metadata(meta, device): + """Create deterministic random tensors matching the metadata specs.""" + torch.manual_seed(42) + inputs = [] + for spec in meta["input_specs"]: + if "dtype" in spec: + dtype = getattr(torch, spec["dtype"].replace("torch.", "")) + shape = spec["shape"] + if dtype.is_floating_point: + inputs.append(torch.randn(shape, dtype=dtype, device=device) * 0.01) + elif dtype in (torch.int32, torch.int64): + inputs.append(torch.randint(0, 1000, shape, dtype=dtype, device=device)) + else: + inputs.append(torch.zeros(shape, dtype=dtype, device=device)) + else: + inputs.append(spec["value"]) + return inputs + + +def _load_graph_module(filepath): + """Import a model file and return an instantiated GraphModule.""" + if not filepath.exists(): + return None + spec = importlib.util.spec_from_file_location("_bench_model", str(filepath)) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.GraphModule() + + +def _run_validation_steps(model, inputs, num_state, num_steps, lr): + """Run model for num_steps training steps with SGD updates. + + Returns a list of per-step loss values. The state inputs (first + num_state entries) are cloned and updated with gradients each step; + data inputs are reused unchanged. 
+ """ + state = [ + inp.clone() if isinstance(inp, torch.Tensor) else inp + for inp in inputs[:num_state] + ] + data = list(inputs[num_state:]) + + losses = [] + for step in range(num_steps): + current_inputs = list(state) + data + with torch.no_grad(): + outputs = model(*current_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + losses.append(outputs[0].item()) + + grads = outputs[1:] + for i in range(min(len(grads), num_state)): + g = grads[i] + if ( + isinstance(g, torch.Tensor) + and isinstance(state[i], torch.Tensor) + and g.shape == state[i].shape + ): + state[i] = state[i] - lr * g + + return losses + + +def _benchmark_model(model, inputs, num_warmup, num_bench): + """Benchmark a model's execution time using CUDA events. + + Returns median time in milliseconds. + """ + for _ in range(num_warmup): + with torch.no_grad(): + model(*inputs) + torch.cuda.synchronize() + + times = [] + for _ in range(num_bench): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + with torch.no_grad(): + model(*inputs) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + return median(times) + + +def _promote_with_mfu_comment(candidate_path, optimized_path, cand_mfu, fingerprint): + """Copy candidate to optimized, inserting an MFU comment at the top.""" + code = candidate_path.read_text() + + mfu_str = f"{cand_mfu:.2f}%" if cand_mfu is not None else "N/A" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M") + comment = f"# MFU: {mfu_str} (promoted {timestamp}, fingerprint: {fingerprint})\n" + + optimized_path.parent.mkdir(parents=True, exist_ok=True) + optimized_path.write_text(comment + code) + + +def main(): + init_logger() + + bench_args, remaining_args = _parse_args() + + from torchtitan.config import ConfigManager + + config_manager = ConfigManager() + sys.argv = [sys.argv[0]] + remaining_args + config = config_manager.parse_args(remaining_args) + + 
config.compile.mode = "aot_fx_trace" + + device_type = utils.device_type + device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) + + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + world_size = dist_utils.init_distributed( + config.comm, + enable_cpu_backend=config.training.enable_cpu_offload, + base_folder=config.dump_folder, + ) + + rank = dist.get_rank() + parallel_dims = ParallelDims.from_config(config.parallelism, world_size) + + model_spec = config.model_spec + if model_spec is None: + raise ValueError("model_spec must be set. Pass --module to specify the model.") + + model_name = model_spec.name.split("/")[-1] + output_name = f"{model_name}_{model_spec.flavor}" + fingerprint = _build_fingerprint(bench_args.hardware, parallel_dims) + + llm_trainer_dir = Path(__file__).parent + fingerprint_dir = llm_trainer_dir / "targets" / fingerprint + optimized_dir = fingerprint_dir / "optimized_models" + candidate_dir = fingerprint_dir / "candidate_models" + + model_filename = f"{output_name}_rank{rank}.py" + meta_filename = f"{output_name}_rank{rank}_meta.json" + + meta_path = optimized_dir / meta_filename + if not meta_path.exists(): + raise FileNotFoundError( + f"Metadata not found at {meta_path}. " + f"Run the flattener first to generate the baseline." + ) + + optimized_path = optimized_dir / model_filename + candidate_path = candidate_dir / model_filename + + if not optimized_path.exists(): + raise FileNotFoundError( + f"Optimized model not found at {optimized_path}. " + f"Run the flattener first." + ) + if not candidate_path.exists(): + raise FileNotFoundError( + f"Candidate model not found at {candidate_path}. " + f"Copy the optimized model to {candidate_dir}/ and optimize it." 
+ ) + + meta = _load_metadata(meta_path) + + print(f"\n{'=' * 60}") + print(f" LLM Trainer Benchmarker") + print(f" Model: {output_name} (rank {rank})") + print(f" Fingerprint: {fingerprint}") + print(f"{'=' * 60}\n") + + print("Loading models...") + optimized_gm = _load_graph_module(optimized_path) + candidate_gm = _load_graph_module(candidate_path) + print(f" Optimized: {optimized_path}") + print(f" Candidate: {candidate_path}") + + print("\nCreating inputs...") + inputs = _create_inputs_from_metadata(meta, device) + num_tensor_inputs = sum(1 for x in inputs if isinstance(x, torch.Tensor)) + print(f" {len(inputs)} inputs ({num_tensor_inputs} tensors)") + + # Model warmup: let models initialize internal state (e.g. CUDA graph + # capture, JIT compilation, caching). These calls happen BEFORE the + # validation check so the check tests the "warmed up" code path. + if bench_args.num_model_warmup > 0: + print(f"\nModel warmup: {bench_args.num_model_warmup} calls per model...") + with torch.no_grad(): + for wi in range(bench_args.num_model_warmup): + optimized_gm(*inputs) + candidate_gm(*inputs) + torch.cuda.synchronize() + print(" Done.") + + num_state = meta["num_state_inputs"] + num_steps = bench_args.num_validation_steps + val_lr = bench_args.validation_lr + + print(f"\n--- Validation ({num_steps} training steps, lr={val_lr}) ---") + + print(" Running optimized model...") + opt_losses = _run_validation_steps( + optimized_gm, inputs, num_state, num_steps, val_lr + ) + + print(" Running candidate model...") + cand_losses = _run_validation_steps( + candidate_gm, inputs, num_state, num_steps, val_lr + ) + + print(f"\n {'Step':>6} {'Optimized Loss':>15} {'Candidate Loss':>15}") + print(f" {'----':>6} {'-' * 15} {'-' * 15}") + for step in range(num_steps): + if step == 0 or (step + 1) % 10 == 0: + print( + f" {step + 1:>6} {opt_losses[step]:>15.6f}" + f" {cand_losses[step]:>15.6f}" + ) + + valid = cand_losses[-1] <= opt_losses[-1] + if valid: + print( + f"\n RESULT: 
CANDIDATE VALID " + f"(final loss {cand_losses[-1]:.6f} <= {opt_losses[-1]:.6f})" + ) + else: + print( + f"\n RESULT: CANDIDATE FAILED " + f"(final loss {cand_losses[-1]:.6f} > {opt_losses[-1]:.6f})" + ) + + num_flops_per_token = meta.get("num_flops_per_token", 0) + seq_len = meta.get("seq_len", 1) + local_batch_size = meta.get("local_batch_size", 1) + tokens_per_step = local_batch_size * seq_len + device_name = torch.cuda.get_device_name(device) + gpu_peak_flops = get_peak_flops(device_name) + + num_bench_rounds = bench_args.promote_runs if bench_args.promote else 1 + pct = (_PROMOTE_SPEEDUP_THRESHOLD - 1) * 100 + speedups = [] + cand_mfu = None + + for run_idx in range(num_bench_rounds): + if num_bench_rounds > 1: + print( + f"\n--- Performance Benchmark (run {run_idx + 1}/{num_bench_rounds}) ---" + ) + else: + print(f"\n--- Performance Benchmark ---") + print( + f" Warmup: {bench_args.num_warmup} iters, " + f"Bench: {bench_args.num_bench} iters" + ) + + opt_time = _benchmark_model( + optimized_gm, inputs, bench_args.num_warmup, bench_args.num_bench + ) + cand_time = _benchmark_model( + candidate_gm, inputs, bench_args.num_warmup, bench_args.num_bench + ) + + speedup = opt_time / cand_time if cand_time > 0 else float("inf") + speedups.append(speedup) + + opt_tps = tokens_per_step / (opt_time / 1000.0) + cand_tps = tokens_per_step / (cand_time / 1000.0) + + if num_flops_per_token > 0 and gpu_peak_flops > 0: + opt_mfu = 100 * num_flops_per_token * opt_tps / gpu_peak_flops + cand_mfu = 100 * num_flops_per_token * cand_tps / gpu_peak_flops + else: + opt_mfu = None + cand_mfu = None + + print(f"\n {'Metric':<25} {'Optimized':>15} {'Candidate':>15}") + print(f" {'-' * 55}") + print(f" {'Time (ms)':<25} {opt_time:>15.3f} {cand_time:>15.3f}") + print(f" {'Tokens/sec':<25} {opt_tps:>15,.0f} {cand_tps:>15,.0f}") + if opt_mfu is not None and cand_mfu is not None: + print(f" {'MFU (%)':<25} {opt_mfu:>15.2f} {cand_mfu:>15.2f}") + print(f" {'Speedup':<25} {'':>15} 
{speedup:>15.3f}x") + + if speedup < _PROMOTE_SPEEDUP_THRESHOLD and bench_args.promote: + print( + f"\n Run {run_idx + 1}: speedup {speedup:.3f}x < {_PROMOTE_SPEEDUP_THRESHOLD:.2f}x threshold, stopping early" + ) + break + + all_passed = all(s >= _PROMOTE_SPEEDUP_THRESHOLD for s in speedups) + enough_runs = len(speedups) == num_bench_rounds + + print(f"\n--- Summary ---") + if not valid: + print(f" CANDIDATE FAILED: final loss exceeds baseline") + elif all_passed and enough_runs: + print( + f" CANDIDATE IS VALID AND >={pct:.0f}% FASTER " + f"({len(speedups)}/{num_bench_rounds} runs passed)" + ) + elif any(s >= _PROMOTE_SPEEDUP_THRESHOLD for s in speedups): + passed = sum(1 for s in speedups if s >= _PROMOTE_SPEEDUP_THRESHOLD) + print( + f" CANDIDATE IS VALID BUT INCONSISTENT " + f"({passed}/{num_bench_rounds} runs >={pct:.0f}% faster)" + ) + elif speedups[-1] >= 1.0: + print(f" CANDIDATE IS VALID BUT <{pct:.0f}% FASTER (not enough to promote)") + else: + print(f" CANDIDATE IS VALID BUT SLOWER") + + if bench_args.promote: + can_promote = valid and all_passed and enough_runs + if can_promote: + print(f"\n Promoting candidate to optimized...") + _promote_with_mfu_comment( + candidate_path, optimized_path, cand_mfu, fingerprint + ) + print(f" Done! {candidate_path.name} -> {optimized_path}") + elif not valid: + print( + f"\n Cannot promote: candidate final loss exceeds baseline. " + f"Fix convergence first." + ) + else: + print( + f"\n Cannot promote: candidate must be >={pct:.0f}% faster " + f"on all {num_bench_rounds} runs." + ) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/llm_trainer/flattener.py b/torchtitan/experiments/llm_trainer/flattener.py new file mode 100644 index 0000000000..a12db6d379 --- /dev/null +++ b/torchtitan/experiments/llm_trainer/flattener.py @@ -0,0 +1,550 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Flattener: traces a model's forward+backward training step and writes it +as a straight-line sequence of PyTorch ops in a standalone Python file. + +Usage (via shell script): + NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_flattener.sh \ + --module graph_trainer.llama3 \ + --config graph_trainer_llama3_debugmodel + +Usage (direct torchrun): + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 \ + -m torchtitan.experiments.llm_trainer.flattener \ + --hardware h100-sm90 \ + --module graph_trainer.llama3 \ + --config graph_trainer_llama3_debugmodel + +This will: + 1. Build and parallelize the model across GPUs via torchrun + 2. Trace each rank's fwd+loss+bwd step via make_fx + 3. Write the traced graph to /flattened_models/_rank{i}.py + 4. Verify bitwise equivalence between in-memory graph and generated file + 5. 
Copy the verified file to /optimized_models/ as the initial baseline +""" + +import argparse +import contextlib +import importlib.util +import json +import os +import shutil +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist + +from torchtitan.config import ConfigManager, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.experiments.graph_trainer.common_utils import ( + maybe_register_blockmask_pytree_node, +) +from torchtitan.experiments.graph_trainer.make_fx_tracer import trace_train_step +from torchtitan.experiments.graph_trainer.trainer import make_fwd_bwd_step +from torchtitan.models.common.decoder import Decoder +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger + + +def _parse_flattener_args(): + """Parse flattener-specific args, forwarding the rest to ConfigManager.""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument( + "--hardware", + type=str, + required=True, + help="Hardware/software specialization namespace (e.g. h100-sm90).", + ) + known, remaining = parser.parse_known_args() + return known, remaining + + +def _build_fingerprint(hardware, parallel_dims): + """Build a directory fingerprint encoding hardware + parallelism. + + Only includes parallelism dimensions > 1 to keep the name concise. + E.g. "h100-sm90" with tp=2, dp_shard=4 -> "h100-sm90_tp2_fsdp4" + """ + parts = [hardware] + dims = [ + ("tp", parallel_dims.tp), + ("fsdp", parallel_dims.dp_shard), + ("dp", parallel_dims.dp_replicate), + ("pp", parallel_dims.pp), + ("cp", parallel_dims.cp), + ("ep", parallel_dims.ep), + ] + for name, val in dims: + if val > 1: + parts.append(f"{name}{val}") + return "_".join(parts) + + +def _derive_output_name(model_spec): + """Derive output filename stem from model spec. + + E.g. 
model_spec.name="graph_trainer/llama3", flavor="debugmodel" + -> "llama3_debugmodel" + """ + model_name = model_spec.name.split("/")[-1] + return f"{model_name}_{model_spec.flavor}" + + +def _bitwise_equal(a, b): + """Compare two tensors at the byte level (handles NaN correctly).""" + if a.shape != b.shape or a.dtype != b.dtype: + return False + a_flat = a.contiguous().reshape(-1) + b_flat = b.contiguous().reshape(-1) + if a_flat.numel() == 0: + return True + return torch.equal( + a_flat.view(torch.uint8), + b_flat.view(torch.uint8), + ) + + +def _create_real_inputs(example_inputs, device): + """Create real tensors matching FakeTensor shapes from tracing.""" + torch.manual_seed(42) + real = [] + for x in example_inputs: + if isinstance(x, torch.Tensor): + shape = tuple(int(s) for s in x.shape) + if x.is_floating_point(): + real.append(torch.randn(shape, dtype=x.dtype, device=device) * 0.01) + elif x.dtype in (torch.int32, torch.int64): + real.append(torch.randint(0, 1000, shape, dtype=x.dtype, device=device)) + else: + real.append(torch.zeros(shape, dtype=x.dtype, device=device)) + else: + real.append(x) + return real + + +def _load_graph_module(filepath): + """Import a generated model file and return an instantiated GraphModule.""" + spec = importlib.util.spec_from_file_location("_gen_model", str(filepath)) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.GraphModule() + + +def _setup_model(config): + """Build and parallelize the model. + + Expects torch.distributed to already be initialized (via torchrun). 
+ """ + device_type = utils.device_type + device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) + + world_size = dist_utils.init_distributed( + config.comm, + enable_cpu_backend=config.training.enable_cpu_offload, + base_folder=config.dump_folder, + ) + + parallel_dims = ParallelDims.from_config(config.parallelism, world_size) + + dist_utils.set_determinism( + parallel_dims, + device, + config.debug, + distinct_seed_mesh_dims=["pp"], + ) + + model_spec = config.model_spec + if model_spec is None: + raise ValueError( + "model_spec must be set. Pass --module to specify the model " + "(e.g. --module graph_trainer.llama3)." + ) + + model_config = model_spec.model + model_config.update_from_config(trainer_config=config) + + logger.info(f"Building {model_spec.name} {model_spec.flavor} on meta device") + with ( + torch.device("meta"), + utils.set_default_dtype(TORCH_DTYPE_MAP[config.training.dtype]), + ): + model = model_config.build() + + model.verify_module_protocol() + + compile_config = config.compile + + model = model_spec.parallelize_fn( + model, + parallel_dims=parallel_dims, + training=config.training, + model_converters=config.model_converters, + parallelism=config.parallelism, + compile_config=compile_config, + ac_config=config.activation_checkpoint, + dump_folder=config.dump_folder, + ) + + model.to_empty(device=device_type) + with torch.no_grad(): + model.init_weights(buffer_device=None) + model.train() + + logger.info("Model parallelized and materialized") + + tokenizer = config.tokenizer.build(tokenizer_path=config.hf_assets_path) + + return ( + model, + model_config, + model_spec, + compile_config, + parallel_dims, + device, + tokenizer, + ) + + +def _generate_model_file(gm, state_fqns, output_name, rank, output_dir): + """Write the traced GraphModule as a standalone Python file.""" + code = gm.print_readable( + print_output=False, + include_stride=True, + include_device=True, + ) + + state_lines = [] + for 
i, fqn in enumerate(state_fqns): + state_lines.append(f"# [{i}] {fqn}") + state_desc = "\n".join(state_lines) + + header = f'''""" +Flattened model: {output_name} (rank {rank}) + +Auto-generated by llm_trainer flattener. DO NOT EDIT this file directly. +To optimize: copy this to candidate_models/ and modify the copy. + +This file contains the complete forward+backward training step as a +straight-line sequence of PyTorch operations. All model parameters, +inputs, and intermediate activations flow through the forward() method +as explicit tensor arguments. + +State inputs (model parameters/buffers, in order): +{state_desc} + +Remaining inputs are user data (tokens, labels, positional embeddings, etc.) + +Outputs (in order): + [0] loss (scalar) + [1..N] gradients for each trainable parameter + +IMPORTANT: Any optimized version MUST produce a validation loss that is +<= the baseline's after 100 training steps. The benchmarker verifies this. +""" +from math import inf, nan +import torch +import torch.nn as nn +from torch import device, tensor + + +''' + + import re + + code = re.sub( + r"class\s+\S+\(torch\.nn\.Module\)", + "class GraphModule(torch.nn.Module)", + code, + count=1, + ) + + full_code = header + code + "\n" + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + filepath = output_dir / f"{output_name}_rank{rank}.py" + filepath.write_text(full_code) + return filepath + + +def _generate_metadata( + state_fqns, + example_inputs, + output_name, + rank, + num_flops_per_token, + model_param_count, + config, + parallel_dims, + output_dir, +): + """Save input specs and model info as JSON for the benchmarker.""" + specs = [] + for i, inp in enumerate(example_inputs): + if isinstance(inp, torch.Tensor): + specs.append( + { + "index": i, + "dtype": str(inp.dtype), + "shape": [int(s) for s in inp.shape], + } + ) + else: + try: + json.dumps(inp) + value = inp + except (TypeError, ValueError): + value = str(inp) + specs.append( + { + "index": 
i, + "type": type(inp).__name__, + "value": value, + } + ) + + metadata = { + "output_name": output_name, + "rank": rank, + "num_state_inputs": len(state_fqns), + "state_fqns": state_fqns, + "input_specs": specs, + "num_inputs": len(example_inputs), + "num_flops_per_token": num_flops_per_token, + "model_param_count": model_param_count, + "seq_len": config.training.seq_len, + "local_batch_size": config.training.local_batch_size, + "world_size": parallel_dims.world_size, + } + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + filepath = output_dir / f"{output_name}_rank{rank}_meta.json" + filepath.write_text(json.dumps(metadata, indent=2)) + return filepath + + +def _verify_equivalence(gm, example_inputs, model_filepath, device): + """Verify generated file produces bitwise identical results to in-memory graph. + + Creates random tensors matching the traced input shapes, runs both the + in-memory GraphModule and the file-loaded GraphModule, and compares + every output tensor at the byte level. 
+ """ + real_inputs = _create_real_inputs(example_inputs, device) + + with torch.no_grad(): + ref_outputs = gm(*real_inputs) + + gen_gm = _load_graph_module(model_filepath) + with torch.no_grad(): + gen_outputs = gen_gm(*real_inputs) + + if isinstance(ref_outputs, torch.Tensor): + ref_outputs = (ref_outputs,) + if isinstance(gen_outputs, torch.Tensor): + gen_outputs = (gen_outputs,) + if isinstance(ref_outputs, list): + ref_outputs = tuple(ref_outputs) + if isinstance(gen_outputs, list): + gen_outputs = tuple(gen_outputs) + + if len(ref_outputs) != len(gen_outputs): + logger.error( + f"Output count mismatch: in-memory={len(ref_outputs)}, " + f"generated={len(gen_outputs)}" + ) + return False + + all_match = True + for i, (ref, gen) in enumerate(zip(ref_outputs, gen_outputs)): + if isinstance(ref, torch.Tensor) and isinstance(gen, torch.Tensor): + if _bitwise_equal(ref, gen): + if i == 0: + logger.info( + f" Output {i} (loss): MATCH " + f"(shape={list(ref.shape)}, dtype={ref.dtype})" + ) + else: + max_diff = (ref.float() - gen.float()).abs().max().item() + logger.error( + f" Output {i}: MISMATCH! 
" + f"(shape={list(ref.shape)}, dtype={ref.dtype}, " + f"max_diff={max_diff:.6e})" + ) + all_match = False + if all_match: + n = len(ref_outputs) + logger.info(f" All {n} outputs match (loss + {n - 1} gradients)") + + return all_match + + +def main(): + init_logger() + + flattener_args, remaining_args = _parse_flattener_args() + + config_manager = ConfigManager() + config = config_manager.parse_args(remaining_args) + + config.compile.mode = "aot_fx_trace" + config.debug.deterministic = True + if config.debug.seed is None: + config.debug.seed = 42 + + ( + model, + model_config, + model_spec, + compile_config, + parallel_dims, + device, + tokenizer, + ) = _setup_model(config) + + model_param_count, num_flops_per_token = model_config.get_nparams_and_flops( + model, config.training.seq_len + ) + logger.info( + f"Model params: {model_param_count:,}, " f"FLOPs/token: {num_flops_per_token:,}" + ) + + loss_fn = model_spec.build_loss_fn(compile_config, parallel_dims=parallel_dims) + fwd_bwd_fn = make_fwd_bwd_step(loss_fn) + + seq_len = config.training.seq_len + local_batch_size = config.training.local_batch_size + vocab_size = model_config.vocab_size + + dummy_inputs = torch.randint( + 0, vocab_size, (local_batch_size, seq_len), device=device + ) + dummy_labels = torch.randint( + 0, vocab_size, (local_batch_size, seq_len), device=device + ) + + global_batch_size = ( + local_batch_size + * parallel_dims.dp_shard + * parallel_dims.dp_replicate + * parallel_dims.cp + ) + dummy_global_valid_tokens = float(global_batch_size * seq_len) + extra_inputs: dict[str, torch.Tensor] = {} + extra_kwargs: dict[str, Any] = {} + + if isinstance(model_config, Decoder.Config): + layer = model_config.layers[0] + attn_config = layer.attention + else: + attn_config = None + mask_type = getattr(attn_config, "mask_type", "causal") + + if mask_type == "block_causal": + extra_kwargs["positions"] = ( + torch.arange(0, seq_len, dtype=torch.int32, device=device) + .unsqueeze(0) + 
.expand(local_batch_size, -1) + ) + elif parallel_dims.cp_enabled: + extra_kwargs["positions"] = torch.arange( + 0, seq_len, dtype=torch.int32, device=device + ).expand(local_batch_size, seq_len) + + inner_attention = getattr(attn_config, "inner_attention", None) + if inner_attention is not None: + from torchtitan.models.common.attention import FlexAttention, VarlenAttention + + if isinstance(inner_attention, (FlexAttention.Config, VarlenAttention.Config)): + extra_kwargs["attention_masks"] = model.get_attention_masks( + input_batch=dummy_inputs, + tokenizer=tokenizer, + extra_inputs=extra_inputs, + ) + + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not config.parallelism.disable_loss_parallel + ) + loss_parallel_ctx = ( + torch.distributed.tensor.parallel.loss_parallel() + if loss_parallel_enabled + else contextlib.nullcontext() + ) + + maybe_register_blockmask_pytree_node() + + logger.info("Tracing fwd+loss+bwd via make_fx...") + with loss_parallel_ctx: + traced_result = trace_train_step(fwd_bwd_fn)( + model, + dummy_inputs, + dummy_labels, + dummy_global_valid_tokens, + extra_inputs, + extra_kwargs, + ) + + gm = traced_result.gm + num_nodes = len(list(gm.graph.nodes)) + logger.info( + f"Traced graph: {num_nodes} nodes, " + f"{len(traced_result.state_fqns)} state entries" + ) + + output_name = _derive_output_name(model_spec) + rank = dist.get_rank() + fingerprint = _build_fingerprint(flattener_args.hardware, parallel_dims) + llm_trainer_dir = Path(__file__).parent + fingerprint_dir = llm_trainer_dir / "targets" / fingerprint + logger.info(f"Fingerprint: {fingerprint}") + + flattened_dir = fingerprint_dir / "flattened_models" + model_file = _generate_model_file( + gm, traced_result.state_fqns, output_name, rank, flattened_dir + ) + logger.info(f"Wrote flattened model: {model_file}") + + meta_file = _generate_metadata( + traced_result.state_fqns, + traced_result.example_inputs, + output_name, + rank, + num_flops_per_token, + model_param_count, + config, 
+ parallel_dims, + flattened_dir, + ) + logger.info(f"Wrote metadata: {meta_file}") + + logger.info("Verifying bitwise equivalence (in-memory graph vs generated file)...") + verified = _verify_equivalence(gm, traced_result.example_inputs, model_file, device) + + if not verified: + logger.error( + "Verification FAILED! The generated file does not match " + "the in-memory graph. This is a bug in the flattener." + ) + dist.destroy_process_group() + raise RuntimeError("Flattener verification failed") + + optimized_dir = fingerprint_dir / "optimized_models" + optimized_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(model_file, optimized_dir / model_file.name) + shutil.copy2(meta_file, optimized_dir / meta_file.name) + logger.info(f"Copied baseline to {optimized_dir}/") + + dist.destroy_process_group() + logger.info("Flattening complete!") + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/llm_trainer/run_benchmarker.sh b/torchtitan/experiments/llm_trainer/run_benchmarker.sh new file mode 100755 index 0000000000..22d385847f --- /dev/null +++ b/torchtitan/experiments/llm_trainer/run_benchmarker.sh @@ -0,0 +1,29 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# Usage: +# NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_benchmarker.sh \ +# --module graph_trainer.llama3 \ +# --config graph_trainer_llama3_debugmodel +# +# # With promote: +# NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_benchmarker.sh \ +# --promote \ +# --module graph_trainer.llama3 \ +# --config graph_trainer_llama3_debugmodel + +NGPU=${NGPU:-"1"} +HARDWARE=${HARDWARE:?'Set HARDWARE (e.g. 
HARDWARE=h100-sm90)'} +export LOG_RANK=${LOG_RANK:-0} + +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ + -m torchtitan.experiments.llm_trainer.benchmarker \ + --hardware "${HARDWARE}" \ + "$@" diff --git a/torchtitan/experiments/llm_trainer/run_flattener.sh b/torchtitan/experiments/llm_trainer/run_flattener.sh new file mode 100755 index 0000000000..0ce3f60bd8 --- /dev/null +++ b/torchtitan/experiments/llm_trainer/run_flattener.sh @@ -0,0 +1,23 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# Usage: +# NGPU=8 HARDWARE=h100-sm90 ./torchtitan/experiments/llm_trainer/run_flattener.sh \ +# --module graph_trainer.llama3 \ +# --config graph_trainer_llama3_debugmodel + +NGPU=${NGPU:-"1"} +HARDWARE=${HARDWARE:?'Set HARDWARE (e.g. HARDWARE=h100-sm90)'} +export LOG_RANK=${LOG_RANK:-0} + +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ + -m torchtitan.experiments.llm_trainer.flattener \ + --hardware "${HARDWARE}" \ + "$@"
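For orientation, the generated `GraphModule` files follow the calling contract described in their headers: state tensors (parameters/buffers) come first, then the user-data inputs, and the outputs are `[loss, grad_0, ..., grad_N]` with gradients aligned to the state inputs. The sketch below illustrates how a benchmarker-style SGD loop consumes that contract; it is a toy stand-in using plain Python floats rather than the real traced graph, and `toy_step`/`sgd_loop` are hypothetical names, not part of the flattener.

```python
def toy_step(params, inputs):
    """Stand-in for a flattened train step: returns [loss, *grads].

    Toy objective: loss = sum((p - x)^2) over paired params/inputs,
    so the gradient for each param is 2 * (p - x).
    """
    loss = sum((p - x) ** 2 for p, x in zip(params, inputs))
    grads = [2 * (p - x) for p, x in zip(params, inputs)]
    return [loss, *grads]


def sgd_loop(step_fn, params, inputs, lr=0.1, steps=100):
    # Mirrors the flattened-model contract: output[0] is the loss,
    # outputs[1:] are gradients for each state input, in order.
    loss = None
    for _ in range(steps):
        out = step_fn(params, inputs)
        loss, grads = out[0], out[1:]
        # Plain SGD update, as the benchmarker applies between steps.
        params = [p - lr * g for p, g in zip(params, grads)]
    return params, loss


params, final_loss = sgd_loop(toy_step, [5.0, -3.0], [1.0, 2.0])
print(final_loss)  # shrinks toward 0 as params approach inputs
```

An optimized candidate may reorder or fuse the internal aten ops freely, but it must preserve exactly this input/output ordering, since the benchmarker maps outputs back to parameters positionally.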