From a30a66aac1b80930cdf23cdf17d4df3ad4d7e30f Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 25 Jul 2024 19:03:28 +0530 Subject: [PATCH] feat: add unit test to load data in both arrow formats Signed-off-by: Mehant Kammakomati --- tests/packaged_modules/test_arrow.py | 47 +++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/packaged_modules/test_arrow.py b/tests/packaged_modules/test_arrow.py index dda3720efe3..55c07c9672f 100644 --- a/tests/packaged_modules/test_arrow.py +++ b/tests/packaged_modules/test_arrow.py @@ -1,8 +1,53 @@ +import pyarrow as pa import pytest from datasets.builder import InvalidConfigName from datasets.data_files import DataFilesList -from datasets.packaged_modules.arrow.arrow import ArrowConfig +from datasets.packaged_modules.arrow.arrow import Arrow, ArrowConfig + + +@pytest.fixture +def arrow_file_streaming_format(tmp_path): + filename = tmp_path / "stream.arrow" + testdata = [[1, 1, 1], [0, 100, 6], [1, 90, 900]] + + schema = pa.schema([pa.field("input_ids", pa.list_(pa.int32()))]) + array = pa.array(testdata, type=pa.list_(pa.int32())) + table = pa.Table.from_arrays([array], schema=schema) + with open(filename, "wb") as f: + with pa.ipc.new_stream(f, schema) as writer: + writer.write_table(table) + return str(filename) + + +@pytest.fixture +def arrow_file_file_format(tmp_path): + filename = tmp_path / "file.arrow" + testdata = [[1, 1, 1], [0, 100, 6], [1, 90, 900]] + + schema = pa.schema([pa.field("input_ids", pa.list_(pa.int32()))]) + array = pa.array(testdata, type=pa.list_(pa.int32())) + table = pa.Table.from_arrays([array], schema=schema) + with open(filename, "wb") as f: + with pa.ipc.new_file(f, schema) as writer: + writer.write_table(table) + return str(filename) + + +@pytest.mark.parametrize( + "file_fixture, config_kwargs", + [ + ("arrow_file_streaming_format", {}), + ("arrow_file_file_format", {}), + ], +) +def test_arrow_generate_tables(file_fixture, config_kwargs, request): + arrow = Arrow(**config_kwargs) + generator = arrow._generate_tables([[request.getfixturevalue(file_fixture)]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + expected = {"input_ids": [[1, 1, 1], [0, 100, 6], [1, 90, 900]]} + assert pa_table.to_pydict() == expected def test_config_raises_when_invalid_name() -> None: