Allow polars as valid output type (#6762)
* Add polars to allowed processed input types in validate_function_output

* add polars as an updatable type

* add polars as a return type in tests

* Update tests/test_arrow_dataset.py

---------

Co-authored-by: Patrick Smyth <docs@huggingface.co>
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
3 people authored Aug 16, 2024
1 parent bececda commit 5f42139
Showing 2 changed files with 100 additions and 15 deletions.
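
In practice this means a `map` function may now return a `polars.DataFrame`, alongside the previously accepted `dict`, `pyarrow.Table`, and `pandas.DataFrame`. A minimal usage sketch of the new behavior (dataset and column names are illustrative, not taken from the diff; as the new tests below show, the returned table's columns define the output schema, and polars string columns surface as Arrow `large_string`):

```python
import polars as pl

from datasets import Dataset

ds = Dataset.from_dict({"id": [0, 1, 2], "text": ["a", "b", "c"]})

# Non-batched: return a one-row DataFrame per example.
ds_mapped = ds.map(lambda x: pl.DataFrame({"text_upper": [x["text"].upper()]}))

# Batched: return one row per example in the input batch.
ds_batched = ds.map(
    lambda batch: pl.DataFrame({"text_upper": [t.upper() for t in batch["text"]]}),
    batched=True,
)

print(ds_mapped[0])  # {'text_upper': 'A'}
```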
14 changes: 12 additions & 2 deletions src/datasets/arrow_dataset.py
@@ -3375,7 +3375,12 @@ class NumExamplesMismatchError(Exception):
 
         def validate_function_output(processed_inputs, indices):
             """Validate output of the map function."""
-            if processed_inputs is not None and not isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame)):
+            allowed_processed_inputs_types = (Mapping, pa.Table, pd.DataFrame)
+            if config.POLARS_AVAILABLE and "polars" in sys.modules:
+                import polars as pl
+
+                allowed_processed_inputs_types += (pl.DataFrame,)
+            if processed_inputs is not None and not isinstance(processed_inputs, allowed_processed_inputs_types):
                 raise TypeError(
                     f"Provided `function` which is applied to all elements of table returns a variable of type {type(processed_inputs)}. Make sure provided `function` returns a variable of type `dict` (or a pyarrow table) to update the dataset or `None` if you are only interested in side effects."
                 )
@@ -3434,7 +3439,12 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example
             returned_lazy_dict = False
             if update_data is None:
                 # Check if the function returns updated examples
-                update_data = isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame))
+                updatable_types = (Mapping, pa.Table, pd.DataFrame)
+                if config.POLARS_AVAILABLE and "polars" in sys.modules:
+                    import polars as pl
+
+                    updatable_types += (pl.DataFrame,)
+                update_data = isinstance(processed_inputs, updatable_types)
             validate_function_output(processed_inputs, indices)
             if not update_data:
                 return None  # Nothing to update, let's move on
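
Both hunks share the same guard, so polars stays an optional dependency: `pl.DataFrame` is appended to the accepted types only when polars is installed (`config.POLARS_AVAILABLE`) and the user has already imported it (`"polars" in sys.modules`), so `datasets` itself never triggers the import. A standalone sketch of that pattern (the helper name is hypothetical, not part of the library):

```python
import sys
from collections.abc import Mapping

import pandas as pd
import pyarrow as pa


def allowed_map_output_types() -> tuple:
    """Hypothetical helper illustrating the lazy polars guard."""
    types = (Mapping, pa.Table, pd.DataFrame)
    # Checking sys.modules is free and never imports polars itself;
    # the import below only re-binds an already-loaded module.
    if "polars" in sys.modules:
        import polars as pl

        types += (pl.DataFrame,)
    return types
```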
101 changes: 88 additions & 13 deletions tests/test_arrow_dataset.py
@@ -915,7 +918,10 @@ def test_flatten(self, in_memory):
             with Dataset.from_dict(
                 {"a": [{"en": "the cat", "fr": ["le chat", "la chatte"], "de": "die katze"}] * 10, "foo": [1] * 10},
                 features=Features(
-                    {"a": TranslationVariableLanguages(languages=["en", "fr", "de"]), "foo": Value("int64")}
+                    {
+                        "a": TranslationVariableLanguages(languages=["en", "fr", "de"]),
+                        "foo": Value("int64"),
+                    }
                 ),
             ) as dset:
                 with self._to(in_memory, tmp_dir, dset) as dset:
@@ -1000,7 +1003,11 @@ def test_flatten_complex_image(self, in_memory):
                 self.assertDictEqual(
                     dset.features,
                     Features(
-                        {"a.b.bytes": Value("binary"), "a.b.path": Value("string"), "foo": Value("int64")}
+                        {
+                            "a.b.bytes": Value("binary"),
+                            "a.b.path": Value("string"),
+                            "foo": Value("int64"),
+                        }
                     ),
                 )
                 self.assertNotEqual(dset._fingerprint, fingerprint)
@@ -1528,6 +1535,48 @@ def func_return_multi_row_pd_dataframe(x):
         with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
             self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)
 
+    @require_polars
+    def test_map_return_pl_dataframe(self, in_memory):
+        import polars as pl
+
+        def func_return_single_row_pl_dataframe(x):
+            return pl.DataFrame({"id": [0], "text": ["a"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pl_dataframe) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("large_string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Batched
+        def func_return_single_row_pl_dataframe_batched(x):
+            batch_size = len(x[next(iter(x))])
+            return pl.DataFrame({"id": [0] * batch_size, "text": ["a"] * batch_size})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pl_dataframe_batched, batched=True) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("large_string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Error when returning a table with more than one row in the non-batched mode
+        def func_return_multi_row_pl_dataframe(x):
+            return pl.DataFrame({"id": [0, 1], "text": ["a", "b"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                self.assertRaises(ValueError, dset.map, func_return_multi_row_pl_dataframe)
+
     @require_numpy1_on_windows
     @require_torch
     def test_map_torch(self, in_memory):
@@ -1831,7 +1880,10 @@ def test_filter_caching(self, in_memory):
 
     def test_keep_features_after_transform_specified(self, in_memory):
         features = Features(
-            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
+            {
+                "tokens": Sequence(Value("string")),
+                "labels": Sequence(ClassLabel(names=["negative", "positive"])),
+            }
         )
 
         def invert_labels(x):
Expand All @@ -1849,7 +1901,10 @@ def invert_labels(x):

def test_keep_features_after_transform_unspecified(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
Expand All @@ -1867,7 +1922,10 @@ def invert_labels(x):

def test_keep_features_after_transform_to_file(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
Expand All @@ -1886,7 +1944,10 @@ def invert_labels(x):

def test_keep_features_after_transform_to_memory(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
Expand All @@ -1903,7 +1964,10 @@ def invert_labels(x):

def test_keep_features_after_loading_from_cache(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
Expand All @@ -1926,7 +1990,10 @@ def invert_labels(x):

def test_keep_features_with_new_features(self, in_memory):
features = Features(
{"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
{
"tokens": Sequence(Value("string")),
"labels": Sequence(ClassLabel(names=["negative", "positive"])),
}
)

def invert_labels(x):
@@ -3710,7 +3777,11 @@ def test_dataset_from_json_features(features, jsonl_path, tmp_path):
 
 def test_dataset_from_json_with_class_label_feature(jsonl_str_path, tmp_path):
     features = Features(
-        {"col_1": ClassLabel(names=["s0", "s1", "s2", "s3"]), "col_2": Value("int64"), "col_3": Value("float64")}
+        {
+            "col_1": ClassLabel(names=["s0", "s1", "s2", "s3"]),
+            "col_2": Value("int64"),
+            "col_3": Value("float64"),
+        }
     )
     cache_dir = tmp_path / "cache"
     dataset = Dataset.from_json(jsonl_str_path, features=features, cache_dir=cache_dir)
@@ -4262,7 +4333,11 @@ def test_task_question_answering(self):
     def test_task_summarization(self):
         # Include a dummy extra column `dummy` to test we drop it correctly
         features_before_cast = Features(
-            {"input_text": Value("string"), "input_summary": Value("string"), "dummy": Value("string")}
+            {
+                "input_text": Value("string"),
+                "input_summary": Value("string"),
+                "dummy": Value("string"),
+            }
         )
         features_after_cast = Features({"text": Value("string"), "summary": Value("string")})
         task = Summarization(text_column="input_text", summary_column="input_summary")
@@ -4882,7 +4957,7 @@ def test_dataset_batch():
         assert len(batch["id"]) == 3
         assert len(batch["text"]) == 3
         assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
-        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
+        assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]
 
     # Check last partial batch
     assert len(batches[3]["id"]) == 1
Expand All @@ -4899,7 +4974,7 @@ def test_dataset_batch():
assert len(batch["id"]) == 3
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]

# Test with batch_size=4 (doesn't evenly divide dataset size)
batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
Expand All @@ -4910,7 +4985,7 @@ def test_dataset_batch():
assert len(batch["id"]) == 4
assert len(batch["text"]) == 4
assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]
assert batch["text"] == [f"Text {4 * i}", f"Text {4 * i + 1}", f"Text {4 * i + 2}", f"Text {4 * i + 3}"]

# Check last partial batch
assert len(batches[2]["id"]) == 2
Expand Down
