[AutoParallel] add release_gradients and comm_buffer_size_MB to strategy (#69410)
AndSonder authored Nov 18, 2024
1 parent b3ef967 commit 2d21046
Showing 3 changed files with 12 additions and 12 deletions.
15 changes: 8 additions & 7 deletions python/paddle/distributed/auto_parallel/api.py
@@ -2244,13 +2244,14 @@ def __init__(
             )
             dist.fleet.init(is_collective=True)
 
-        if isinstance(optimizer, _ShardOptimizer) and use_pir_api():
-            shard_fn = optimizer._shard_fn
-            optimizer = optimizer._inner_opt
-            if isinstance(optimizer._shard_fn, ShardingStage1):
-                optimizer = ShardingOptimizerStage1(
-                    optimizer, shard_fn, self._inner_strategy
-                )
+        if os.environ.get('FLAGS_enable_sharding_stage1_tensor_fusion', False):
+            if isinstance(optimizer, _ShardOptimizer) and use_pir_api():
+                shard_fn = optimizer._shard_fn
+                optimizer = optimizer._inner_opt
+                if isinstance(optimizer._shard_fn, ShardingStage1):
+                    optimizer = ShardingOptimizerStage1(
+                        optimizer, shard_fn, self._inner_strategy
+                    )
 
         self._engine = Engine(
             layer, loss, optimizer, metrics, strategy=self._inner_strategy
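With this hunk the ShardingStage1 tensor-fusion path becomes opt-in: the optimizer is only swapped for ShardingOptimizerStage1 when FLAGS_enable_sharding_stage1_tensor_fusion is present in the environment at construction time (the gate reads the raw environment string, so any non-empty value enables it). A minimal sketch of enabling the path; the model/optimizer setup around it is assumed and not part of the commit:

import os

# The gate in __init__ checks os.environ directly, so the flag must be set
# before the auto-parallel model wrapper is constructed.
os.environ['FLAGS_enable_sharding_stage1_tensor_fusion'] = '1'

# ...then build the distributed model as usual, e.g. (placeholders):
# dist_model = dist.to_static(layer, loader, loss, optimizer, strategy=strategy)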
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -157,6 +157,8 @@ class _AMPConfig(TypedDict, total=False):  # noqa: PYI049
 set_field_default_config(SHARDING, "partition_algor", "greedy_even")
 set_field_default_config(SHARDING, "enable_tuning", False)
 set_field_default_config(SHARDING, "tuning_range", [])
+set_field_default_config(SHARDING, "release_gradients", False)
+set_field_default_config(SHARDING, "comm_buffer_size_MB", -1)
 
 if TYPE_CHECKING:
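These defaults make the two new knobs visible on the sharding section of the auto-parallel strategy. A minimal usage sketch, assuming a typical stage-1 setup; the concrete values are illustrative, not taken from the commit:

import paddle.distributed as dist

strategy = dist.Strategy()
strategy.sharding.enable = True
strategy.sharding.stage = 1
strategy.sharding.degree = 8
# Fields introduced by this commit (defaults: False and -1).
strategy.sharding.release_gradients = True    # drop fused gradient storage after the optimizer update
strategy.sharding.comm_buffer_size_MB = 256   # cap each fused gradient communication buffer at 256 MB

The strategy is then passed to the auto-parallel entry point as before; per the api.py hunk above, it is the same self._inner_strategy object that gets forwarded into ShardingOptimizerStage1.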
7 changes: 2 additions & 5 deletions python/paddle/distributed/auto_parallel/sharding.py
@@ -120,10 +120,7 @@ def apply_gradients(self, params_grads):
         self._place = paddle.base.libpaddle.Place()
         self._place.set_place(place)
 
-        sharding_config = fleet.fleet._user_defined_strategy.hybrid_configs[
-            'sharding_configs'
-        ]
-        comm_buffer_size_MB = sharding_config.comm_buffer_size_MB
+        comm_buffer_size_MB = self._strategy.sharding.comm_buffer_size_MB
         parameters_dict = {}
         grads_dict = {}
         has_dist_param = False
@@ -227,7 +224,7 @@ def apply_gradients(self, params_grads):
                     // align[dtype]
                 )
                 align_size = align_size * self._sharding_degree
-                if not sharding_config.release_gradients:
+                if not self._strategy.sharding.release_gradients:
                     _, fused_grad = paddle._C_ops.coalesce_tensor_(
                         group_grad_list,
                         dtype,
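After these hunks the stage-1 optimizer reads both knobs from the auto-parallel strategy it was constructed with (self._strategy) instead of reaching into fleet's hybrid_configs. For intuition on the buffer-size knob, the sketch below shows the usual way such a cap splits gradients into fused communication buckets; it is illustrative only, with made-up names, and treating the -1 default as "no cap" is an assumption, not something stated in the diff:

# Illustrative sketch, not the Paddle implementation.
def group_by_buffer_size(grads, comm_buffer_size_MB):
    """Split (name, numel, dtype_size_bytes) tuples into buckets no larger than the cap."""
    if comm_buffer_size_MB <= 0:          # assume -1 (the default) means "no cap": one bucket
        return [list(grads)]
    cap_bytes = comm_buffer_size_MB * 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    for g in grads:
        nbytes = g[1] * g[2]
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)           # close the full bucket and start a new one
            current, current_bytes = [], 0
        current.append(g)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

# Example: three fp32 gradients (80 MB, 40 MB, 20 MB) with a 100 MB cap -> two buckets.
print(group_by_buffer_size(
    [("w1", 20_000_000, 4), ("w2", 10_000_000, 4), ("w3", 5_000_000, 4)],
    comm_buffer_size_MB=100,
))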
