dump memory snapshot to analyze OOMs #395
Conversation
train_configs/debug_model.toml
Outdated
@@ -9,6 +9,7 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
Existing `.toml` files without `enable_memory_snapshot` still work; `enable_memory_snapshot` is optional via `getattr(config.profiling, 'enable_memory_snapshot', False)`. I am just adding it here so people can start toggling it.
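As a hedged illustration of that backward-compatibility point (the helper name and `job_config` argument are made up for this sketch, not the actual torchtitan code):

```python
# Sketch only: .toml files written before this PR have no
# enable_memory_snapshot key under [profiling], so getattr falls back to
# False and the memory profiler simply stays disabled for them.
def memory_snapshot_enabled(job_config) -> bool:
    return getattr(job_config.profiling, "enable_memory_snapshot", False)
```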
ditto: we should put the default option `False` into config_manager and remove this option from all the toml config files. Maybe only enable it (set it to `True`) in debug_model.
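A rough sketch of what that could look like, mirroring the `self.parser.add_argument` calls shown later in this PR (the `store_true` wiring and standalone parser here are assumptions for illustration):

```python
import argparse

# Sketch: register the option with a False default in config_manager.py,
# so individual .toml files only need to mention it when turning it on.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--profiling.enable_memory_snapshot",
    action="store_true",
    default=False,
    help="Whether to dump memory snapshot",
)
```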
@@ -15,6 +16,14 @@
# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how many memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
`MEMORY_SNAPSHOT_MAX_ENTRIES` controls how large the `.pickle` file can get. Right now it's about 36MB.
with open(
    f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
) as output:
    pickle.dump(torch.cuda.memory._snapshot(), output)
Maybe add a threshold so the memory snapshot is only dumped when memory usage exceeds it, to avoid overwhelming amounts of data?
Do you mean a threshold in MB? Right now it's bounded by the number of free/allocate ops via `MEMORY_SNAPSHOT_MAX_ENTRIES`. For an MB-based threshold, I can google around.
Googled for an MB threshold but did not find anything useful. Currently `MEMORY_SNAPSHOT_MAX_ENTRIES=100000` keeps the file size to about 36MB. Let me know if this is still a blocker.
This is a great addition to torchtitan! Had some comments on how to structure the configs.
Also, I wonder if it makes sense to have a very short tutorial on how to read/parse the output of the memory profiler. Maybe extract part of this tutorial.
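On reading the output: the dumped pickle can be dragged into https://pytorch.org/memory_viz, or rendered locally with PyTorch's private `_memory_viz` helpers. A hedged sketch (the file path below is illustrative and assumes a dump produced by this PR's folder layout):

```python
import pickle

from torch.cuda._memory_viz import trace_plot

# Path follows the memory_snapshot/iteration_x/rank{rank}_memory_snapshot.pickle
# layout described in this PR; adjust the rank and iteration as needed.
with open("memory_snapshot/iteration_10/rank0_memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Write an interactive HTML timeline of allocations/frees, viewable in a browser.
with open("rank0_memory_trace.html", "w") as f:
    f.write(trace_plot(snapshot))
```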
torchtitan/profiling.py
Outdated
# default memory snapshot folder
ENABLE_MEMORY_SNAPSHOT_KEY = "enable_memory_snapshot"
MEMORY_SNAPSHOT_FOLDER_KEY = "memory_snapshot_folder"
MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE = "memory_snapshot"
We should make these into configs. Please refer to how torch_profiler does this part, e.g. put into config_manager.py
Good to know about config_manager.py. I will move the defaults into config_manager.
Converting to draft now; will publish again after moving the defaults into config_manager.py. But the current version is good for benchmarking float8 + compile + fsdp2 on MAST.
The overhead from `torch.profiler` is only around the steps where dumping actually happens (warmup steps + actual profiling steps). If `_record_memory_history` is always enabled for the entire training run, there will constantly be overhead from this memory profiler.
In other words, `torch.profiler` only profiles one step per `freq` steps, while `MemoryProfiler` profiles every step and dumps all `freq` iterations per `freq` steps. As a result, adjusting `freq` only affects how often the snapshots are grouped into one pickle file. If we run a job for 3000 steps, there will be a snapshot for every step, regardless of `freq`.
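A minimal sketch of the pattern being described, using PyTorch's private `_record_memory_history` / `_snapshot` APIs; the class shape, argument names, and single-rank output path are assumptions for illustration, not the exact torchtitan implementation:

```python
import os
import pickle

import torch

MEMORY_SNAPSHOT_MAX_ENTRIES = 100000

class MemoryProfiler:
    def __init__(self, freq: int, snapshot_dir: str = "memory_snapshot"):
        # History recording is switched on once and stays on for the entire run,
        # so every training step pays the recording overhead, unlike
        # torch.profiler, which only wraps its warmup + active window.
        torch.cuda.memory._record_memory_history(
            max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
        )
        self.freq = freq
        self.snapshot_dir = snapshot_dir

    def step(self, step_num: int) -> None:
        # Dumping only groups the accumulated history into one pickle every
        # `freq` steps; it does not change what gets recorded in between.
        if step_num % self.freq != 0:
            return
        out_dir = os.path.join(self.snapshot_dir, f"iteration_{step_num}")
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, "rank0_memory_snapshot.pickle"), "wb") as f:
            pickle.dump(torch.cuda.memory._snapshot(), f)
```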
torchtitan/profiling.py
Outdated
if not exit_ctx and self.step_num % self.freq != 0:
    self.step_num += 1
    return
if not exit_ctx:
    curr_step = self.step_num
    self.step_num += 1
    dir_name = f"iteration_{curr_step}"
torch.profiler starts from step 0, whereas train.py starts from step 1. In order to make things work as expected, I suggest we do the following, so that if we set `profile_freq=10` and run training for 10 steps, there will be memory snapshots for `iteration_10` (similar to torch.profiler) and `iteration_10_exit`. I've tested this offline.
Suggested change, from:

if not exit_ctx and self.step_num % self.freq != 0:
    self.step_num += 1
    return
if not exit_ctx:
    curr_step = self.step_num
    self.step_num += 1
    dir_name = f"iteration_{curr_step}"

to:

self.step_num += 1
if not exit_ctx and self.step_num % self.freq != 0:
    return
if not exit_ctx:
    curr_step = self.step_num
    dir_name = f"iteration_{curr_step}"
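In the suggested version, `step_num` is incremented before the modulo check, so the step number written into `dir_name` matches train.py's 1-based counting, and a 10-step run with `profile_freq=10` dumps into `iteration_10`, as described above.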
thanks for pointing out the difference. updated accordingly
great work!! thank you!
please address my nits before merge :)
@@ -9,6 +9,7 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = true
nit: let's put the folder here as well to be consistent and informative
Added `save_memory_snapshot_folder` in `.toml`.
torchtitan/config_manager.py
Outdated
help="Whether to dump memory snapshot", | ||
) | ||
self.parser.add_argument( | ||
"--profiling.memory_snapshot_folder", |
nit: please rename it `save_memory_snapshot_folder` to be consistent with `save_traces_folder` and `save_tb_folder`.
torchtitan/config_manager.py
Outdated
self.parser.add_argument(
    "--profiling.memory_snapshot_folder",
    type=str,
    default="memory_snapshots",
nit: let's call it `memory_snapshot`
Thanks. Will address the feedback before merging.
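A sketch of how the folder flag might look after both nits (the rename and the `memory_snapshot` default) are applied; the standalone parser and help text are illustrative assumptions:

```python
import argparse

# Sketch: the renamed folder flag, mirroring save_traces_folder / save_tb_folder,
# with the default changed from "memory_snapshots" to "memory_snapshot".
parser = argparse.ArgumentParser()
parser.add_argument(
    "--profiling.save_memory_snapshot_folder",
    type=str,
    default="memory_snapshot",
    help="Memory snapshot files location",
)
```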
When setting `enable_memory_snapshot = true` in `.toml`:
* dump memory snapshots in case of OOMs. Output folder is `memory_snapshot/iteration_x_exit`
* dump regularly according to `profile_freq`. Output folder is `memory_snapshot/iteration_x`
* existing `.toml` files still work, since `enable_memory_snapshot=False` by default

The snapshot below is an example of the dump when an OOM happens:

<img width="1640" alt="Screenshot 2024-06-12 at 9 26 53 PM" src="/~https://github.com/pytorch/torchtitan/assets/134637289/6420799c-ae68-4b35-b8bb-f5b6ab3dd053">
Add [memory_profiler](#395) to the README, explaining how to use the memory profiler with `--profiling.enable_memory_snapshot` and `--profiling.save_memory_snapshot_folder`.