Refactor Checkpointer #871
Conversation
    @mock.patch(
        "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save
    )
    def test_async_save_calls_async_wait(self, *_):
Can this test for memory leaks as well?
No, unfortunately. Checking for memory leakage requires a more thorough test, and I'm not sure there is an easy way to do that with unittest.
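For what it's worth, a leak check is easier outside of unittest mocks: run repeated saves and watch process memory. A rough sketch, assuming a checkpointer that exposes save(step, force=...) as in this PR and that psutil is available; everything else here is illustrative, not torchtitan code:

import gc
import psutil

def rss_growth_after_saves(checkpointer, num_saves: int = 20) -> int:
    """Return RSS growth in bytes after repeated saves; steady growth
    after the warm-up hints at staging buffers that are never freed."""
    proc = psutil.Process()
    # Warm up so one-time allocations (pinned staging memory, DCP metadata)
    # are not mistaken for a leak.
    for step in range(1, 4):
        checkpointer.save(step, force=True)
    gc.collect()
    baseline = proc.memory_info().rss
    for step in range(4, 4 + num_saves):
        checkpointer.save(step, force=True)
    gc.collect()
    return proc.memory_info().rss - baseline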
    and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
    and self.staging
    self.keep_latest_k > 0
    and dist.get_rank() == 0
For my own learning: why do we only do the purge on dist.get_rank() == 0?
We assume that all ranks can access the same files; that is DCP's assumption. If we let every rank purge, the file system will complain.
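To make the rationale concrete: on a shared filesystem every rank sees the same checkpoint folder, so concurrent deletes from multiple ranks would race on the same paths. A minimal sketch of the rank-0 guard; the folder layout and helper name are illustrative, not the exact code in this PR:

import os
import shutil

import torch.distributed as dist

def purge_stale_checkpoints(folder: str, keep_latest_k: int) -> None:
    # Only rank 0 touches the shared filesystem; letting every rank delete
    # the same directories would trigger ENOENT / permission errors.
    if keep_latest_k <= 0 or dist.get_rank() != 0:
        return
    steps = sorted(
        int(name.split("-")[-1])
        for name in os.listdir(folder)
        if name.startswith("step-")
    )
    for step in steps[:-keep_latest_k]:
        shutil.rmtree(os.path.join(folder, f"step-{step}"), ignore_errors=True)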
I'm not in a position to review every detail of this PR, so I'll leave that to others. But it looks quite good to me.
torchtitan/components/checkpoint.py (outdated)
@@ -44,6 +50,8 @@ class AsyncMode(str, enum.Enum):
     ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


 # TODO: move this out from checkpoint.py and merge it with the trainer.py
 # We probably want to create a Trainer objecta.
Suggested change:
- # We probably want to create a Trainer objecta.
+ # We probably want to create a Trainer object.
Hmm, I still don't get why we need a Trainer (yet).
@tianyu-l, a Trainer class is not strictly necessary, but it is cleaner:
class Trainer(Stateful):
    def __init__(self, job_config) -> None:
        # Move all of the init code currently in train.py here.
        ...

    def train(self) -> None:
        # The training loop, ending with a checkpoint of the trainer itself.
        self.checkpoint.save(state={"trainer": self})

    def state_dict(self) -> Dict[str, Any]:
        # ... plus the rest of the training state
        return {"step": self.step, "log_steps": self.log_steps}

    def load_state_dict(self, sd) -> None:
        self.step = sd["step"]
        ...
While we could simplify the original train() by splitting it into two separate functions and keeping TrainerState, a single Trainer class is more natural, since it keeps the state and the methods in one class/object.
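For illustration, the entrypoint would then shrink to something like the sketch below; JobConfig and parse_args stand in for the existing config handling and are assumptions here, not part of this PR:

# Hypothetical train.py entrypoint once the init code lives in Trainer.__init__.
if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    Trainer(config).train()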
I see, makes sense to me!
        self.purge_thread.join()

    @torch.no_grad()
    def save(self, curr_step: int, force: bool = False) -> None:
Nit: IIUC, force is only used for unit tests, right? Might be worth mentioning that in the comment?
It's not; it is also used when the training finishes.
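To illustrate the two call sites (the interval check and attribute names below are illustrative of the intent, not the exact implementation):

# Inside save(): force=True bypasses the periodic-interval check, so the
# final checkpoint at the end of training is always written.
def _should_save(self, curr_step: int, force: bool) -> bool:
    if not self.enable_checkpoint:
        return False
    return force or curr_step % self.interval == 0

# In the training loop:
checkpoint.save(train_state.step)              # periodic save
checkpoint.save(train_state.step, force=True)  # final save when training finishes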
LGTM
Reland #871 due to ghstack issues.
Reland pytorch#871 due to ghstack issues.
Stack from ghstack (oldest at bottom):
Several bug fixes, refactors, and feature improvements in preparation for the next PR (integration with TorchFT).