Skip to content

Commit

Permalink
feat: support non streamable arrow file binary format
Browse files Browse the repository at this point in the history
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
  • Loading branch information
kmehant committed Jul 9, 2024
1 parent 689447f commit 2e3af68
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions src/datasets/packaged_modules/arrow/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def _split_generators(self, dl_manager):
if self.info.features is None:
for file in itertools.chain.from_iterable(files):
with open(file, "rb") as f:
self.info.features = datasets.Features.from_arrow_schema(pa.ipc.open_stream(f).schema)
try:
reader = pa.ipc.open_stream(f)
except pa.lib.ArrowInvalid:
reader = pa.ipc.open_file(f)
self.info.features = datasets.Features.from_arrow_schema(reader.schema)
break
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits
Expand All @@ -57,14 +61,19 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:

def _generate_tables(self, files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
with open(file, "rb") as f:
try:
for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
pa_table = pa.Table.from_batches([record_batch])
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
try:
with open(file, "rb") as f:
try:
batches = pa.ipc.open_stream(f)
except pa.lib.ArrowInvalid:
reader = pa.ipc.open_file(f)
batches = [reader.get_batch(i) for i in range(reader.num_record_batches)]
for batch_idx, record_batch in enumerate(batches):
pa_table = pa.Table.from_batches([record_batch])
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise

0 comments on commit 2e3af68

Please sign in to comment.