diff --git a/src/datasets/utils/_dill.py b/src/datasets/utils/_dill.py index 15578198a39..2dedf7f1fbc 100644 --- a/src/datasets/utils/_dill.py +++ b/src/datasets/utils/_dill.py @@ -162,11 +162,17 @@ def _save_torchTensor(pickler, obj): import torch # type: ignore # `torch.from_numpy` is not picklable in `torch>=1.11.0` - def create_torchTensor(np_array): - return torch.from_numpy(np_array) + def create_torchTensor(np_array, dtype=None): + tensor = torch.from_numpy(np_array) + if dtype: + tensor = tensor.type(torch.bfloat16) + return tensor log(pickler, f"To: {obj}") - args = (obj.detach().cpu().numpy(),) + if obj.dtype == torch.bfloat16: + args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16) + else: + args = (obj.detach().cpu().numpy(),) pickler.save_reduce(create_torchTensor, args, obj=obj) log(pickler, "# To")