diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 07eab262b..3715b3f02 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -97,12 +97,56 @@ def save_checkpoint( async_callback_fn=async_callback_fn if save_to_NFS else None, ) + for cp in checkpoints[1:]: + assert PathManager.copy(checkpoints[0], cp, overwrite=True), f"Failed to copy {checkpoints[0]} to {cp}" + write_timer.stop() logger.info( f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) " f"(writing took {write_timer.sum} seconds)" ) + delete_old_checkpoint_files(cfg, end_of_epoch, suffix, trainer.is_data_parallel_master) + + +def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str, is_data_parallel_master: bool): + if not end_of_epoch and cfg.keep_last_updates > 0: + # remove old checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_updates:]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + if cfg.keep_last_epochs > 0: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs:]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): + """Retrieves all checkpoints found in `path` directory. + + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. + """ + pt_regexp = re.compile(pattern) + files = os.listdir(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): """ @@ -547,16 +591,16 @@ def load_model_ensemble_and_task( f"!!! cfg does not exist in state keys = {state.keys()} !!!" ) - # Load 175B model trained on megatron (model parallel) branch - # "cfg.common.model_parallel_size == 1" checks if model parallel is - # enabled at load time. If it's not, fall back to non-MP - # transformer code path. - if ( - getattr(cfg.model, "arch", None) == "transformer_lm_megatron" - and cfg.common.model_parallel_size == 1 - ): - cfg.model.arch = "transformer_lm_gpt" - cfg.model._name = "transformer_lm_gpt" + # # Load 175B model trained on megatron (model parallel) branch + # # "cfg.common.model_parallel_size == 1" checks if model parallel is + # # enabled at load time. If it's not, fall back to non-MP + # # transformer code path. + # if ( + # getattr(cfg.model, "arch", None) == "transformer_lm_megatron" + # and cfg.common.model_parallel_size == 1 + # ): + # cfg.model.arch = "transformer_lm_gpt" + # cfg.model._name = "transformer_lm_gpt" # We now copy embed_tokens over to output_proj (if its missing) for all arches (only OPT here so far). oproj_key = "decoder.output_projection.weight" diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 4076f30f2..b3d4574cc 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -112,17 +112,25 @@ def main(cfg: DictConfig) -> None: # Print args logger.info(cfg) - # Setup task, e.g., translation, language modeling, etc. - task = tasks.setup_task(cfg.task) - assert cfg.criterion, "Please specify criterion to train a model" - # Build model and criterion - if cfg.distributed_training.ddp_backend == "fully_sharded": - extra = { - "use_sharded_state": cfg.distributed_training.use_sharded_state, - } + # Build task, model and criterion + extra = {"use_sharded_state": cfg.distributed_training.use_sharded_state,} + if cfg.distributed_training.task_ddp_backend == "fully_sharded": + # As the task is non-trainable, we witch flags to more optimized ones. + memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16 + fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter + cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16 + cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16 + with fsdp_enable_wrap(cfg.distributed_training, **extra): + # Setup task, e.g., translation, language modeling, etc. + task = tasks.setup_task(cfg.task) + cfg.distributed_training.memory_efficient_fp16 = memory_efficient_fp16 + cfg.distributed_training.fp32_reduce_scatter = fp32_reduce_scatter + else: + task = tasks.setup_task(cfg.task) + if cfg.distributed_training.ddp_backend == "fully_sharded": with fsdp_enable_wrap(cfg.distributed_training, **extra): model = fsdp_wrap( task.build_model(cfg.model), @@ -130,6 +138,7 @@ def main(cfg: DictConfig) -> None: ) else: model = task.build_model(cfg.model) + # TODO[Susan]: FSDP on criterion? criterion = task.build_criterion(cfg.criterion) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index 4aae1c9f5..35641b9a5 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -247,6 +247,9 @@ class DistributedTrainingConfig(MetaseqDataclass): ddp_backend: DDP_BACKEND_CHOICES = field( default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} ) + task_ddp_backend: DDP_BACKEND_CHOICES = field( + default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend for task"} + ) bucket_cap_mb: int = field( default=25, metadata={"help": "bucket size for reduction"} ) diff --git a/metaseq/trainer.py b/metaseq/trainer.py index 394a13384..05b0d9807 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -254,7 +254,6 @@ def _build_ema(self): if self.is_fsdp: # Build FSDP model extra = { - "is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0, "use_sharded_state": self.use_sharded_state, } with fsdp_enable_wrap(self.cfg.distributed_training, **extra): @@ -1212,7 +1211,7 @@ def _prepare_sample(self, sample, is_dummy=False): def lower_precision(t): """Converts a tensor to the desired dtype based on our cfg.""" if t.dtype is torch.float32: - if self.cfg.common.bf16 or self.cfg.bf16: + if self.cfg.common.bf16: return t.bfloat16() return t.half() return t diff --git a/setup.py b/setup.py index 8cc013e7c..d33e435b0 100644 --- a/setup.py +++ b/setup.py @@ -169,6 +169,13 @@ def do_setup(package_data): "torch", "tqdm", "typing_extensions", + "einops", + "webdataset==0.1.103", + "matplotlib", + "pytorchvideo==0.1.5", + "wandb", + "albumentations", + "dalle_pytorch", ], dependency_links=dependency_links, packages=find_packages(