Adding to_tf_dataset method #2731

Merged
46 commits merged on Sep 16, 2021
Commits
92cad15
Rebase onto master
Rocketknight1 Jul 29, 2021
74b5bad
Support multiple label_cols, replaced tokenizer with collate_fn, supp…
Rocketknight1 Jul 30, 2021
97917bc
Standardize int and float dtypes to keep TF happy
Rocketknight1 Jul 30, 2021
4eb79f5
Add a prefetch buffer for improved performance
Rocketknight1 Jul 30, 2021
bed394a
TF dataset is actually kinda performant now!
Rocketknight1 Aug 4, 2021
ea525a2
TF dataset is actually kinda performant now!
Rocketknight1 Aug 4, 2021
d3a8140
Style pass
Rocketknight1 Aug 4, 2021
3ce6dc4
Helpful error message if my code gets caught off-guard by unexpected …
Rocketknight1 Aug 4, 2021
67c0657
Style pass
Rocketknight1 Aug 4, 2021
2963f0a
Added drop_remainder argument, removed pad_to
Rocketknight1 Aug 5, 2021
7f11d76
Correct shape signatures when we're not dropping the remainder
Rocketknight1 Aug 5, 2021
bbf6197
Style pass
Rocketknight1 Aug 5, 2021
f902bde
Support ClassLabel columns too!
Rocketknight1 Aug 5, 2021
990f150
Re-enable `tf.ragged` by avoiding `tf.ragged.constant` unless absolut…
Rocketknight1 Aug 16, 2021
fa06206
Style pass
Rocketknight1 Aug 16, 2021
29415cd
Adding a comment to explain myself in tf_formatter.py
Rocketknight1 Aug 26, 2021
ca93c34
Fixes for shuffling and the case where the collator adds new columns
Rocketknight1 Aug 26, 2021
d78cd50
Style pass
Rocketknight1 Aug 26, 2021
0bf0050
Ensuring we respect TF dtype args
Rocketknight1 Aug 26, 2021
6c91fc7
Style pass
Rocketknight1 Aug 26, 2021
1954862
Updating tests
Rocketknight1 Aug 31, 2021
7f2a8f1
Updating tests
Rocketknight1 Aug 31, 2021
6eef188
Fixing things so they work in TF2.6
Rocketknight1 Sep 2, 2021
a63dfb9
Style pass
Rocketknight1 Sep 2, 2021
d7048a4
Correctly set output shapes - fixes a whole lot of issues
Rocketknight1 Sep 2, 2021
56ea08f
Fix an embarrassing regression bug
Rocketknight1 Sep 7, 2021
2ddf7c6
Style pass
Rocketknight1 Sep 7, 2021
ddfda69
Added `config.TF_AVAILABLE` checks and dict literals
Rocketknight1 Sep 8, 2021
c87d47e
Handling for special cases around label/labels and very nested dtypes
Rocketknight1 Sep 9, 2021
e7d1ce8
Fix for accidentally shuffling even when flag was False
Rocketknight1 Sep 10, 2021
48045fb
Adding dummy labels by default
Rocketknight1 Sep 14, 2021
ec4f7d4
Adding docstrings and type hints
Rocketknight1 Sep 15, 2021
88e9f1e
Style pass
Rocketknight1 Sep 15, 2021
a7b4574
Add tests, bugfix to handling scalar columns
Rocketknight1 Sep 15, 2021
b35267d
Style pass
Rocketknight1 Sep 15, 2021
6273d73
Fix to `numpy_pad`
Rocketknight1 Sep 15, 2021
4ff6d2e
Replace assertion with more robust syntax
Rocketknight1 Sep 15, 2021
589c575
Add cleanup deletion of tf_dataset in tests
Rocketknight1 Sep 15, 2021
d70fe94
Rebasing onto Master
Rocketknight1 Sep 15, 2021
a189740
Fixes for the new approach
Rocketknight1 Sep 15, 2021
c8f251b
Force dtype to ensure Windows compatibility
Rocketknight1 Sep 15, 2021
f1f8888
Fixing things because I am bad at merging
Rocketknight1 Sep 15, 2021
ef9a7bb
Fix issues with passing a mutable list to columns argument
Rocketknight1 Sep 16, 2021
b8523e4
Update src/datasets/arrow_dataset.py
lhoestq Sep 16, 2021
46c2507
Merge branch 'master' into tf_dataset_conversion
Rocketknight1 Sep 16, 2021
397bcb7
Fix unused import
Rocketknight1 Sep 16, 2021
254 changes: 252 additions & 2 deletions src/datasets/arrow_dataset.py
@@ -44,7 +44,7 @@
from . import config, utils
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .features import ClassLabel, Features, Value
from .features import ClassLabel, Features, Sequence, Value, _ArrayXD
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .fingerprint import (
fingerprint_transform,
@@ -159,6 +159,252 @@ def version(self):
return self._info.version


class TensorflowDatasetMixIn:
def __init__(self):
pass

@staticmethod
def _get_output_signature(dataset, cols_to_retain, test_batch, batch_size):
if config.TF_AVAILABLE:
import tensorflow as tf
else:
raise ImportError("Called a Tensorflow-specific function but could not import it!")

signatures = {}
for column, col_feature in dataset.features.items():
if column not in cols_to_retain:
continue
dtype_feature = col_feature
while hasattr(dtype_feature, "feature"): # Descend this godforsaken nested rabbit hole as long as it takes
dtype_feature = dtype_feature.feature
dtype_str = dtype_feature.dtype
if dtype_str.startswith("int") or dtype_str.startswith("uint"):
dtype = tf.int64
elif dtype_str.startswith("float"):
dtype = tf.float32
else:
raise ValueError(f"Could not convert datatype {dtype_str} in column {column}!")

shape = []
shape_feature = col_feature
while not isinstance(shape_feature, (Value, ClassLabel)):
if isinstance(shape_feature, _ArrayXD):
shape.extend(list(shape_feature.shape))
break
elif isinstance(shape_feature, Sequence):
shape.insert(0, shape_feature.length)
shape_feature = shape_feature.feature
else:
raise ValueError(
f"Couldn't parse feature {column} with type {type(col_feature)}! "
"This may indicate a column was included with an unusual datatype "
"that we were unable to process correctly. "
"If you're getting this error with one of our datasets, and you're "
"sure the column should be convertable to tf.Tensor, please "
"file an issue at github.com/huggingface/datasets and tag "
"@rocketknight1!"
)
shape = [batch_size] + shape
shape = [dim if dim != -1 else None for dim in shape]

signatures[column] = tf.TensorSpec(shape=shape, dtype=dtype)

# Catching columns added by the collate_fn, such as MLM labels
for column, tensor in test_batch.items():
if column in signatures:
continue
if column.startswith("label"):
if "input_ids" in signatures and test_batch[column].shape == test_batch["input_ids"].shape:
shape = signatures["input_ids"].shape
else:
# If this doesn't look like LM labels that got added by the collate_fn, let's not say anything
# about the dimensions we're unsure of
shape = [batch_size] + [None for dim in tensor.shape.as_list()[1:]]
else:
# If this doesn't look like LM labels that got added by the collate_fn, let's not say anything
# about the dimensions we're unsure of
shape = [batch_size] + [None for dim in tensor.shape.as_list()[1:]]
signatures[column] = tf.TensorSpec(shape=shape, dtype=tensor.dtype)
return signatures

def to_tf_dataset(
self,
columns: Union[str, List[str]],
batch_size: int,
shuffle: bool,
drop_remainder: bool = None,
collate_fn: Callable = None,
collate_fn_args: Dict[str, Any] = None,
label_cols: Union[str, List[str]] = None,
dummy_labels: bool = True,
prefetch: bool = True,
):
"""Create a tf.data.Dataset from the underlying Dataset. This tf.data.Dataset will load and collate batches from
the Dataset, and is suitable for passing to methods like model.fit() or model.predict().

Args:
columns (:obj:`List[str]` or :obj:`str`): Dataset column(s) to load in the tf.data.Dataset. In general,
only columns that the model can use as input should be included here (numeric data only).
batch_size (:obj:`int`): Size of batches to load from the dataset.
shuffle(:obj:`bool`): Shuffle the dataset order when loading. Recommended True for training, False for
validation/evaluation.
drop_remainder(:obj:`bool`, default ``None``): Drop the last incomplete batch when loading. If not provided,
defaults to the same setting as shuffle.
collate_fn(:obj:`Callable`): A function or callable object (such as a `DataCollator`) that will collate
lists of samples into a batch.
collate_fn_args (:obj:`Dict`, optional): An optional `dict` of keyword arguments to be passed to the
`collate_fn`.
label_cols (:obj:`List[str]` or :obj:`str`, default ``None``): Dataset column(s) to load as
labels. Note that many models compute loss internally rather than letting Keras do it, in which case it is
not necessary to actually pass the labels here, as long as they're in the input `columns`.
dummy_labels (:obj:`bool`, default ``True``): If no `label_cols` are set, output an array of "dummy" labels
with each batch. This setting ensures that Keras `fit()` or `train_on_batch()` does not get confused
by the missing labels.
prefetch (:obj:`bool`, default ``True``): Whether to run the dataloader in a separate thread and maintain
a small buffer of batches for training. Improves performance by allowing data to be loaded in the
background while the model is training.
"""
if config.TF_AVAILABLE:
import tensorflow as tf
else:
raise ImportError("Called a Tensorflow-specific function but could not import it!")

if collate_fn_args is None:
collate_fn_args = {}

if label_cols is None:
label_cols = []
elif isinstance(label_cols, str):
label_cols = [label_cols]
elif len(set(label_cols)) < len(label_cols):
raise ValueError("List of label_cols contains duplicates!")
if not columns:
raise ValueError("Need to specify at least one column!")
elif isinstance(columns, str):
columns = [columns]
elif len(set(columns)) < len(columns):
raise ValueError("List of columns contains duplicates!")
if label_cols is not None:
cols_to_retain = list(set(columns + label_cols))
else:
cols_to_retain = columns
# Special casing when the dataset has 'label' and the model expects 'labels' and the collator fixes it up for us
if "labels" in cols_to_retain and "labels" not in self.features and "label" in self.features:
cols_to_retain[cols_to_retain.index("labels")] = "label"
for col in cols_to_retain:
if col not in self.features:
raise ValueError(f"Couldn't find column {col} in dataset!")
if drop_remainder is None:
# We assume that if you're shuffling it's the train set, so we drop the remainder unless told not to
drop_remainder = shuffle
self.set_format("numpy", columns=cols_to_retain)

def numpy_pad(data):
try:
# When this is finally fully removed, remove this line
# Alternatively, find a more elegant way to do this whole thing
np.warnings.filterwarnings("error", category=np.VisibleDeprecationWarning)
data = np.array(data)
if data.dtype == np.object:
raise AssertionError # Do it this way so that the assert doesn't get optimized out
return data
except (np.VisibleDeprecationWarning, AssertionError):
pass
# Get lengths of each row of data
lens = np.array([len(i) for i in data])

# Mask of valid places in each row
mask = np.arange(lens.max()) < lens[:, None]

# Setup output array and put elements from data into masked positions
out = np.zeros(mask.shape, dtype=np.array(data[0]).dtype)
out[mask] = np.concatenate(data)
return out

def np_get_batch(indices):
batch = self[indices]
out_batch = []
if collate_fn is not None:
actual_size = len(list(batch.values())[0]) # Get the length of one of the arrays, assume all same
# Our collators expect a list of dicts, not a dict of lists/arrays, so we invert
batch = [{key: value[i] for key, value in batch.items()} for i in range(actual_size)]
batch = collate_fn(batch, **collate_fn_args)
# Special casing when the dataset has 'label' and the model
# expects 'labels' and the collator fixes it up for us
if "label" in cols_to_retain and "label" not in batch and "labels" in batch:
cols_to_retain[cols_to_retain.index("label")] = "labels"
for key in cols_to_retain:
# In case the collate_fn returns something strange
array = np.array(batch[key])
cast_dtype = np.int64 if np.issubdtype(array.dtype, np.integer) else np.float32
array = array.astype(cast_dtype)
out_batch.append(array)
else:
for key in cols_to_retain:
array = batch[key]
array = numpy_pad(array)
cast_dtype = np.int64 if np.issubdtype(array.dtype, np.integer) else np.float32
array = array.astype(cast_dtype)
Comment on lines +346 to +347

Member:
Would this work for string types or nested types?

Member Author:
I've had some success with nested dtypes (in multiple-choice datasets). This does fail on string types, though - the tf.data.Dataset is intended to be passed straight to a model, so the assumption was that everything coming out of it would be convertible to a tf.Tensor. We could possibly make strings work in this context, but I'd need to think about a more generic approach to building the dataset and doing shape inference.

Member:
Ok! Maybe we can mention this in the docstring?

Member:
I just mentioned in the docstring that only numeric data is expected :)
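As a rough sketch of the intended pattern (the column names and the character-level "tokenizer" below are purely illustrative, not part of this PR), string columns would be converted to numeric features before calling `to_tf_dataset()`:

```python
from datasets import Dataset

# Toy dataset with a string column; to_tf_dataset() only handles numeric columns,
# so we first map the text to integer features (a real pipeline would use a tokenizer).
ds = Dataset.from_dict({"text": ["hello world", "hi"], "label": [0, 1]})
ds = ds.map(lambda ex: {"input_ids": [ord(c) for c in ex["text"]]})

tf_ds = ds.to_tf_dataset(
    columns=["input_ids"],  # numeric inputs only - the 'text' column is left out
    label_cols="label",
    batch_size=2,
    shuffle=False,
    dummy_labels=False,
)
features, labels = next(iter(tf_ds))  # without a collate_fn, ragged input_ids are zero-padded
```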

out_batch.append(array)
return [tf.convert_to_tensor(arr) for arr in out_batch]

test_batch = np_get_batch(np.arange(batch_size))

@tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
def fetch_function(indices):
output = tf.numpy_function(
np_get_batch, inp=[indices], Tout=[tf.dtypes.as_dtype(arr.dtype) for arr in test_batch]
)
return {key: output[i] for i, key in enumerate(cols_to_retain)}

test_batch_dict = {key: test_batch[i] for i, key in enumerate(cols_to_retain)}
output_signature = self._get_output_signature(
self, cols_to_retain, test_batch_dict, batch_size=batch_size if drop_remainder else None
)

def ensure_shapes(input_dict):
return {key: tf.ensure_shape(val, output_signature[key].shape) for key, val in input_dict.items()}

tf_dataset = tf.data.Dataset.from_tensor_slices(np.arange(len(self), dtype=np.int64))

if shuffle:
tf_dataset = tf_dataset.shuffle(len(self))

tf_dataset = tf_dataset.batch(batch_size, drop_remainder=drop_remainder).map(fetch_function).map(ensure_shapes)

if label_cols:

def split_features_and_labels(input_batch):
features = {key: tensor for key, tensor in input_batch.items() if key in columns}
labels = {key: tensor for key, tensor in input_batch.items() if key in label_cols}
if len(features) == 1:
features = list(features.values())[0]
if len(labels) == 1:
labels = list(labels.values())[0]
return features, labels

tf_dataset = tf_dataset.map(split_features_and_labels)

elif len(columns) == 1:
tf_dataset = tf_dataset.map(lambda x: list(x.values())[0])

if dummy_labels and not label_cols:
print(
"Warning: No label_cols specified - adding some dummy labels to ensure fit() works correctly. If you "
"only want to use this dataset with predict() or custom training loops, you can disable this "
"behaviour by setting dummy_labels to False."
)

def add_dummy_labels(input_batch):
return input_batch, tf.zeros(tf.shape(input_batch[columns[0]])[0])

tf_dataset = tf_dataset.map(add_dummy_labels)

if prefetch:
tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)
return tf_dataset


class DatasetTransformationNotAllowedError(Exception):
pass

@@ -238,7 +484,7 @@ class NonExistentDatasetError(Exception):
pass


class Dataset(DatasetInfoMixin, IndexableMixin):
class Dataset(DatasetInfoMixin, IndexableMixin, TensorflowDatasetMixIn):
"""A Dataset backed by an Arrow table."""

def __init__(
@@ -1355,12 +1601,16 @@ def set_format(
# Check filter column
if isinstance(columns, str):
columns = [columns]
if isinstance(columns, tuple):
columns = list(columns)
if columns is not None and any(col not in self._data.column_names for col in columns):
raise ValueError(
"Columns {} not in the dataset. Current columns in the dataset: {}".format(
list(filter(lambda col: col not in self._data.column_names, columns)), self._data.column_names
)
)
if columns is not None:
columns = columns.copy() # Ensures modifications made to the list after this call don't cause bugs

self._format_type = type
self._format_kwargs = format_kwargs
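To make the new API above a little more concrete, here is a hedged usage sketch with a hand-rolled padding collator (the dataset contents, the `pad_collate` helper, and the column names are illustrative assumptions, not code from this PR). The `collate_fn` receives a list of example dicts and returns a dict of batched arrays, and the `label`/`labels` special-casing in `to_tf_dataset()` lets the collator rename the label column:

```python
import numpy as np
from datasets import Dataset

def pad_collate(examples):
    """Pad 'input_ids' to a common length and rename 'label' -> 'labels'."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    input_ids = np.zeros((len(examples), max_len), dtype=np.int64)
    for i, ex in enumerate(examples):
        input_ids[i, : len(ex["input_ids"])] = ex["input_ids"]
    return {
        "input_ids": input_ids,
        "labels": np.array([ex["label"] for ex in examples], dtype=np.int64),
    }

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]], "label": [0, 1]})
tf_ds = ds.to_tf_dataset(
    columns=["input_ids"],
    label_cols="labels",    # produced by the collator from the raw 'label' column
    batch_size=2,
    shuffle=True,           # drop_remainder defaults to the shuffle setting
    collate_fn=pad_collate,
    dummy_labels=False,
)
# tf_ds now yields (features, labels) tuples and can be passed to Keras model.fit()
```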
32 changes: 26 additions & 6 deletions src/datasets/formatting/tf_formatter.py
@@ -35,13 +35,33 @@ def __init__(self, **tf_tensor_kwargs):
def _tensorize(self, value):
import tensorflow as tf

default_dtype = {}
if np.issubdtype(value.dtype, np.integer):
default_dtype = {"dtype": tf.int64}
elif np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": tf.float32}
if "dtype" not in self.tf_tensor_kwargs:
if np.issubdtype(value.dtype, np.integer):
np_dtype = np.int64
tf_dtype = tf.int64
default_dtype = {"dtype": tf_dtype}
elif np.issubdtype(value.dtype, np.floating):
np_dtype = np.float32
tf_dtype = tf.float32
default_dtype = {"dtype": tf_dtype}
else:
np_dtype = None
tf_dtype = None
default_dtype = {}
else:
tf_dtype = self.tf_tensor_kwargs["dtype"]
np_dtype = tf_dtype.as_numpy_dtype
default_dtype = {}

return tf.ragged.constant(value, **{**default_dtype, **self.tf_tensor_kwargs})
# Saving the most expensive methods for last
try:
return tf.convert_to_tensor(value, dtype=tf_dtype)
except ValueError:
try:
return tf.ragged.stack([np.array(subarr, dtype=np_dtype) for subarr in value])
except ValueError:
# tf.ragged.constant is orders of magnitude slower than tf.ragged.stack
return tf.ragged.constant(value, **{**default_dtype, **self.tf_tensor_kwargs})

def _recursive_tensorize(self, data_struct: dict):
# support for nested types like struct of list of struct
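The ordering of fallbacks in `_tensorize` above relies on standard TensorFlow behaviour that can be sketched in isolation (this snippet illustrates that behaviour; it is not code from the PR): rectangular data goes through the cheap dense path, and only genuinely ragged data falls back to `tf.ragged.stack`, with `tf.ragged.constant` kept as a last resort because it is much slower:

```python
import numpy as np
import tensorflow as tf

rectangular = [[1, 2, 3], [4, 5, 6]]
ragged = [[1, 2, 3], [4, 5]]

# Rectangular data converts directly to a dense tf.Tensor (the cheap path).
print(tf.convert_to_tensor(rectangular, dtype=tf.int64).shape)  # (2, 3)

# Ragged data raises ValueError in convert_to_tensor, so a RaggedTensor is built instead.
try:
    tf.convert_to_tensor(ragged, dtype=tf.int64)
except ValueError:
    print(tf.ragged.stack([np.array(row, dtype=np.int64) for row in ragged]).shape)  # (2, None)
```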
19 changes: 17 additions & 2 deletions tests/test_arrow_dataset.py
@@ -1816,8 +1816,8 @@ def test_format_vectors(self, in_memory):
self.assertIsInstance(dset[0][col], (tf.Tensor, tf.RaggedTensor))
self.assertIsInstance(dset[:2][col], (tf.Tensor, tf.RaggedTensor))
self.assertIsInstance(dset[col], (tf.Tensor, tf.RaggedTensor))
self.assertEqual(tuple(dset[:2]["vec"].shape), (2, None))
self.assertEqual(tuple(dset["vec"][:2].shape), (2, None))
self.assertEqual(tuple(dset[:2]["vec"].shape), (2, 3))
self.assertEqual(tuple(dset["vec"][:2].shape), (2, 3))

dset.set_format("numpy")
self.assertIsNotNone(dset[0])
@@ -1997,6 +1997,21 @@ def test_with_transform(self, in_memory):
self.assertNotEqual(dset.format, dset2.format)
self.assertNotEqual(dset._fingerprint, dset2._fingerprint)

@require_tf
def test_tf_dataset_conversion(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
tf_dataset = dset.to_tf_dataset(columns="col_3", batch_size=4, shuffle=False, dummy_labels=False)
batch = next(iter(tf_dataset))
self.assertEqual(batch.shape.as_list(), [4, 4])
self.assertEqual(batch.dtype.name, "int64")
with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
tf_dataset = dset.to_tf_dataset(columns="col_1", batch_size=4, shuffle=False, dummy_labels=False)
batch = next(iter(tf_dataset))
self.assertEqual(batch.shape.as_list(), [4])
self.assertEqual(batch.dtype.name, "int64")
del tf_dataset # For correct cleanup


class MiscellaneousDatasetTest(TestCase):
def test_from_pandas(self):
16 changes: 8 additions & 8 deletions tests/test_formatting.py
@@ -164,20 +164,20 @@ def test_tf_formatter(self):
pa_table = self._create_dummy_table()
formatter = TFFormatter()
row = formatter.format_row(pa_table)
tf.debugging.assert_equal(row["a"], tf.ragged.constant(_COL_A, dtype=tf.int64)[0])
tf.debugging.assert_equal(row["b"], tf.ragged.constant(_COL_B, dtype=tf.string)[0])
tf.debugging.assert_equal(row["c"], tf.ragged.constant(_COL_C, dtype=tf.float32)[0])
tf.debugging.assert_equal(row["a"], tf.convert_to_tensor(_COL_A, dtype=tf.int64)[0])
tf.debugging.assert_equal(row["b"], tf.convert_to_tensor(_COL_B, dtype=tf.string)[0])
tf.debugging.assert_equal(row["c"], tf.convert_to_tensor(_COL_C, dtype=tf.float32)[0])
col = formatter.format_column(pa_table)
tf.debugging.assert_equal(col, tf.ragged.constant(_COL_A, dtype=tf.int64))
batch = formatter.format_batch(pa_table)
tf.debugging.assert_equal(batch["a"], tf.ragged.constant(_COL_A, dtype=tf.int64))
tf.debugging.assert_equal(batch["b"], tf.ragged.constant(_COL_B, dtype=tf.string))
self.assertIsInstance(batch["c"], tf.RaggedTensor)
tf.debugging.assert_equal(batch["a"], tf.convert_to_tensor(_COL_A, dtype=tf.int64))
tf.debugging.assert_equal(batch["b"], tf.convert_to_tensor(_COL_B, dtype=tf.string))
self.assertIsInstance(batch["c"], tf.Tensor)
self.assertEqual(batch["c"].dtype, tf.float32)
tf.debugging.assert_equal(
batch["c"].bounding_shape(), tf.ragged.constant(_COL_C, dtype=tf.float32).bounding_shape()
batch["c"].shape.as_list(), tf.convert_to_tensor(_COL_C, dtype=tf.float32).shape.as_list()
)
tf.debugging.assert_equal(batch["c"].flat_values, tf.ragged.constant(_COL_C, dtype=tf.float32).flat_values)
tf.debugging.assert_equal(tf.convert_to_tensor(batch["c"]), tf.convert_to_tensor(_COL_C, dtype=tf.float32))

@require_tf
def test_tf_formatter_np_array_kwargs(self):
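Finally, a hedged sketch of the user-visible effect of the formatter change, mirroring the updated tests above (column names are illustrative): fixed-length columns now come back as dense `tf.Tensor`s with fully defined shapes, while variable-length columns still come back as `tf.RaggedTensor`s:

```python
from datasets import Dataset

ds = Dataset.from_dict({
    "vec": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # rectangular -> dense tf.Tensor
    "tokens": [[1, 2, 3, 4], [5, 6]],           # ragged -> tf.RaggedTensor
})
ds.set_format("tf")

print(type(ds[:2]["vec"]).__name__, ds[:2]["vec"].shape)        # e.g. EagerTensor (2, 3)
print(type(ds[:2]["tokens"]).__name__, ds[:2]["tokens"].shape)  # e.g. RaggedTensor (2, None)
```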