Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support looped PP schedules in torchtitan #358

Merged
merged 20 commits into from
Jun 21, 2024

Conversation

wconstab
Copy link
Contributor

@wconstab wconstab commented May 23, 2024

Stack from ghstack (oldest at bottom):


  • refactor some per-model logic into helper functions

[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 23, 2024
ghstack-source-id: 39a1559ba3ecf1c7c8b2704151ca2781bfe0001b
Pull Request resolved: #358
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 23, 2024
[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 23, 2024
ghstack-source-id: 527a6f22d3c0955e527ac34167a00023deab6981
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 23, 2024
ghstack-source-id: db6559fe5a5d2b338bd27553d3d1b66a6c64d3b9
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 3, 2024
ghstack-source-id: 94567ac8c62948a130e7d062c8d66f3c34f5ff7f
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 11, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: 4741d494bdb61cd28f7bf5ad91094f0c174f88c2
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 12, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: 4fcd38adafe9926799366c4c868219d47f7bc03c
Pull Request resolved: #358
@wconstab wconstab changed the title add todos mocking changes for looped PP support Support Looped PP schedules Jun 13, 2024
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 13, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: a7768287ed2d31272b07ac9f3601b6e23e90c710
Pull Request resolved: #358
@wconstab wconstab changed the title Support Looped PP schedules Add train loop support for looped PP schedules Jun 13, 2024
wconstab added a commit that referenced this pull request Jun 13, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: a7768287ed2d31272b07ac9f3601b6e23e90c710
Pull Request resolved: #358
@wconstab wconstab changed the title Add train loop support for looped PP schedules Add support for looped PP schedules Jun 14, 2024
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 14, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: c40342e4d577a044d4094ef766de16ba496ab835
Pull Request resolved: #358
"optimizer": OptimizerWrapper(model_parts, optimizers),
# TODO(whc) flatten lr_schedulers using a wrapper and somehow handle resharding?
# or store one per key and explicitly dont support resharding?
# "lr_scheduler": lr_scheduler,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have to fix this part. i think @fegin had a suggestion for a workaround that would support resharding. i am not sure if i should do this or just let it be saved in a way that would break for resharding and don't care about resharding for now.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 14, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: 1d313526b76b7ba76376d82d39171b75294fd831
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 14, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: 2f0b57f3cbfb2d27f37850d09a92d64e5b7fbc87
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 14, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: 049327e0eb74dd0f1e8a6ccd8f1e7391ed4c339b
Pull Request resolved: #358
[ghstack-poisoned]
@wconstab wconstab changed the title Add support for looped PP schedules Add train loop support for looped PP schedules Jun 15, 2024
wconstab added a commit that referenced this pull request Jun 15, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: d9cd4b2de66ff263b68db13f717f3f597cbd6d0d
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 15, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: f0f106158e366922573d91e1e11ca278d900f136
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 17, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: 64587ca052b8107ef86112a64891a4bab54b7f27
Pull Request resolved: #358
[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jun 17, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: 8ce18913aff539ec8ca102383663448c69fa6632
Pull Request resolved: #358
wconstab added 3 commits June 17, 2024 15:07
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).

3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RE:
image

I did it this way because i thought if we are not supporting resharding of lr_scheduler, then i may as well save each one. If i save each one, then at load time i have a form of assertion provided for me by dcp- if the runtime loading the checkpoint has the same number of ranks, they will match up and load OK. if not, they would throw an error.

I could switch back to the version where I save only one copy, but then i have to do some validation up front. Do you think this way is better?
@fegin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are not supporting resharding, then this implementation is better.

Copy link
Contributor

@kwen2501 kwen2501 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pipelining part looks good to me. Left two minor comments.

@@ -26,7 +33,7 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
n_microbatches = job_config.experimental.pipeline_parallel_degree

return schedule_class(
stage,
stages if looped_schedule else stages[0],

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the init function of ScheduleInterleaved1F1B takes a list of _PipelineStageBase
if we only pass stages[0], will it cause ScheduleInterleaved1F1B fail ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, the point of this code is that the schedule class could be a simple schedule or a looped schedule.

If it's a looped schedule, like Interleaved1F1B, then we must have 'stages'.

if its a simple schedule, then we must have just one stage, so 'stages[0]' is appropriate.

But we should never have Interleaved1F1B(stages[0]).

wconstab added 2 commits June 20, 2024 14:27
[ghstack-poisoned]
[ghstack-poisoned]
@wconstab wconstab merged commit 6fc2045 into gh/wconstab/28/base Jun 21, 2024
5 checks passed
wconstab added a commit that referenced this pull request Jun 21, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: a2376627e2864deeb9e4fbf49cecd0990bc434ea
Pull Request resolved: #358
@wconstab wconstab deleted the gh/wconstab/28/head branch June 21, 2024 16:40
@wconstab wconstab changed the title Add train loop support for looped PP schedules Support looped PP schedules in torchtitan Jun 25, 2024
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: a2376627e2864deeb9e4fbf49cecd0990bc434ea
Pull Request resolved: pytorch#358
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: a2376627e2864deeb9e4fbf49cecd0990bc434ea
Pull Request resolved: #358
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
- refactor some per-model logic into helper functions

ghstack-source-id: a2376627e2864deeb9e4fbf49cecd0990bc434ea
Pull Request resolved: pytorch#358
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants