Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions metaseq/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,16 +547,16 @@ def load_model_ensemble_and_task(
f"!!! cfg does not exist in state keys = {state.keys()} !!!"
)

# Load 175B model trained on megatron (model parallel) branch
# "cfg.common.model_parallel_size == 1" checks if model parallel is
# enabled at load time. If it's not, fall back to non-MP
# transformer code path.
if (
getattr(cfg.model, "arch", None) == "transformer_lm_megatron"
and cfg.common.model_parallel_size == 1
):
cfg.model.arch = "transformer_lm_gpt"
cfg.model._name = "transformer_lm_gpt"
# # Load 175B model trained on megatron (model parallel) branch
Comment thread
suchenzang marked this conversation as resolved.
# # "cfg.common.model_parallel_size == 1" checks if model parallel is
# # enabled at load time. If it's not, fall back to non-MP
# # transformer code path.
# if (
# getattr(cfg.model, "arch", None) == "transformer_lm_megatron"
# and cfg.common.model_parallel_size == 1
# ):
# cfg.model.arch = "transformer_lm_gpt"
# cfg.model._name = "transformer_lm_gpt"

# We now copy embed_tokens over to output_proj (if its missing) for all arches (only OPT here so far).
oproj_key = "decoder.output_projection.weight"
Expand Down
14 changes: 11 additions & 3 deletions metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,32 @@ def main(cfg: DictConfig) -> None:
# Print args
logger.info(cfg)

# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(cfg.task)

assert cfg.criterion, "Please specify criterion to train a model"

# Build model and criterion
if cfg.distributed_training.ddp_backend == "fully_sharded":
extra = {
"use_sharded_state": cfg.distributed_training.use_sharded_state,
}
memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16
fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter
cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16
cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16
Comment thread
suchenzang marked this conversation as resolved.
with fsdp_enable_wrap(cfg.distributed_training, **extra):
# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(cfg.task)
cfg.distributed_training.memory_efficient_fp16 = memory_efficient_fp16
cfg.distributed_training.fp32_reduce_scatter = fp32_reduce_scatter

with fsdp_enable_wrap(cfg.distributed_training, **extra):
model = fsdp_wrap(
task.build_model(cfg.model),
process_group=distributed_utils.get_data_parallel_group(),
)
else:
task = tasks.setup_task(cfg.task)
model = task.build_model(cfg.model)

# TODO[Susan]: FSDP on criterion?
criterion = task.build_criterion(cfg.criterion)

Expand Down
3 changes: 1 addition & 2 deletions metaseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def _build_ema(self):
if self.is_fsdp:
# Build FSDP model
extra = {
"is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0,
"use_sharded_state": self.use_sharded_state,
}
with fsdp_enable_wrap(self.cfg.distributed_training, **extra):
Expand Down Expand Up @@ -1212,7 +1211,7 @@ def _prepare_sample(self, sample, is_dummy=False):
def lower_precision(t):
"""Converts a tensor to the desired dtype based on our cfg."""
if t.dtype is torch.float32:
if self.cfg.common.bf16 or self.cfg.bf16:
if self.cfg.common.bf16:
return t.bfloat16()
return t.half()
return t
Expand Down
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ def do_setup(package_data):
"torch",
"tqdm",
"typing_extensions",
"einops",
"webdataset==0.1.103",
"matplotlib",
"pytorchvideo==0.1.5",
"wandb",
Comment thread
suchenzang marked this conversation as resolved.
"albumentations",
"dalle_pytorch",
],
dependency_links=dependency_links,
packages=find_packages(
Expand Down