[speedup] Use indices mappings instead of deepcopy for all the samples reordering methods #513
Conversation
    writer_batch_size=writer_batch_size,
    verbose=verbose,
)

def export(
Moved this method without modification to keep all the samples re-ordering/selection methods (`select`, `sort`, `shuffle`, `shard`, `train_test_split`) in the same part of the file. Sorry for that.
@@ -1419,8 +1482,8 @@ def train_test_split(
    generator: Optional[np.random.Generator] = None,
    keep_in_memory: bool = False,
    load_from_cache_file: bool = True,
    train_cache_file_name: Optional[str] = None,
    test_cache_file_name: Optional[str] = None,
    train_indices_cache_file_name: Optional[str] = None,
This is a bit long but I think it's important that the user does not mistake this cache for the dataset table cache.
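For illustration, a hedged usage sketch (the `dataset` object and file path are made up; only the argument name comes from this PR):

```python
# The cache file passed here stores only the indices mapping (the new row order),
# not a rewritten copy of the dataset table.
shuffled = dataset.shuffle(seed=42, indices_cache_file_name="shuffle_indices.arrow")
```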
Very cool!
Things will be so much faster :)
A few comments:
@@ -998,7 +1090,7 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F

def filter(
Shall we use the indices mapping for `filter` too?
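For context, a minimal pyarrow sketch (not the library's code) of what filtering through an indices mapping instead of rewriting the table could look like:

```python
import pyarrow as pa

# Toy table standing in for the dataset's arrow table.
table = pa.table({"text": ["a", "bb", "ccc", "dddd"]})

# Keep the row positions that pass the predicate instead of writing a new table.
keep = [i for i in range(table.num_rows) if len(table.column("text")[i].as_py()) > 1]
indices = pa.array(keep, type=pa.uint64())

# Reading the filtered view then goes through Table.take on the base table.
filtered = table.take(indices)
print(filtered.to_pydict())  # {'text': ['bb', 'ccc', 'dddd']}
```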
    cache_file_name: Optional[str] = None,
    writer_batch_size: Optional[int] = 1000,
    reader_batch_size: Optional[int] = 1000,
    features: Optional[Features] = None,
Not sure why we have `features` here? I agree that it can be done for free, but I didn't expect to see it here.
Free lunch :-)
@@ -1549,28 +1612,34 @@ def train_test_split(
    "seed": seed,
    "keep_in_memory": keep_in_memory,
    "load_from_cache_file": load_from_cache_file,
    "train_cache_file_name": train_cache_file_name,
    "test_cache_file_name": test_cache_file_name,
As in `shuffle`, you probably need to add `"length": len(self)` here.
I'm not sure about that, because I feel like it's handled by the hashes on the indices and data files. I added some tests for this.
Yes, you're right, it's taken into account in `previous_files_string`.
@@ -1589,16 +1658,14 @@ def train_test_split(
    train_split = self.select(
        indices=train_indices,
        keep_in_memory=keep_in_memory,
        load_from_cache_file=load_from_cache_file,
Maybe keep `load_from_cache_file` here as well.
@@ -1611,8 +1678,7 @@ def shard(
    index: int,
    contiguous: bool = False,
    keep_in_memory: bool = False,
    load_from_cache_file: bool = True,
We can keep `load_from_cache_file` here and also pass it to `.select()`.
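Something like this rough sketch (parameter names follow the diff above; the body is illustrative only, not the exact code):

```python
def shard(self, num_shards, index, contiguous=False, keep_in_memory=False,
          load_from_cache_file=True, indices_cache_file_name=None, writer_batch_size=1000):
    # Hypothetical sketch: compute the shard's indices and let select() handle caching.
    if contiguous:
        div, mod = divmod(len(self), num_shards)
        start = div * index + min(index, mod)
        end = start + div + (1 if index < mod else 0)
        indices = range(start, end)
    else:
        indices = range(index, len(self), num_shards)
    return self.select(
        indices=indices,
        keep_in_memory=keep_in_memory,
        load_from_cache_file=load_from_cache_file,  # forwarded as suggested above
        indices_cache_file_name=indices_cache_file_name,
        writer_batch_size=writer_batch_size,
    )
```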
Ok, I fixed it.
Ok, adding some benchmarks for map/filter and then I'll merge.
Warning from PyTorch that we should maybe consider at some point @lhoestq:
@@ -717,7 +717,7 @@ def _map_indices(self, indices: Union[int, slice, pa.Array, Iterable]):

    # We can do a slice
    if array_indices is None:
        return self._indices.column(0).slice(array_indices[0], array_indices[1] - array_indices[0])
        return self._indices.column(0).slice(slice_indices[0], slice_indices[1] - slice_indices[0])
Good catch!
Not sure why we have that; it's probably linked to zero-copy from Arrow to NumPy.
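For reference, the fixed line relies on pyarrow's (offset, length) slice semantics; a tiny standalone sketch with assumed names:

```python
import pyarrow as pa

# Stand-in for self._indices.column(0): the indices mapping column.
indices_column = pa.chunked_array([pa.array([4, 2, 7, 1, 9], type=pa.uint64())])

# slice_indices is assumed to hold (start, stop) positions in the mapping;
# ChunkedArray.slice takes (offset, length), hence the stop - start.
slice_indices = (1, 4)
window = indices_column.slice(slice_indices[0], slice_indices[1] - slice_indices[0])
print(window.to_pylist())  # [2, 7, 1]
```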
Use an indices mapping instead of rewriting the dataset for all the samples re-ordering/selection methods (`select`, `sort`, `shuffle`, `shard`, `train_test_split`).

Added a `flatten_indices` method, with tests, which copies the dataset to a new table in order to remove the indices mapping.

All the samples re-ordering/selection methods should be a lot faster. The downside is that iterating over very large batches of the dataset might be a little slower once the order of the samples has been changed, since in that case we use `pyarrow.Table.take` instead of `pyarrow.Table.slice`. There is no free lunch, but the speed of iterating over the dataset is rarely the bottleneck.

Backward breaking change: the `cache_file_name` argument in all the samples re-ordering/selection methods (`select`, `sort`, `shuffle`, `shard`, `train_test_split`) is now called `indices_cache_file_name`, on purpose, to make it explicit to the user that this cache file is used for caching the indices mapping and not the dataset itself.
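A short end-to-end sketch of the behavior described above (assuming the `datasets` API from this PR):

```python
from datasets import Dataset

ds = Dataset.from_dict({"idx": list(range(10))})

# shuffle/select/sort/shard/train_test_split now only write an indices mapping;
# the underlying arrow table is left untouched.
shuffled = ds.shuffle(seed=42)

# Reads on a re-ordered dataset go through pyarrow.Table.take (random access)
# rather than pyarrow.Table.slice (contiguous reads), which can be slightly
# slower on very large batches.
batch = shuffled[:8]

# flatten_indices copies the rows into a new contiguous table and drops the
# indices mapping, restoring slice-based reads.
flat = shuffled.flatten_indices()
```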