Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
dc55c3f
- support seq_len > 2048 (4096 and 8192)
urielsinger May 15, 2023
1a63cb5
old/new tokens conversion
urielsinger May 22, 2023
6fba607
fsdp double wrap disable
urielsinger May 22, 2023
3c40198
- local symlink
urielsinger May 27, 2023
d464d96
- local symlink
urielsinger May 27, 2023
f37dcc2
Merge remote-tracking branch 'origin/main'
urielsinger Jun 4, 2023
56f8b54
pr fix
urielsinger Jun 5, 2023
c5a58b6
revert "force_distributed=True"
urielsinger Jun 5, 2023
f0b5275
fixed
adampolyak Jun 7, 2023
d54eca8
Merge remote-tracking branch 'origin/main'
urielsinger Jun 11, 2023
d30921d
improve free port finding for single node dist init
adampolyak Jun 11, 2023
865c4b3
Merge remote-tracking branch 'origin/cm3_seq_len'
urielsinger Jun 13, 2023
75b74e9
- pytorch FSDP support
urielsinger Jun 21, 2023
61f8792
fix bug
urielsinger Jun 21, 2023
4a677a4
fix bug
urielsinger Jun 25, 2023
af57884
back to fairscale
urielsinger Jul 4, 2023
59556ee
back to fairscale
urielsinger Jul 4, 2023
82d4c77
fix delete_old_checkpoint_files
urielsinger Jul 9, 2023
a72de97
stop training when loss_scale reached minimum
urielsinger Jul 10, 2023
00df75a
stop training when loss_scale reached minimum
urielsinger Jul 10, 2023
bc68a84
add validate_on_first_step support
urielsinger Aug 3, 2023
d5f50e8
fix for single files
adampolyak Nov 3, 2023
a42d648
add no_c10d support
urielsinger Nov 22, 2023
2c484fa
Merge remote-tracking branch 'origin/cm3_seq_len' into cm3_seq_len
urielsinger Nov 22, 2023
7e7b5e3
criterion fsdp
urielsinger Dec 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metaseq/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffi
if not end_of_epoch and cfg.keep_last_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
cfg.save_dir, pattern=r"checkpoint_(\d+){}\.pt".format(suffix)
)
for old_chk in checkpoints[cfg.keep_last_updates :]:
if os.path.lexists(old_chk):
Expand Down
39 changes: 32 additions & 7 deletions metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from metaseq.data import iterators, data_utils
from metaseq.data.plasma_utils import PlasmaStore
from metaseq.dataclass.utils import convert_namespace_to_omegaconf
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel, utils as distributed_utils
from metaseq.file_io import PathManager
from metaseq.logging import meters, metrics, progress_bar
from metaseq.trainer import Trainer
Expand Down Expand Up @@ -144,15 +144,36 @@ def main(cfg: DictConfig) -> None:
cfg.distributed_training,
use_sharded_state=cfg.distributed_training.use_sharded_state,
):
model = fsdp_wrap(
task.build_model(cfg.model),
process_group=distributed_utils.get_data_parallel_group(),
)
model = task.build_model(cfg.model)
if not isinstance(model, FullyShardedDataParallel):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, this is for loading up consolidated model for training?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.
I added support to change the MP size during job lunch, and for that I need to wrap it in FullyShardedDataParallel inside the build_model.
As I don't want to double wrap it, I needed to add this if..

model = fsdp_wrap(
model,
process_group=distributed_utils.get_data_parallel_group(),
)
else:
model = task.build_model(cfg.model)

# TODO[Susan]: FSDP on criterion?
criterion = task.build_criterion(cfg.criterion)
if cfg.distributed_training.criterion_ddp_backend == "fully_sharded":
# As the task is non-trainable, we switch flags to more optimized ones.
# See https://github.com/facebookresearch/metaseq/pull/668 for when/why this was added.
orig_memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16
orig_fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter
# Clobber memory_efficient_fp16 and fp32_reduce_scatter
cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16
cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16

with fsdp_enable_wrap(
cfg.distributed_training,
use_sharded_state=cfg.distributed_training.use_sharded_state,
):
criterion = task.build_criterion(cfg.criterion)

# Reset memory_efficient_fp16 and fp32_reduce_scatter values.
cfg.distributed_training.memory_efficient_fp16 = orig_memory_efficient_fp16
cfg.distributed_training.fp32_reduce_scatter = orig_fp32_reduce_scatter
else:
criterion = task.build_criterion(cfg.criterion)


logger.info(model)
logger.info("task: {}".format(task.__class__.__name__))
Expand Down Expand Up @@ -483,6 +504,10 @@ def validate_and_save(
and num_updates % cfg.dataset.validate_interval_updates == 0
and was_successful_step
)
or (
num_updates == cfg.dataset.validate_on_first_step
and was_successful_step
)
) and not cfg.dataset.disable_validation

# Save checkpoint before validating.
Expand Down
54 changes: 35 additions & 19 deletions metaseq/data/cm3_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are changes to the cm3 objectives that i landed in scaling_racm3 correct?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.


import numpy as np
import random
import torch

from typing import List, Optional, Tuple
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
to_skip=0,
permute_documents=True,
source_target=False,
percent_full_document_rotation: float = 0.0
):
super().__init__(
dataset,
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(
self.sentinel_fixed = self.sentinel_method == "fixed"
self.allow_rotation_across_eod = allow_rotation_across_eod
self.eod = eod
self.percent_full_document_rotation = percent_full_document_rotation

def get_sentinel(self, i):
return self.sentinel_tokens[i]
Expand Down Expand Up @@ -139,7 +142,8 @@ def sentinel_targets(self, document: torch.Tensor, spans: List[Tuple[int, int]])
index = index + size + 1
return target

def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]:
def get_spans_to_mask(self, document_length: int, document_boundaries: List[Tuple[int, int]]) -> List[
Tuple[int, int]]:
# Ok, we do not use a budget here but instead
# our goal is to sample from ~ U[0,1] in the case of len(sentinel_tokens) = 1
# If len(sentinel_tokens) > 1 we try to find len(sentinel_tokens) non intersecting spans
Expand All @@ -156,18 +160,23 @@ def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]:
if len_sentinel_tokens == 0:
return None
if len_sentinel_tokens == 1:
if np.random.random() < self.percent_full_document_rotation:
return [random.choice(document_boundaries)]

start, end = np.random.uniform(size=2)
if end < start:
start, end = end, start
# round down
start = int(start * document_length)
start = max(1, int(start * document_length))
# round up
end = int(end * document_length + 0.5)
if start == end:
return None
else:
assert start < end
return [(start, end)]
if len_sentinel_tokens < len(document_boundaries) and np.random.random() < self.percent_full_document_rotation:
return random.sample(document_boundaries, len_sentinel_tokens)

# Let's implement the general case. We will create len(self.sentinel_tokens) ** 2 possible candidates
# And we will filter one by one to insure no intersections. If we can't find anything then so be it.
Expand Down Expand Up @@ -200,24 +209,31 @@ def get_document_boundaries(self, item: torch.Tensor):
boundaries = boundaries + [item.size(0)]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is get_document_boundaries() robust to the case that there is no break tokens

spans = []
for i in range(1, len(boundaries)):
spans.append((boundaries[i - 1], boundaries[i]))
spans.append((boundaries[i - 1] + 1, boundaries[i]))
return spans

def cm3_shuffle(self, item):
assert len(item) > 0
document_boundaries = self.get_document_boundaries(item)
spans = self.get_spans_to_mask(len(item), document_boundaries)
if not self.allow_rotation_across_eod and spans is not None:
spans = adjust_spans(spans, document_boundaries)
if spans is None:
return item
else:
spans = self.get_ordered_spans(spans)
causal_source = self.sentinel_masking(item, spans)
causal_masked = self.sentinel_targets(item, spans)

total_count = len(causal_source) + len(causal_masked)
total_diff = total_count - self.tokens_per_sample
total_causal_length = len(causal_source) - total_diff
return torch.cat([
causal_source[:total_causal_length],
causal_masked
])[: self.tokens_per_sample] # EOSS tokens can add just enough tokens to get off by 1-2.

def __iter__(self):
for packed_item in super().__iter__():
item = packed_item["block"]
assert len(item) > 0
spans = self.get_spans_to_mask(len(item))
if not self.allow_rotation_across_eod:
document_boundaries = self.get_document_boundaries(item)
spans = adjust_spans(spans, document_boundaries)
if spans is None:
yield packed_item
else:
spans = self.get_ordered_spans(spans)
causal_source = self.sentinel_masking(item, spans)
causal_masked = self.sentinel_targets(item, spans)
packed_item["block"] = torch.cat([causal_source, causal_masked])[
: self.tokens_per_sample
] # EOSS tokens can add just enough tokens to get off by 1-2.
yield packed_item
packed_item["block"] = self.cm3_shuffle(packed_item["block"])
yield packed_item
3 changes: 3 additions & 0 deletions metaseq/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path

try:
from collections.abc import Iterable
Expand Down Expand Up @@ -279,6 +280,8 @@ def _find_extra_valid_paths(dataset_path: str) -> set:
for sub_dir in paths:
if "://" in sub_dir:
continue
if not Path(sub_dir).is_dir():
continue
contents = PathManager.ls(sub_dir)
valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
all_valid_paths |= {os.path.basename(p) for p in valid_paths}
Expand Down
7 changes: 7 additions & 0 deletions metaseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ class DistributedTrainingConfig(MetaseqDataclass):
default="none",
metadata={"help": "If set to fully_sharded, will fsdp wrap task."},
)
criterion_ddp_backend: TASK_DDP_BACKEND_CHOICES = field(
default="none",
metadata={"help": "If set to fully_sharded, will fsdp wrap task."},
)
bucket_cap_mb: int = field(
default=25, metadata={"help": "bucket size for reduction"}
)
Expand Down Expand Up @@ -375,6 +379,9 @@ class DatasetConfig(MetaseqDataclass):
validate_interval_updates: int = field(
default=0, metadata={"help": "validate every N updates"}
)
validate_on_first_step: int = field(
default=-1, metadata={"help": "validate on first step. default not to validate."}
)
validate_after_updates: int = field(
default=0, metadata={"help": "dont validate until reaching this many updates"}
)
Expand Down
2 changes: 2 additions & 0 deletions metaseq/dataclass/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def ChoiceEnum(choices: List[str]):
"c10d", # alias for pytorch_ddp
"fully_sharded", # FullyShardedDataParallel from fairscale
"pytorch_ddp",
"no_c10d",
"legacy_ddp",
]
)

Expand Down
2 changes: 2 additions & 0 deletions metaseq/distributed/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def fsdp_enable_wrap(
cfg: DistributedTrainingConfig, use_sharded_state: bool = False, **kwargs
):
try:
# from torch.distributed.fsdp.wrap import enable_wrap
# from torch.distributed.fsdp import MixedPrecision
from fairscale.nn import enable_wrap
except ImportError:
raise ImportError(
Expand Down
Loading