From 6b74cf76cbaf521cd34633a572acb6abbbd124d8 Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Tue, 11 Apr 2023 12:11:47 +0800
Subject: [PATCH] mp sync params & grads & opt states. (#51428)

---
 .../framework/distributed_strategy.proto      |   8 +
 .../fleet/base/distributed_strategy.py        |   6 +
 .../hybrid_parallel_optimizer.py              |  81 +++++++++-
 .../fleet/hybrid_parallel_mp_model.py         | 144 ++++++++++++++++++
 4 files changed, 238 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index b9055d38d38c5..de2e38c2f1165 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -50,11 +50,19 @@ message ShardingConfig {
   optional bool enable_tuning = 15 [ default = false ]; // incubate for auto parallel
 }
 
+// for dygraph
+message MpConfig {
+  optional bool sync_param = 1 [ default = false ];
+  optional bool sync_grad = 2 [ default = false ];
+  optional bool sync_moment = 3 [ default = false ];
+}
+
 message HybridConfig {
   optional int32 dp_degree = 1 [ default = -1 ];
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
+  optional MpConfig mp_configs = 5;
 }
 
 message AMPConfig {
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 0f09440e4337c..86292a2d90e79 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -1696,6 +1696,12 @@ def hybrid_configs(self, configs):
         check_configs_key(
             self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
         )
+
+        if "mp_configs" in configs:
+            assign_configs_value(
+                self.strategy.hybrid_configs.mp_configs, configs["mp_configs"]
+            )
+            configs.pop("mp_configs")
         assign_configs_value(self.strategy.hybrid_configs, configs)
 
     @property
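The framework-side changes above only define the new MpConfig message and let
DistributedStrategy.hybrid_configs accept an "mp_configs" entry. What follows is
a minimal usage sketch modelled on the unit test added at the end of this patch;
the dp/mp degrees and the multi-GPU collective launch it assumes are
illustrative and not part of the patch itself.

from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 1,
    # New switches introduced by this patch; all three default to False.
    "mp_configs": {
        "sync_param": True,    # broadcast selected parameters after step()
        "sync_grad": False,    # broadcast their gradients before step()
        "sync_moment": False,  # broadcast Adam/AdamW moments after step()
    },
}
# Assumes a collective launch, e.g. python -m paddle.distributed.launch.
fleet.init(is_collective=True, strategy=strategy)

With the strategy in place, fleet.distributed_optimizer() returns a
HybridParallelOptimizer whose step() performs the extra model-parallel
broadcasts implemented in the next file.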
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index 98604b8db3d8c..acd34f1b1d5b8 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from paddle import framework
 from paddle.autograd import no_grad
+from paddle.distributed import fleet
 from paddle.framework import core
 from paddle.nn import ClipGradByGlobalNorm, clip
@@ -292,6 +294,83 @@ def __init__(self, optimizer, hcg, strategy):
                 self._inner_opt._grad_clip, hcg
             )
 
+    def _filter_fn(self, param):
+        p_name = param.name
+        tar_param = ["embedding", "layer_norm", ".b_"]
+        if param.is_distributed is False:
+            for tar in tar_param:
+                if tar in p_name:
+                    return True
+        return False
+
+    def _step(self, parameters_list):
+        mp_group = self._hcg.get_model_parallel_group()
+        src_rank = self._hcg.get_model_parallel_group_src_rank()
+        params = None
+        mp_configs = None
+
+        if mp_group.nranks > 1:
+            mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
+                "mp_configs"
+            ]
+
+        if mp_configs and (
+            mp_configs.sync_param
+            or mp_configs.sync_grad
+            or mp_configs.sync_moment
+        ):
+            params = sorted(
+                [p for p in parameters_list if self._filter_fn(p)],
+                key=lambda p: p.name,
+            )
+
+        if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
+            for p in params:
+                if p.grad is None:
+                    continue
+                paddle.distributed.broadcast(
+                    p.grad, src=src_rank, group=mp_group, sync_op=True
+                )
+
+        self._inner_opt.step()
+
+        if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
+            for p in params:
+                paddle.distributed.broadcast(
+                    p, src=src_rank, group=mp_group, sync_op=True
+                )
+
+        if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
+            for p in params:
+                # Only the optimizer states of Adam and AdamW are broadcast for now.
+                if isinstance(
+                    self._inner_opt,
+                    (paddle.optimizer.Adam, paddle.optimizer.AdamW),
+                ):
+                    if (
+                        self._inner_opt._multi_precision
+                        and p.name in self._master_weights
+                    ):
+                        paddle.distributed.broadcast(
+                            self._inner_opt._master_weights[p.name],
+                            src=src_rank,
+                            group=mp_group,
+                            sync_op=True,
+                        )
+
+                    moment1 = self._inner_opt._get_accumulator(
+                        self._inner_opt._moment1_acc_str, p
+                    )
+                    moment2 = self._inner_opt._get_accumulator(
+                        self._inner_opt._moment2_acc_str, p
+                    )
+                    paddle.distributed.broadcast(
+                        moment1, src=src_rank, group=mp_group, sync_op=True
+                    )
+                    paddle.distributed.broadcast(
+                        moment2, src=src_rank, group=mp_group, sync_op=True
+                    )
+
     @no_grad()
     @framework.dygraph_only
     def step(self):
@@ -302,7 +381,7 @@ def step(self):
 
         if self._dp_enable:
             fused_allreduce_gradients(list(parameters_list), self._hcg)
-        self._inner_opt.step()
+        self._step(parameters_list)
 
     @no_grad()
     def minimize(
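The synchronization above is deliberately narrow: _filter_fn only selects
parameters that are not split across model-parallel ranks (is_distributed is
False) and whose names contain "embedding", "layer_norm", or ".b_", and the
selection is sorted by name, presumably so that every rank issues the
broadcasts in the same order. The snippet below is a runnable illustration of
just that selection rule; FakeParam is a hypothetical stand-in, not a Paddle
type, chosen only because real parameters expose the same name and
is_distributed attributes.

from collections import namedtuple

# Hypothetical stand-in for a parameter; not a Paddle type.
FakeParam = namedtuple("FakeParam", ["name", "is_distributed"])

def needs_sync(param):
    # Same rule as HybridParallelOptimizer._filter_fn above.
    targets = ["embedding", "layer_norm", ".b_"]
    if param.is_distributed is False:
        return any(tar in param.name for tar in targets)
    return False

params = [
    FakeParam("embedding_0.w_0", False),              # name matches, not split -> selected
    FakeParam("column_parallel_linear_0.w_0", True),  # is_distributed -> skipped
    FakeParam("linear_1.b_0", False),                 # ".b_" bias name -> selected
    FakeParam("layer_norm_2.w_0", False),             # layer_norm name -> selected
]
selected = sorted((p for p in params if needs_sync(p)), key=lambda p: p.name)
print([p.name for p in selected])
# ['embedding_0.w_0', 'layer_norm_2.w_0', 'linear_1.b_0']

The broadcasts themselves then follow the order laid out in _step: gradients
before self._inner_opt.step(); parameters and, for Adam/AdamW, master weights
and moments after it, always from the model-parallel group's source rank.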
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py
index dec1eb949ddb8..26e740bfa6b79 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_mp_model.py
@@ -181,6 +181,150 @@ def forward(self, x):
         return x
 
 
+class TestDistMPSyncTraning(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "sync_param": False,
+                "sync_grad": False,
+                "sync_moment": False,
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def build_model_optimizer_train(
+        self,
+        batchs,
+        fp16=False,
+        mp_sync_param=False,
+        mp_sync_grad=False,
+        mp_sync_moment=False,
+    ):
+        hcg = fleet.get_hybrid_communicate_group()
+        word_size = hcg.get_model_parallel_world_size()
+        mp_id = hcg.get_model_parallel_rank()
+        dp_id = hcg.get_data_parallel_rank()
+        rank_id = dist.get_rank()
+        paddle.seed(2023)
+        np.random.seed(2023)
+        random.seed(2023)
+        set_random_seed(1024, dp_id, rank_id)
+
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model = SimpleMPNet(
+            vocab_size,
+            hidden_size,
+            inner_size,
+            output_size,
+            np_fc1,
+            np_fc2,
+            mp_id,
+        )
+        optimizer = paddle.optimizer.AdamW(
+            learning_rate=0.1, parameters=model.parameters()
+        )
+
+        strategy = fleet.fleet._user_defined_strategy
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "sync_param": mp_sync_param,
+                "sync_grad": mp_sync_grad,
+                "sync_moment": mp_sync_moment,
+            },
+        }
+
+        model = fleet.distributed_model(model)
+        optimizer = fleet.distributed_optimizer(optimizer)
+        return self.train_batch(batchs, model, optimizer, fp16)
+
+    def train_batch(self, batchs, model, optimizer, fp16=False):
+        losses = []
+        if fp16:
+            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+            scaler = fleet.distributed_scaler(scaler)
+        for batch in batchs:
+            with paddle.amp.auto_cast(enable=fp16, level='O1'):
+                output = model(batch)
+                loss = output.mean()
+                losses.append(loss.numpy())
+            if fp16:
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                optimizer.step()
+            optimizer.clear_grad()
+        return losses
+
+    def mp_sync_base(
+        self, mp_sync_param=False, mp_sync_grad=False, mp_sync_moment=False
+    ):
+        batchs = []
+        for _ in range(5):
+            np_data = np.random.randint(
+                0,
+                vocab_size,
+                (
+                    batch_size,
+                    seq_length,
+                ),
+            )
+            batchs.append(paddle.to_tensor(np_data))
+
+        losses = self.build_model_optimizer_train(batchs)
+        losses_sync = self.build_model_optimizer_train(
+            batchs,
+            mp_sync_param=mp_sync_param,
+            mp_sync_grad=mp_sync_grad,
+            mp_sync_moment=mp_sync_moment,
+        )
+
+        for i in range(len(losses)):
+            np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6)
+
+        # test fp16
+        losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True)
+        losses_sync_fp16 = self.build_model_optimizer_train(
+            batchs,
+            fp16=True,
+            mp_sync_param=mp_sync_param,
+            mp_sync_grad=mp_sync_grad,
+            mp_sync_moment=mp_sync_moment,
+        )
+
+        for i in range(len(losses_fp16)):
+            np.testing.assert_allclose(
+                losses_fp16[i], losses_sync_fp16[i], rtol=1e-6
+            )
+
+    def test_mp_sync_param(self):
+        self.mp_sync_base(mp_sync_param=True)
+
+    def test_mp_sync_grad(self):
+        self.mp_sync_base(mp_sync_grad=True)
+
+    def test_mp_sync_moment(self):
+        self.mp_sync_base(mp_sync_moment=True)
+
+    def test_mp_sync_all(self):
+        self.mp_sync_base(
+            mp_sync_param=True, mp_sync_grad=True, mp_sync_moment=True
+        )
+
+
 class TestDistMPTraning(unittest.TestCase):
     def setUp(self):
         strategy = fleet.DistributedStrategy()