Refactor Checkpointer #871

Merged: 13 commits, merged on Feb 27, 2025
Conversation

@fegin (Contributor) commented Feb 20, 2025

Stack from ghstack (oldest at bottom):

Several bug fixes, refactors, and feature improvements in preparation for the next PR (integration with TorchFT):

  1. Refactor the code for better readability.
  2. Remove the time-based checkpoint condition. It is unused, can cause deadlocks when integrating with TorchFT, and removing it makes the code simpler.
  3. Fix an async_with_pinned_memory bug.
  4. The original keep_last_k implementation can raise exceptions in certain cases and is also slow. Fix the bugs and use a background thread to purge checkpoints (a rough sketch of the idea follows below).
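
Not the PR's actual code, but a minimal sketch of the purge-thread idea under some assumptions: checkpoint folders are named step-<n>, a successful save enqueues a purge request, and a background thread deletes everything except the newest keep_latest_k folders so the training loop never blocks on slow filesystem deletes. The class and method names here are made up for illustration.

import os
import queue
import shutil
import threading

class CheckpointPurger:
    """Delete stale checkpoint folders on a background thread (illustrative sketch)."""

    def __init__(self, folder: str, keep_latest_k: int) -> None:
        self.folder = folder
        self.keep_latest_k = keep_latest_k
        self.purge_queue: queue.Queue = queue.Queue()
        self.purge_thread = threading.Thread(target=self._purge_worker, daemon=True)
        self.purge_thread.start()

    def enqueue(self) -> None:
        # Called (on rank 0 only) after a save completes; does not block training.
        self.purge_queue.put(None)

    def _purge_worker(self) -> None:
        while True:
            self.purge_queue.get()  # wait until a new checkpoint has landed
            steps = sorted(
                int(name.split("-")[1])
                for name in os.listdir(self.folder)
                if name.startswith("step-")
            )
            for step in steps[: -self.keep_latest_k]:
                path = os.path.join(self.folder, f"step-{step}")
                shutil.rmtree(path, ignore_errors=True)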

@fegin requested review from d4l3k, tianyu-l and fduwjj on February 25, 2025

@mock.patch(
    "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save
)
def test_async_save_calls_async_wait(self, *_):
Contributor:

Can this test memory leak as well?

fegin (Author):

No, unfortunately. Memory leakage requires a more thorough test. I'm not sure there is an easy way to test it with unittest.
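
For readers unfamiliar with the pattern in the quoted decorator, here is a small, self-contained sketch of how mocking dcp.async_save with a fake future lets a unit test assert that the previous async save is waited on before the next one starts. The Checkpointer and FakeAsyncFuture classes below are hypothetical stand-ins, not torchtitan's real classes, and, as noted above, this style of test does not cover memory-leak detection.

import unittest
from unittest import mock

class FakeAsyncFuture:
    """Stand-in for the future returned by dcp.async_save."""
    def __init__(self) -> None:
        self.waited = False

    def result(self) -> None:
        self.waited = True

class Checkpointer:
    """Hypothetical checkpointer: waits on the previous async save before starting a new one."""
    def __init__(self, async_save) -> None:
        self._async_save = async_save
        self._last_future = None

    def save(self, state: dict, step: int) -> None:
        if self._last_future is not None:
            self._last_future.result()  # the behavior the test asserts
        self._last_future = self._async_save(state, checkpoint_id=f"step-{step}")

class TestAsyncSave(unittest.TestCase):
    def test_async_save_calls_async_wait(self):
        fake_async_save = mock.Mock(side_effect=lambda *a, **kw: FakeAsyncFuture())
        ckpt = Checkpointer(fake_async_save)
        ckpt.save({}, step=10)
        first_future = ckpt._last_future
        ckpt.save({}, step=20)
        self.assertTrue(first_future.waited)
        self.assertEqual(fake_async_save.call_count, 2)

if __name__ == "__main__":
    unittest.main()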

and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
and self.staging

self.keep_latest_k > 0
and dist.get_rank() == 0
Contributor:

For my own learning: why do we only do the purge on dist.get_rank() == 0?

fegin (Author):

We assume that all ranks can access the same files; that is DCP's assumption. If we let all ranks purge, the file system will complain.
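
A hedged restatement of the condition quoted above, as a tiny helper (the function name is made up for illustration):

import torch.distributed as dist

def should_purge(keep_latest_k: int) -> bool:
    # Only rank 0 deletes old checkpoints: DCP assumes every rank sees the same
    # shared checkpoint folder, so letting all ranks delete the same files would
    # just race on the filesystem.
    return keep_latest_k > 0 and dist.is_initialized() and dist.get_rank() == 0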

@tianyu-l (Contributor) left a comment:

I'm not in a position to review every detail of this PR -- so I will leave it to others.
But it looks quite good to me.

@@ -44,6 +50,8 @@ class AsyncMode(str, enum.Enum):
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


# TODO: move this out from checkpoint.py and merge it with the trainer.py
# We probably want to create a Trainer objecta.
Contributor:

Suggested change:
- # We probably want to create a Trainer objecta.
+ # We probably want to create a Trainer object.

Contributor:

Hmm I still didn't get why we need a Trainer (yet).

@fegin (Author) commented Feb 26, 2025:

@tianyu-l, a Trainer class is not necessary, but it is cleaner.

class Trainer(Stateful):
    def __init__(self, job_config) -> None:
        # move all of the init code currently in train.py here
        ...

    def train(self) -> None:
        # training loop goes here
        ...
        self.checkpoint.save(state={"trainer": self})

    def state_dict(self) -> Dict[str, Any]:
        # step, log_steps, and any other trainer state
        return {"step": self.step, "log_steps": self.log_steps}

    def load_state_dict(self, sd) -> None:
        self.step = sd["step"]
        ...

While we could simplify the original train() by splitting it into two separate functions and keeping TrainerState, a single Trainer class is more natural, as it keeps the state and methods in one class/object.

Contributor:

I see, makes sense to me!

self.purge_thread.join()

@torch.no_grad()
def save(self, curr_step: int, force: bool = False) -> None:
Contributor:

Nit: IIUC, force is only for unit tests, right? Might be worth mentioning this in the comment?

fegin (Author):

It's not; it is also used when training is finished.
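
For illustration only, a self-contained sketch of that second use: the usual interval gate applies on normal steps, while force=True guarantees a final checkpoint when training finishes. total_steps, interval, and save_checkpoint are made-up stand-ins, not the real train.py or Checkpointer.save.

def run_training(total_steps: int, interval: int) -> None:
    def save_checkpoint(curr_step: int, force: bool = False) -> None:
        # Save on the usual interval, or unconditionally when forced.
        if force or curr_step % interval == 0:
            print(f"saving checkpoint at step {curr_step} (force={force})")

    for step in range(1, total_steps + 1):
        # ... one training step would run here ...
        # force=True on the last step so the final state is always saved,
        # even if it does not fall on a checkpoint interval.
        save_checkpoint(curr_step=step, force=(step == total_steps))

run_training(total_steps=10, interval=4)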

@fduwjj (Contributor) left a comment:

LGTM

@fegin merged commit 29caadc into gh/fegin/13/base on Feb 27, 2025
6 checks passed
fegin added a commit that referenced this pull request on Feb 27, 2025: Reland #871 due to ghstack issues.
K-H-Ismail pushed a commit to K-H-Ismail/torchtitan that referenced this pull request on Feb 28, 2025.