From 8247202a7ed1c3164c88f8f183513c5f003aa2af Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Tue, 5 Mar 2024 12:17:04 +0100
Subject: [PATCH] Fix sliced ConcatenationTable pickling with mixed schemas
 vertically (#6715)

* fix ConcatenationTable pickling with mixed schemas vertically

* fix slicing

* fix tests
---
 src/datasets/table.py       | 11 ++++++++++-
 tests/test_arrow_dataset.py | 11 ++++++++---
 tests/test_table.py         | 21 +++++++++++++++++++++
 3 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/src/datasets/table.py b/src/datasets/table.py
index 1811bd1e18a..366f8bc185d 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -1311,11 +1311,20 @@ def __init__(self, table: pa.Table, blocks: List[List[TableBlock]]):
         )
 
     def __getstate__(self):
-        return {"blocks": self.blocks}
+        return {"blocks": self.blocks, "schema": self.table.schema}
 
     def __setstate__(self, state):
         blocks = state["blocks"]
+        schema = state["schema"]
         table = self._concat_blocks_horizontally_and_vertically(blocks)
+        if schema is not None and table.schema != schema:
+            # We fix the columns by concatenating with an empty table with the right columns
+            empty_table = pa.Table.from_batches([], schema=schema)
+            # we set promote=True to fill missing columns with null values
+            if config.PYARROW_VERSION.major < 14:
+                table = pa.concat_tables([table, empty_table], promote=True)
+            else:
+                table = pa.concat_tables([table, empty_table], promote_options="default")
         ConcatenationTable.__init__(self, table, blocks=blocks)
 
     @staticmethod
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index bc2c8846fd0..b101c4f5715 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -73,6 +73,10 @@ def __reduce__(self):
 
 
 class Unpicklable:
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
     def __getstate__(self):
         raise pickle.PicklingError()
 
@@ -811,6 +815,7 @@ def test_concatenate_pickle(self, in_memory):
                 Dataset.from_dict(data2, info=info2),
                 Dataset.from_dict(data3),
             )
+            schema = dset1.data.schema
             # mix from in-memory and on-disk datasets
             dset1, dset2 = self._to(in_memory, tmp_dir, dset1, dset2)
             dset3 = self._to(not in_memory, tmp_dir, dset3)
@@ -835,13 +840,13 @@ def test_concatenate_pickle(self, in_memory):
             dset3 = dset3.rename_column("foo", "new_foo")
             dset3 = dset3.remove_columns("new_foo")
             if in_memory:
-                dset3._data.table = Unpicklable()
+                dset3._data.table = Unpicklable(schema=schema)
             else:
-                dset1._data.table, dset2._data.table = Unpicklable(), Unpicklable()
+                dset1._data.table, dset2._data.table = Unpicklable(schema=schema), Unpicklable(schema=schema)
             dset1, dset2, dset3 = (pickle.loads(pickle.dumps(d)) for d in (dset1, dset2, dset3))
             with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
                 if not in_memory:
-                    dset_concat._data.table = Unpicklable()
+                    dset_concat._data.table = Unpicklable(schema=schema)
                 with pickle.loads(pickle.dumps(dset_concat)) as dset_concat:
                     self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
                     self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
diff --git a/tests/test_table.py b/tests/test_table.py
index 7db58037245..9ac6b7ee7b9 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -787,6 +787,27 @@ def test_concatenation_table_slice(
     assert isinstance(table, ConcatenationTable)
 
 
+def test_concatenation_table_slice_mixed_schemas_vertically(arrow_file):
+    t1 = MemoryMappedTable.from_file(arrow_file)
+    t2 = InMemoryTable.from_pydict({"additional_column": ["foo"]})
+    expected = pa.table(
+        {
+            **{column: values + [None] for column, values in t1.to_pydict().items()},
+            "additional_column": [None] * len(t1) + ["foo"],
+        }
+    )
+    blocks = [[t1], [t2]]
+    table = ConcatenationTable.from_blocks(blocks)
+    assert table.to_pydict() == expected.to_pydict()
+    assert isinstance(table, ConcatenationTable)
+    reloaded = pickle.loads(pickle.dumps(table))
+    assert reloaded.to_pydict() == expected.to_pydict()
+    assert isinstance(reloaded, ConcatenationTable)
+    reloaded = pickle.loads(pickle.dumps(table.slice(1, 2)))
+    assert reloaded.to_pydict() == expected.slice(1, 2).to_pydict()
+    assert isinstance(reloaded, ConcatenationTable)
+
+
 @pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
 def test_concatenation_table_filter(
     blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
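
For reference, a minimal standalone sketch of the schema-alignment trick the new
__setstate__ relies on: concatenating a table with an empty table that carries the
target schema makes pa.concat_tables fill the missing columns with nulls. The column
names and the pa.__version__ check below are illustrative stand-ins (the patch itself
uses the library-internal config.PYARROW_VERSION), not part of the commit.

    import pyarrow as pa

    # A table that is missing "additional_column" relative to the target schema.
    table = pa.table({"col": [1, 2, 3]})
    target_schema = pa.schema([("col", pa.int64()), ("additional_column", pa.string())])

    # Empty table with the desired schema; concatenating with it promotes the missing column.
    empty_table = pa.Table.from_batches([], schema=target_schema)
    if int(pa.__version__.split(".")[0]) < 14:
        # pyarrow < 14: promote=True fills missing columns with null values
        aligned = pa.concat_tables([table, empty_table], promote=True)
    else:
        # pyarrow >= 14: the equivalent spelling is promote_options="default"
        aligned = pa.concat_tables([table, empty_table], promote_options="default")

    assert aligned.schema.names == ["col", "additional_column"]
    assert aligned.to_pydict() == {"col": [1, 2, 3], "additional_column": [None, None, None]}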