42 changes: 42 additions & 0 deletions torchtitan/components/checkpoint.py
@@ -439,6 +439,10 @@ def dcp_load(
its own model definition and safetensors format.
"""

logger.info(f"[DCP LOAD] Starting DCP load from: {checkpoint_id}")
logger.info(f"[DCP LOAD] from_hf={from_hf}, from_quantized={from_quantized}")
logger.info(f"[DCP LOAD] State dict keys before load: {list(state_dict.keys())}")

if from_hf:
assert (
self.sd_adapter is not None
@@ -448,20 +452,56 @@ def dcp_load(
checkpoint_id, from_quantized
)

logger.info(f"[DCP LOAD] Loading from HF format with storage reader")
dcp.load(
hf_state_dict,
storage_reader=hf_storage_reader,
)

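# Map the HF-keyed tensors back into torchtitan's own state-dict layout.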
state_dict = self.sd_adapter.from_hf(hf_state_dict)
logger.info("[DCP LOAD] Converted from HF format, loading into model")
self.states[MODEL].load_state_dict(state_dict)
logger.info("[DCP LOAD] Model state loaded successfully from HF")
else:
logger.info(f"[DCP LOAD] Loading from native DCP format")
dcp.load(state_dict, checkpoint_id=checkpoint_id)
logger.info(f"[DCP LOAD] DCP load completed")

# TODO: Since we flatten the model states in state_dict, we need to
# manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states:
logger.info(f"[DCP LOAD] Loading model state dict")
self.states[MODEL].load_state_dict(state_dict)
logger.info(f"[DCP LOAD] Model state loaded successfully")

# Log actual parameter values after loading
logger.info(f"[DCP LOAD] ===== Verifying loaded parameter values =====")
if MODEL in self.states:
model_wrapper = self.states[MODEL]
for idx, model_part in enumerate(model_wrapper.model):
logger.info(f"[DCP LOAD] Model part {idx} parameters:")
for name, param in model_part.named_parameters():
logger.info(f" - {name}:")
logger.info(f" param: {param}")
# Show first few values
if isinstance(param, torch.distributed.tensor.DTensor):
flat_data = param.data.to_local().flatten()
else:
flat_data = param.flatten()
logger.info(f" first 10 values: {flat_data[:10].tolist()}")
logger.info(f"[DCP LOAD] Model part {idx} total parameters: {sum(p.numel() for p in model_part.parameters()):,}")

# Log dataloader state if loaded
if DATALOADER in self.states and DATALOADER in state_dict:
logger.info(f"[DCP LOAD] Dataloader state check:")
dl_state = self.states[DATALOADER].state_dict()
logger.info(f" Dataloader state keys: {list(dl_state.keys())}")
if hasattr(self.states[DATALOADER], '_rank_id'):
rank_id = self.states[DATALOADER]._rank_id
logger.info(f" Expected rank_id: {rank_id}")
logger.info(f" Rank_id in state: {rank_id in dl_state}")

logger.info(f"[DCP LOAD] ===== Load verification complete =====")

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
@@ -619,8 +659,10 @@ def load(self, step: int = -1) -> bool:
)

logger.info(f"Loading the checkpoint from {checkpoint_id}.")
logger.info(f"[CHECKPOINT LOAD] model_only={model_only}, from_hf={from_hf}")
begin = time.monotonic()
states = self._states_to_load(model_only)
logger.info(f"[CHECKPOINT LOAD] States to load: {list(states.keys())}")
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
16 changes: 14 additions & 2 deletions torchtitan/components/dataloader.py
@@ -86,19 +86,31 @@ def state_dict(self) -> dict[str, Any]:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
# State being empty is valid.
if not state_dict:
logger.info(f"[DATALOADER] load_state_dict called with empty state_dict for dp_rank={self.dp_rank}")
return

logger.info(f"[DATALOADER] load_state_dict called for dp_rank={self.dp_rank}")
logger.info(f"[DATALOADER] Expected rank_id: {self._rank_id}")
logger.info(f"[DATALOADER] state_dict keys: {list(state_dict.keys())}")
logger.info(f"[DATALOADER] world_size in state: {state_dict.get('world_size', 'NOT FOUND')}, current: {self.dp_world_size}")

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self.dp_rank}, "
"expected key {self._rank_id}"
f"expected key {self._rank_id}. Available keys: {list(state_dict.keys())}"
)
return

assert self.dp_world_size == state_dict["world_size"], (
"dp_degree is inconsistent before and after checkpoint, "
"dataloader resharding is not supported yet."
)

logger.info(f"[DATALOADER] Loading state for {self._rank_id}")
unpickled_state = pickle.loads(state_dict[self._rank_id])
logger.info(f"[DATALOADER] Unpickled state keys: {list(unpickled_state.keys()) if isinstance(unpickled_state, dict) else type(unpickled_state)}")

# We wouldn't need pickle here, since DCP serializes the state_dict itself,
# but we keep it for backward compatibility with existing checkpoints.
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
super().load_state_dict(unpickled_state)
logger.info(f"[DATALOADER] State loaded successfully for dp_rank={self.dp_rank}")
2 changes: 1 addition & 1 deletion torchtitan/experiments/transformers_backend/__init__.py
@@ -47,7 +47,7 @@ class TitanDenseModelArgs:
"debugmodel": HFTransformerModelArgs(
titan_dense_args=TitanDenseModelArgs(
dim=256,
n_layers=6,
n_layers=2,
n_heads=16,
n_kv_heads=16,
),
60 changes: 60 additions & 0 deletions torchtitan/experiments/transformers_backend/aa.sh
@@ -0,0 +1,60 @@
# aa.sh — debug launcher: run a single-rank pp=1 baseline, then the same
# training as a two-rank pp=2 run split across two GPUs, all starting from
# the same step-0 seed checkpoint.
torchrun \
--nproc_per_node 1 \
--nnodes 1 \
--rdzv_endpoint localhost:0 \
--rdzv_backend c10d \
--max_restarts 0 \
--role rank \
--local_ranks_filter 0 \
--tee 3 \
-m torchtitan.train \
--checkpoint.enable \
--checkpoint.initial_load_path debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/seed_checkpoint/checkpoint/step-0 \
--training.seed 42 \
--training.deterministic \
--training.steps 2 \
--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config \
--job.config_file debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/fsdp1_tp1_cp1_pp1/config.toml \
2>&1 | tee log_baseline.txt

# Rank 0 of the pp=2 run (node_rank 0, GPU 0)
CUDA_VISIBLE_DEVICES=0 torchrun \
--nproc_per_node 1 \
--nnodes 2 \
--node_rank 0 \
--rdzv_endpoint localhost:29500 \
--rdzv_backend c10d \
--max_restarts 0 \
--role rank \
--tee 3 \
-m torchtitan.train \
--checkpoint.enable \
--checkpoint.initial_load_path debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/seed_checkpoint/checkpoint/step-0 \
--training.seed 42 \
--training.deterministic \
--training.steps 2 \
--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config \
--job.config_file debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/fsdp1_tp1_cp1_pp2/config.toml \
2>&1 | tee log_pp2_rank0.txt &


# Rank 1 of the pp=2 run (node_rank 1, GPU 1)
CUDA_VISIBLE_DEVICES=1 torchrun \
--nproc_per_node 1 \
--nnodes 2 \
--node_rank 1 \
--rdzv_endpoint localhost:29500 \
--rdzv_backend c10d \
--max_restarts 0 \
--role rank \
--tee 3 \
-m torchtitan.train \
--checkpoint.enable \
--checkpoint.initial_load_path debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/seed_checkpoint/checkpoint/step-0 \
--training.seed 42 \
--training.steps 2 \
--training.deterministic \
--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config \
--job.config_file debug_local_results/meta-llama/Llama-3.2-1B/debugmodel/fsdp1_tp1_cp1_pp2/config.toml \
2>&1 | tee log_pp2_rank1.txt &
35 changes: 34 additions & 1 deletion torchtitan/experiments/transformers_backend/model/model.py
@@ -17,6 +17,31 @@
from .args import HFTransformerModelArgs


class SlicableModuleDict(nn.ModuleDict):
"""
A ModuleDict that supports slicing like ModuleList.
Keys are expected to be string representations of integers (e.g., "0", "1", "2").
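
Example (minimal sketch):
    layers = SlicableModuleDict({str(i): nn.Linear(4, 4) for i in range(3)})
    first_two = layers[:2]  # new SlicableModuleDict holding keys "0" and "1"
    single = layers["2"]    # plain string lookup still works
    for layer in layers:    # iterates values in integer-key order, like ModuleList
        ...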
"""

def __getitem__(self, key):
if isinstance(key, slice):
# Handle slicing: convert slice to list of keys
keys = sorted(self.keys(), key=lambda x: int(x) if x.isdigit() else float('inf'))
sliced_keys = keys[key]
# Return a new SlicableModuleDict with the sliced items
return SlicableModuleDict({k: self[k] for k in sliced_keys})
return super().__getitem__(key)

def __iter__(self):
# Iterate over values in sorted order by key (as integers)
keys = sorted(self.keys(), key=lambda x: int(x) if x.isdigit() else float('inf'))
for key in keys:
yield self[key]

def __len__(self):
return len(self._modules)


class HFTransformerModel(nn.Module):
def __init__(self, model_args: HFTransformerModelArgs):
super().__init__()
@@ -76,7 +101,15 @@ def __init__(self, model_args: HFTransformerModelArgs):
self.max_seq_len = model_args.max_seq_len
self.cp_mesh = None

for layer in self.model.model.layers:
# Convert ModuleList to ModuleDict to preserve original indices
# This ensures state dict keys match checkpoint keys
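# (e.g. if pipeline splitting later keeps only layer "2", it still serializes
# as "layers.2.*" instead of being renumbered to "layers.0.*", which is what
# would happen if the remaining layers were packed into a fresh ModuleList)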
if isinstance(self.model.model.layers, nn.ModuleList):
self.model.model.layers = SlicableModuleDict({
str(i): layer
for i, layer in enumerate(self.model.model.layers)
})

for layer in self.model.model.layers.values():
layer.moe_enabled = False

def set_cp_mesh(self, mesh):
53 changes: 50 additions & 3 deletions torchtitan/train.py
@@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

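# Debug-only: lovely_tensors monkey-patches tensor reprs into compact summaries
# for the tensor dumps in the logging below.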
import lovely_tensors as lt; lt.monkey_patch()
import importlib
import os
import time
@@ -32,6 +32,7 @@
maybe_enable_profiling,
)

from torchtitan.utils.test_utils import debug_structure_param

class Trainer(torch.distributed.checkpoint.stateful.Stateful):
# core configs
@@ -151,11 +152,12 @@ def __init__(self, job_config: JobConfig):
utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
):
model = self.train_spec.model_cls(model_args)

# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)

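# Debug helper added for this investigation; dumps the model structure and
# parameters for comparison across runs.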
debug_structure_param(model)

# metrics logging
build_metrics_processor_fn = (
build_metrics_processor
@@ -382,11 +384,29 @@ def batch_generator(
"""Returns an iterator that processes batches from the data iterator."""
device_type = utils.device_type
data_iterator = iter(data_iterable)

# Log dataloader state at start
logger.info("[BATCH GEN] Creating batch generator")
if hasattr(data_iterable, 'state_dict'):
dl_state = data_iterable.state_dict()
logger.info(f"[BATCH GEN] Dataloader state keys: {list(dl_state.keys())}")

batch_count = 0
while True:
data_load_start = time.perf_counter()
try:
batch = next(data_iterator)
batch_count += 1

# Log first few batches
if batch_count <= 3:
logger.info(f"[BATCH GEN] Fetched batch {batch_count}")
input_dict, labels = batch
if 'input' in input_dict:
inp = input_dict['input']
logger.info(f" Input: shape={inp.shape}, is_zeros={torch.all(inp==0).item()}, first_val={inp.flatten()[0].item()}")
logger.info(f" Labels: shape={labels.shape}, first_val={labels.flatten()[0].item()}")

except StopIteration as ex:
# If data runs out during gradient accumulation, that
# entire step will not be executed.
@@ -509,6 +529,33 @@ def train_step(
# entire step will not be executed.
for _microbatch in range(self.gradient_accumulation_steps):
input_dict, labels = next(data_iterator)

# Log first few batches to verify data
if self.step == 1 and _microbatch < 2:
logger.info(f"[BATCH DATA] ===== Step {self.step}, Microbatch {_microbatch} =====")
logger.info(f"[BATCH DATA] input_dict keys: {list(input_dict.keys())}")

if 'input' in input_dict:
inp = input_dict['input']
logger.info(f"[BATCH DATA] Input tensor:")
logger.info(f"inp: {inp}")
logger.info(f" is all zeros: {torch.all(inp == 0).item()}")
logger.info(f" is all same value: {torch.all(inp == inp.flatten()[0]).item()}")
logger.info(f" unique values count: {torch.unique(inp).numel()}")
# Show actual token IDs (first 20)
logger.info(f" first 20 token IDs: {inp.flatten()[:20].tolist()}")
# Show last 20 token IDs
logger.info(f" last 20 token IDs: {inp.flatten()[-20:].tolist()}")
# Check if it matches expected data
if torch.all(inp == 0):
logger.warning(f"[BATCH DATA] ⚠️ INPUT IS ALL ZEROS - THIS IS THE ISSUE!")

logger.info(f"[BATCH DATA] Labels tensor:")
logger.info(f" labels: {labels}")
logger.info(f" first 20 labels: {labels.flatten()[:20].tolist()}")
logger.info(f" last 20 labels: {labels.flatten()[-20:].tolist()}")
logger.info(f" is all same value: {torch.all(labels == labels.flatten()[0]).item()}")

loss = self.forward_backward_step(input_dict, labels)
accumulated_losses.append(loss.detach())
