From ef33c450f5cd44297ece12f1c50e8033305dd5fe Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 17:32:47 -0400 Subject: [PATCH 1/5] fsdp wrap task --- metaseq/cli/train.py | 32 +++++++++++++++++++++++++------- metaseq/dataclass/configs.py | 4 ++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 4076f30f2..2056eaf36 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -113,23 +113,41 @@ def main(cfg: DictConfig) -> None: logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. - task = tasks.setup_task(cfg.task) + if cfg.distributed_training.task_ddp_backend == "fully_sharded": + # As the task is non-trainable, we switch flags to more optimized ones. + orig_memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 + orig_fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter + # Clobber memory_efficient_fp16 and fp32_reduce_scatter + cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16 + cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16 + + with fsdp_enable_wrap( + cfg.distributed_training, + use_sharded_state=cfg.distributed_training.use_sharded_state, + ): + task = tasks.setup_task(cfg.task) + + # Reset memory_efficient_fp16 and fp32_reduce_scatter values. + cfg.distributed_training.memory_efficient_fp16 = orig_memory_efficient_fp16 + cfg.distributed_training.fp32_reduce_scatter = orig_fp32_reduce_scatter + else: + task = tasks.setup_task(cfg.task) + # Build model and criterion assert cfg.criterion, "Please specify criterion to train a model" - # Build model and criterion if cfg.distributed_training.ddp_backend == "fully_sharded": - extra = { - "use_sharded_state": cfg.distributed_training.use_sharded_state, - } - - with fsdp_enable_wrap(cfg.distributed_training, **extra): + with fsdp_enable_wrap( + cfg.distributed_training, + use_sharded_state=cfg.distributed_training.use_sharded_state, + ): model = fsdp_wrap( task.build_model(cfg.model), process_group=distributed_utils.get_data_parallel_group(), ) else: model = task.build_model(cfg.model) + # TODO[Susan]: FSDP on criterion? criterion = task.build_criterion(cfg.criterion) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 4aae1c9f5..c9a9f7465 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -247,6 +247,10 @@ class DistributedTrainingConfig(MetaseqDataclass): ddp_backend: DDP_BACKEND_CHOICES = field( default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) + # Reference: https://github.com/facebookresearch/metaseq/pull/668 + task_ddp_backend: DDP_BACKEND_CHOICES = field( + default="", metadata={"help": "If set to fully_sharded, will fsdp wrap task."} + ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} ) From e285e558873fa1059f3b71f318a63b33b7c781c7 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 17:58:45 -0400 Subject: [PATCH 2/5] bf16 cfg fix --- metaseq/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metaseq/trainer.py b/metaseq/trainer.py index 394a13384..f4d175629 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -1212,11 +1212,12 @@ def _prepare_sample(self, sample, is_dummy=False): def lower_precision(t): """Converts a tensor to the desired dtype based on our cfg.""" if t.dtype is torch.float32: - if self.cfg.common.bf16 or self.cfg.bf16: + if self.cfg.common.bf16: return t.bfloat16() return t.half() return t + # TODO[Susan]: sample dict is full of int64 tensors - check this. if self.cfg.common.fp16: sample = utils.apply_to_sample(lower_precision, sample) From ab231f06670659afcf04f061d07f8a40811ce239 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 18:27:41 -0400 Subject: [PATCH 3/5] split out TASK_DDP_BACKEND_CHOICES --- metaseq/dataclass/configs.py | 6 ++++-- metaseq/dataclass/constants.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index c9a9f7465..e2843a798 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -13,6 +13,7 @@ from metaseq.dataclass.constants import ( DATASET_IMPL_CHOICES, DDP_BACKEND_CHOICES, + TASK_DDP_BACKEND_CHOICES, LOG_FORMAT_CHOICES, CLIP_GRAD_NORM_TYPE_CHOICES, ) @@ -248,8 +249,9 @@ class DistributedTrainingConfig(MetaseqDataclass): default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) # Reference: https://github.com/facebookresearch/metaseq/pull/668 - task_ddp_backend: DDP_BACKEND_CHOICES = field( - default="", metadata={"help": "If set to fully_sharded, will fsdp wrap task."} + task_ddp_backend: TASK_DDP_BACKEND_CHOICES = field( + default="none", + metadata={"help": "If set to fully_sharded, will fsdp wrap task."}, ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} diff --git a/metaseq/dataclass/constants.py b/metaseq/dataclass/constants.py index 401fb973c..b13e83974 100644 --- a/metaseq/dataclass/constants.py +++ b/metaseq/dataclass/constants.py @@ -42,6 +42,13 @@ def ChoiceEnum(choices: List[str]): "pytorch_ddp", ] ) + +TASK_DDP_BACKEND_CHOICES = ChoiceEnum( + [ + "none", # default + "fully_sharded", # FSDP wraps task. See https://github.com/facebookresearch/metaseq/pull/668/ + ] +) DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"]) CLIP_GRAD_NORM_TYPE_CHOICES = ChoiceEnum(["l2", "inf"]) From a5a968ee0dc47a28c3a89d318c33034ce9771c24 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 19:13:53 -0400 Subject: [PATCH 4/5] add one more comment --- metaseq/cli/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 2056eaf36..c131a270c 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -115,6 +115,7 @@ def main(cfg: DictConfig) -> None: # Setup task, e.g., translation, language modeling, etc. if cfg.distributed_training.task_ddp_backend == "fully_sharded": # As the task is non-trainable, we switch flags to more optimized ones. + # See https://github.com/facebookresearch/metaseq/pull/668 for when/why this was added. orig_memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 orig_fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter # Clobber memory_efficient_fp16 and fp32_reduce_scatter From e5af71d88146e0da6ebd89e8d088656e68aa8f75 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Mon, 13 Mar 2023 05:20:23 -0700 Subject: [PATCH 5/5] remove extra is_moe arg --- metaseq/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metaseq/trainer.py b/metaseq/trainer.py index f4d175629..faa888d11 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -254,7 +254,6 @@ def _build_ema(self): if self.is_fsdp: # Build FSDP model extra = { - "is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0, "use_sharded_state": self.use_sharded_state, } with fsdp_enable_wrap(self.cfg.distributed_training, **extra):