IterableDataset: Unsupported ScalarType BFloat16 #7000

Closed
stoical07 opened this issue Jun 25, 2024 · 3 comments · Fixed by #7002

@stoical07

Describe the bug

IterableDataset.from_generator crashes when a bfloat16 tensor is passed via gen_kwargs:

  File "/usr/local/lib/python3.11/site-packages/datasets/utils/_dill.py", line 169, in _save_torchTensor
    args = (obj.detach().cpu().numpy(),)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^

TypeError: Got unsupported ScalarType BFloat16
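
The root cause is that NumPy has no bfloat16 dtype, so the tensor-to-NumPy round-trip in datasets' dill hook (_save_torchTensor above) cannot handle these tensors. The failing call can be reproduced on its own, independent of datasets:

import torch

# NumPy has no bfloat16 dtype, so converting a bfloat16 tensor raises
# the same error as in the traceback above.
torch.tensor([1.], dtype=torch.bfloat16).numpy()
# TypeError: Got unsupported ScalarType BFloat16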

Steps to reproduce the bug

import torch
from datasets import IterableDataset


def demo(x):
    yield {"x": x}


x = torch.tensor([1.], dtype=torch.bfloat16)

dataset = IterableDataset.from_generator(
    demo,
    gen_kwargs=dict(x=x),
)

example = next(iter(dataset))
print(example)

Expected behavior

The code sample should print:

{'x': tensor([1.], dtype=torch.bfloat16)}

Environment info

datasets==2.20.0
torch==2.2.2
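
Until a fix is released, one possible workaround (a sketch only, not the approach taken in #7002) is to keep the bfloat16 tensor out of gen_kwargs, since it is the pickling of gen_kwargs that hits the failing NumPy conversion, and cast back to bfloat16 inside the generator:

import torch
from datasets import IterableDataset


def demo(x):
    # Cast back to bfloat16 inside the generator; only gen_kwargs go
    # through the dill/NumPy round-trip that rejects bfloat16.
    yield {"x": x.to(torch.bfloat16)}


x = torch.tensor([1.], dtype=torch.bfloat16)

dataset = IterableDataset.from_generator(
    demo,
    # Pass a float32 copy so the tensor can be pickled via NumPy.
    gen_kwargs=dict(x=x.float()),
)

print(next(iter(dataset)))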
@stoical07 (Author)

@lhoestq Thank you for merging #6607, but unfortunately the issue persists for IterableDataset 😔

@lhoestq (Member) commented Jun 25, 2024

Hi! I opened #7002 to fix this bug.

@stoical07 (Author)

Amazing, thank you so much @lhoestq! 🙏
