mp sync params & grads & opt states. (#51428)
wuhuachaocoding authored Apr 11, 2023
1 parent f80a0fe commit 6b74cf7
Showing 4 changed files with 238 additions and 1 deletion.
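In short, this commit adds three switches under hybrid_configs["mp_configs"]: sync_param, sync_grad and sync_moment. When enabled, HybridParallelOptimizer broadcasts the gradients (before the inner optimizer step), the updated parameters (after it), and the Adam/AdamW moments and master weights of the replicated parameters (non-distributed embedding, layer_norm and bias parameters) from the model-parallel source rank, keeping those replicas identical across model-parallel ranks. A minimal sketch of how the switches are meant to be enabled, modeled on the test added below; the Linear model is a placeholder and a 2-GPU launch (e.g. python -m paddle.distributed.launch --gpus 0,1) is assumed:

import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 1,
    "mp_configs": {
        "sync_param": True,   # broadcast updated parameters after each step
        "sync_grad": True,    # broadcast gradients before each step
        "sync_moment": True,  # broadcast Adam/AdamW moments after each step
    },
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(8, 8)  # placeholder; the test below uses SimpleMPNet
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.1, parameters=model.parameters()
)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)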
8 changes: 8 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -50,11 +50,19 @@ message ShardingConfig {
  optional bool enable_tuning = 15 [ default = false ]; // incubate for auto parallel
}

// for dygraph
message MpConfig {
  optional bool sync_param = 1 [ default = false ];
  optional bool sync_grad = 2 [ default = false ];
  optional bool sync_moment = 3 [ default = false ];
}

message HybridConfig {
  optional int32 dp_degree = 1 [ default = -1 ];
  optional int32 mp_degree = 2 [ default = 1 ];
  optional int32 pp_degree = 3 [ default = 1 ];
  optional int32 sharding_degree = 4 [ default = 1 ];
  optional MpConfig mp_configs = 5;
}

message AMPConfig {
6 changes: 6 additions & 0 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -1696,6 +1696,12 @@ def hybrid_configs(self, configs):
        check_configs_key(
            self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
        )

        if "mp_configs" in configs:
            assign_configs_value(
                self.strategy.hybrid_configs.mp_configs, configs["mp_configs"]
            )
            configs.pop("mp_configs")
        assign_configs_value(self.strategy.hybrid_configs, configs)

    @property
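For reference, a small sketch of the round trip the setter above enables: the nested "mp_configs" dict is popped out of configs and assigned onto the MpConfig message, and the getter then exposes it as a message with attribute access, which is how the optimizer change below reads it back (fleet.fleet._user_defined_strategy.hybrid_configs["mp_configs"].sync_param). No distributed launch is needed for this snippet; the exact getter return type is an assumption based on that usage:

from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 1,
    "mp_configs": {"sync_param": True},
}
# The nested dict lands in the MpConfig message defined in distributed_strategy.proto.
print(strategy.hybrid_configs["mp_configs"].sync_param)  # True
print(strategy.hybrid_configs["mp_configs"].sync_grad)   # False (proto default)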
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
from paddle import framework
from paddle.autograd import no_grad
from paddle.distributed import fleet
from paddle.framework import core
from paddle.nn import ClipGradByGlobalNorm, clip

@@ -292,6 +294,83 @@ def __init__(self, optimizer, hcg, strategy):
                self._inner_opt._grad_clip, hcg
            )

    def _filter_fn(self, param):
        # Keep only parameters that are replicated (not sliced) across the model
        # parallel group: non-distributed embedding, layer_norm and bias (".b_")
        # parameters, which are the ones that need to be kept in sync.
        p_name = param.name
        tar_param = ["embedding", "layer_norm", ".b_"]
        if param.is_distributed is False:
            for tar in tar_param:
                if tar in p_name:
                    return True
        return False

    def _step(self, parameters_list):
        # Depending on the sync_grad / sync_param / sync_moment switches, broadcast
        # the gradients, updated values and optimizer states of the replicated
        # parameters selected by _filter_fn from the model-parallel source rank.
        mp_group = self._hcg.get_model_parallel_group()
        src_rank = self._hcg.get_model_parallel_group_src_rank()
        params = None
        mp_configs = None

        if mp_group.nranks > 1:
            mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
                "mp_configs"
            ]

        if mp_configs and (
            mp_configs.sync_param
            or mp_configs.sync_grad
            or mp_configs.sync_moment
        ):
            params = sorted(
                [p for p in parameters_list if self._filter_fn(p)],
                key=lambda p: p.name,
            )

        if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
            for p in params:
                if p.grad is None:
                    continue
                paddle.distributed.broadcast(
                    p.grad, src=src_rank, group=mp_group, sync_op=True
                )

        self._inner_opt.step()

        if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
            for p in params:
                paddle.distributed.broadcast(
                    p, src=src_rank, group=mp_group, sync_op=True
                )

        if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
            for p in params:
                # Only the optimizer states of Adam and AdamW are broadcast for now.
                if isinstance(
                    self._inner_opt,
                    (paddle.optimizer.Adam, paddle.optimizer.AdamW),
                ):
                    if (
                        self._inner_opt._multi_precision
                        and p.name in self._master_weights
                    ):
                        paddle.distributed.broadcast(
                            self._inner_opt._master_weights[p.name],
                            src=src_rank,
                            group=mp_group,
                            sync_op=True,
                        )

                    moment1 = self._inner_opt._get_accumulator(
                        self._inner_opt._moment1_acc_str, p
                    )
                    moment2 = self._inner_opt._get_accumulator(
                        self._inner_opt._moment2_acc_str, p
                    )
                    paddle.distributed.broadcast(
                        moment1, src=src_rank, group=mp_group, sync_op=True
                    )
                    paddle.distributed.broadcast(
                        moment2, src=src_rank, group=mp_group, sync_op=True
                    )

    @no_grad()
    @framework.dygraph_only
    def step(self):
@@ -302,7 +381,7 @@ def step(self):
        if self._dp_enable:
            fused_allreduce_gradients(list(parameters_list), self._hcg)

-        self._inner_opt.step()
+        self._step(parameters_list)

    @no_grad()
    def minimize(
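The synchronization in _step is a plain paddle.distributed.broadcast from the model-parallel source rank. A standalone illustration of that collective under the same hybrid setup, assuming a 2-GPU run launched with python -m paddle.distributed.launch --gpus 0,1 (the tensor values are made up):

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
src_rank = hcg.get_model_parallel_group_src_rank()

# Each rank starts with a different value; after the broadcast every rank in the
# model-parallel group holds the source rank's copy, which is exactly how _step
# keeps the replicated parameters, gradients and moments identical.
t = paddle.full([2, 2], float(dist.get_rank()))
dist.broadcast(t, src=src_rank, group=mp_group, sync_op=True)
print(t.numpy())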
@@ -181,6 +181,150 @@ def forward(self, x):
        return x


class TestDistMPSyncTraning(unittest.TestCase):
    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        self.data_parallel_size = 1
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
            "mp_configs": {
                "sync_param": False,
                "sync_grad": False,
                "sync_moment": False,
            },
        }
        fleet.init(is_collective=True, strategy=strategy)

    def build_model_optimizer_train(
        self,
        batchs,
        fp16=False,
        mp_sync_param=False,
        mp_sync_grad=False,
        mp_sync_moment=False,
    ):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        mp_id = hcg.get_model_parallel_rank()
        dp_id = hcg.get_data_parallel_rank()
        rank_id = dist.get_rank()
        paddle.seed(2023)
        np.random.seed(2023)
        random.seed(2023)
        set_random_seed(1024, dp_id, rank_id)

        np_fc1 = np.random.random_sample((hidden_size, inner_size))
        np_fc2 = np.random.random_sample((inner_size, hidden_size))

        model = SimpleMPNet(
            vocab_size,
            hidden_size,
            inner_size,
            output_size,
            np_fc1,
            np_fc2,
            mp_id,
        )
        optimizer = paddle.optimizer.AdamW(
            learning_rate=0.1, parameters=model.parameters()
        )

        strategy = fleet.fleet._user_defined_strategy
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
            "mp_configs": {
                "sync_param": mp_sync_param,
                "sync_grad": mp_sync_grad,
                "sync_moment": mp_sync_moment,
            },
        }

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)
        return self.train_batch(batchs, model, optimizer, fp16)

    def train_batch(self, batchs, model, optimizer, fp16=False):
        losses = []
        if fp16:
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            scaler = fleet.distributed_scaler(scaler)
        for batch in batchs:
            with paddle.amp.auto_cast(enable=fp16, level='O1'):
                output = model(batch)
                loss = output.mean()
            losses.append(loss.numpy())
            if fp16:
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            optimizer.clear_grad()
        return losses

    def mp_sync_base(
        self, mp_sync_param=False, mp_sync_grad=False, mp_sync_moment=False
    ):
        batchs = []
        for _ in range(5):
            np_data = np.random.randint(
                0,
                vocab_size,
                (
                    batch_size,
                    seq_length,
                ),
            )
            batchs.append(paddle.to_tensor(np_data))

        losses = self.build_model_optimizer_train(batchs)
        losses_sync = self.build_model_optimizer_train(
            batchs,
            mp_sync_param=mp_sync_param,
            mp_sync_grad=mp_sync_grad,
            mp_sync_moment=mp_sync_moment,
        )

        for i in range(len(losses)):
            np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6)

        # test fp16
        losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True)
        losses_sync_fp16 = self.build_model_optimizer_train(
            batchs,
            fp16=True,
            mp_sync_param=mp_sync_param,
            mp_sync_grad=mp_sync_grad,
            mp_sync_moment=mp_sync_moment,
        )

        for i in range(len(losses_fp16)):
            np.testing.assert_allclose(
                losses_fp16[i], losses_sync_fp16[i], rtol=1e-6
            )

    def test_mp_sync_param(self):
        self.mp_sync_base(mp_sync_param=True)

    def test_mp_sync_grad(self):
        self.mp_sync_base(mp_sync_grad=True)

    def test_mp_sync_moment(self):
        self.mp_sync_base(mp_sync_moment=True)

    def test_mp_sync_all(self):
        self.mp_sync_base(
            mp_sync_param=True, mp_sync_grad=True, mp_sync_moment=True
        )


class TestDistMPTraning(unittest.TestCase):
    def setUp(self):
        strategy = fleet.DistributedStrategy()
