Support looped PP schedules in torchtitan #358
Conversation
- refactor some per-model logic into helper functions (Pull Request resolved: #358)
torchtitan/checkpoint.py
Outdated
"optimizer": OptimizerWrapper(model_parts, optimizers), | ||
# TODO(whc) flatten lr_schedulers using a wrapper and somehow handle resharding? | ||
# or store one per key and explicitly dont support resharding? | ||
# "lr_scheduler": lr_scheduler, |
I have to fix this part. I think @fegin had a suggestion for a workaround that would support resharding. I'm not sure whether I should do that, or just let it be saved in a way that would break resharding and not worry about resharding for now.
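For context, a minimal sketch of the "store one per key" option from the TODO above; the class name `SchedulersContainer` and the key format are illustrative, not necessarily what the PR ends up using:

```python
# Sketch only (not the PR's actual code): each lr_scheduler is saved under its
# own indexed key, and resharding is explicitly not supported.
from torch.distributed.checkpoint.stateful import Stateful


class SchedulersContainer(Stateful):
    def __init__(self, lr_schedulers):
        # one scheduler per pipeline-stage model part on this rank
        self.lr_schedulers = lr_schedulers

    def state_dict(self):
        return {
            f"lr_scheduler_{idx}": s.state_dict()
            for idx, s in enumerate(self.lr_schedulers)
        }

    def load_state_dict(self, state_dict):
        # keys only line up if the loading job has the same sharding, so a
        # mismatch surfaces as a hard error instead of silent corruption
        for idx, s in enumerate(self.lr_schedulers):
            s.load_state_dict(state_dict[f"lr_scheduler_{idx}"])
```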
which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).

3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
I did it this way because I thought that if we are not supporting resharding of lr_scheduler, then I may as well save each one. If I save each one, then at load time I get a form of assertion from DCP: if the runtime loading the checkpoint has the same number of ranks, they will match up and load OK; if not, it will throw an error.
I could switch back to the version where I save only one copy, but then I have to do some validation up front. Do you think this way is better?
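For comparison, a rough sketch of the "save only one copy" alternative and the up-front validation it would need; the helper name is hypothetical:

```python
# Hypothetical helper illustrating the "save one copy + validate" alternative,
# not code from this PR.
def pick_single_lr_scheduler_state(lr_schedulers):
    states = [s.state_dict() for s in lr_schedulers]
    # All schedulers are stepped in lockstep, so their states should be
    # identical; verify that before collapsing them into a single entry.
    if any(st != states[0] for st in states[1:]):
        raise RuntimeError("lr_scheduler states diverged; cannot save a single copy")
    return states[0]
```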
@fegin
If we are not supporting resharding, then this implementation is better.
The pipelining part looks good to me. I left two minor comments.
@@ -26,7 +33,7 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
     n_microbatches = job_config.experimental.pipeline_parallel_degree

     return schedule_class(
-        stage,
+        stages if looped_schedule else stages[0],
The init function of ScheduleInterleaved1F1B takes a list of _PipelineStageBase.
If we only pass stages[0], will that cause ScheduleInterleaved1F1B to fail?
Well, the point of this code is that the schedule class could be either a simple schedule or a looped schedule.
If it's a looped schedule, like Interleaved1F1B, then we must pass 'stages'.
If it's a simple schedule, then we must have just one stage, so 'stages[0]' is appropriate.
But we should never have Interleaved1F1B(stages[0]).
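To make the dispatch concrete, here is a sketch assuming the schedule base classes come from `torch.distributed.pipelining.schedules` (as in recent PyTorch); `construct_schedule` and deriving `looped_schedule` from the class are illustrative simplifications, not the exact torchtitan code:

```python
from torch.distributed.pipelining.schedules import (
    PipelineScheduleMulti,
    PipelineScheduleSingle,
)


def construct_schedule(schedule_class, stages, n_microbatches, loss_fn):
    looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
    if not looped_schedule:
        assert issubclass(schedule_class, PipelineScheduleSingle)
    return schedule_class(
        # looped schedules (e.g. ScheduleInterleaved1F1B) take a list of stages;
        # simple schedules (e.g. Schedule1F1B) take exactly one stage
        stages if looped_schedule else stages[0],
        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
    )
```

Deriving `looped_schedule` from the class keeps the conditional and the chosen schedule class from drifting apart.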
Stack from ghstack (oldest at bottom):