Skip to content

Load target embeddings without full model init#11

Open
zhangxin81 wants to merge 1 commit into
deepseek-ai:mainfrom
zhangxin81:feat/load-target-weights-only
Open

Load target embeddings without full model init#11
zhangxin81 wants to merge 1 commit into
deepseek-ai:mainfrom
zhangxin81:feat/load-target-weights-only

Conversation

@zhangxin81

Copy link
Copy Markdown

Summary

  • Avoid constructing the full target AutoModelForCausalLM during trainer startup.
  • Add a lightweight checkpoint reader that loads only target input embeddings and lm_head weights from HF checkpoints.
  • Reuse the same initialization helper from both base/DSpark trainers and Eagle3 trainers.
  • Support sharded/single-file safetensors checkpoints, PyTorch .bin fallbacks, common Qwen/Gemma key prefixes, and tied embedding checkpoints without lm_head.weight.

GPU memory impact

Previously, trainer startup instantiated the entire target model only to read get_input_embeddings() and get_output_embeddings(). In environments where that temporary target model is materialized on CUDA, the peak GPU memory includes the full target checkpoint's BF16 parameter footprint. This PR keeps target checkpoint tensors on CPU and copies only into the already-existing draft embedding/head parameters, so the extra target-model GPU peak is effectively removed.

Estimated peak GPU memory avoided per rank for the current training configs:

Target model Old transient target-model footprint New extra target-model GPU footprint Estimated peak saving
Qwen/Qwen3-4B ~7.49 GiB ~0 GiB ~7.49 GiB
Qwen/Qwen3-8B ~15.26 GiB ~0 GiB ~15.26 GiB
Qwen/Qwen3-14B ~27.51 GiB ~0 GiB ~27.51 GiB
google/gemma-4-12B-it ~22.28 GiB ~0 GiB ~22.28 GiB

Notes:

  • Qwen numbers come from model.safetensors.index.json.metadata.total_size for the public HF checkpoints, converted from bytes to GiB.
  • Gemma4 uses the HF safetensors BF16 parameter count (11,959,730,224) × 2 bytes, converted to GiB.
  • The draft model's own frozen embed_tokens/lm_head parameters are unchanged; this only removes the temporary full target model allocation.
  • If a particular Transformers setup already keeps from_pretrained() fully on CPU, then the GPU saving is smaller, but this still avoids constructing/loading all unneeded target layers and reduces CPU memory/init overhead.

Testing

  • PYTHONPYCACHEPREFIX=/private/tmp/deepspec-pycache python3 -m py_compile deepspec/trainer/base_trainer.py deepspec/trainer/eagle3_trainer.py deepspec/utils/target_weights.py
  • git diff --check

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant