Make metrics logging work for pipeline parallelism #383

Merged
wconstab merged 3 commits into gh/wconstab/34/base from gh/wconstab/34/head on Jun 4, 2024

Conversation

@wconstab (Contributor) commented Jun 3, 2024


Avoid complicating the UX and keep the status quo of two user-selectable
behaviors:

  • log from rank 0 (the default)
  • log from all ranks (not the default)

Modify the meaning of 'log from rank 0': in non-pipeline-parallel runs, log
from global rank 0; when pipeline parallelism is enabled, log from the local
rank 0 within the last pipeline-parallel stage group. (Note: earlier pipeline
stages still produce some metrics, such as MFU and memory, but do not compute
loss.)
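As an illustrative example (not taken from the PR itself): with 8 ranks and a pipeline-parallel degree of 2, and with pp as the outermost mesh dimension, ranks 0-3 form the first stage group and ranks 4-7 the last, so rank 4 is the one that logs loss metrics.

    # Illustrative only; mirrors the rank arithmetic used in this PR.
    world_size = 8
    pp_size = 2
    # global rank of local rank 0 in the last pipeline-parallel stage group
    metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
    print(metrics_log_rank)  # 4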

wconstab added a commit that referenced this pull request Jun 3, 2024
train.py (outdated snippet under review):

    metric_logger = build_metric_logger(job_config)
    if parallel_dims.pp_enabled:
        pp_size = pp_mesh.size()
        metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1))
Reviewer (Contributor):
Does this rank computation assume that PP is outermost? If so, should we assert/check for that?

wconstab (Author):
Yeah, it does. I wonder if there is a better place to do the assert; I'll add it here for now.

I also don't like doing it this way. I might want to propose adding a DeviceMesh API to help do a calculation like this more robustly.

wconstab (Author):
I moved this into a util and added an assert.
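A minimal sketch of what such a util could look like, assuming the pipeline-parallel dimension is the outermost dimension of the device mesh; the function name and exact signature here are hypothetical, not necessarily what landed in the PR:

    # Hypothetical helper; the name and signature are illustrative, not from the PR.
    def get_metrics_log_rank(world_mesh, parallel_dims) -> int:
        """Return the global rank that should log loss metrics."""
        if not parallel_dims.pp_enabled:
            return 0
        # This arithmetic assumes pp is the outermost mesh dimension, so the
        # last pipeline stage group owns the highest contiguous block of ranks.
        assert world_mesh.mesh_dim_names[0] == "pp", (
            "metrics-rank computation assumes pp is the outermost mesh dimension"
        )
        pp_size = world_mesh["pp"].size()
        # local rank 0 within the last pipeline-parallel stage group
        return int((world_mesh.size() // pp_size) * (pp_size - 1))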

wconstab added commits that referenced this pull request Jun 4, 2024
wconstab requested a review from wanchaol on June 4, 2024 at 16:43
@fegin (Contributor) left a comment:

We probably need to extend DeviceMesh to make calculating a specific rank easier.

@wconstab (Author) commented Jun 4, 2024

    We probably need to extend DeviceMesh to make calculating a specific rank easier.

Yes, my thoughts exactly. I would like to discuss this offline. I couldn't quickly think of the best API proposal for DeviceMesh, so I went this route instead.
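Purely as an illustration of the kind of indexing such a DeviceMesh extension might wrap (this is not an existing API; it only touches the raw mesh tensor and assumes pp is the outermost dimension with the default contiguous rank layout):

    # Not an existing DeviceMesh API: indexing the underlying mesh tensor directly.
    # world_mesh.mesh has shape (pp, dp, tp, ...) when pp is the outermost dimension.
    last_pp_group = world_mesh.mesh[-1]                  # ranks in the last pp stage group
    metrics_log_rank = int(last_pp_group.flatten()[0])   # local rank 0 of that group
    # A hypothetical convenience method on DeviceMesh could hide this indexing.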

wconstab merged commit 44a0046 into gh/wconstab/34/base on Jun 4, 2024
4 checks passed
wconstab added a commit that referenced this pull request Jun 4, 2024
wconstab deleted the gh/wconstab/34/head branch on June 4, 2024 at 18:03
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024