Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions metaseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ class DatasetConfig(MetaseqDataclass):
"help": "if set, validate language model at the beginning of training or fine-tuning process"
},
)
log_training_trajectory: bool = field(
default=False,
metadata={
"help": "(InstructOPT specific) if set, evaluate and save checkpoints more frequently in early stage of training"
},
)
fixed_validation_seed: Optional[int] = field(
default=None, metadata={"help": "specified random seed for validation"}
)
Expand Down
8 changes: 8 additions & 0 deletions metaseq_cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def validate_and_save(
and num_updates >= cfg.dataset.validate_after_updates
and was_successful_step
)
or (
cfg.dataset.log_training_trajectory

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to save a checkpoint at these steps? Couldn't we just validate?

@todpole3 todpole3 Dec 12, 2022

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Saving lets us plot the trajectory for generation tasks (ROUGE) as well; the eval will be fast.

and num_updates in [10, 20, 50, 100, 200, 500]
)
)
do_validate = (
(
Expand All @@ -432,6 +436,10 @@ def validate_and_save(
and was_successful_step
)
or (cfg.dataset.validate_at_beginning and num_updates == 0)
or (
cfg.dataset.log_training_trajectory
and num_updates in [0, 10, 20, 50, 100, 200, 500]
)
) and not cfg.dataset.disable_validation
valid_losses = [None]
if do_validate:
Expand Down