
Migrating to use native Pytorch AMP #2827

Merged: 85 commits merged into main from issue_1512 on Jan 5, 2023
Conversation

sjrl (Contributor) commented on Jul 15, 2022

Related Issue(s): Issue #1512 Issue #1222

Proposed changes:
Migrating to Pytorch's native AMP (https://pytorch.org/docs/stable/notes/amp_examples.html) because it is much easier to use (no additional dependency on apex), requires fewer code changes to support, and is the recommended approach (https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994).

Using the native AMP support in Pytorch requires using torch.cuda.amp.autocast and torch.cuda.amp.GradScaler together. Both can easily be "turned off", so that no autocasting or scaling occurs, by passing the option enabled=False.

For example, the following code performs a standard training loop because we have passed enabled=False to both autocast and GradScaler:

import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

use_amp = False

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler(enabled=use_amp)

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast(enabled=use_amp):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

And similarly it is easy to turn on AMP by passing enabled=True.
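
The same scaler pattern also composes with gradient accumulation, which the TODO list below exercises via grad_acc_steps. The following is only a minimal sketch (not the exact Haystack trainer code), reusing scaler and autocast from above; accumulation_steps is an illustrative value:

accumulation_steps = 4  # illustrative, not a Haystack default

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        # Forward pass under autocast, as before.
        with autocast(enabled=use_amp):
            output = model(input)
            # Average over the accumulation window so gradients match a larger effective batch.
            loss = loss_fn(output, target) / accumulation_steps

        scaler.scale(loss).backward()

        # Unscale/step/update only once per accumulation window.
        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()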

There are breaking changes:

  1. The input type for use_amp in the FARMReader.train method is now bool instead of str (see the usage sketch after this list).
  2. This PR deprecates apex; it does not attempt to support both Pytorch AMP and Apex AMP side by side.
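
For illustration, a hedged sketch of how the new boolean flag would be passed to FARMReader.train; the other arguments and values shown are only illustrative and not specific to this PR:

from haystack.nodes import FARMReader

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

# use_amp is now a plain bool (previously a str)
reader.train(
    data_dir="data/squad20",
    train_filename="train-v2.0.json",
    n_epochs=1,
    use_amp=True,
)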

Pre-flight checklist

  • I have read the contributors guidelines
  • If this is a code change, I added tests or updated existing ones
  • If this is a code change, I updated the docstrings

TODO

  • Checked that FARMReader.train works with use_amp=True (and False) on a single GPU
  • Test tutorial 9 DPR training in GPU environment with use_amp turned on (and off) and with grad_acc_steps
  • Add use_amp to trainer state dict so when restarting from a checkpoint AMP is set up as expected
  • Test multi-gpu training with AMP
    • Works with torch.nn.DistributedDataParallel (Link to docs)
    • Works with torch.nn.DataParallel. Tested on 4 GPUs (g4dn.12xlarge instance); confirmed usage with nvidia-smi. It appears the "apply autocast as part of your model’s forward method to ensure it’s enabled in side threads." statement only applies when using multiple GPUs per process. Docs here. From what I understand, we only use one GPU per process, which is the recommended setup (see the sketch after this list).
  • Open PR in /~https://github.com/deepset-ai/haystack-website editing this file explaining AMP
    • I could imagine a section on that page where you briefly describe what AMP is, when to use it and how to use it. Perhaps also give readers an idea of how big a trade off in accuracy / speed it is.
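
For reference, the autocast-in-forward pattern that the PyTorch docs recommend for the multiple-GPUs-per-process case (not the setup used here) would look roughly like the following sketch; MyModel is only a stand-in for the actual model class:

import torch
from torch import nn
from torch.cuda.amp import autocast

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    # Applying autocast inside forward ensures mixed precision is also active
    # in the side threads spawned by nn.DataParallel.
    @autocast()
    def forward(self, x):
        return self.linear(x)

model = nn.DataParallel(MyModel().cuda())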

CLAassistant commented on Jul 15, 2022

CLA assistant check
All committers have signed the CLA.

@sjrl changed the title from "Started making changes to use native Pytorch AMP" to "Migrating to use native Pytorch AMP" on Jul 15, 2022
@sjrl sjrl requested a review from MichelBartels July 15, 2022 12:41
@sjrl sjrl requested a review from julian-risch July 15, 2022 13:18
julian-risch (Member) commented
First of all, I support removing nvidia apex and adding pytorch amp. 👍 Doing quick research regarding this decision, it seems to be the preferred, future-proof way: https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994/2
What we should do once this PR is ready to be merged is a small benchmark. And we would also need to ensure that the documentation explains when to use this feature, how to use this feature and what to expect from it.

sjrl (Contributor, Author) commented on Jul 19, 2022

And we would also need to ensure that the documentation explains when to use this feature, how to use this feature and what to expect from it.

Where should this documentation be added?

@masci masci linked an issue Nov 28, 2022 that may be closed by this pull request
julian-risch (Member) commented
Hi @sjrl we are planning a release in the next two weeks. Could this PR maybe make it into the new release? Did you have the chance to test multi-gpu training with AMP? In the todo list there is another open item "For torch.nn.DataParallel it looks like we would need to "apply autocast as part of your model’s forward method to ensure it’s enabled in side threads.""

julian-risch (Member) commented
FYI: we might upgrade to torch 1.13.1 once it's released.

sjrl (Contributor, Author) commented on Dec 8, 2022

Hi @sjrl we are planning a release in the next two weeks. Could this PR maybe make it into the new release? Did you have the chance to test multi-gpu training with AMP? In the todo list there is another open item "For torch.nn.DataParallel it looks like we would need to "apply autocast as part of your model’s forward method to ensure it’s enabled in side threads.""

Hey @julian-risch. Sorry there was a small miscommunication. I have verified that amp works with torch.nn.DataParallel, but I have not verified that it works with torch.nn.DistributedDataParallel. I have updated the checklist in the top message to reflect this.

I haven't had time to test the multi-gpu training with torch.nn.DistributedDataParallel. I can try to get this done before the next release, but I am fairly busy at the moment so I'd also be happy to receive help here if you have the time.

sjrl (Contributor, Author) commented on Dec 8, 2022

Hey @julian-risch, I first tried to get DistributedDataParallel to work today by turning on distributed training (without amp). Right now this option is hardcoded to be off, since we are using the default value in the call to initialize_optimizer:

def initialize_optimizer(
    model: AdaptiveModel,
    n_batches: int,
    n_epochs: int,
    device: torch.device,
    learning_rate: float,
    optimizer_opts: Optional[Dict[Any, Any]] = None,
    schedule_opts: Optional[Dict[Any, Any]] = None,
    distributed: bool = False,

And even trying to set this to True when calling initialize_optimizer in the training loop

model, optimizer, lr_schedule = initialize_optimizer(

causes a multiprocessing error. So it looks like we would need to first fix the distributed training feature before confirming that amp works with it as well.
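
(For context, torch.nn.parallel.DistributedDataParallel also expects each worker process to join a process group before the model is wrapped, which the current training loop does not set up. A minimal sketch of that setup, assuming a torchrun-style launcher that sets LOCAL_RANK, would look roughly like this:)

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model: torch.nn.Module) -> DDP:
    # Each worker must join the process group before wrapping the model.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # autocast/GradScaler are used unchanged once the model is wrapped.
    return DDP(model.to(local_rank), device_ids=[local_rank])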

julian-risch (Member) commented
So what's your opinion on the best way forward? I'd say we merge the changes that we have up to now and support just torch.nn.DataParallel but not torch.nn.DistributedDataParallel. Investigating the multiprocessing error and supporting torch.nn.DistributedDataParallel should then become the topic of a separate issue that we can add to the backlog.

sjrl (Contributor, Author) commented on Jan 4, 2023

@julian-risch Yes, I completely agree.

julian-risch (Member) left a comment
Looks very good to me! 👍 Thanks for putting in the extra effort and adding a fast test!

@julian-risch julian-risch added this to the 1.13.0 milestone Jan 4, 2023
@sjrl sjrl merged commit e84fae2 into main Jan 5, 2023
@sjrl sjrl deleted the issue_1512 branch January 5, 2023 08:14
Successfully merging this pull request may close these issues: Pytorch Native AMP support