Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,42 @@ def main(cfg: DictConfig) -> None:
logger.info(cfg)

# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(cfg.task)
if cfg.distributed_training.task_ddp_backend == "fully_sharded":
    # The task is non-trainable, so it is set up under FSDP with flags derived
    # from fp16 instead of the values configured for the rest of training.
    # See https://github.com/facebookresearch/metaseq/pull/668 for when/why this was added.
    orig_memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16
    orig_fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter
    try:
        # Temporarily clobber memory_efficient_fp16 and fp32_reduce_scatter.
        cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16
        cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16

        with fsdp_enable_wrap(
            cfg.distributed_training,
            use_sharded_state=cfg.distributed_training.use_sharded_state,
        ):
            task = tasks.setup_task(cfg.task)
    finally:
        # Restore the original values even if task setup raises, so the shared
        # cfg is never left holding the temporary overrides.
        cfg.distributed_training.memory_efficient_fp16 = orig_memory_efficient_fp16
        cfg.distributed_training.fp32_reduce_scatter = orig_fp32_reduce_scatter
else:
    task = tasks.setup_task(cfg.task)

# Build model and criterion
assert cfg.criterion, "Please specify criterion to train a model"

# Build model and criterion
if cfg.distributed_training.ddp_backend == "fully_sharded":
extra = {
"use_sharded_state": cfg.distributed_training.use_sharded_state,
}

with fsdp_enable_wrap(cfg.distributed_training, **extra):
with fsdp_enable_wrap(
cfg.distributed_training,
use_sharded_state=cfg.distributed_training.use_sharded_state,
):
model = fsdp_wrap(
task.build_model(cfg.model),
process_group=distributed_utils.get_data_parallel_group(),
)
else:
model = task.build_model(cfg.model)

# TODO[Susan]: FSDP on criterion?
criterion = task.build_criterion(cfg.criterion)

Expand Down
6 changes: 6 additions & 0 deletions metaseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from metaseq.dataclass.constants import (
DATASET_IMPL_CHOICES,
DDP_BACKEND_CHOICES,
TASK_DDP_BACKEND_CHOICES,
LOG_FORMAT_CHOICES,
CLIP_GRAD_NORM_TYPE_CHOICES,
)
Expand Down Expand Up @@ -247,6 +248,11 @@ class DistributedTrainingConfig(MetaseqDataclass):
ddp_backend: DDP_BACKEND_CHOICES = field(
default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
)
# Backend used to wrap the task (as opposed to the model, which `ddp_backend`
# controls). In metaseq/cli/train.py, "fully_sharded" makes `tasks.setup_task`
# run inside `fsdp_enable_wrap`; any other value sets the task up unwrapped.
# Reference: https://github.com/facebookresearch/metaseq/pull/668
task_ddp_backend: TASK_DDP_BACKEND_CHOICES = field(
default="none",
metadata={"help": "If set to fully_sharded, will fsdp wrap task."},
)
bucket_cap_mb: int = field(
default=25, metadata={"help": "bucket size for reduction"}
)
Expand Down
7 changes: 7 additions & 0 deletions metaseq/dataclass/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def ChoiceEnum(choices: List[str]):
"pytorch_ddp",
]
)

# Backends the task (not the model) can be set up under:
#   "none"          -- default
#   "fully_sharded" -- FSDP wraps task. See https://github.com/facebookresearch/metaseq/pull/668/
TASK_DDP_BACKEND_CHOICES = ChoiceEnum(["none", "fully_sharded"])
# Further choice sets built with ChoiceEnum; metaseq/dataclass/configs.py imports both.
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"])
CLIP_GRAD_NORM_TYPE_CHOICES = ChoiceEnum(["l2", "inf"])

Expand Down
4 changes: 2 additions & 2 deletions metaseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def _build_ema(self):
if self.is_fsdp:
# Build FSDP model
extra = {
"is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0,
"use_sharded_state": self.use_sharded_state,
}
with fsdp_enable_wrap(self.cfg.distributed_training, **extra):
Expand Down Expand Up @@ -1212,11 +1211,12 @@ def _prepare_sample(self, sample, is_dummy=False):
def lower_precision(t):
Comment thread
suchenzang marked this conversation as resolved.
"""Converts a tensor to the desired dtype based on our cfg."""
if t.dtype is torch.float32:
if self.cfg.common.bf16 or self.cfg.bf16:
if self.cfg.common.bf16:
return t.bfloat16()
return t.half()
return t

# TODO[Susan]: sample dict is full of int64 tensors - check this.
if self.cfg.common.fp16:
sample = utils.apply_to_sample(lower_precision, sample)

Expand Down