Fix wrong type hints in data_files #6910

Merged: 7 commits, May 23, 2024
src/datasets/data_files.py: 52 changes (28 additions, 24 deletions)
@@ -22,6 +22,9 @@
 from .utils.py_utils import glob_pattern_to_regex, string_to_dict
 
 
+SingleOriginMetadata = Union[Tuple[str, str], Tuple[str], Tuple[()]]
+
+
 SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN)
 
 
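The new `SingleOriginMetadata` alias names the three tuple shapes that `_get_single_origin_metadata` (below) can return: a `(repo_id, revision)` pair for files on the Hugging Face Hub, a one-element tuple (an ETag or mtime) for other remote or local files, and an empty tuple when no identifying metadata is found. A minimal sketch of the three shapes, with made-up values:

```python
from typing import Tuple, Union

SingleOriginMetadata = Union[Tuple[str, str], Tuple[str], Tuple[()]]

# (repo_id, revision) for files resolved through the Hugging Face Hub
hub_meta: SingleOriginMetadata = ("user/my_dataset", "main")
# one-element tuple, e.g. an ETag or a stringified mtime, for other files
etag_meta: SingleOriginMetadata = ("5d41402abc4b2a76b9719d911017c592",)
# empty tuple when no identifying metadata is available
no_meta: SingleOriginMetadata = ()
```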
@@ -361,6 +364,7 @@ def resolve_pattern(
         base_path (str): Base path to use when resolving relative paths.
         allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
             For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
+        download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters.
     Returns:
         List[str]: List of paths or URLs to the local or remote files that match the patterns.
     """
@@ -516,17 +520,17 @@ def get_metadata_patterns(
 def _get_single_origin_metadata(
     data_file: str,
     download_config: Optional[DownloadConfig] = None,
-) -> Tuple[str]:
+) -> SingleOriginMetadata:
     data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
     fs, *_ = url_to_fs(data_file, **storage_options)
     if isinstance(fs, HfFileSystem):
         resolved_path = fs.resolve_path(data_file)
-        return (resolved_path.repo_id, resolved_path.revision)
+        return resolved_path.repo_id, resolved_path.revision
     elif isinstance(fs, HTTPFileSystem) and data_file.startswith(config.HF_ENDPOINT):
         hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
         data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
         resolved_path = hffs.resolve_path(data_file)
-        return (resolved_path.repo_id, resolved_path.revision)
+        return resolved_path.repo_id, resolved_path.revision
     info = fs.info(data_file)
     # s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime
     for key in ["ETag", "etag", "mtime"]:
@@ -539,7 +543,7 @@ def _get_origin_metadata(
     data_files: List[str],
     download_config: Optional[DownloadConfig] = None,
     max_workers: Optional[int] = None,
-) -> Tuple[str]:
+) -> List[SingleOriginMetadata]:
     max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
     return thread_map(
         partial(_get_single_origin_metadata, download_config=download_config),
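`thread_map` runs `_get_single_origin_metadata` over every data file and collects one result per file, so `List[SingleOriginMetadata]` is what the function actually returns; the old `Tuple[str]` described neither the container nor its elements. A self-contained sketch of the same map-in-threads pattern, using a stand-in worker rather than the real one:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

def fetch_meta(path: str) -> Tuple[str, ...]:
    # stand-in for _get_single_origin_metadata: pretend the path is its own metadata
    return (path,)

def fetch_all_meta(paths: List[str], max_workers: int = 4) -> List[Tuple[str, ...]]:
    # like tqdm.contrib.concurrent.thread_map, returns one result per input, in order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fetch_meta, paths))

print(fetch_all_meta(["a.csv", "b.csv"]))  # [('a.csv',), ('b.csv',)]
```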
@@ -555,11 +559,11 @@ def _get_origin_metadata(
 class DataFilesList(List[str]):
     """
     List of data files (absolute local paths or URLs).
-    It has two construction methods given the user's data files patterns :
+    It has two construction methods given the user's data files patterns:
     - ``from_hf_repo``: resolve patterns inside a dataset repository
     - ``from_local_or_remote``: resolve patterns from a local path
 
-    Moreover DataFilesList has an additional attribute ``origin_metadata``.
+    Moreover, DataFilesList has an additional attribute ``origin_metadata``.
     It can store:
     - the last modified time of local files
     - ETag of remote files
@@ -570,11 +574,11 @@ class DataFilesList(List[str]):
     This is useful for caching Dataset objects that are obtained from a list of data files.
     """
 
-    def __init__(self, data_files: List[str], origin_metadata: List[Tuple[str]]):
+    def __init__(self, data_files: List[str], origin_metadata: List[SingleOriginMetadata]) -> None:
         super().__init__(data_files)
         self.origin_metadata = origin_metadata
 
-    def __add__(self, other):
+    def __add__(self, other: "DataFilesList") -> "DataFilesList":
         return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata)
 
     @classmethod
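With `__add__` annotated, it is clear that concatenating two `DataFilesList`s also concatenates their `origin_metadata`, keeping files and their provenance aligned. A quick illustration, assuming the internal `DataFilesList` from this module is importable (it is not public API, so the import path may change) and using made-up metadata values:

```python
from datasets.data_files import DataFilesList

a = DataFilesList(["train_0.csv"], [("user/repo_a", "main")])
b = DataFilesList(["train_1.csv"], [("user/repo_b", "main")])
c = a + b  # still a DataFilesList, per the new annotation
assert list(c) == ["train_0.csv", "train_1.csv"]
assert c.origin_metadata == [("user/repo_a", "main"), ("user/repo_b", "main")]
```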
@@ -646,9 +650,9 @@ class DataFilesDict(Dict[str, DataFilesList]):
     - ``from_hf_repo``: resolve patterns inside a dataset repository
     - ``from_local_or_remote``: resolve patterns from a local path
 
-    Moreover each list is a DataFilesList. It is possible to hash the dictionary
+    Moreover, each list is a DataFilesList. It is possible to hash the dictionary
     and get a different hash if and only if at least one file changed.
-    For more info, see ``DataFilesList``.
+    For more info, see [`DataFilesList`].
 
     This is useful for caching Dataset objects that are obtained from a list of data files.
 
@@ -666,14 +670,14 @@ def from_local_or_remote(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_local_or_remote(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_local_or_remote(
                     patterns_for_key,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out
 
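The reshuffled conditional expression is behavior-preserving: dropping the `not` and swapping the two branches of `A if cond else B` yields the same result, with the cheap passthrough case now read first. The same reshuffle appears in `from_hf_repo`, `from_patterns`, and `DataFilesPatternsDict.from_patterns` below. In isolated form:

```python
def as_list(x):
    # old style: build a list unless the input already is one
    old = list(x) if not isinstance(x, list) else x
    # new style: pass through if it already is a list, otherwise build one
    new = x if isinstance(x, list) else list(x)
    assert old == new  # identical behavior either way
    return new

print(as_list((1, 2)))  # [1, 2]
print(as_list([3]))     # [3]
```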
@@ -689,15 +693,15 @@ def from_hf_repo(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_hf_repo(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_hf_repo(
                     patterns_for_key,
                     dataset_info=dataset_info,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out
 
@@ -712,14 +716,14 @@ def from_patterns(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_patterns(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_patterns(
                     patterns_for_key,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out
 
@@ -751,7 +755,7 @@ def __add__(self, other):
     @classmethod
     def from_patterns(
         cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
-    ) -> "DataFilesPatternsDict":
+    ) -> "DataFilesPatternsList":
         return cls(patterns, [allowed_extensions] * len(patterns))
 
     def resolve(
@@ -777,7 +781,7 @@ def resolve(
         origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
         return DataFilesList(data_files, origin_metadata)
 
-    def filter_extensions(self, extensions: List[str]) -> "DataFilesList":
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
         return DataFilesPatternsList(
             self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions]
        )
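`from_patterns` builds its result with `cls(...)` and `filter_extensions` explicitly constructs a `DataFilesPatternsList`, so the earlier annotations (`"DataFilesPatternsDict"` and `"DataFilesList"`) pointed type checkers at the wrong classes and their attributes. A stripped-down stand-in showing why the accurate return type matters (class and attribute names simplified, not the real implementation):

```python
from typing import List, Optional

class PatternsList(List[str]):
    # simplified stand-in for DataFilesPatternsList
    def __init__(self, patterns: List[str], allowed_extensions: List[Optional[List[str]]]) -> None:
        super().__init__(patterns)
        self.allowed_extensions = allowed_extensions

    @classmethod
    def from_patterns(cls, patterns: List[str]) -> "PatternsList":
        # with the correct annotation, checkers know the result has .allowed_extensions
        return cls(patterns, [None] * len(patterns))

pl = PatternsList.from_patterns(["**/*.csv"])
print(pl.allowed_extensions)  # [None]
```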
@@ -795,12 +799,12 @@ def from_patterns(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesPatternsList.from_patterns(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesPatternsList)
+                else DataFilesPatternsList.from_patterns(
                     patterns_for_key,
                     allowed_extensions=allowed_extensions,
                 )
-                if not isinstance(patterns_for_key, DataFilesPatternsList)
-                else patterns_for_key
             )
         return out
 