From 7b7df9080751a1530197a8a22973697d5f85c877 Mon Sep 17 00:00:00 2001 From: Victoria Lin Date: Sun, 11 Dec 2022 02:22:07 -0800 Subject: [PATCH 1/7] add log_training_trajectory option --- metaseq/dataclass/configs.py | 6 ++++++ metaseq_cli/train.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index c618d540e..41ff4ff4e 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -385,6 +385,12 @@ class DatasetConfig(MetaseqDataclass): "help": "if set, validate language model at the beginning of training or fine-tuning process" }, ) + log_training_trajectory: bool = field( + default=False, + metadata={ + "help": "(InstructOPT specific) if set, evaluate and save checkpoints more frequently in early stage of training" + }, + ) fixed_validation_seed: Optional[int] = field( default=None, metadata={"help": "specified random seed for validation"} ) diff --git a/metaseq_cli/train.py b/metaseq_cli/train.py index 721df05f7..97cd7a8cc 100644 --- a/metaseq_cli/train.py +++ b/metaseq_cli/train.py @@ -406,6 +406,10 @@ def validate_and_save( and num_updates >= cfg.dataset.validate_after_updates and was_successful_step ) + or ( + cfg.dataset.log_traininig_trajectory + and num_updates in [10, 20, 50, 100, 200, 500] + ) ) do_validate = ( ( @@ -420,6 +424,10 @@ def validate_and_save( and was_successful_step ) or (cfg.dataset.validate_at_beginning and num_updates == 0) + or ( + cfg.dataset.log_traininig_trajectory + and num_updates in [0, 10, 20, 50, 100, 200, 500] + ) ) and not cfg.dataset.disable_validation valid_losses = [None] if do_validate: From 5cd7183835aa969e033b4d0fe85ff856f61ec272 Mon Sep 17 00:00:00 2001 From: Victoria Lin Date: Sun, 11 Dec 2022 02:23:58 -0800 Subject: [PATCH 2/7] fix typo --- metaseq_cli/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metaseq_cli/train.py b/metaseq_cli/train.py index 97cd7a8cc..9e6e77cb2 100644 --- 
a/metaseq_cli/train.py +++ b/metaseq_cli/train.py @@ -407,7 +407,7 @@ def validate_and_save( and was_successful_step ) or ( - cfg.dataset.log_traininig_trajectory + cfg.dataset.log_training_trajectory and num_updates in [10, 20, 50, 100, 200, 500] ) ) @@ -425,7 +425,7 @@ def validate_and_save( ) or (cfg.dataset.validate_at_beginning and num_updates == 0) or ( - cfg.dataset.log_traininig_trajectory + cfg.dataset.log_training_trajectory and num_updates in [0, 10, 20, 50, 100, 200, 500] ) ) and not cfg.dataset.disable_validation From e0e726ec9c74b5ebbf346fc25be3a9040aefc13e Mon Sep 17 00:00:00 2001 From: Victoria Lin Date: Sun, 11 Dec 2022 02:22:07 -0800 Subject: [PATCH 3/7] add log_training_trajectory option --- metaseq/dataclass/configs.py | 6 ++++++ metaseq_cli/train.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py index c618d540e..41ff4ff4e 100644 --- a/metaseq/dataclass/configs.py +++ b/metaseq/dataclass/configs.py @@ -385,6 +385,12 @@ class DatasetConfig(MetaseqDataclass): "help": "if set, validate language model at the beginning of training or fine-tuning process" }, ) + log_training_trajectory: bool = field( + default=False, + metadata={ + "help": "(InstructOPT specific) if set, evaluate and save checkpoints more frequently in early stage of training" + }, + ) fixed_validation_seed: Optional[int] = field( default=None, metadata={"help": "specified random seed for validation"} ) diff --git a/metaseq_cli/train.py b/metaseq_cli/train.py index 3669c2b7e..3509b374b 100644 --- a/metaseq_cli/train.py +++ b/metaseq_cli/train.py @@ -418,6 +418,10 @@ def validate_and_save( and num_updates >= cfg.dataset.validate_after_updates and was_successful_step ) + or ( + cfg.dataset.log_traininig_trajectory + and num_updates in [10, 20, 50, 100, 200, 500] + ) ) do_validate = ( ( @@ -432,6 +436,10 @@ def validate_and_save( and was_successful_step ) or (cfg.dataset.validate_at_beginning and num_updates == 0) + 
or ( + cfg.dataset.log_traininig_trajectory + and num_updates in [0, 10, 20, 50, 100, 200, 500] + ) ) and not cfg.dataset.disable_validation valid_losses = [None] if do_validate: From 54b3e87e1db68f72bd7793d15d42f141bf5a51a3 Mon Sep 17 00:00:00 2001 From: Victoria Lin Date: Sun, 11 Dec 2022 02:23:58 -0800 Subject: [PATCH 4/7] fix typo --- metaseq_cli/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metaseq_cli/train.py b/metaseq_cli/train.py index 3509b374b..a0fcdd33f 100644 --- a/metaseq_cli/train.py +++ b/metaseq_cli/train.py @@ -419,7 +419,7 @@ def validate_and_save( and was_successful_step ) or ( - cfg.dataset.log_traininig_trajectory + cfg.dataset.log_training_trajectory and num_updates in [10, 20, 50, 100, 200, 500] ) ) @@ -437,7 +437,7 @@ def validate_and_save( ) or (cfg.dataset.validate_at_beginning and num_updates == 0) or ( - cfg.dataset.log_traininig_trajectory + cfg.dataset.log_training_trajectory and num_updates in [0, 10, 20, 50, 100, 200, 500] ) ) and not cfg.dataset.disable_validation From 9b0e2ae4ab4b844eeaea8fbf419006c7389b335b Mon Sep 17 00:00:00 2001 From: Victoria Lin Date: Mon, 12 Dec 2022 01:40:17 -0800 Subject: [PATCH 5/7] hack training trajectory checkpoint saving --- metaseq/checkpoint_utils.py | 72 +++++++++++++++++++++++++++---------- metaseq_cli/train.py | 5 ++- 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 472fedd8b..98289f368 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -33,6 +33,7 @@ def save_checkpoint( trainer, epoch_itr, val_loss, + log_training_trajectory=False, training_finished=False, async_callback_fn=None, ): @@ -72,10 +73,11 @@ def is_better(a, b): checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 ) - checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( - not 
end_of_epoch - and cfg.save_interval_updates > 0 - and updates % cfg.save_interval_updates == 0 + checkpoint_conds[ + "checkpoint_{}_{}{}.pt".format(epoch, updates, suffix) + ] = not end_of_epoch and ( + (cfg.save_interval_updates > 0 and updates % cfg.save_interval_updates == 0) + or (log_training_trajectory and updates in [10, 20, 50, 100, 200, 500]) ) checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = ( val_loss is not None @@ -182,12 +184,18 @@ def _delete_old_checkpoint_files( def verify_shards(cfg, dir=None, checkpoint_name=None, distributed_world_size=None): # verifies that all the shards of the checkpoint are present checkpoint_name = checkpoint_name.replace(".pt", "") - num_gpus = DistributedTrainingConfig.distributed_world_size if not distributed_world_size else distributed_world_size + num_gpus = ( + DistributedTrainingConfig.distributed_world_size + if not distributed_world_size + else distributed_world_size + ) num_shards = 0 for file in os.listdir(dir): if file.startswith(checkpoint_name): num_shards += 1 - logger.info(f'dir {dir} checkpoint_name {checkpoint_name} num_shards {num_shards} num_gpus {num_gpus}') + logger.info( + f"dir {dir} checkpoint_name {checkpoint_name} num_shards {num_shards} num_gpus {num_gpus}" + ) return num_shards == num_gpus @@ -202,7 +210,12 @@ def get_last_good_checkpoint(cfg, distributed_world_size): unique_checkpoints = set(unique_checkpoints) unique_checkpoints = sorted(unique_checkpoints, reverse=True) for checkpoint in unique_checkpoints: - if verify_shards(cfg, dir=cfg.save_dir, checkpoint_name=checkpoint, distributed_world_size=distributed_world_size): + if verify_shards( + cfg, + dir=cfg.save_dir, + checkpoint_name=checkpoint, + distributed_world_size=distributed_world_size, + ): return checkpoint # no good checkpoints available, first launch return None @@ -239,7 +252,10 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): ) first_launch = not PathManager.exists(checkpoint_path_to_load) 
if not first_launch and not verify_shards( - cfg, dir=cfg.save_dir, checkpoint_name=cfg.restore_file, distributed_world_size=distributed_world_size + cfg, + dir=cfg.save_dir, + checkpoint_name=cfg.restore_file, + distributed_world_size=distributed_world_size, ): # checkpoint_last is corrupted best_checkpoint = get_last_good_checkpoint(cfg, distributed_world_size) @@ -251,13 +267,18 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): else: first_launch = True - elif first_launch and get_last_good_checkpoint(cfg, distributed_world_size) is not None: + elif ( + first_launch + and get_last_good_checkpoint(cfg, distributed_world_size) is not None + ): # possible past checkpoint to load from cfg.restore_file = os.path.join( - cfg.save_dir, get_last_good_checkpoint(cfg, distributed_world_size) + ".pt" + cfg.save_dir, + get_last_good_checkpoint(cfg, distributed_world_size) + ".pt", ) checkpoint_path_to_load = os.path.join( - cfg.save_dir, get_last_good_checkpoint(cfg, distributed_world_size) + suffix + ".pt" + cfg.save_dir, + get_last_good_checkpoint(cfg, distributed_world_size) + suffix + ".pt", ) first_launch = False @@ -288,7 +309,12 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): # --restore-file was passed checkpoint_name = cfg.restore_file.split("/")[-1] dir = cfg.restore_file.replace("/" + checkpoint_name, "") - if verify_shards(cfg, dir=dir, checkpoint_name=checkpoint_name, distributed_world_size=distributed_world_size): + if verify_shards( + cfg, + dir=dir, + checkpoint_name=checkpoint_name, + distributed_world_size=distributed_world_size, + ): # when the checkpoint passed by the user is not corrupted checkpoint_path_to_load = cfg.restore_file.replace(".pt", suffix + ".pt") else: @@ -298,7 +324,10 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): if get_last_good_checkpoint(cfg, distributed_world_size) is not None: # checkpoint passed by user is corrupted but there is a good checkpoint 
to fall back on checkpoint_path_to_load = os.path.join( - cfg.save_dir, get_last_good_checkpoint(cfg, distributed_world_size) + suffix + ".pt" + cfg.save_dir, + get_last_good_checkpoint(cfg, distributed_world_size) + + suffix + + ".pt", ) else: first_launch = True @@ -309,20 +338,25 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): else: checkpoint_name = cfg.restore_file.split("/")[-1] dir = cfg.restore_file.replace("/" + checkpoint_name, "") - if verify_shards(cfg, dir=dir, checkpoint_name=checkpoint_name, distributed_world_size=distributed_world_size): + if verify_shards( + cfg, + dir=dir, + checkpoint_name=checkpoint_name, + distributed_world_size=distributed_world_size, + ): checkpoint_path_to_load = cfg.restore_file else: logger.warning("Passed checkpoint is corrupted or does not exist") first_launch = True checkpoint_path_to_load = cfg.restore_file - # Commenting this because it fails when checkpoint_last isnt saved i.e. + # Commenting this because it fails when checkpoint_last isnt saved i.e. 
# when we save per-interval checkpoints # if cfg.restore_file != default_restore_file and cfg.finetune_from_model: - # raise ValueError( - # "--finetune-from-model and --restore-file (non-default value) " - # "can not be specified together: " + str(cfg) - # ) + # raise ValueError( + # "--finetune-from-model and --restore-file (non-default value) " + # "can not be specified together: " + str(cfg) + # ) # Azure logic try: diff --git a/metaseq_cli/train.py b/metaseq_cli/train.py index a0fcdd33f..da2946f04 100644 --- a/metaseq_cli/train.py +++ b/metaseq_cli/train.py @@ -37,7 +37,9 @@ from metaseq.model_parallel.megatron_trainer import MegatronTrainer from metaseq.trainer import Trainer from metaseq.tasks.streaming_language_modeling import StreamingLanguageModelingTask -from metaseq.tasks.streaming_finetune_language_modeling import StreamingFinetuneLanguageModelingTask +from metaseq.tasks.streaming_finetune_language_modeling import ( + StreamingFinetuneLanguageModelingTask, +) logging.basicConfig( @@ -454,6 +456,7 @@ def validate_and_save( trainer, epoch_itr, valid_losses[0], + log_training_trajectory=cfg.dataset.log_training_trajectory, training_finished=should_stop, async_callback_fn=functools.partial(post_checkpoint_callback, cfg) if cfg.checkpoint.cloud_upload_path From d659c00be17b3185edea88218d5eadd7b202be54 Mon Sep 17 00:00:00 2001 From: Victoria Lin Date: Mon, 12 Dec 2022 01:41:25 -0800 Subject: [PATCH 6/7] hack training trajectory checkpoint saving --- metaseq/checkpoint_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 98289f368..30fa17cee 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -1,3 +1,4 @@ +# fmt: off # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
# # This source code is licensed under the MIT license found in the From 22dca46079f026b87d8de0a9cfdcbc4680bd83fd Mon Sep 17 00:00:00 2001 From: Victoria Lin Date: Mon, 12 Dec 2022 01:43:38 -0800 Subject: [PATCH 7/7] reduce black formatting changes --- metaseq/checkpoint_utils.py | 70 ++++++++++--------------------------- 1 file changed, 18 insertions(+), 52 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 30fa17cee..eeecf638e 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -74,11 +74,9 @@ def is_better(a, b): checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 ) - checkpoint_conds[ - "checkpoint_{}_{}{}.pt".format(epoch, updates, suffix) - ] = not end_of_epoch and ( - (cfg.save_interval_updates > 0 and updates % cfg.save_interval_updates == 0) - or (log_training_trajectory and updates in [10, 20, 50, 100, 200, 500]) + checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( + not end_of_epoch + and ((cfg.save_interval_updates > 0 and updates % cfg.save_interval_updates == 0) or (log_training_trajectory and updates in [10, 20, 50, 100, 200, 500])) ) checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = ( val_loss is not None @@ -185,18 +183,12 @@ def _delete_old_checkpoint_files( def verify_shards(cfg, dir=None, checkpoint_name=None, distributed_world_size=None): # verifies that all the shards of the checkpoint are present checkpoint_name = checkpoint_name.replace(".pt", "") - num_gpus = ( - DistributedTrainingConfig.distributed_world_size - if not distributed_world_size - else distributed_world_size - ) + num_gpus = DistributedTrainingConfig.distributed_world_size if not distributed_world_size else distributed_world_size num_shards = 0 for file in os.listdir(dir): if file.startswith(checkpoint_name): num_shards += 1 - logger.info( - f"dir {dir} checkpoint_name {checkpoint_name} 
num_shards {num_shards} num_gpus {num_gpus}" - ) + logger.info(f'dir {dir} checkpoint_name {checkpoint_name} num_shards {num_shards} num_gpus {num_gpus}') return num_shards == num_gpus @@ -211,12 +203,7 @@ def get_last_good_checkpoint(cfg, distributed_world_size): unique_checkpoints = set(unique_checkpoints) unique_checkpoints = sorted(unique_checkpoints, reverse=True) for checkpoint in unique_checkpoints: - if verify_shards( - cfg, - dir=cfg.save_dir, - checkpoint_name=checkpoint, - distributed_world_size=distributed_world_size, - ): + if verify_shards(cfg, dir=cfg.save_dir, checkpoint_name=checkpoint, distributed_world_size=distributed_world_size): return checkpoint # no good checkpoints available, first launch return None @@ -253,10 +240,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): ) first_launch = not PathManager.exists(checkpoint_path_to_load) if not first_launch and not verify_shards( - cfg, - dir=cfg.save_dir, - checkpoint_name=cfg.restore_file, - distributed_world_size=distributed_world_size, + cfg, dir=cfg.save_dir, checkpoint_name=cfg.restore_file, distributed_world_size=distributed_world_size ): # checkpoint_last is corrupted best_checkpoint = get_last_good_checkpoint(cfg, distributed_world_size) @@ -268,18 +252,13 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): else: first_launch = True - elif ( - first_launch - and get_last_good_checkpoint(cfg, distributed_world_size) is not None - ): + elif first_launch and get_last_good_checkpoint(cfg, distributed_world_size) is not None: # possible past checkpoint to load from cfg.restore_file = os.path.join( - cfg.save_dir, - get_last_good_checkpoint(cfg, distributed_world_size) + ".pt", + cfg.save_dir, get_last_good_checkpoint(cfg, distributed_world_size) + ".pt" ) checkpoint_path_to_load = os.path.join( - cfg.save_dir, - get_last_good_checkpoint(cfg, distributed_world_size) + suffix + ".pt", + cfg.save_dir, get_last_good_checkpoint(cfg, 
distributed_world_size) + suffix + ".pt" ) first_launch = False @@ -310,12 +289,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): # --restore-file was passed checkpoint_name = cfg.restore_file.split("/")[-1] dir = cfg.restore_file.replace("/" + checkpoint_name, "") - if verify_shards( - cfg, - dir=dir, - checkpoint_name=checkpoint_name, - distributed_world_size=distributed_world_size, - ): + if verify_shards(cfg, dir=dir, checkpoint_name=checkpoint_name, distributed_world_size=distributed_world_size): # when the checkpoint passed by the user is not corrupted checkpoint_path_to_load = cfg.restore_file.replace(".pt", suffix + ".pt") else: @@ -325,10 +299,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): if get_last_good_checkpoint(cfg, distributed_world_size) is not None: # checkpoint passed by user is corrupted but there is a good checkpoint to fall back on checkpoint_path_to_load = os.path.join( - cfg.save_dir, - get_last_good_checkpoint(cfg, distributed_world_size) - + suffix - + ".pt", + cfg.save_dir, get_last_good_checkpoint(cfg, distributed_world_size) + suffix + ".pt" ) else: first_launch = True @@ -339,25 +310,20 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): else: checkpoint_name = cfg.restore_file.split("/")[-1] dir = cfg.restore_file.replace("/" + checkpoint_name, "") - if verify_shards( - cfg, - dir=dir, - checkpoint_name=checkpoint_name, - distributed_world_size=distributed_world_size, - ): + if verify_shards(cfg, dir=dir, checkpoint_name=checkpoint_name, distributed_world_size=distributed_world_size): checkpoint_path_to_load = cfg.restore_file else: logger.warning("Passed checkpoint is corrupted or does not exist") first_launch = True checkpoint_path_to_load = cfg.restore_file - # Commenting this because it fails when checkpoint_last isnt saved i.e. + # Commenting this because it fails when checkpoint_last isnt saved i.e. 
# when we save per-interval checkpoints # if cfg.restore_file != default_restore_file and cfg.finetune_from_model: - # raise ValueError( - # "--finetune-from-model and --restore-file (non-default value) " - # "can not be specified together: " + str(cfg) - # ) + # raise ValueError( + # "--finetune-from-model and --restore-file (non-default value) " + # "can not be specified together: " + str(cfg) + # ) # Azure logic try: