Fit sharding optimization for auto parallel llama #8021

Merged
37 changes: 33 additions & 4 deletions paddlenlp/trainer/training_args.py
@@ -230,6 +230,10 @@
The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to
data parallel, sharding stage1, tensor parallel and pipeline parallel strategy.
)
data_parallel_config (`str`, *optional*)(
Some additional configs which affect data parallel performance; we provide some options to configure it.
The following config is supported:
enable_allreduce_avg_in_gradinent_scale: replaces the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data parallel, which improves performance. ONLY supported for auto mode now.
tensor_parallel_config (`str`, *optional*)(
Some additional configs which affect model parallel performance; we provide some options to configure it.
The following configs are supported:
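
For context on the new data_parallel_config option: below is a minimal sketch, not part of this diff, of passing the flag through PdArgumentParser (the dataclass argument parser PaddleNLP trainers use); the "--output_dir" value is just a placeholder.

# Minimal sketch (not part of this diff): feeding the new flag through
# PdArgumentParser; "--output_dir" is only a placeholder for illustration.
from paddlenlp.trainer import PdArgumentParser, TrainingArguments

parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    [
        "--output_dir", "./out",
        "--data_parallel_config", "enable_allreduce_avg_in_gradinent_scale",
    ]
)
print(training_args.data_parallel_config)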
@@ -571,6 +575,16 @@
)
},
)
data_parallel_config: str = field(
default="",
metadata={
"help": (
"Some additional configs which affect data parallel performance; we provide some options to configure it.\n"
"The following config is supported:\n"
"enable_allreduce_avg_in_gradinent_scale: replaces the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data parallel, which improves performance. ONLY supported for auto mode now.\n"
)
},
)
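
As with the other *_parallel_config fields in this file, the value is one string of space-separated option names; a tiny sketch of that convention (mirroring the split(" ") parsing further down):

# Sketch of the convention only: options are space-separated in one string
# and later split into a set, as the parsing code below does.
raw = "enable_allreduce_avg_in_gradinent_scale"
options = set(raw.split(" "))
assert "enable_allreduce_avg_in_gradinent_scale" in options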
tensor_parallel_config: str = field(
default="",
metadata={
@@ -951,6 +965,7 @@
# TODO use paddle.distributed.is_initialized() after paddle 2.4rc
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
strategy = fleet.DistributedStrategy()
assert self.data_parallel_config == "", "data_parallel_config is not supported in hybrid parallel"

if self.pipeline_parallel_degree > 1:
pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
for x in pipeline_parallel_config:
@@ -1165,6 +1180,17 @@
warnings.warn("`offload` is not supported NOW!")

strategy = fleet.auto.Strategy()
if self.data_parallel_degree > 1:
data_parallel_config = set(self.data_parallel_config.split(" "))
for x in data_parallel_config:
if len(x) > 0:
if x not in ["enable_allreduce_avg_in_gradinent_scale"]:
raise ValueError(
f"Found unknown data parallel config {x}, accept config is enable_allreduce_avg_in_gradinent_scale."
)
if "enable_allreduce_avg_in_gradinent_scale" in data_parallel_config:
strategy.gradient_scale_using_allreduce_avg = True
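
For intuition, gradient_scale_using_allreduce_avg folds the reduce and the 1/dp_degree scale into one averaging collective; a single-process sketch of the arithmetic being replaced (made-up values, no real communication):

# Single-process sketch (hypothetical values, no real communication):
# averaging in one collective equals summing and then scaling by 1/dp_degree.
grads_per_rank = [2.0, 4.0, 6.0, 8.0]  # one parameter's gradient on 4 dp ranks
dp_degree = len(grads_per_rank)

allreduce_sum_then_scale = sum(grads_per_rank) * (1.0 / dp_degree)
allreduce_avg = sum(grads_per_rank) / dp_degree  # what the avg collective produces

assert allreduce_sum_then_scale == allreduce_avg == 5.0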


# naive-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
@@ -1254,9 +1280,9 @@
for x in sharding_parallel_config:
if len(x) > 0:
if x not in [
# "enable_stage1_tensor_fusion",
# "enable_stage1_overlap",
# "enable_stage2_overlap",
"enable_stage1_tensor_fusion",
"enable_stage1_overlap",
"enable_stage2_overlap",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, " f"accpet config is reduce_overlap."
@@ -1266,7 +1292,10 @@
"enable_stage1_overlap" in sharding_parallel_config
or "enable_stage2_overlap" in sharding_parallel_config
):
sharding.reduce_overlap = True
sharding.enable_overlap = True
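
A hypothetical example value for the re-enabled sharding options (space-separated, like the other *_config fields); note that either stage1 or stage2 overlap flips the same enable_overlap switch set above:

# Hypothetical example (not part of this diff): the re-enabled options are
# space-separated in one string, like the other *_config fields.
sharding_parallel_config = "enable_stage1_tensor_fusion enable_stage1_overlap"

tokens = set(sharding_parallel_config.split(" "))
# Either stage1 or stage2 overlap maps onto the same switch, mirroring the
# enable_overlap assignment above.
enable_overlap = bool({"enable_stage1_overlap", "enable_stage2_overlap"} & tokens)
assert enable_overlap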


if "enable_stage1_tensor_fusion" in sharding_parallel_config:
sharding.grad_bucket_size_numel = 210355872
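
A rough back-of-the-envelope on the hard-coded bucket size, assuming (as the numel suffix suggests) that it counts gradient elements fused per bucket:

# Back-of-the-envelope sketch; assumption: the knob counts fused gradient
# elements per bucket, as the "numel" suffix suggests.
bucket_numel = 210355872
print(bucket_numel * 2 / 2**20)  # ~401 MiB per bucket for bf16/fp16 gradients
print(bucket_numel * 4 / 2**20)  # ~802 MiB per bucket for fp32 gradients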


if self.bf16 or self.fp16:
amp = strategy.amp