Support loading dataset from multiple zipped CSV data files #3021

Merged · 9 commits · Oct 6, 2021
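This PR teaches the packaged csv builder to expand directories into the individual files they contain, so that a zip archive holding several CSV files (which the download manager extracts to a directory) can be loaded directly. A minimal usage sketch mirroring the new test, assuming a local archive at a made-up path whose CSVs share the three columns used in the fixtures:

```python
from datasets import Features, Value, load_dataset

# "multi_part.csv.zip" is a placeholder for any zip containing several CSVs with the same columns.
features = Features({"col_1": Value("string"), "col_2": Value("int32"), "col_3": Value("float32")})
ds = load_dataset("csv", split="train", data_files="multi_part.csv.zip", features=features)
print(next(iter(ds)))  # first row of the first CSV in the archive
```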
src/datasets/packaged_modules/csv/csv.py (16 additions, 0 deletions)
@@ -1,4 +1,6 @@
# coding=utf-8
import glob
import os
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -17,6 +19,16 @@
_PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS = ["encoding_errors", "on_bad_lines"]


def _iter_files(files):
    for file in files:
        if os.path.isfile(file):
            yield file
        else:
            for subfile in glob.glob(os.path.join(file, "**", "*"), recursive=True):
                if os.path.isfile(subfile):
                    yield subfile


@dataclass
class CsvConfig(datasets.BuilderConfig):
"""BuilderConfig for CSV."""
@@ -138,11 +150,15 @@ def _split_generators(self, dl_manager):
            files = data_files
            if isinstance(files, str):
                files = [files]
            if any(os.path.isdir(file) for file in files):
                files = [file for file in _iter_files(files)]
            return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            if any(os.path.isdir(file) for file in files):
                files = [file for file in _iter_files(files)]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

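As a side note, here is a standalone sketch of what the new _iter_files helper does (the temporary directory stands in for the output of extracting a zip; names are made up): plain file paths pass through unchanged, while a directory entry is expanded recursively into every regular file beneath it.

```python
import glob
import os
import tempfile

def _iter_files(files):
    # Same logic as the helper added above: keep files, walk directories recursively.
    for file in files:
        if os.path.isfile(file):
            yield file
        else:
            for subfile in glob.glob(os.path.join(file, "**", "*"), recursive=True):
                if os.path.isfile(subfile):
                    yield subfile

# Hypothetical layout standing in for an extracted "dataset.csv.zip"
with tempfile.TemporaryDirectory() as extracted:
    os.makedirs(os.path.join(extracted, "nested"))
    for name in ["part1.csv", os.path.join("nested", "part2.csv")]:
        with open(os.path.join(extracted, name), "w") as f:
            f.write("col_1,col_2,col_3\n0,0,0.0\n")
    print(sorted(_iter_files([extracted])))  # lists both part1.csv and nested/part2.csv
```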
tests/conftest.py (27 additions, 0 deletions)
@@ -1,6 +1,7 @@
import csv
import json
import lzma
import os
import textwrap

import pyarrow as pa
@@ -177,6 +178,10 @@ def xml_file(tmp_path_factory):
{"col_1": "2", "col_2": 2, "col_3": 2.0},
{"col_1": "3", "col_2": 3, "col_3": 3.0},
]
DATA2 = [
{"col_1": "4", "col_2": 4, "col_3": 4.0},
{"col_1": "5", "col_2": 5, "col_3": 5.0},
]
DATA_DICT_OF_LISTS = {
"col_1": ["0", "1", "2", "3"],
"col_2": [0, 1, 2, 3],
@@ -220,6 +225,17 @@ def csv_path(tmp_path_factory):
    return path


@pytest.fixture(scope="session")
def csv2_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
with open(path, "w") as f:
writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
writer.writeheader()
for item in DATA:
writer.writerow(item)
return path


@pytest.fixture(scope="session")
def bz2_csv_path(csv_path, tmp_path_factory):
    import bz2
@@ -233,6 +249,17 @@ def bz2_csv_path(csv_path, tmp_path_factory):
    return path


@pytest.fixture(scope="session")
def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset.csv.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(csv_path, arcname=os.path.basename(csv_path))
f.write(csv2_path, arcname=os.path.basename(csv2_path))
return path


@pytest.fixture(scope="session")
def parquet_path(tmp_path_factory):
    path = str(tmp_path_factory.mktemp("data") / "dataset.parquet")
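For reference, roughly what the fixtures above assemble, written as a plain script (the row values mirror the DATA constant; the temporary directory and everything else is illustrative): two small CSVs with identical columns are zipped under distinct archive names so both members survive extraction.

```python
import csv
import os
import tempfile
import zipfile

# Standalone rendition of the csv_path / csv2_path / zip_csv_path fixtures (paths are made up).
tmp_dir = tempfile.mkdtemp()
rows = [{"col_1": str(i), "col_2": i, "col_3": float(i)} for i in range(4)]

csv_paths = []
for name in ["dataset.csv", "dataset2.csv"]:
    path = os.path.join(tmp_dir, name)
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
        writer.writeheader()
        writer.writerows(rows)
    csv_paths.append(path)

zip_path = os.path.join(tmp_dir, "dataset.csv.zip")
with zipfile.ZipFile(zip_path, "w") as f:
    for csv_file in csv_paths:
        f.write(csv_file, arcname=os.path.basename(csv_file))

print(zipfile.ZipFile(zip_path).namelist())  # ['dataset.csv', 'dataset2.csv']
```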
tests/test_load.py (8 additions, 0 deletions)
@@ -371,6 +371,14 @@ def test_load_dataset_streaming_csv(path_extension, streaming, csv_path, bz2_csv
    assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}


def test_load_dataset_zip_csv(zip_csv_path):
    data_files = str(zip_csv_path)
    features = Features({"col_1": Value("string"), "col_2": Value("int32"), "col_3": Value("float32")})
    ds = load_dataset("csv", split="train", data_files=data_files, features=features)
    ds_item = next(iter(ds))
    assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}


def test_loading_from_the_datasets_hub():
    with tempfile.TemporaryDirectory() as tmp_dir:
        dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=tmp_dir)