Skip to content

Commit

Permalink
Fix sliced ConcatenationTable pickling with mixed schemas vertically (#6715)
Browse files Browse the repository at this point in the history

* fix ConcatenationTable pickling with mixed schemas vertically

* fix slicing

* fix tests
  • Loading branch information
lhoestq authored Mar 5, 2024
1 parent 1fe9483 commit 8247202
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
11 changes: 10 additions & 1 deletion src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,11 +1311,20 @@ def __init__(self, table: pa.Table, blocks: List[List[TableBlock]]):
)

def __getstate__(self):
    """Return the picklable state: the blocks plus the schema of the full table.

    The schema is stored alongside the blocks so that ``__setstate__`` can
    compare it with the schema of the table rebuilt from the blocks and
    restore any columns that are missing after re-concatenation.
    """
    # Only one return: the older blocks-only state dropped the schema,
    # which ``__setstate__`` reads back.
    return {"blocks": self.blocks, "schema": self.table.schema}

def __setstate__(self, state):
    """Rebuild the concatenation table from its pickled state.

    Args:
        state: Mapping produced by ``__getstate__``. It must contain
            ``"blocks"`` and may contain ``"schema"``, the schema the
            rebuilt table is expected to have.
    """
    blocks = state["blocks"]
    # States written by the older blocks-only ``__getstate__`` have no
    # "schema" entry; treat that as "no schema to enforce" instead of
    # raising KeyError.
    schema = state.get("schema")
    table = self._concat_blocks_horizontally_and_vertically(blocks)
    if schema is not None and table.schema != schema:
        # We fix the columns by concatenating with an empty table with the right columns
        empty_table = pa.Table.from_batches([], schema=schema)
        # Ask pyarrow to fill missing columns with null values; the keyword
        # used for that depends on the installed pyarrow major version.
        if config.PYARROW_VERSION.major < 14:
            table = pa.concat_tables([table, empty_table], promote=True)
        else:
            table = pa.concat_tables([table, empty_table], promote_options="default")
    ConcatenationTable.__init__(self, table, blocks=blocks)

@staticmethod
Expand Down
11 changes: 8 additions & 3 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def __reduce__(self):


class Unpicklable:
    """Test helper that carries arbitrary attributes but always fails to pickle."""

    def __init__(self, **kwargs):
        """Expose every keyword argument as an instance attribute."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getstate__(self):
        """Refuse to be pickled by raising ``pickle.PicklingError``."""
        raise pickle.PicklingError()

Expand Down Expand Up @@ -811,6 +815,7 @@ def test_concatenate_pickle(self, in_memory):
Dataset.from_dict(data2, info=info2),
Dataset.from_dict(data3),
)
schema = dset1.data.schema
# mix from in-memory and on-disk datasets
dset1, dset2 = self._to(in_memory, tmp_dir, dset1, dset2)
dset3 = self._to(not in_memory, tmp_dir, dset3)
Expand All @@ -835,13 +840,13 @@ def test_concatenate_pickle(self, in_memory):
dset3 = dset3.rename_column("foo", "new_foo")
dset3 = dset3.remove_columns("new_foo")
if in_memory:
dset3._data.table = Unpicklable()
dset3._data.table = Unpicklable(schema=schema)
else:
dset1._data.table, dset2._data.table = Unpicklable(), Unpicklable()
dset1._data.table, dset2._data.table = Unpicklable(schema=schema), Unpicklable(schema=schema)
dset1, dset2, dset3 = (pickle.loads(pickle.dumps(d)) for d in (dset1, dset2, dset3))
with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
if not in_memory:
dset_concat._data.table = Unpicklable()
dset_concat._data.table = Unpicklable(schema=schema)
with pickle.loads(pickle.dumps(dset_concat)) as dset_concat:
self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
Expand Down
21 changes: 21 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,27 @@ def test_concatenation_table_slice(
assert isinstance(table, ConcatenationTable)


def test_concatenation_table_slice_mixed_schemas_vertically(arrow_file):
    """Check pickling of a vertically concatenated table whose blocks have different columns.

    The table is round-tripped through pickle both as a whole and after
    slicing, and compared against an expected table in which the columns
    missing from each block are filled with ``None``.
    """
    t1 = MemoryMappedTable.from_file(arrow_file)
    t2 = InMemoryTable.from_pydict({"additional_column": ["foo"]})
    # t1's columns get one trailing null (the t2 row); the additional column
    # is null for every t1 row and "foo" for the t2 row.
    expected = pa.table(
        {
            **{column: values + [None] for column, values in t1.to_pydict().items()},
            "additional_column": [None] * len(t1) + ["foo"],
        }
    )
    blocks = [[t1], [t2]]
    table = ConcatenationTable.from_blocks(blocks)
    assert table.to_pydict() == expected.to_pydict()
    assert isinstance(table, ConcatenationTable)
    # Round-trip the full table.
    reloaded = pickle.loads(pickle.dumps(table))
    assert reloaded.to_pydict() == expected.to_pydict()
    assert isinstance(reloaded, ConcatenationTable)
    # Round-trip a slice of the table.
    reloaded = pickle.loads(pickle.dumps(table.slice(1, 2)))
    assert reloaded.to_pydict() == expected.slice(1, 2).to_pydict()
    assert isinstance(reloaded, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_filter(
blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
Expand Down

0 comments on commit 8247202

Please sign in to comment.