Skip to content

Commit

Permalink
Fix load_dataset for data_files with protocols other than HF (#6862)
Browse files Browse the repository at this point in the history
* Issue 6598: load_dataset broken for data_files on s3

* Issue 6598: do not set DownloadConfig.storage_options[hf] for all protocols

* Issue 6598: ruff format

* Issue 6598: mock creds for new test, keep __post_init__ when token spec'd

* Issue 6598: call _prepare_path_and_storage_options for remote URLs via cached_path

* Do not set default 'hf' in DownloadConfig.storage_options

* Test cached_path for different protocols

* Mark test with require_moto

* Refactor test using fixture

* Make import local

---------

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
  • Loading branch information
matstrand and albertvillanova committed Aug 13, 2024
1 parent efd4b69 commit 243dfd0
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 5 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4",
"moto[server]",
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
"py7zr",
"rarfile>=4.0",
Expand Down
2 changes: 0 additions & 2 deletions src/datasets/download/download_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def __post_init__(self, use_auth_token):
FutureWarning,
)
self.token = use_auth_token
if "hf" not in self.storage_options:
self.storage_options["hf"] = {"token": self.token, "endpoint": config.HF_ENDPOINT}

def copy(self) -> "DownloadConfig":
    """Return a new instance of the same class built from deep copies of this instance's attributes."""
    # Deep-copy every attribute so the new object shares no mutable state with this one.
    cloned_fields = {name: copy.deepcopy(value) for name, value in self.__dict__.items()}
    return self.__class__(**cloned_fields)
Expand Down
5 changes: 4 additions & 1 deletion src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ def cached_path(

if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary)
url_or_filename, storage_options = _prepare_path_and_storage_options(
url_or_filename, download_config=download_config
)
output_path = get_from_cache(
url_or_filename,
cache_dir=cache_dir,
Expand All @@ -210,7 +213,7 @@ def cached_path(
max_retries=download_config.max_retries,
token=download_config.token,
ignore_url_params=download_config.ignore_url_params,
storage_options=download_config.storage_options,
storage_options=storage_options,
download_desc=download_config.download_desc,
disable_tqdm=download_config.disable_tqdm,
)
Expand Down
23 changes: 22 additions & 1 deletion tests/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import re
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
import zstandard as zstd
Expand Down Expand Up @@ -77,6 +77,27 @@ def tmpfs_file(tmpfs):
return FILE_PATH


@pytest.mark.parametrize("protocol", ["hf", "s3"])
def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
    """Check that `cached_path` passes only the URL protocol's storage options downstream.

    `fsspec_head` and `fsspec_get` in `datasets.utils.file_utils` are replaced with
    mocks so the test can inspect the arguments each one receives.

    Args:
        protocol: URL scheme under test, either "hf" or "s3".
        monkeypatch: pytest fixture used to install the mocks.
        tmp_path: pytest fixture providing a temporary directory for the cache.
    """
    # GH-6598: Test no TypeError: __init__() got an unexpected keyword argument 'hf'
    mock_fsspec_head = MagicMock(return_value={})
    mock_fsspec_get = MagicMock(return_value=None)
    monkeypatch.setattr("datasets.utils.file_utils.fsspec_head", mock_fsspec_head)
    monkeypatch.setattr("datasets.utils.file_utils.fsspec_get", mock_fsspec_get)
    cache_dir = tmp_path / "cache"
    storage_options = {} if protocol == "hf" else {"s3": {"anon": True}}
    download_config = DownloadConfig(cache_dir=cache_dir, storage_options=storage_options)
    urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
    url = urls[protocol]
    cached_path(url, download_config=download_config)
    for mock in (mock_fsspec_head, mock_fsspec_get):
        # Each helper must be hit exactly once, with the URL first and
        # storage options keyed by nothing but the URL's own protocol.
        mock.assert_called_once()
        assert mock.call_args.args[0] == url
        assert list(mock.call_args.kwargs["storage_options"]) == [protocol]


@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
Expand Down
46 changes: 46 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
assert_arrow_memory_doesnt_increase,
assert_arrow_memory_increases,
offline,
require_moto,
require_not_windows,
require_pil,
require_sndfile,
Expand Down Expand Up @@ -1686,6 +1687,51 @@ def test_load_from_disk_with_default_in_memory(
_ = load_from_disk(dataset_path)


@pytest.fixture
def moto_server(monkeypatch):
    """Start a threaded moto server for the duration of a test and stop it afterwards.

    While the fixture is active, `os.environ` is replaced wholesale by a plain dict
    holding only the four AWS settings below; monkeypatch restores it at teardown.
    """
    from moto.server import ThreadedMotoServer

    # NOTE(review): the endpoint presumably matches the port ThreadedMotoServer
    # listens on by default - confirm against the moto documentation.
    fake_aws_env = {
        "AWS_ENDPOINT_URL": "http://localhost:5000",
        "AWS_DEFAULT_REGION": "us-east-1",
        "AWS_ACCESS_KEY_ID": "FOO",
        "AWS_SECRET_ACCESS_KEY": "BAR",
    }
    monkeypatch.setattr("os.environ", fake_aws_env)

    moto = ThreadedMotoServer()
    moto.start()
    try:
        yield
    finally:
        # Always shut the server down, even when the test body raised.
        moto.stop()


@require_moto
def test_load_file_from_s3(moto_server):
    """Load a CSV file through `load_dataset` from an `s3://` URL served by the moto server.

    Args:
        moto_server: fixture that starts the mock server and sets the AWS environment.
    """
    # we need server mode here because of an aiobotocore incompatibility with moto.mock_aws
    # (https://github.com/getmoto/moto/issues/6836)
    import boto3

    bucket_name = "test-bucket"
    key = "test-file.csv"
    csv_data = "Island\nIsabela\nBaltra"

    # Create a mock S3 bucket and upload the CSV file to it
    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket=bucket_name)
    s3.put_object(Bucket=bucket_name, Key=key, Body=csv_data)

    # Load the file from the mock bucket; the URL is built from the same
    # bucket/key used for the upload so the two cannot drift apart.
    ds = datasets.load_dataset("csv", data_files={"train": f"s3://{bucket_name}/{key}"})

    # Check if the loaded content matches the original content
    assert list(ds["train"]) == [{"Island": "Isabela"}, {"Island": "Baltra"}]


@pytest.mark.integration
def test_remote_data_files():
repo_id = "hf-internal-testing/raw_jsonl"
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def parse_flag_from_env(key, default=False):


# Skip marker: skips a test when find_spec("faiss") returns None or the platform is win32.
require_faiss = pytest.mark.skipif(find_spec("faiss") is None or sys.platform == "win32", reason="test requires faiss")

# Skip marker: skips a test when find_spec("moto") returns None.
require_moto = pytest.mark.skipif(find_spec("moto") is None, reason="test requires moto")
require_numpy1_on_windows = pytest.mark.skipif(
version.parse(importlib.metadata.version("numpy")) >= version.parse("2.0.0") and sys.platform == "win32",
reason="test requires numpy < 2.0 on windows",
Expand Down

0 comments on commit 243dfd0

Please sign in to comment.