Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions deepspec/trainer/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from deepspec.data import CacheDataset, validate_train_cache
from deepspec.data.cuda_prefetcher import CUDAPrefetcher
Expand All @@ -29,6 +29,7 @@
)
import deepspec.utils.training_logger as training_logger
from deepspec.utils.hfai_suspend import SuspendController
from deepspec.utils.target_weights import load_target_embeddings_and_head


_PRECISION_DTYPES = {
Expand Down Expand Up @@ -257,22 +258,23 @@ def build_models(self):
)
draft_model = draft_model.to(device=self.device, dtype=self.precision_dtype)

self._initialize_embeddings_and_head_from_target(draft_model, model_args)
return draft_model, tokenizer

def _initialize_embeddings_and_head_from_target(self, draft_model, model_args):
# Training only uses the target checkpoint to initialize frozen draft
# embeddings and lm_head weights.
target_model = AutoModelForCausalLM.from_pretrained(
target_embed_tokens, target_lm_head = load_target_embeddings_and_head(
model_args.target_model_name_or_path,
embed_shape=draft_model.embed_tokens.weight.shape,
lm_head_shape=draft_model.lm_head.weight.shape,
dtype=self.precision_dtype,
).to(device="cpu").eval()
target_embed_tokens = target_model.get_input_embeddings()
target_lm_head = target_model.get_output_embeddings()
assert (target_lm_head is not None) and (target_embed_tokens is not None)
)
draft_model.initialize_embeddings_and_head(
embed_tokens=target_embed_tokens,
lm_head=target_lm_head,
freeze=True,
)
del target_model
return draft_model, tokenizer

def _build_draft_model(self, *, target_config, model_args):
raise NotImplementedError
Expand Down
17 changes: 2 additions & 15 deletions deepspec/trainer/eagle3_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from deepspec.data import CacheCollator
from deepspec.modeling.eagle3.gemma4 import Gemma4Eagle3Model
Expand Down Expand Up @@ -32,23 +32,10 @@ def build_models(self):
)
draft_model = draft_model.to(device=self.device, dtype=self.precision_dtype)

target_model = AutoModelForCausalLM.from_pretrained(
model_args.target_model_name_or_path,
dtype=self.precision_dtype,
).to(device="cpu").eval()
target_embed_tokens = target_model.get_input_embeddings()
target_lm_head = target_model.get_output_embeddings()
assert (target_lm_head is not None) and (target_embed_tokens is not None)

# The draft head and norm stay frozen / target-independent to match
# the DSpark setup: head is not trained and norm is not inherited.
draft_model.initialize_embeddings_and_head(
embed_tokens=target_embed_tokens,
lm_head=target_lm_head,
freeze=True,
)
self._initialize_embeddings_and_head_from_target(draft_model, model_args)

del target_model
return draft_model, tokenizer

def _build_draft_model(self, *, target_config, model_args):
Expand Down
210 changes: 210 additions & 0 deletions deepspec/utils/target_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from __future__ import annotations

import json
import os
from collections.abc import Iterable

import torch
from torch import nn


_SAFETENSORS_INDEX = "model.safetensors.index.json"
_SAFETENSORS_FILE = "model.safetensors"
_PYTORCH_INDEX = "pytorch_model.bin.index.json"
_PYTORCH_FILE = "pytorch_model.bin"

_EMBED_SUFFIXES = (
"model.embed_tokens.weight",
"language_model.model.embed_tokens.weight",
"language_model.embed_tokens.weight",
"text_model.embed_tokens.weight",
"embed_tokens.weight",
)
_LM_HEAD_SUFFIXES = (
"lm_head.weight",
"language_model.lm_head.weight",
"text_model.lm_head.weight",
)


class _WeightOnlyModule(nn.Module):
def __init__(self, weight: torch.Tensor) -> None:
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False)


def load_target_embeddings_and_head(
model_name_or_path: str,
*,
embed_shape: torch.Size,
lm_head_shape: torch.Size,
dtype: torch.dtype,
) -> tuple[nn.Module, nn.Module]:
"""Load only target input embeddings and lm_head from a HF checkpoint."""

reader = _CheckpointReader(model_name_or_path)
embed_weight = reader.find_tensor(
suffixes=_EMBED_SUFFIXES,
expected_shape=tuple(embed_shape),
description="input embeddings",
)
lm_head_weight = reader.find_tensor(
suffixes=_LM_HEAD_SUFFIXES,
expected_shape=tuple(lm_head_shape),
description="lm_head",
required=False,
)
if lm_head_weight is None:
assert tuple(embed_weight.shape) == tuple(lm_head_shape), (
"Target checkpoint does not contain lm_head.weight and input "
"embeddings cannot be reused as a tied lm_head because shapes differ: "
f"embed={tuple(embed_weight.shape)}, lm_head={tuple(lm_head_shape)}"
)
lm_head_weight = embed_weight

embed_weight = embed_weight.detach().cpu().to(dtype=dtype).contiguous()
lm_head_weight = lm_head_weight.detach().cpu().to(dtype=dtype).contiguous()
return _WeightOnlyModule(embed_weight), _WeightOnlyModule(lm_head_weight)


class _CheckpointReader:
def __init__(self, model_name_or_path: str) -> None:
self.model_name_or_path = str(model_name_or_path)
self.local_dir = (
os.path.abspath(os.path.expanduser(self.model_name_or_path))
if os.path.isdir(os.path.expanduser(self.model_name_or_path))
else None
)
self.index_path = self._resolve_file(_SAFETENSORS_INDEX)
self.is_safetensors = True
if self.index_path is not None:
self.weight_map = self._read_weight_map(self.index_path)
return

self.single_file = self._resolve_file(_SAFETENSORS_FILE)
if self.single_file is not None:
self.weight_map = None
return

self.index_path = self._resolve_file(_PYTORCH_INDEX)
self.is_safetensors = False
if self.index_path is not None:
self.weight_map = self._read_weight_map(self.index_path)
return

self.single_file = self._resolve_file(_PYTORCH_FILE)
if self.single_file is not None:
self.weight_map = None
self._state_dict_cache = {}
return

raise FileNotFoundError(
"Could not find a Hugging Face checkpoint in "
f"{self.model_name_or_path!r}. Expected one of: "
f"{_SAFETENSORS_INDEX}, {_SAFETENSORS_FILE}, "
f"{_PYTORCH_INDEX}, {_PYTORCH_FILE}."
)

def find_tensor(
self,
*,
suffixes: Iterable[str],
expected_shape: tuple[int, ...],
description: str,
required: bool = True,
) -> torch.Tensor | None:
candidates = self._candidate_keys(suffixes)
for key in candidates:
tensor = self._get_tensor(key)
if tuple(tensor.shape) == expected_shape:
return tensor
if not required:
return None
raise KeyError(
f"Could not find target {description} with shape {expected_shape} in "
f"{self.model_name_or_path!r}. Tried keys: {candidates}"
)

def _candidate_keys(self, suffixes: Iterable[str]) -> list[str]:
suffixes = tuple(suffixes)
keys = list(self.weight_map) if self.weight_map is not None else self._keys()
candidates = [
key
for key in keys
if any(key == suffix or key.endswith(f".{suffix}") for suffix in suffixes)
]
return sorted(candidates, key=lambda key: self._key_priority(key, suffixes))

@staticmethod
def _key_priority(key: str, suffixes: tuple[str, ...]) -> tuple[int, int, str]:
for index, suffix in enumerate(suffixes):
if key == suffix or key.endswith(f".{suffix}"):
return index, len(key), key
return len(suffixes), len(key), key

@staticmethod
def _read_weight_map(index_path: str) -> dict[str, str]:
with open(index_path, encoding="utf-8") as handle:
index = json.load(handle)
return dict(index["weight_map"])

def _resolve_file(self, filename: str) -> str | None:
if self.local_dir is not None:
path = os.path.join(self.local_dir, filename)
return path if os.path.exists(path) else None

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError

try:
return hf_hub_download(self.model_name_or_path, filename)
except (EntryNotFoundError, LocalEntryNotFoundError, FileNotFoundError):
return None

def _resolve_shard(self, shard_name: str) -> str:
if self.local_dir is not None:
return os.path.join(self.local_dir, shard_name)

from huggingface_hub import hf_hub_download

return hf_hub_download(self.model_name_or_path, shard_name)

def _keys(self) -> list[str]:
if self.is_safetensors:
from safetensors import safe_open

with safe_open(self.single_file, framework="pt", device="cpu") as handle:
return list(handle.keys())
return list(self._load_state_dict(self.single_file))

def _get_tensor(self, key: str) -> torch.Tensor:
checkpoint_file = (
self._resolve_shard(self.weight_map[key])
if self.weight_map is not None
else self.single_file
)
if self.is_safetensors:
from safetensors import safe_open

with safe_open(checkpoint_file, framework="pt", device="cpu") as handle:
return handle.get_tensor(key)
return self._load_state_dict(checkpoint_file)[key]

def _load_state_dict(self, checkpoint_file: str) -> dict[str, torch.Tensor]:
if not hasattr(self, "_state_dict_cache"):
self._state_dict_cache = {}
if checkpoint_file not in self._state_dict_cache:
try:
state_dict = torch.load(
checkpoint_file,
map_location="cpu",
weights_only=True,
)
except TypeError:
state_dict = torch.load(checkpoint_file, map_location="cpu")
if "state_dict" in state_dict and isinstance(
state_dict["state_dict"], dict
):
state_dict = state_dict["state_dict"]
self._state_dict_cache[checkpoint_file] = state_dict
return self._state_dict_cache[checkpoint_file]