
Low bit optimizers quality #744

Closed · tsengalb99 opened this issue Aug 24, 2024 · 16 comments

@tsengalb99

I saw that the quality numbers in /~https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim changed recently. Was there a bug in the AO low bit optimizer implementation before?

@msaroufim
Member

There was a bug in how we handled LR schedulers; the fix (#736) also makes things go faster now.

@tsengalb99
Author

So nothing changed functionally, the optimizers are just faster, and that let you test on a larger model?

@msaroufim
Member

The largest model we tried was on the order of 500M parameters. For 1B+, try it out and feel free to share loss curves.

@tsengalb99
Author

Did you test pretraining, or only finetuning? Thanks

@msaroufim
Member

msaroufim commented Aug 24, 2024

This was a finetuning benchmark:

A benchmark script for fine-tuning a timm model on the resisc45 dataset is available at benchmarks/benchmark_low_bit_adam.py

We were hoping to expand to more fine-tuning benchmarks with torchtune: /~https://github.com/pytorch/torchtune

And we're working on fully quantized training here: /~https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training

If you're doing research in this space, feel free to reach out; the team's focus right now is more on infra and making sure things run fast.

@tsengalb99
Author

I tried using the AO 8-bit optimizer in nanoGPT by swapping out AdamW with Adam8bit here and keeping the embedding and non-decayed parameters (layernorms etc.) in a regular torch AdamW instance (rough sketch below). This strategy seems to work with the lpmm codebase, where I get close to FP32 performance, but with AO I get severely degraded performance. Is there anything special I need to do with the AO low bit optimizers?

Also, I wasn't able to get the AO 4-bit optimizer to work out of the box: I had to disable the compile call here, or torch.compile would complain about the input to lerp being negative. I'm using torch 2.4.0 if that matters.
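
For reference, the parameter split looks roughly like this (a toy model stands in for nanoGPT and the hyperparameters are placeholders, so treat it as a sketch rather than my exact script):

import torch
from torch import nn
from torchao.prototype.low_bit_optim import Adam8bit  # the optimizer I swapped in for the decayed group

# Toy stand-in for nanoGPT, on GPU to match the setup described above;
# the real split happens in nanoGPT's configure_optimizers.
model = nn.Sequential(nn.Embedding(100, 64), nn.Linear(64, 64), nn.LayerNorm(64)).cuda()

lowbit_params, fp32_params = [], []
for module in model.modules():
    for p in module.parameters(recurse=False):
        # embeddings, layernorms, and biases stay in regular torch AdamW;
        # the remaining matrix-shaped (decayed) weights go to the 8-bit optimizer
        if isinstance(module, (nn.Embedding, nn.LayerNorm)) or p.dim() < 2:
            fp32_params.append(p)
        else:
            lowbit_params.append(p)

opt_8bit = Adam8bit(lowbit_params, lr=6e-4, betas=(0.9, 0.95), weight_decay=0.1)
opt_fp32 = torch.optim.AdamW(fp32_params, lr=6e-4, betas=(0.9, 0.95), weight_decay=0.0)

# each training step calls loss.backward(), then step() and zero_grad() on both optimizers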

@msaroufim
Member

msaroufim commented Aug 25, 2024

Yeah, for sure try using the ao and PyTorch nightlies. In the meantime @gau-nernst, it might make sense to do convergence benchmarks with at least Llama 8B.

@tsengalb99 do you mind sharing more details on what you're observing too? What size of model? Loss curves?

@gau-nernst
Collaborator

@tsengalb99 Yes, having more details about your training would be great to debug the issue.

The input to lerp being negative sounds like something is wrong with the training. I think we only use lerp here:

new_exp_avg = exp_avg.lerp(grad, 1 - beta1)
new_exp_avg_sq = exp_avg_sq.lerp(grad.square(), 1 - beta2)
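
For anyone following along, the lerp weight there is just the usual EMA coefficient. A minimal standalone check of what those two lines compute, with made-up values:

import torch

beta1, beta2 = 0.9, 0.999
grad = torch.randn(8)
exp_avg = torch.randn(8)
exp_avg_sq = torch.zeros(8)

# exp_avg.lerp(grad, 1 - beta1) == exp_avg + (1 - beta1) * (grad - exp_avg)
#                               == beta1 * exp_avg + (1 - beta1) * grad
new_exp_avg = exp_avg.lerp(grad, 1 - beta1)
assert torch.allclose(new_exp_avg, beta1 * exp_avg + (1 - beta1) * grad)

new_exp_avg_sq = exp_avg_sq.lerp(grad.square(), 1 - beta2)
assert torch.allclose(new_exp_avg_sq, beta2 * exp_avg_sq + (1 - beta2) * grad.square())

So the (1 - beta) weight itself is a constant in [0, 1]; a negative value showing up points to something outside these two lines.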

Would you mind trying bnb AdamW8bit too? Our implementation should match bnb's exactly. Also, I notice you are using ao Adam8bit. For a fair comparison, shouldn't AdamW8bit be used instead?
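
Something like this is what I mean by an apples-to-apples comparison (a sketch with placeholder parameters and hyperparameters; you would pick one optimizer per training run):

import bitsandbytes as bnb
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

# placeholder for your model's decayed parameters (both optimizers expect CUDA tensors)
params = [torch.nn.Parameter(torch.randn(64, 64, device="cuda"))]

# ao 8-bit AdamW, intended as a drop-in replacement for torch.optim.AdamW
optim_ao = AdamW8bit(params, lr=6e-4, betas=(0.9, 0.95), weight_decay=0.1)

# bitsandbytes 8-bit AdamW with the same hyperparameters, for the comparison run
optim_bnb = bnb.optim.AdamW8bit(params, lr=6e-4, betas=(0.9, 0.95), weight_decay=0.1)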

@msaroufim Regarding llama-8B, are you thinking pre-training or fine-tuning? Should be a drop-in replacement for torchtune and torchtitan. I think @awgu ran a small test before?

@msaroufim
Member

I was thinking of prioritizing larger-scale finetuning experiments first, with pretraining more of a hail mary.

@tsengalb99
Author

> Input to lerp being negative sounds like something is wrong with the training. I think we only use lerp here

Yes, that's where the error is being thrown.

> Also, I notice you are using ao Adam8bit. For a fair comparison, shouldn't AdamW8bit be used instead?

Ah, I totally forgot to use AdamW rather than Adam. That's probably the issue - thanks!
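
For anyone else who hits this: the difference that matters once weight decay is non-zero is decoupled decay. A rough single-tensor sketch with made-up values (bias correction omitted):

import torch

lr, wd, beta1, beta2, eps = 1e-3, 0.1, 0.9, 0.95, 1e-8
param = torch.randn(8)
grad = torch.randn(8)
exp_avg = torch.zeros(8)
exp_avg_sq = torch.zeros(8)

# Adam + L2: the decay term is folded into the gradient, so it also gets rescaled
# by the adaptive 1/sqrt(v) denominator.
g = grad + wd * param
m = exp_avg.lerp(g, 1 - beta1)
v = exp_avg_sq.lerp(g.square(), 1 - beta2)
param_adam = param - lr * m / (v.sqrt() + eps)

# AdamW: decay is applied directly to the weights, and the Adam step uses the raw gradient.
m = exp_avg.lerp(grad, 1 - beta1)
v = exp_avg_sq.lerp(grad.square(), 1 - beta2)
param_adamw = param - lr * wd * param - lr * m / (v.sqrt() + eps)

With the fairly large weight decay nanoGPT applies to the 2D weights, that difference could plausibly explain the gap I was seeing.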

@msaroufim
Member

@tsengalb99 is this still an issue?

@tsengalb99
Author

tsengalb99 commented Aug 27, 2024 via email

@gau-nernst
Collaborator

gau-nernst commented Aug 27, 2024

@tsengalb99 Was your error with Adam4bit like this?

  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1038, in view_from_base
    fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 967, in symint_visitor_fn
    symbol = shape_env.create_symbol(s, sym_source)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 245, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3454, in create_symbol
    assert not (positive and val < 0), f"positive set for negative value: {val}"
AssertionError: positive set for negative value: -1

from user code:
   File "/opt/conda/lib/python3.11/site-packages/torchao/prototype/low_bit_optim/adamw.py", line 140, in single_param_adamw
    new_exp_avg = exp_avg.lerp(grad, 1 - beta1)

I'm seeing this error with PyTorch 2.4. Seems like a torch.compile issue.

@msaroufim Turns out our CI skips the 4-bit optimizer tests for 2.4:

if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("4-bit Adam requires PyTorch > 2.4")

I think I initially set this line to TORCH_VERSION_AFTER_2_4, so the test will NOT run on 2.4, and then you changed it to TORCH_VERSION_AT_LEAST_2_5 during the refactor PR.
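
To spell out why neither flag exercises 2.4 stable, here is a hypothetical illustration with packaging.version (not the actual torchao helpers):

from packaging.version import parse
import torch

release = parse(torch.__version__).release   # e.g. (2, 4, 0) for both "2.4.0" and "2.4.0+cu121"

after_2_4 = release > (2, 4, 0)      # False on the 2.4 stable release
at_least_2_5 = release >= (2, 5, 0)  # also False on 2.4

# Either way, `not <flag>` is True on 2.4 stable, so the 4-bit tests get skipped there
# and a 2.4-only regression can slip through CI.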

@SunMarc Did you manage to successfully run 4-bit Adam when you implemented huggingface/transformers#31865? It seems like you added a test but it didn't run in CI since torchao was not installed? The only way to run 4-bit optim right now is to use PyTorch nightly 🌚.

As of now, the HF Trainer is still only guarding for PyTorch >= 2.3:

/~https://github.com/huggingface/transformers/blob/d1f39c484d8347aa7b3170ea250a1e8f3bdfdf31/src/transformers/trainer.py#L1482-L1486

I will update the docs and make the test clearer. If I have time, maybe I'll try to make it work for 2.4 as well. If our 4-bit Adam doesn't work with the latest stable release, I feel it will hinder people trying it out.

@SunMarc

SunMarc commented Aug 27, 2024

> @SunMarc Did you manage to successfully run 4-bit Adam when you implemented huggingface/transformers#31865? It seems like you added a test but it didn't run in CI since torchao was not installed? The only way to run 4-bit optim right now is to use PyTorch nightly 🌚.

Indeed, I'm facing the same issue as you without the PyTorch nightly. When I opened that PR, I still needed to install the torchao nightly for the 4-bit optimizer, and it might have worked with a prior version of torch. I will update the requirements to let users know that they need to install the torch nightly.

@tsengalb99
Author

Yep, that's the issue I'm seeing.

@gau-nernst
Collaborator

Sorry for the regression! I have fixed it in #755, so 4-bit optim now works in PyTorch 2.3 again. It slipped through our tests previously, and I have updated the tests to make sure this won't happen again.
