Make metrics logging work for pipeline parallelism #383
Conversation
train.py (Outdated)

```python
metric_logger = build_metric_logger(job_config)
if parallel_dims.pp_enabled:
    pp_size = pp_mesh.size()
    metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1))
```
Does this rank computation assume that PP is outermost? If so, should we assert/check for that?
Yeah, it does. I wonder if there is a better place to do the assert; I'll add it here for now.
I also don't like doing it this way. I might want to propose adding a DeviceMesh API to help do a calculation like this more robustly.
I moved this into a util and added an assert.
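For reference, a minimal sketch of what such a util could look like; the function name, signature, and the `pp_is_outermost` flag are illustrative, not the exact code that landed:

```python
def get_metrics_log_rank(world_size: int, pp_size: int, pp_is_outermost: bool) -> int:
    """Global rank that should log loss metrics.

    Rank 0 when pipeline parallelism is off; otherwise the local rank 0 of
    the last pipeline stage, since that stage is the one computing the loss.
    """
    # The arithmetic below only holds if pp is the outermost mesh dimension,
    # i.e. the last pp stage owns the final contiguous block of global ranks.
    assert pp_is_outermost, "metrics log rank calculation assumes pp is the outermost mesh dim"
    if pp_size <= 1:
        return 0
    ranks_per_stage = world_size // pp_size
    return ranks_per_stage * (pp_size - 1)
```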
We probably need to extend DeviceMesh to make calculating a specific rank easier.
Yes, my thoughts exactly. I would like to discuss this offline. I couldn't quickly think of what the best API proposal for DeviceMesh would be, so I went this route instead.
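One possible shape for a mesh-based version of the same calculation, assuming the world mesh was built with "pp" as the outermost dimension and leaning on DeviceMesh exposing its underlying rank tensor as `.mesh` (a sketch, not a proposed API):

```python
def metrics_log_rank_from_mesh(world_mesh) -> int:
    # Slicing the outermost ("pp") dimension at its last index yields the
    # global ranks of the last pipeline stage; its first element is that
    # stage's local rank 0.
    return int(world_mesh.mesh[-1].flatten()[0].item())
```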
Stack from ghstack (oldest at bottom):
Avoid complicating the UX and leave the status quo of 2 user-selectable behaviors:

- log from rank 0 (the default)
- log from all ranks (not the default)

Modify the meaning of 'log from rank 0' to log from rank 0 in non-pipeline-parallel runs, and log from the local rank 0 within the last pipeline-parallel stage group if PP is enabled. (Note: earlier pipeline stages still produce some metrics like MFU/memory, but do not compute loss.)
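A hedged sketch of how those two behaviors could fit together; `log_from_all_ranks`, `world_size`, and `pp_size` are illustrative parameters rather than torchtitan's actual config fields or helper names:

```python
import torch.distributed as dist

def should_log_metrics(log_from_all_ranks: bool, world_size: int, pp_size: int) -> bool:
    # Non-default behavior: every rank builds a real metric logger.
    if log_from_all_ranks:
        return True
    # Default "log from rank 0" behavior: plain rank 0 without PP, otherwise
    # local rank 0 of the last pipeline stage, which is where loss is computed.
    target_rank = 0 if pp_size <= 1 else (world_size // pp_size) * (pp_size - 1)
    return dist.get_rank() == target_rank
```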