diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 4076f30f2..c131a270c 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -113,23 +113,42 @@ def main(cfg: DictConfig) -> None: logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. - task = tasks.setup_task(cfg.task) + if cfg.distributed_training.task_ddp_backend == "fully_sharded": + # As the task is non-trainable, we switch flags to more optimized ones. + # See https://github.com/facebookresearch/metaseq/pull/668 for when/why this was added. + orig_memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 + orig_fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter + # Clobber memory_efficient_fp16 and fp32_reduce_scatter + cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16 + cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16 + + with fsdp_enable_wrap( + cfg.distributed_training, + use_sharded_state=cfg.distributed_training.use_sharded_state, + ): + task = tasks.setup_task(cfg.task) + + # Reset memory_efficient_fp16 and fp32_reduce_scatter values. + cfg.distributed_training.memory_efficient_fp16 = orig_memory_efficient_fp16 + cfg.distributed_training.fp32_reduce_scatter = orig_fp32_reduce_scatter + else: + task = tasks.setup_task(cfg.task) + # Build model and criterion assert cfg.criterion, "Please specify criterion to train a model" - # Build model and criterion if cfg.distributed_training.ddp_backend == "fully_sharded": - extra = { - "use_sharded_state": cfg.distributed_training.use_sharded_state, - } - - with fsdp_enable_wrap(cfg.distributed_training, **extra): + with fsdp_enable_wrap( + cfg.distributed_training, + use_sharded_state=cfg.distributed_training.use_sharded_state, + ): model = fsdp_wrap( task.build_model(cfg.model), process_group=distributed_utils.get_data_parallel_group(), ) else: model = task.build_model(cfg.model) + # TODO[Susan]: FSDP on criterion? criterion = task.build_criterion(cfg.criterion) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 4aae1c9f5..e2843a798 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -13,6 +13,7 @@ from metaseq.dataclass.constants import ( DATASET_IMPL_CHOICES, DDP_BACKEND_CHOICES, + TASK_DDP_BACKEND_CHOICES, LOG_FORMAT_CHOICES, CLIP_GRAD_NORM_TYPE_CHOICES, ) @@ -247,6 +248,11 @@ class DistributedTrainingConfig(MetaseqDataclass): ddp_backend: DDP_BACKEND_CHOICES = field( default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) + # Reference: https://github.com/facebookresearch/metaseq/pull/668 + task_ddp_backend: TASK_DDP_BACKEND_CHOICES = field( + default="none", + metadata={"help": "If set to fully_sharded, will fsdp wrap task."}, + ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} ) diff --git a/metaseq/dataclass/constants.py b/metaseq/dataclass/constants.py index 401fb973c..b13e83974 100644 --- a/metaseq/dataclass/constants.py +++ b/metaseq/dataclass/constants.py @@ -42,6 +42,13 @@ def ChoiceEnum(choices: List[str]): "pytorch_ddp", ] ) + +TASK_DDP_BACKEND_CHOICES = ChoiceEnum( + [ + "none", # default + "fully_sharded", # FSDP wraps task. See https://github.com/facebookresearch/metaseq/pull/668/ + ] +) DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"]) CLIP_GRAD_NORM_TYPE_CHOICES = ChoiceEnum(["l2", "inf"]) diff --git a/metaseq/trainer.py b/metaseq/trainer.py index 394a13384..faa888d11 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -254,7 +254,6 @@ def _build_ema(self): if self.is_fsdp: # Build FSDP model extra = { - "is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0, "use_sharded_state": self.use_sharded_state, } with fsdp_enable_wrap(self.cfg.distributed_training, **extra): @@ -1212,11 +1211,12 @@ def _prepare_sample(self, sample, is_dummy=False): def lower_precision(t): """Converts a tensor to the desired dtype based on our cfg.""" if t.dtype is torch.float32: - if self.cfg.common.bf16 or self.cfg.bf16: + if self.cfg.common.bf16: return t.bfloat16() return t.half() return t + # TODO[Susan]: sample dict is full of int64 tensors - check this. if self.cfg.common.fp16: sample = utils.apply_to_sample(lower_precision, sample)