[Dygraph]Add dygraph sharding stage3 #38052

Merged (6 commits) on Jan 14, 2022

@Baibaifan commented on Dec 10, 2021

PR types

New features

PR changes

Others

Describe

Add dygraph sharding stage3

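import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

# initialize collective communication and create a 2-rank group
fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])

# build the model and optimizer, then wrap the model
model = model_class(...)
# assumption: any Paddle optimizer works here; AdamW is only an example
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
model = ShardingStage3(model, optimizer=optimizer, group=group)

# use the optimizer as normal (data is one batch from your data loader)
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)

loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()
optimizer.step()
optimizer.clear_grad()
# gather the full parameters from the per-rank parameter slices
model.get_all_parameters()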
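
The idea behind "parameter slices" and `get_all_parameters()` can be pictured with a small pure-Python simulation. This is only a conceptual sketch: it does not use Paddle, and the helper names `shard_parameter` and `gather_parameter`, the zero-padding scheme, and the sample values are all hypothetical and are not taken from this PR's implementation.

```python
# Conceptual sketch only: simulates parameter slicing across ranks with plain
# Python lists. It does not use Paddle and does not mirror ShardingStage3 internals.
from typing import List


def shard_parameter(flat_param: List[float], world_size: int) -> List[List[float]]:
    """Split a flattened parameter into equal slices, one per rank (zero-padded)."""
    slice_len = -(-len(flat_param) // world_size)  # ceiling division
    padded = flat_param + [0.0] * (slice_len * world_size - len(flat_param))
    return [padded[r * slice_len:(r + 1) * slice_len] for r in range(world_size)]


def gather_parameter(slices: List[List[float]], numel: int) -> List[float]:
    """Concatenate the per-rank slices (an all-gather) and drop the padding."""
    full = [x for rank_slice in slices for x in rank_slice]
    return full[:numel]


param = [0.1, 0.2, 0.3, 0.4, 0.5]              # hypothetical flattened weight
slices = shard_parameter(param, world_size=2)  # two ranks, like the group [0, 1]
restored = gather_parameter(slices, numel=len(param))
```

In this picture each rank keeps only its own slice between steps, and the full parameter is rebuilt only when it is needed, which is presumably why the usage example calls `model.get_all_parameters()` after training.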

[Figure: stage3 and DP, fp16 O2, GPT 117M]

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@Baibaifan force-pushed the dygraph_sharding_stage3 branch from 3af862c to f3523ab on December 10, 2021 16:21
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from f3523ab to d715ccc on December 10, 2021 16:43
@Baibaifan force-pushed the dygraph_sharding_stage3 branch 2 times, most recently from 8a7b0b1 to 1b7ae8e on December 13, 2021 03:50
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from 1b7ae8e to b95678c on December 14, 2021 07:47
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from b95678c to 36fc5ba on December 17, 2021 14:32
@Baibaifan changed the title from "Add dygraph sharding stage3" to "[Dygraph]Add dygraph sharding stage3" on Dec 20, 2021
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from 80a78b0 to bb72283 on December 21, 2021 13:53
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from bb72283 to b94f685 on December 22, 2021 07:50
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from b94f685 to 73f82b7 on December 23, 2021 14:37
@paddle-bot-old

Sorry to inform you that the CIs for 73f82b7 passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

@Baibaifan force-pushed the dygraph_sharding_stage3 branch from 656781a to 2cfc15e on January 11, 2022 13:37
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from 2cfc15e to 5b70cb5 on January 12, 2022 03:16
@Baibaifan closed this on Jan 12, 2022
@Baibaifan reopened this on Jan 12, 2022
@Baibaifan force-pushed the dygraph_sharding_stage3 branch from 351866c to da970c5 on January 12, 2022 08:37

@XieYunshen left a comment

LGTM for set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120)

@ForFishes left a comment

LGTM

@Baibaifan merged commit 4c77a90 into PaddlePaddle:develop on Jan 14, 2022