From b33d5d231afe8d2488c5cae907565fd1df4a2217 Mon Sep 17 00:00:00 2001 From: Uriel Singer Date: Fri, 3 Mar 2023 15:34:56 +0200 Subject: [PATCH 1/4] working pipe --- metaseq/checkpoint_utils.py | 20 ++++++++++---------- metaseq/trainer.py | 3 +-- setup.py | 7 +++++++ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 07eab262b..e79cd35f8 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -547,16 +547,16 @@ def load_model_ensemble_and_task( f"!!! cfg does not exist in state keys = {state.keys()} !!!" ) - # Load 175B model trained on megatron (model parallel) branch - # "cfg.common.model_parallel_size == 1" checks if model parallel is - # enabled at load time. If it's not, fall back to non-MP - # transformer code path. - if ( - getattr(cfg.model, "arch", None) == "transformer_lm_megatron" - and cfg.common.model_parallel_size == 1 - ): - cfg.model.arch = "transformer_lm_gpt" - cfg.model._name = "transformer_lm_gpt" + # # Load 175B model trained on megatron (model parallel) branch + # # "cfg.common.model_parallel_size == 1" checks if model parallel is + # # enabled at load time. If it's not, fall back to non-MP + # # transformer code path. + # if ( + # getattr(cfg.model, "arch", None) == "transformer_lm_megatron" + # and cfg.common.model_parallel_size == 1 + # ): + # cfg.model.arch = "transformer_lm_gpt" + # cfg.model._name = "transformer_lm_gpt" # We now copy embed_tokens over to output_proj (if its missing) for all arches (only OPT here so far). 
oproj_key = "decoder.output_projection.weight" diff --git a/metaseq/trainer.py b/metaseq/trainer.py index 394a13384..05b0d9807 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -254,7 +254,6 @@ def _build_ema(self): if self.is_fsdp: # Build FSDP model extra = { - "is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0, "use_sharded_state": self.use_sharded_state, } with fsdp_enable_wrap(self.cfg.distributed_training, **extra): @@ -1212,7 +1211,7 @@ def _prepare_sample(self, sample, is_dummy=False): def lower_precision(t): """Converts a tensor to the desired dtype based on our cfg.""" if t.dtype is torch.float32: - if self.cfg.common.bf16 or self.cfg.bf16: + if self.cfg.common.bf16: return t.bfloat16() return t.half() return t diff --git a/setup.py b/setup.py index 8cc013e7c..d33e435b0 100644 --- a/setup.py +++ b/setup.py @@ -169,6 +169,13 @@ def do_setup(package_data): "torch", "tqdm", "typing_extensions", + "einops", + "webdataset==0.1.103", + "matplotlib", + "pytorchvideo==0.1.5", + "wandb", + "albumentations", + "dalle_pytorch", ], dependency_links=dependency_links, packages=find_packages( From c49c4a032140c24be1fe91e27a5b840280d48f21 Mon Sep 17 00:00:00 2001 From: Uriel Singer Date: Sat, 4 Mar 2023 21:44:42 +0200 Subject: [PATCH 2/4] fsdp_enable_wrap on entire task --- metaseq/cli/train.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 4076f30f2..5f756687f 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -112,9 +112,6 @@ def main(cfg: DictConfig) -> None: # Print args logger.info(cfg) - # Setup task, e.g., translation, language modeling, etc. 
- task = tasks.setup_task(cfg.task) - assert cfg.criterion, "Please specify criterion to train a model" # Build model and criterion @@ -122,6 +119,15 @@ def main(cfg: DictConfig) -> None: extra = { "use_sharded_state": cfg.distributed_training.use_sharded_state, } + memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 + fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter + cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16 + cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16 + with fsdp_enable_wrap(cfg.distributed_training, **extra): + # Setup task, e.g., translation, language modeling, etc. + task = tasks.setup_task(cfg.task) + cfg.distributed_training.memory_efficient_fp16 = memory_efficient_fp16 + cfg.distributed_training.fp32_reduce_scatter = fp32_reduce_scatter with fsdp_enable_wrap(cfg.distributed_training, **extra): model = fsdp_wrap( @@ -129,7 +135,9 @@ def main(cfg: DictConfig) -> None: process_group=distributed_utils.get_data_parallel_group(), ) else: + task = tasks.setup_task(cfg.task) model = task.build_model(cfg.model) + # TODO[Susan]: FSDP on criterion? 
criterion = task.build_criterion(cfg.criterion) From 0768bf34c76740ed8cb5530441e1cdcc07a60f21 Mon Sep 17 00:00:00 2001 From: Uriel Singer Date: Sun, 12 Mar 2023 09:25:40 +0200 Subject: [PATCH 3/4] - save checkpoint_last - keep last k ckpts - added cfg.distributed_training.task_ddp_backend --- metaseq/checkpoint_utils.py | 44 ++++++++++++++++++++++++++++++++++++ metaseq/cli/train.py | 12 +++++----- metaseq/dataclass/configs.py | 3 +++ 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index e79cd35f8..3715b3f02 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -97,12 +97,56 @@ def save_checkpoint( async_callback_fn=async_callback_fn if save_to_NFS else None, ) + for cp in checkpoints[1:]: + assert PathManager.copy(checkpoints[0], cp, overwrite=True), f"Failed to copy {checkpoints[0]} to {cp}" + write_timer.stop() logger.info( f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) " f"(writing took {write_timer.sum} seconds)" ) + delete_old_checkpoint_files(cfg, end_of_epoch, suffix, trainer.is_data_parallel_master) + + +def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str, is_data_parallel_master: bool): + if not end_of_epoch and cfg.keep_last_updates > 0: + # remove old checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_updates:]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + if cfg.keep_last_epochs > 0: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs:]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): + 
"""Retrieves all checkpoints found in `path` directory. + + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. + """ + pt_regexp = re.compile(pattern) + files = os.listdir(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): """ diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 5f756687f..c277299f1 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -114,11 +114,9 @@ def main(cfg: DictConfig) -> None: assert cfg.criterion, "Please specify criterion to train a model" - # Build model and criterion - if cfg.distributed_training.ddp_backend == "fully_sharded": - extra = { - "use_sharded_state": cfg.distributed_training.use_sharded_state, - } + # Build task, model and criterion + extra = {"use_sharded_state": cfg.distributed_training.use_sharded_state,} + if cfg.distributed_training.task_ddp_backend == "fully_sharded": memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16 @@ -128,14 +126,16 @@ def main(cfg: DictConfig) -> None: task = tasks.setup_task(cfg.task) cfg.distributed_training.memory_efficient_fp16 = memory_efficient_fp16 cfg.distributed_training.fp32_reduce_scatter = fp32_reduce_scatter + else: + task = tasks.setup_task(cfg.task) + if cfg.distributed_training.ddp_backend == "fully_sharded": with fsdp_enable_wrap(cfg.distributed_training, **extra): model = fsdp_wrap( task.build_model(cfg.model), process_group=distributed_utils.get_data_parallel_group(), ) else: - 
task = tasks.setup_task(cfg.task) model = task.build_model(cfg.model) # TODO[Susan]: FSDP on criterion? diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 4aae1c9f5..35641b9a5 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -247,6 +247,9 @@ class DistributedTrainingConfig(MetaseqDataclass): ddp_backend: DDP_BACKEND_CHOICES = field( default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) + task_ddp_backend: DDP_BACKEND_CHOICES = field( + default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend for task"} + ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} ) From 20fb304095d6307d4eeefc7114cf72b57b4c5a8a Mon Sep 17 00:00:00 2001 From: Uriel Singer Date: Sun, 12 Mar 2023 09:40:43 +0200 Subject: [PATCH 4/4] added comment for task fsdp wrapper --- metaseq/cli/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index c277299f1..b3d4574cc 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -117,6 +117,7 @@ def main(cfg: DictConfig) -> None: # Build task, model and criterion extra = {"use_sharded_state": cfg.distributed_training.use_sharded_state,} if cfg.distributed_training.task_ddp_backend == "fully_sharded": + # As the task is non-trainable, we switch flags to more optimized ones. memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16