From e52c28a9db33168d2fa96a11411949dc7f4cf8e4 Mon Sep 17 00:00:00 2001
From: 3outeille
Date: Wed, 5 Nov 2025 16:07:31 +0000
Subject: [PATCH 1/2] seems like the bug comes from loading weights in PP,
 which differs if you use ModuleList vs ModuleDict

---
 torchtitan/components/checkpoint.py | 42 +++++++++++++
 torchtitan/components/dataloader.py | 16 ++++-
 .../transformers_backend/__init__.py | 2 +-
 .../experiments/transformers_backend/aa.sh | 60 +++++++++++++++++++
 torchtitan/train.py | 53 +++++++++++++++-
 5 files changed, 167 insertions(+), 6 deletions(-)
 create mode 100644 torchtitan/experiments/transformers_backend/aa.sh

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index d4c5416aa2..d3c612aa15 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -439,6 +439,10 @@ def dcp_load(
         its own model definition and safetensors format.
         """
+        logger.info(f"[DCP LOAD] Starting DCP load from: {checkpoint_id}")
+        logger.info(f"[DCP LOAD] from_hf={from_hf}, from_quantized={from_quantized}")
+        logger.info(f"[DCP LOAD] State dict keys before load: {list(state_dict.keys())}")
+
         if from_hf:
             assert (
                 self.sd_adapter is not None
@@ -448,20 +452,56 @@
                 checkpoint_id, from_quantized
             )
+            logger.info(f"[DCP LOAD] Loading from HF format with storage reader")
             dcp.load(
                 hf_state_dict,
                 storage_reader=hf_storage_reader,
             )
             state_dict = self.sd_adapter.from_hf(hf_state_dict)
+            logger.info(f"[DCP LOAD] Converted from HF format, loading into model")
             self.states[MODEL].load_state_dict(state_dict)
+            logger.info(f"[DCP LOAD] Model state loaded successfully from HF")
         else:
+            logger.info(f"[DCP LOAD] Loading from native DCP format")
             dcp.load(state_dict, checkpoint_id=checkpoint_id)
+            logger.info(f"[DCP LOAD] DCP load completed")

             # TODO: Since we flatten the model states in state_dict, we need to
             # manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states: + logger.info(f"[DCP LOAD] Loading model state dict") self.states[MODEL].load_state_dict(state_dict) + logger.info(f"[DCP LOAD] Model state loaded successfully") + + # Log actual parameter values after loading + logger.info(f"[DCP LOAD] ===== Verifying loaded parameter values =====") + if MODEL in self.states: + model_wrapper = self.states[MODEL] + for idx, model_part in enumerate(model_wrapper.model): + logger.info(f"[DCP LOAD] Model part {idx} parameters:") + for name, param in model_part.named_parameters(): + logger.info(f" - {name}:") + logger.info(f" shape: {param.shape}, dtype: {param.dtype}, device: {param.device}") + logger.info(f" mean: {param.data.mean().item():.6f}, std: {param.data.std().item():.6f}") + logger.info(f" min: {param.data.min().item():.6f}, max: {param.data.max().item():.6f}") + # Show first few values + flat_data = param.data.flatten() + logger.info(f" first 10 values: {flat_data[:10].tolist()}") + logger.info(f" is all zeros: {torch.all(param.data == 0).item()}") + logger.info(f"[DCP LOAD] Model part {idx} total parameters: {sum(p.numel() for p in model_part.parameters()):,}") + + # Log dataloader state if loaded + if DATALOADER in self.states and DATALOADER in state_dict: + logger.info(f"[DCP LOAD] Dataloader state check:") + dl_state = self.states[DATALOADER].state_dict() + logger.info(f" Dataloader state keys: {list(dl_state.keys())}") + if hasattr(self.states[DATALOADER], '_rank_id'): + rank_id = self.states[DATALOADER]._rank_id + logger.info(f" Expected rank_id: {rank_id}") + logger.info(f" Rank_id in state: {rank_id in dl_state}") + + logger.info(f"[DCP LOAD] ===== Load verification complete =====") @torch.no_grad() def save(self, curr_step: int, last_step: bool = False) -> None: @@ -619,8 +659,10 @@ def load(self, step: int = -1) -> bool: ) logger.info(f"Loading the checkpoint from {checkpoint_id}.") + logger.info(f"[CHECKPOINT LOAD] model_only={model_only}, from_hf={from_hf}") begin = time.monotonic() states = self._states_to_load(model_only) + logger.info(f"[CHECKPOINT LOAD] States to load: {list(states.keys())}") self.dcp_load( states, checkpoint_id=checkpoint_id, diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 071af84d54..c4578b3a2e 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -86,12 +86,18 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: # State being empty is valid. if not state_dict: + logger.info(f"[DATALOADER] load_state_dict called with empty state_dict for dp_rank={self.dp_rank}") return + logger.info(f"[DATALOADER] load_state_dict called for dp_rank={self.dp_rank}") + logger.info(f"[DATALOADER] Expected rank_id: {self._rank_id}") + logger.info(f"[DATALOADER] state_dict keys: {list(state_dict.keys())}") + logger.info(f"[DATALOADER] world_size in state: {state_dict.get('world_size', 'NOT FOUND')}, current: {self.dp_world_size}") + if self._rank_id not in state_dict: logger.warning( f"DataLoader state is empty for dp rank {self.dp_rank}, " - "expected key {self._rank_id}" + f"expected key {self._rank_id}. Available keys: {list(state_dict.keys())}" ) return @@ -99,6 +105,12 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: "dp_degree is inconsistent before and after checkpoint, " "dataloader resharding is not supported yet." 
) + + logger.info(f"[DATALOADER] Loading state for {self._rank_id}") + unpickled_state = pickle.loads(state_dict[self._rank_id]) + logger.info(f"[DATALOADER] Unpickled state keys: {list(unpickled_state.keys()) if isinstance(unpickled_state, dict) else type(unpickled_state)}") + # We don't have to use pickle as DCP will serialize the state_dict. However, we have to # keep this for backward compatibility. - super().load_state_dict(pickle.loads(state_dict[self._rank_id])) + super().load_state_dict(unpickled_state) + logger.info(f"[DATALOADER] State loaded successfully for dp_rank={self.dp_rank}") diff --git a/torchtitan/experiments/transformers_backend/__init__.py b/torchtitan/experiments/transformers_backend/__init__.py index b72b77760c..c54b656d0a 100644 --- a/torchtitan/experiments/transformers_backend/__init__.py +++ b/torchtitan/experiments/transformers_backend/__init__.py @@ -47,7 +47,7 @@ class TitanDenseModelArgs: "debugmodel": HFTransformerModelArgs( titan_dense_args=TitanDenseModelArgs( dim=256, - n_layers=6, + n_layers=2, n_heads=16, n_kv_heads=16, ), diff --git a/torchtitan/experiments/transformers_backend/aa.sh b/torchtitan/experiments/transformers_backend/aa.sh new file mode 100644 index 0000000000..971e679c33 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/aa.sh @@ -0,0 +1,60 @@ +# aa.sh +torchrun \ + --nproc_per_node 1 \ + --nnodes 1 \ + --rdzv_endpoint localhost:0 \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role rank \ + --local_ranks_filter 0 \ + --tee 3 \ + -m torchtitan.train \ + --checkpoint.enable \ + --checkpoint.initial_load_path debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/seed_checkpoint/checkpoint/step-0 \ + --training.seed 42 \ + --training.deterministic \ + --training.steps 2 \ + --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config \ + --job.config_file debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/fsdp1_tp1_cp1_pp1/config.toml \ + 2>&1 | tee log_baseline.txt + +# Rank 0 +CUDA_VISIBLE_DEVICES=0 torchrun \ + --nproc_per_node 1 \ + --nnodes 2 \ + --node_rank 0 \ + --rdzv_endpoint localhost:29500 \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role rank \ + --tee 3 \ + -m torchtitan.train \ + --checkpoint.enable \ + --checkpoint.initial_load_path debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/seed_checkpoint/checkpoint/step-0 \ + --training.seed 42 \ + --training.deterministic \ + --training.steps 2 \ + --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config \ + --job.config_file debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/fsdp1_tp1_cp1_pp2/config.toml \ + 2>&1 | tee log_pp2_rank0.txt & + + +# rank 1 +CUDA_VISIBLE_DEVICES=1 torchrun \ + --nproc_per_node 1 \ + --nnodes 2 \ + --node_rank 1 \ + --rdzv_endpoint localhost:29500 \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role rank \ + --tee 3 \ + -m torchtitan.train \ + --checkpoint.enable \ + --checkpoint.initial_load_path debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/seed_checkpoint/checkpoint/step-0 \ + --training.seed 42 \ + --training.steps 2 \ + --training.deterministic \ + --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config \ + --job.config_file debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/fsdp1_tp1_cp1_pp2/config.toml \ + 2>&1 | tee log_pp2_rank1.txt & \ No newline at end of file diff --git a/torchtitan/train.py b/torchtitan/train.py index d4de8bc5d4..ed82e22a04 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -3,7 +3,7 @@ # # 
This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +import lovely_tensors as lt; lt.monkey_patch() import importlib import os import time @@ -32,6 +32,7 @@ maybe_enable_profiling, ) +from torchtitan.utils.test_utils import debug_structure_param class Trainer(torch.distributed.checkpoint.stateful.Stateful): # core configs @@ -151,11 +152,12 @@ def __init__(self, job_config: JobConfig): utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): model = self.train_spec.model_cls(model_args) - # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) + debug_structure_param(model) + # metrics logging build_metrics_processor_fn = ( build_metrics_processor @@ -382,11 +384,29 @@ def batch_generator( """Returns an iterator that processes batches from the data iterator.""" device_type = utils.device_type data_iterator = iter(data_iterable) - + + # Log dataloader state at start + logger.info(f"[BATCH GEN] Creating batch generator") + if hasattr(data_iterable, 'state_dict'): + dl_state = data_iterable.state_dict() + logger.info(f"[BATCH GEN] Dataloader state keys: {list(dl_state.keys())}") + + batch_count = 0 while True: data_load_start = time.perf_counter() try: batch = next(data_iterator) + batch_count += 1 + + # Log first few batches + if batch_count <= 3: + logger.info(f"[BATCH GEN] Fetched batch {batch_count}") + input_dict, labels = batch + if 'input' in input_dict: + inp = input_dict['input'] + logger.info(f" Input: shape={inp.shape}, is_zeros={torch.all(inp==0).item()}, first_val={inp.flatten()[0].item()}") + logger.info(f" Labels: shape={labels.shape}, first_val={labels.flatten()[0].item()}") + except StopIteration as ex: # If data runs out during gradient accumulation, that # entire step will not be executed. @@ -509,6 +529,33 @@ def train_step( # entire step will not be executed. 
         for _microbatch in range(self.gradient_accumulation_steps):
             input_dict, labels = next(data_iterator)
+
+            # Log first few batches to verify data
+            if self.step == 1 and _microbatch < 2:
+                logger.info(f"[BATCH DATA] ===== Step {self.step}, Microbatch {_microbatch} =====")
+                logger.info(f"[BATCH DATA] input_dict keys: {list(input_dict.keys())}")
+
+                if 'input' in input_dict:
+                    inp = input_dict['input']
+                    logger.info(f"[BATCH DATA] Input tensor:")
+                    logger.info(f"inp: {inp}")
+                    logger.info(f" is all zeros: {torch.all(inp == 0).item()}")
+                    logger.info(f" is all same value: {torch.all(inp == inp.flatten()[0]).item()}")
+                    logger.info(f" unique values count: {torch.unique(inp).numel()}")
+                    # Show actual token IDs (first 20)
+                    logger.info(f" first 20 token IDs: {inp.flatten()[:20].tolist()}")
+                    # Show last 20 token IDs
+                    logger.info(f" last 20 token IDs: {inp.flatten()[-20:].tolist()}")
+                    # Check if it matches expected data
+                    if torch.all(inp == 0):
+                        logger.warning(f"[BATCH DATA] ⚠️ INPUT IS ALL ZEROS - THIS IS THE ISSUE!")
+
+                logger.info(f"[BATCH DATA] Labels tensor:")
+                logger.info(f" labels: {labels}")
+                logger.info(f" first 20 labels: {labels.flatten()[:20].tolist()}")
+                logger.info(f" last 20 labels: {labels.flatten()[-20:].tolist()}")
+                logger.info(f" is all same value: {torch.all(labels == labels.flatten()[0]).item()}")
+
             loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())

From b0e6efbbe81bc61e82214bb1773196ef9c08e2dd Mon Sep 17 00:00:00 2001
From: 3outeille
Date: Thu, 6 Nov 2025 16:14:24 +0000
Subject: [PATCH 2/2] issue when loading weights due to use of ModuleList; now
 using ModuleDict

---
 torchtitan/components/checkpoint.py | 10 +++---
 .../transformers_backend/model/model.py | 35 ++++++++++++++++++-
 2 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index d3c612aa15..5c84df6029 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -482,13 +482,13 @@ def dcp_load(
                 logger.info(f"[DCP LOAD] Model part {idx} parameters:")
                 for name, param in model_part.named_parameters():
                     logger.info(f" - {name}:")
-                    logger.info(f" shape: {param.shape}, dtype: {param.dtype}, device: {param.device}")
-                    logger.info(f" mean: {param.data.mean().item():.6f}, std: {param.data.std().item():.6f}")
-                    logger.info(f" min: {param.data.min().item():.6f}, max: {param.data.max().item():.6f}")
+                    logger.info(f" param: {param}")
                     # Show first few values
-                    flat_data = param.data.flatten()
+                    if isinstance(param, torch.distributed.tensor.DTensor):
+                        flat_data = param.data.to_local().flatten()
+                    else:
+                        flat_data = param.flatten()
                     logger.info(f" first 10 values: {flat_data[:10].tolist()}")
-                    logger.info(f" is all zeros: {torch.all(param.data == 0).item()}")
                 logger.info(f"[DCP LOAD] Model part {idx} total parameters: {sum(p.numel() for p in model_part.parameters()):,}")

             # Log dataloader state if loaded
diff --git a/torchtitan/experiments/transformers_backend/model/model.py b/torchtitan/experiments/transformers_backend/model/model.py
index 8041e54f70..21bee85fc8 100644
--- a/torchtitan/experiments/transformers_backend/model/model.py
+++ b/torchtitan/experiments/transformers_backend/model/model.py
@@ -17,6 +17,31 @@
 from .args import HFTransformerModelArgs


+class SlicableModuleDict(nn.ModuleDict):
+    """
+    A ModuleDict that supports slicing like ModuleList.
+    Keys are expected to be string representations of integers (e.g., "0", "1", "2").
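+
+    Note: unlike a plain nn.ModuleDict, which iterates over its keys, iterating this
+    class yields the modules themselves in integer-key order, matching nn.ModuleList.
+
+    Illustrative usage (a sketch; the modules here are hypothetical, not torchtitan code):
+        layers = SlicableModuleDict({str(i): nn.Linear(4, 4) for i in range(3)})
+        layers["1"]  # the module registered under key "1"
+        layers[1:]   # a SlicableModuleDict with keys "1" and "2", indices preserved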
+ """ + + def __getitem__(self, key): + if isinstance(key, slice): + # Handle slicing: convert slice to list of keys + keys = sorted(self.keys(), key=lambda x: int(x) if x.isdigit() else float('inf')) + sliced_keys = keys[key] + # Return a new SlicableModuleDict with the sliced items + return SlicableModuleDict({k: self[k] for k in sliced_keys}) + return super().__getitem__(key) + + def __iter__(self): + # Iterate over values in sorted order by key (as integers) + keys = sorted(self.keys(), key=lambda x: int(x) if x.isdigit() else float('inf')) + for key in keys: + yield self[key] + + def __len__(self): + return len(self._modules) + + class HFTransformerModel(nn.Module): def __init__(self, model_args: HFTransformerModelArgs): super().__init__() @@ -76,7 +101,15 @@ def __init__(self, model_args: HFTransformerModelArgs): self.max_seq_len = model_args.max_seq_len self.cp_mesh = None - for layer in self.model.model.layers: + # Convert ModuleList to ModuleDict to preserve original indices + # This ensures state dict keys match checkpoint keys + if isinstance(self.model.model.layers, nn.ModuleList): + self.model.model.layers = SlicableModuleDict({ + str(i): layer + for i, layer in enumerate(self.model.model.layers) + }) + + for layer in self.model.model.layers.values(): layer.moe_enabled = False def set_cp_mesh(self, mesh):