Add batch method to Dataset class #7064

Merged on Jul 25, 2024 (8 commits)

lappemic
Contributor

This PR introduces a new batch method on the Dataset class, aligning its functionality with the IterableDataset.batch() method (implemented in #7054). The implementation reuses the existing map method to batch examples efficiently.

Key changes:

  • Add batch method to Dataset class in arrow_dataset.py
  • Utilize map method for batching
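For readers who want to see why a batched map is enough to implement batching, here is a minimal pure-Python sketch. It is not the merged code: `map_batched`, `batch_fn`, `batch`, and the `drop_last_batch` flag are hypothetical names chosen for illustration, and the columnar dict stands in for a real Dataset. The idea it shows is that a batched map hands the function a slice of every column, and wrapping each slice in a one-element list turns the whole slice into a single output row.

```python
from typing import Any, Callable, Dict, List

# A columnar table: column name -> list of values (stand-in for a Dataset).
Columns = Dict[str, List[Any]]


def num_rows(columns: Columns) -> int:
    """Number of rows, taken from the first column (0 if there are no columns)."""
    return len(next(iter(columns.values()))) if columns else 0


def map_batched(
    columns: Columns,
    fn: Callable[[Columns], Columns],
    batch_size: int,
    drop_last_batch: bool = False,
) -> Columns:
    """Hypothetical stand-in for a batched map: apply fn to consecutive slices."""
    out: Columns = {}
    for start in range(0, num_rows(columns), batch_size):
        chunk = {k: v[start:start + batch_size] for k, v in columns.items()}
        if drop_last_batch and num_rows(chunk) < batch_size:
            break  # skip the incomplete final slice
        for key, values in fn(chunk).items():
            out.setdefault(key, []).extend(values)
    return out


def batch_fn(chunk: Columns) -> Columns:
    """Wrap each column slice in a list so the slice becomes one output row."""
    return {key: [values] for key, values in chunk.items()}


def batch(columns: Columns, batch_size: int, drop_last_batch: bool = False) -> Columns:
    """Group rows into batches by delegating to the batched map above."""
    return map_batched(columns, batch_fn, batch_size, drop_last_batch)


if __name__ == "__main__":
    data = {"id": [0, 1, 2, 3, 4], "text": ["a", "b", "c", "d", "e"]}
    print(batch(data, batch_size=2))
```

Each output row holds up to `batch_size` original rows per column, so no separate batching loop is needed beyond what the map already does.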

Closes #7063

Once the approach is approved, I will create the tests and update the documentation.

@lhoestq
Member

lhoestq commented Jul 23, 2024

Looks good to me! :)

You might want to add map's num_proc argument as well, for people who want to make it run faster.
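As a rough illustration of what a worker-count argument can mean for batching, the sketch below splits a columnar table into contiguous shards and batches each shard independently. This is an assumption about how parallel batching could behave, not something confirmed in this thread, and it is not the library's implementation: `split_contiguous`, `batch_shard`, and `batch_parallel` are hypothetical helpers, and threads are used here only as a lightweight stand-in for worker processes.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List

# A columnar table: column name -> list of values (stand-in for a Dataset).
Columns = Dict[str, List[Any]]


def num_rows(columns: Columns) -> int:
    """Number of rows, taken from the first column (0 if there are no columns)."""
    return len(next(iter(columns.values()))) if columns else 0


def split_contiguous(columns: Columns, num_shards: int) -> List[Columns]:
    """Split rows into num_shards contiguous shards of near-equal size."""
    base, extra = divmod(num_rows(columns), num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        size = base + (1 if i < extra else 0)
        shards.append({k: v[start:start + size] for k, v in columns.items()})
        start += size
    return shards


def batch_shard(shard: Columns, batch_size: int) -> Columns:
    """Group the rows of one shard into lists of up to batch_size items."""
    n = num_rows(shard)
    return {
        key: [values[i:i + batch_size] for i in range(0, n, batch_size)]
        for key, values in shard.items()
    }


def batch_parallel(columns: Columns, batch_size: int, num_workers: int) -> Columns:
    """Batch every shard in its own worker, then concatenate in shard order."""
    shards = split_contiguous(columns, num_workers)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        results = list(pool.map(lambda s: batch_shard(s, batch_size), shards))
    merged: Columns = {key: [] for key in columns}
    for result in results:
        for key, batches in result.items():
            merged[key].extend(batches)
    return merged


if __name__ == "__main__":
    data = {"id": list(range(10))}
    print(batch_parallel(data, batch_size=4, num_workers=2))
```

Under this sharding assumption, each shard can end with its own short batch, so the parallel result may contain more incomplete batches than a single-worker run over the same rows.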

@lappemic lappemic marked this pull request as ready for review July 24, 2024 06:14
@lappemic
Contributor Author

Thanks for the feedback @lhoestq! The last commits include:

  • Adding the num_proc parameter to batch
  • Adding tests similar to the one done for IterableDataset.batch()
  • Updating the documentation. I think it is actually misplaced on the Stream page, but I could not find a better place at the moment. Where would you put this documentation?

WDYT?

@lhoestq
Member

lhoestq commented Jul 24, 2024

You can put the documentation in process.mdx :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lappemic lappemic force-pushed the 7063-add-batch-method-to-Dataset branch from af3d739 to 7b02d5f Compare July 25, 2024 09:23
@lappemic
Contributor Author

I reset the head to the commit before I added the Dataset.batch() documentation to stream.mdx and instead added the documentation to process.mdx.

Member

@lhoestq lhoestq left a comment

LGTM, thanks! The CI failures are unrelated to your PR.

@lhoestq lhoestq merged commit 9c98e06 into huggingface:main Jul 25, 2024
2 of 14 checks passed
@lappemic lappemic deleted the 7063-add-batch-method-to-Dataset branch July 25, 2024 13:47

PyArrow==8.0.0

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.005736 / 0.011353 (-0.005617) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.003959 / 0.011008 (-0.007049) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.063259 / 0.038508 (0.024751) |
| read_batch_unformated after write_array2d | 0.030705 / 0.023109 (0.007596) |
| read_batch_unformated after write_flattened_sequence | 0.245706 / 0.275898 (-0.030192) |
| read_batch_unformated after write_nested_sequence | 0.278766 / 0.323480 (-0.044714) |
| read_col_formatted_as_numpy after write_array2d | 0.003354 / 0.007986 (-0.004632) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004246 / 0.004328 (-0.000082) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.049346 / 0.004250 (0.045095) |
| read_col_unformated after write_array2d | 0.046439 / 0.037052 (0.009386) |
| read_col_unformated after write_flattened_sequence | 0.257930 / 0.258489 (-0.000559) |
| read_col_unformated after write_nested_sequence | 0.295562 / 0.293841 (0.001722) |
| read_formatted_as_numpy after write_array2d | 0.030529 / 0.128546 (-0.098017) |
| read_formatted_as_numpy after write_flattened_sequence | 0.012465 / 0.075646 (-0.063182) |
| read_formatted_as_numpy after write_nested_sequence | 0.205595 / 0.419271 (-0.213677) |
| read_unformated after write_array2d | 0.036319 / 0.043533 (-0.007214) |
| read_unformated after write_flattened_sequence | 0.243872 / 0.255139 (-0.011267) |
| read_unformated after write_nested_sequence | 0.275834 / 0.283200 (-0.007366) |
| write_array2d | 0.020330 / 0.141683 (-0.121353) |
| write_flattened_sequence | 1.108337 / 1.452155 (-0.343817) |
| write_nested_sequence | 1.150406 / 1.492716 (-0.342310) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.113498 / 0.018006 (0.095491) |
| get_batch_of_1024_rows | 0.306654 / 0.000490 (0.306164) |
| get_first_row | 0.000238 / 0.000200 (0.000038) |
| get_last_row | 0.000043 / 0.000054 (-0.000012) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.019092 / 0.037411 (-0.018319) |
| shard | 0.063180 / 0.014526 (0.048654) |
| shuffle | 0.078244 / 0.176557 (-0.098313) |
| sort | 0.126106 / 0.737135 (-0.611030) |
| train_test_split | 0.078651 / 0.296338 (-0.217687) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.284132 / 0.215209 (0.068923) |
| read 50000 | 2.781250 / 2.077655 (0.703595) |
| read_batch 50000 10 | 1.471864 / 1.504120 (-0.032256) |
| read_batch 50000 100 | 1.354661 / 1.541195 (-0.186534) |
| read_batch 50000 1000 | 1.362839 / 1.468490 (-0.105651) |
| read_formatted numpy 5000 | 0.719126 / 4.584777 (-3.865651) |
| read_formatted pandas 5000 | 2.396969 / 3.745712 (-1.348743) |
| read_formatted tensorflow 5000 | 2.987924 / 5.269862 (-2.281938) |
| read_formatted torch 5000 | 1.910555 / 4.565676 (-2.655121) |
| read_formatted_batch numpy 5000 10 | 0.078612 / 0.424275 (-0.345663) |
| read_formatted_batch numpy 5000 1000 | 0.005170 / 0.007607 (-0.002437) |
| shuffled read 5000 | 0.333876 / 0.226044 (0.107832) |
| shuffled read 50000 | 3.298340 / 2.268929 (1.029412) |
| shuffled read_batch 50000 10 | 1.853332 / 55.444624 (-53.591292) |
| shuffled read_batch 50000 100 | 1.551919 / 6.876477 (-5.324557) |
| shuffled read_batch 50000 1000 | 1.585677 / 2.142072 (-0.556395) |
| shuffled read_formatted numpy 5000 | 0.802487 / 4.805227 (-4.002741) |
| shuffled read_formatted_batch numpy 5000 10 | 0.134828 / 6.500664 (-6.365837) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.041966 / 0.075469 (-0.033503) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 0.992277 / 1.841788 (-0.849511) |
| map fast-tokenizer batched | 11.626887 / 8.074308 (3.552578) |
| map identity | 9.715623 / 10.191392 (-0.475769) |
| map identity batched | 0.140306 / 0.680424 (-0.540117) |
| map no-op batched | 0.014528 / 0.534201 (-0.519673) |
| map no-op batched numpy | 0.306247 / 0.579283 (-0.273036) |
| map no-op batched pandas | 0.263067 / 0.434364 (-0.171297) |
| map no-op batched pytorch | 0.342325 / 0.540337 (-0.198013) |
| map no-op batched tensorflow | 0.432299 / 1.386936 (-0.954637) |

PyArrow==latest

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.006004 / 0.011353 (-0.005349) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.003890 / 0.011008 (-0.007118) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.050408 / 0.038508 (0.011900) |
| read_batch_unformated after write_array2d | 0.031880 / 0.023109 (0.008771) |
| read_batch_unformated after write_flattened_sequence | 0.273114 / 0.275898 (-0.002784) |
| read_batch_unformated after write_nested_sequence | 0.296653 / 0.323480 (-0.026826) |
| read_col_formatted_as_numpy after write_array2d | 0.004569 / 0.007986 (-0.003416) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.002831 / 0.004328 (-0.001497) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.050032 / 0.004250 (0.045782) |
| read_col_unformated after write_array2d | 0.040468 / 0.037052 (0.003415) |
| read_col_unformated after write_flattened_sequence | 0.284718 / 0.258489 (0.026229) |
| read_col_unformated after write_nested_sequence | 0.321754 / 0.293841 (0.027913) |
| read_formatted_as_numpy after write_array2d | 0.033863 / 0.128546 (-0.094684) |
| read_formatted_as_numpy after write_flattened_sequence | 0.012183 / 0.075646 (-0.063463) |
| read_formatted_as_numpy after write_nested_sequence | 0.060805 / 0.419271 (-0.358466) |
| read_unformated after write_array2d | 0.034919 / 0.043533 (-0.008614) |
| read_unformated after write_flattened_sequence | 0.274354 / 0.255139 (0.019215) |
| read_unformated after write_nested_sequence | 0.293477 / 0.283200 (0.010277) |
| write_array2d | 0.019418 / 0.141683 (-0.122265) |
| write_flattened_sequence | 1.151571 / 1.452155 (-0.300584) |
| write_nested_sequence | 1.217174 / 1.492716 (-0.275542) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.097326 / 0.018006 (0.079320) |
| get_batch_of_1024_rows | 0.316277 / 0.000490 (0.315787) |
| get_first_row | 0.000225 / 0.000200 (0.000025) |
| get_last_row | 0.000045 / 0.000054 (-0.000009) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.022932 / 0.037411 (-0.014479) |
| shard | 0.077455 / 0.014526 (0.062929) |
| shuffle | 0.088949 / 0.176557 (-0.087608) |
| sort | 0.129447 / 0.737135 (-0.607688) |
| train_test_split | 0.093705 / 0.296338 (-0.202634) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.303918 / 0.215209 (0.088709) |
| read 50000 | 2.973866 / 2.077655 (0.896211) |
| read_batch 50000 10 | 1.593165 / 1.504120 (0.089045) |
| read_batch 50000 100 | 1.465312 / 1.541195 (-0.075883) |
| read_batch 50000 1000 | 1.484503 / 1.468490 (0.016013) |
| read_formatted numpy 5000 | 0.731849 / 4.584777 (-3.852928) |
| read_formatted pandas 5000 | 0.953337 / 3.745712 (-2.792375) |
| read_formatted tensorflow 5000 | 2.887815 / 5.269862 (-2.382047) |
| read_formatted torch 5000 | 1.923618 / 4.565676 (-2.642058) |
| read_formatted_batch numpy 5000 10 | 0.080073 / 0.424275 (-0.344202) |
| read_formatted_batch numpy 5000 1000 | 0.005460 / 0.007607 (-0.002148) |
| shuffled read 5000 | 0.359876 / 0.226044 (0.133832) |
| shuffled read 50000 | 3.532251 / 2.268929 (1.263323) |
| shuffled read_batch 50000 10 | 1.987778 / 55.444624 (-53.456846) |
| shuffled read_batch 50000 100 | 1.685572 / 6.876477 (-5.190905) |
| shuffled read_batch 50000 1000 | 1.827141 / 2.142072 (-0.314932) |
| shuffled read_formatted numpy 5000 | 0.815953 / 4.805227 (-3.989274) |
| shuffled read_formatted_batch numpy 5000 10 | 0.136698 / 6.500664 (-6.363967) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.042185 / 0.075469 (-0.033285) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.032508 / 1.841788 (-0.809280) |
| map fast-tokenizer batched | 12.526918 / 8.074308 (4.452610) |
| map identity | 10.202942 / 10.191392 (0.011550) |
| map identity batched | 0.145920 / 0.680424 (-0.534504) |
| map no-op batched | 0.015643 / 0.534201 (-0.518558) |
| map no-op batched numpy | 0.300465 / 0.579283 (-0.278818) |
| map no-op batched pandas | 0.126786 / 0.434364 (-0.307578) |
| map no-op batched pytorch | 0.342885 / 0.540337 (-0.197453) |
| map no-op batched tensorflow | 0.438139 / 1.386936 (-0.948797) |

albertvillanova pushed a commit that referenced this pull request Aug 13, 2024
* feat: add `batch` method to `Dataset` class

* feat: add `num_proc` arg from `map` to `batch`

* test: add test for `Dataset.batch()`

* style: formatting...

* docs: move `Dataset.batch()` documentation to `process.mdx`

* docs: add `num_proc` to docs

* Apply suggestions from code review

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
albertvillanova pushed a commit that referenced this pull request Aug 14, 2024