diff --git a/unicore/distributed/bp.py b/unicore/distributed/bp.py
new file mode 100644
index 0000000..68a6797
--- /dev/null
+++ b/unicore/distributed/bp.py
@@ -0,0 +1,104 @@
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import BroadcastOptions, AllreduceOptions, ReduceOp
+from .comm_group import scg
+
+
+def broadcast(tensor, src):
+    """Broadcast ``tensor`` in place from rank ``src`` of the bp group.
+
+    Args:
+        tensor: tensor that is sent on ``src`` and overwritten elsewhere.
+        src: rank inside the bp group that holds the data; must be 0 or 1.
+
+    Returns:
+        The same ``tensor`` object on every path, matching ``all_reduce``.
+    """
+    if scg.get_bp_world_size() == 1:
+        return tensor
+
+    assert src in [0, 1], "Branch Parallel is only support bp_degree=2 now!"
+
+    group = scg.get_bp_group()
+
+    opts = BroadcastOptions()
+    opts.rootRank = src
+    opts.rootTensor = 0
+    work = group.broadcast([tensor], opts)
+    work.wait()
+
+    return tensor
+
+
+def all_reduce(tensor):
+    """Sum ``tensor`` in place across the bp group and return it.
+
+    No communication happens when the bp group has a single rank.
+    """
+    if scg.get_bp_world_size() == 1:
+        return tensor
+
+    group = scg.get_bp_group()
+
+    opts = AllreduceOptions()
+    opts.reduceOp = ReduceOp.SUM
+
+    work = group.allreduce([tensor], opts)
+    work.wait()
+
+    return tensor
+
+
+class SyncEvoformerResults(torch.autograd.Function):
+    """Autograd function that syncs ``msa`` and ``pair`` across the bp group.
+
+    Forward broadcasts ``outer`` and ``msa`` from bp rank 0 and ``pair``
+    (with ``outer`` added on bp rank 1) from bp rank 1. Backward broadcasts
+    the ``pair`` gradient of bp rank 1 as the gradient of ``outer``.
+    """
+
+    @staticmethod
+    def forward(ctx, outer, msa, pair, training):
+        broadcast(outer, 0)
+        if scg.get_bp_rank_in_group() == 1:
+            if training:
+                # out-of-place add: the caller's ``pair`` tensor is left as is
+                pair = pair + outer
+            else:
+                pair += outer
+        broadcast(pair, 1)
+        broadcast(msa, 0)
+        return msa.clone(), pair.clone()
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        msa_grad = grad_output[0]
+        pair_grad = grad_output[1]
+
+        if scg.get_bp_rank_in_group() == 0:
+            pair_grad = torch.zeros_like(pair_grad)
+
+        outer_grad = pair_grad.clone()
+        broadcast(outer_grad, 1)
+
+        # autograd expects one gradient per forward() input, in the order
+        # (outer, msa, pair, training); the ``training`` flag gets None.
+        return outer_grad.clone(), msa_grad.clone(), pair_grad.clone(), None
+
+
+def sync_evoformer_results(outer, msa, pair, training):
+    """Wrapper around ``SyncEvoformerResults`` for syncing ``msa`` and ``pair``.
+
+    Returns ``(msa, pair)`` untouched when the bp group has a single rank.
+    """
+    if scg.get_bp_world_size() == 1:
+        return msa, pair
+
+    # NOTE(review): this skips the sync exactly when gradients are tracked
+    # for all three inputs, which looks inverted -- confirm the intent.
+    if torch.is_grad_enabled() and outer.requires_grad and msa.requires_grad and pair.requires_grad:
+        return msa, pair
+
+    msa, pair = SyncEvoformerResults.apply(outer, msa, pair, training)
+
+    return msa, pair
diff --git a/unicore/distributed/comm_group.py b/unicore/distributed/comm_group.py
new file mode 100644
index 0000000..ae6050c
--- /dev/null
+++ b/unicore/distributed/comm_group.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Communication group manager
+"""
+
+import numpy as np
+import torch.distributed as dist
+
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
+
+
+class SingletonCommunicationGroup(object):
+    """Singleton holder of the bp, dap and dp process groups.
+
+    Rank and world-size getters return fallback values (0, -1 or 1) until
+    ``init_group`` has run; the ``get_*_group`` getters assert instead.
+    """
+
+    def __init__(self):
+        self.initialized = False
+
+    @staticmethod
+    def _build_group(rank_rows, name, rank):
+        """Create one process group per row and return the one holding ``rank``."""
+        own_group = None
+        for row in rank_rows:
+            ranks = row.tolist()
+            # every process must create every group, member or not
+            group = dist.new_group(ranks)
+            print('> {} ranks:'.format(name), ranks, '{} group:'.format(name), group)
+            if rank in ranks:
+                own_group = group
+        return own_group
+
+    def init_group(self, bp_degree=1, dap_degree=1, dap_comm_sync=True):
+        """Create the bp, dap and dp process groups for this process.
+
+        The dp degree is derived as ``world_size // (bp_degree * dap_degree)``.
+
+        Args:
+            bp_degree: number of ranks in each bp group.
+            dap_degree: number of ranks in each dap group.
+            dap_comm_sync: stored and reported by ``dap_is_comm_sync``.
+        """
+        assert not self.initialized, "Communication group is already initialized!"
+
+        # check valid config
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+        inner_degree = bp_degree * dap_degree
+        ensure_divisibility(world_size, bp_degree)
+        ensure_divisibility(world_size, dap_degree)
+        ensure_divisibility(world_size, inner_degree)
+
+        self.dp_degree = world_size // inner_degree
+        self.bp_degree = bp_degree
+        self.dap_degree = dap_degree
+        self.dap_comm_sync = dap_comm_sync
+
+        # rank layout is (dp, dap, bp): bp neighbours are adjacent ranks
+        arr = np.arange(0, world_size).reshape(self.dp_degree, self.dap_degree, self.bp_degree)
+
+        # group creation order (bp, dap, dp) is the same on every rank
+        bp_rows = arr.reshape(-1, self.bp_degree)
+        self.bp_group = self._build_group(bp_rows, 'bp', rank)
+
+        dap_rows = arr.transpose((0, 2, 1)).reshape(-1, self.dap_degree)
+        self.dap_group = self._build_group(dap_rows, 'dap', rank)
+
+        dp_rows = arr.transpose((1, 2, 0)).reshape(-1, self.dp_degree)
+        self.dp_group = self._build_group(dp_rows, 'dp', rank)
+
+        self.initialized = True
+        if dist.get_rank() == 0:
+            print('> initialize branch parallel with size {}'.format(self.bp_degree))
+            print('> initialize dynamic axial parallel with size {}'.format(self.dap_degree))
+            print('> initialize data parallel with size {}'.format(self.dp_degree))
+
+    def dap_is_comm_sync(self):
+        """Return the ``dap_comm_sync`` flag that was passed to ``init_group``."""
+        return self.dap_comm_sync
+
+    def bp_is_initialized(self):
+        """Return whether the bp communication group is initialized."""
+        return self.initialized
+
+    def dap_is_initialized(self):
+        """Return whether the dap communication group is initialized."""
+        return self.initialized
+
+    def dp_is_initialized(self):
+        """Return whether the dp communication group is initialized."""
+        return self.initialized
+
+    def is_initialized(self):
+        """Return whether the hybrid communication groups are initialized."""
+        return self.initialized
+
+    def get_bp_group(self):
+        """Return the bp communication group; asserts that it exists."""
+        assert self.initialized, "bp group is not initialized!"
+        return self.bp_group
+
+    def get_bp_rank(self):
+        """Return the rank reported by the bp group object (0 if uninitialized)."""
+        if not self.initialized:
+            return 0
+        # ProcessGroup.rank is a method in torch; call it to get the int
+        return self.bp_group.rank()
+
+    def get_bp_rank_in_group(self):
+        """Return this process's rank inside the bp group (-1 if uninitialized)."""
+        if not self.initialized:
+            return -1
+        return dist.get_rank(self.bp_group)
+
+    def get_bp_world_size(self):
+        """Return the size of the bp group (1 if uninitialized)."""
+        if not self.initialized:
+            return 1
+        return dist.get_world_size(self.bp_group)
+
+    def get_dap_group(self):
+        """Return the dap communication group; asserts that it exists."""
+        assert self.initialized, "dap group is not initialized!"
+        return self.dap_group
+
+    def get_dap_rank(self):
+        """Return the rank reported by the dap group object (0 if uninitialized)."""
+        if not self.initialized:
+            return 0
+        # ProcessGroup.rank is a method in torch; call it to get the int
+        return self.dap_group.rank()
+
+    def get_dap_rank_in_group(self):
+        """Return this process's rank inside the dap group (-1 if uninitialized)."""
+        if not self.initialized:
+            return -1
+        return dist.get_rank(self.dap_group)
+
+    def get_dap_world_size(self):
+        """Return the size of the dap group (1 if uninitialized)."""
+        if not self.initialized:
+            return 1
+        return dist.get_world_size(self.dap_group)
+
+    def get_dp_group(self):
+        """Return the dp communication group; asserts that it exists."""
+        assert self.initialized, "dp group is not initialized!"
+        return self.dp_group
+
+    def get_dp_rank(self):
+        """Return the rank reported by the dp group object (0 if uninitialized)."""
+        if not self.initialized:
+            return 0
+        # ProcessGroup.rank is a method in torch; call it to get the int
+        return self.dp_group.rank()
+
+    def get_dp_rank_in_group(self):
+        """Return this process's rank inside the dp group (-1 if uninitialized)."""
+        if not self.initialized:
+            return -1
+        return dist.get_rank(self.dp_group)
+
+    def get_dp_world_size(self):
+        """Return the size of the dp group (1 if uninitialized)."""
+        if not self.initialized:
+            return 1
+        return dist.get_world_size(self.dp_group)
+
+
+scg = SingletonCommunicationGroup()
diff --git a/unicore/distributed/utils.py b/unicore/distributed/utils.py
index 4c26fc7..ff70400 100644
--- a/unicore/distributed/utils.py
+++ b/unicore/distributed/utils.py
@@ -21,6 +21,8 @@
 
 import torch
 import torch.distributed as dist
+from .comm_group import scg
+
 
 logger = logging.getLogger(__name__)
 
@@ -137,6 +139,9 @@ def distributed_init(args):
         if torch.cuda.is_available():
             dist.all_reduce(torch.zeros(1).cuda())
 
+    scg.init_group(bp_degree=args.bp_degree, dap_degree=1)
+
+    args.dp_rank = scg.get_dp_rank_in_group() if torch.distributed.get_world_size() > 1 else 0
     args.distributed_rank = torch.distributed.get_rank()
 
     if is_master(args):
diff --git a/unicore/options.py b/unicore/options.py
index c923a93..b5e8d26 100644
--- a/unicore/options.py
+++ b/unicore/options.py
@@ -294,6 +294,8 @@ def add_distributed_training_args(parser):
                        help="number of GPUs in each node. An allreduce operation across GPUs in "
                             "a node is very fast. Hence, we do allreduce across GPUs in a node, "
                             "and gossip across different nodes")
+    group.add_argument('--bp-degree', default=1, type=int,
+                       help="branch parallel degree (number of ranks in each bp group)")
     # fmt: on
     return group
 
diff --git a/unicore_cli/train.py b/unicore_cli/train.py
index 0871ef6..bfbeb9c 100644
--- a/unicore_cli/train.py
+++ b/unicore_cli/train.py
@@ -49,6 +49,8 @@ def main(args) -> None:
     ), "Must specify batch size either with --batch-size"
     metrics.reset()
 
+    # args.dp_rank is assigned in distributed_init; fall back to 0 if absent
+    args.seed += getattr(args, "dp_rank", 0)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     if torch.cuda.is_available():