Skip to content

Commit

Permalink
feat: add unit test to load data in both arrow formats
Browse files Browse the repository at this point in the history
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
  • Loading branch information
kmehant committed Jul 25, 2024
1 parent d8f8428 commit bd6546c
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion tests/packaged_modules/test_arrow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,53 @@
import pyarrow as pa
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.arrow.arrow import ArrowConfig
from datasets.packaged_modules.arrow.arrow import Arrow, ArrowConfig


@pytest.fixture
def arrow_file_streaming_format(tmp_path):
    """Write a small table in the Arrow IPC streaming format.

    The table has a single ``input_ids`` column of ``list<int32>`` holding
    three rows. Returns the path of the written file as a string.
    """
    rows = [[1, 1, 1], [0, 100, 6], [1, 90, 900]]
    list_type = pa.list_(pa.int32())

    schema = pa.schema([pa.field("input_ids", list_type)])
    column = pa.array(rows, type=list_type)
    table = pa.Table.from_arrays([column], schema=schema)

    path = tmp_path / "stream.arrow"
    # new_stream writes the IPC *streaming* layout.
    with open(path, "wb") as sink, pa.ipc.new_stream(sink, schema) as writer:
        writer.write_table(table)
    return str(path)


@pytest.fixture
def arrow_file_file_format(tmp_path):
    """Write a small table in the Arrow IPC file format.

    The table has a single ``input_ids`` column of ``list<int32>`` holding
    three rows. Returns the path of the written file as a string.
    """
    rows = [[1, 1, 1], [0, 100, 6], [1, 90, 900]]
    list_type = pa.list_(pa.int32())

    schema = pa.schema([pa.field("input_ids", list_type)])
    column = pa.array(rows, type=list_type)
    table = pa.Table.from_arrays([column], schema=schema)

    path = tmp_path / "file.arrow"
    # new_file writes the IPC *file* layout.
    with open(path, "wb") as sink, pa.ipc.new_file(sink, schema) as writer:
        writer.write_table(table)
    return str(path)


@pytest.mark.parametrize(
    "file_fixture, config_kwargs",
    [
        ("arrow_file_streaming_format", {}),
        ("arrow_file_file_format", {}),
    ],
)
def test_arrow_generate_tables(file_fixture, config_kwargs, request):
    """Check that ``Arrow._generate_tables`` yields the rows written to disk.

    Parametrized over both fixture files (streaming and file IPC layouts);
    the fixture is resolved by name through ``request.getfixturevalue``.
    """
    builder = Arrow(**config_kwargs)
    arrow_path = request.getfixturevalue(file_fixture)

    # _generate_tables yields (key, table) pairs; only the tables matter here.
    tables = [tbl for _, tbl in builder._generate_tables([[arrow_path]])]
    combined = pa.concat_tables(tables)

    assert combined.to_pydict() == {
        "input_ids": [[1, 1, 1], [0, 100, 6], [1, 90, 900]],
    }


def test_config_raises_when_invalid_name() -> None:
Expand Down

0 comments on commit bd6546c

Please sign in to comment.