
add multi-proc in to_json #2747

Merged: 19 commits, Sep 13, 2021

Conversation

@bhavitvyamalik (Contributor) commented Aug 3, 2021

Closes #2663. I've tried adding multiprocessing to to_json. Here's some benchmarking I did to compare the timings of the current version (say v1) and the multi-proc version (say v2). I ran this with a cpu_count of 4 (2015 MacBook Air):

  1. Dataset name: ascent_kb - 8.9M samples (all samples were used, reporting this for a single run)
    v1- ~225 seconds for converting whole dataset to json
    v2- ~200 seconds for converting whole dataset to json

  2. Dataset name: lama - 1.3M samples (all samples were used, reporting this for 2 runs)
    v1- ~26 seconds for converting whole dataset to json
    v2- ~23.6 seconds for converting whole dataset to json

I think it's safe to say that v2 is about 10% faster than v1. Timings may improve further with a better configuration.

The only bottleneck I see is writing the output list to the file. If we can improve that aspect, timings may improve further.

Let me know if any changes/improvements can be made here, @stas00, @lhoestq, @albertvillanova. @lhoestq even suggested extending this work to other export methods as well, like CSV or Parquet.
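
For readers, here is a minimal sketch of the general idea (hypothetical helper names, a simplified version of the approach rather than the exact code in this PR): worker processes serialize batches to JSON lines, and the parent process writes the returned chunks to the file in order, which is why the final write is the remaining bottleneck.

# Rough, hypothetical sketch of multi-proc JSON export
import json
from multiprocessing import Pool

def _batch_to_jsonl(batch):
    # batch is a dict of column name -> list of values (what dataset[i:j] returns)
    keys = list(batch.keys())
    rows = (dict(zip(keys, values)) for values in zip(*batch.values()))
    return "".join(json.dumps(row) + "\n" for row in rows)

def dataset_to_jsonl(dataset, path, batch_size=10_000, num_proc=4):
    batches = (dataset[i : i + batch_size] for i in range(0, len(dataset), batch_size))
    with Pool(num_proc) as pool, open(path, "w", encoding="utf-8") as f:
        # imap preserves batch order; this single writer is the bottleneck mentioned above
        for chunk in pool.imap(_batch_to_jsonl, batches):
            f.write(chunk)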

@stas00 (Contributor) commented Aug 3, 2021

Thank you for working on this, @bhavitvyamalik

10% doesn't solve the issue; we want 5-10x faster on a machine that has lots of resources but limited processing time.

So let's benchmark it on an instance with many more cores. I can test with 12 on my dev box and 40 on JZ.

Could you please share the test I could run with both versions?

Should we also test the sharded version I shared in #2663 (comment)? That would make 3 versions to test, optionally.

@bhavitvyamalik (Contributor, Author) commented Aug 4, 2021

Since I was facing OSError: [Errno 12] Cannot allocate memory in the CircleCI tests, I've added a num_proc option instead of always using the full cpu_count. You can test both v1 and v2 through this branch (some redundancy still needs to be removed).

Update: I was able to convert to JSON in 50% less time than v1 on the ascent_kb dataset. I will post the benchmarking script with results here.
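
For reference, usage on this branch looks roughly like this (num_proc and batch_size being the options discussed in this PR):

from datasets import load_dataset

dataset = load_dataset("ascent_kb", split="train")
# num_proc omitted or set to 1 falls back to the single-process v1 path
dataset.to_json("ascent_kb.jsonl", batch_size=100_000, num_proc=4)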

@bhavitvyamalik (Contributor, Author) commented Aug 5, 2021

Here are the benchmarks with the current branch for both v1 and v2 (dataset: ascent_kb, 8.9M samples):

batch_size    time in sec (num_proc = 1)    time in sec (num_proc = 4)
10k           185.56                        170.11
50k           175.79                        86.84
100k          191.09                        78.35
125k          198.28                        90.89

Increasing the batch size on my machine made v2 around 50% faster than v1. Timings may vary depending on the machine. I'm including the benchmarking script as well. The CircleCI errors are unrelated (something related to bertscore).

import time
import shutil
import gc
from pathlib import Path

from datasets import load_dataset

batch_sizes = [10_000, 50_000, 100_000, 125_000]
num_procs = [1, 4]  # change this according to your machine

SAVE_LOC = "./new_dataset.json"

for batch in batch_sizes:
    for num in num_procs:
        dataset = load_dataset("ascent_kb")

        local_start = time.time()
        ans = dataset['train'].to_json(SAVE_LOC, batch_size=batch, num_proc=num)
        local_end = time.time() - local_start

        print(f"Time taken on {num} num_proc and {batch} batch_size: ", local_end)

        # remove the newly generated json and the cached dataset
        # so that the next iteration starts from scratch
        Path(SAVE_LOC).unlink()

        try:
            shutil.rmtree(Path.home() / ".cache" / "huggingface")
        except OSError as e:
            print("Error: %s - %s." % (e.filename, e.strerror))

        gc.collect()

This downloads the dataset in every iteration and then runs to_json. I didn't run to_json multiple times (for a given batch_size and num_proc) and take the average, because I found that v1 got faster after the 1st iteration (maybe something is cached somewhere). Since you'll be doing this operation only once, I thought it would be better to report how both v1 and v2 perform in a single iteration.

Important: the benchmarking script deletes the newly generated JSON and ~/.cache/huggingface/ after every iteration so that it doesn't end up using any cached data (just to be on the safe side).

@stas00 (Contributor) commented Aug 5, 2021

Thank you for sharing the benchmark, @bhavitvyamalik. Your results look promising.

But if I remember correctly, the sharded version at #2663 (comment) was much faster. So we should probably compare to it as well? And if it's faster, then at least document that manual sharding version?


That's a dangerous benchmark, as it'd wipe out many other HF things. Why not wipe out only:

~/.cache/huggingface/datasets/ascent_kb/
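
For example, a more targeted cleanup could be (a sketch; the exact cache layout may differ across datasets versions):

import shutil
from pathlib import Path

# remove only this dataset's cache instead of the whole ~/.cache/huggingface
shutil.rmtree(Path.home() / ".cache" / "huggingface" / "datasets" / "ascent_kb", ignore_errors=True)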

Running the benchmark now.

@stas00 (Contributor) commented Aug 5, 2021

Weird, I tried to adapt your benchmark to use shards, and the program no longer works. It quickly uses up all the available RAM and hangs. Has something changed recently in datasets? You can try:

import time
from multiprocessing import Process

from datasets import load_dataset

batch_sizes = [10_000, 50_000, 100_000, 125_000]
num_shards = [1, 8]  # change this according to your machine

DATASET_NAME = "ascent_kb"

for batch in batch_sizes:
    for shards in num_shards:
        dataset = load_dataset(DATASET_NAME)["train"]

        def process_shard(idx):
            print(f"Sharding {idx}")
            ds_shard = dataset.shard(shards, idx, contiguous=True)
            # ds_shard = ds_shard.shuffle() # remove contiguous=True above if shuffling
            print(f"Saving {DATASET_NAME}-{idx}.jsonl")
            ds_shard.to_json(f"{DATASET_NAME}-{idx}.jsonl", orient="records", lines=True, force_ascii=False)

        local_start = time.time()
        # one process per shard, each writing its own jsonl file
        processes = [Process(target=process_shard, args=(idx,)) for idx in range(shards)]
        for p in processes:
            p.start()

        for p in processes:
            p.join()
        local_end = time.time() - local_start

        print(f"Time taken on {shards} shards and {batch} batch_size: ", local_end)

Just be careful so that it doesn't crash your compute environment, as it almost crashed mine.

@stas00 (Contributor) commented Aug 5, 2021

So this part seems to no longer work:

        dataset = load_dataset("ascent_kb")["train"]
        ds_shard = dataset.shard(1, 0, contiguous=True)
        ds_shard.to_json("ascent_kb-0.jsonl", orient="records", lines=True, force_ascii=False)

@bhavitvyamalik (Contributor, Author) commented:

If you use to_json without num_proc, or with num_proc=1, then it essentially falls back to v1, which I've kept as is (the tests were passing as well).

That's a dangerous benchmark, as it'd wipe out many other HF things. Why not wipe out only:

That's because some dataset-related files were still left inside the ~/.cache/huggingface/datasets folder. You could maybe wipe just the datasets folder inside your cache.

dataset = load_dataset("ascent_kb")["train"]
ds_shard = dataset.shard(1, 0, contiguous=True)
ds_shard.to_json("ascent_kb-0.jsonl", orient="records", lines=True, force_ascii=False)

I tried this with the lama dataset (1.3M samples) and it worked fine. Trying it with ascent_kb currently; will update here.

@stas00 (Contributor) commented Aug 5, 2021

I don't think the issue has anything to do with your work, @bhavitvyamalik. I forgot to mention that I tested with the latest datasets release and saw the same problem.

Interesting, I tried your suggestion. This:

python -c 'import datasets; ds="lama"; dataset = datasets.load_dataset(ds)["train"]; \
dataset.shard(1, 0, contiguous=True).to_json(f"{ds}-0.jsonl", orient="records", lines=True, force_ascii=False)'

works fine and takes just a few GBs to complete.

This, on the other hand, blows up memory-wise:

python -c 'import datasets; ds="ascent_kb"; dataset = datasets.load_dataset(ds)["train"]; \
dataset.shard(1, 0, contiguous=True).to_json(f"{ds}-0.jsonl", orient="records", lines=True, force_ascii=False)'

and I have to kill it before it uses up all RAM. (I have 128GB of it, so it should be more than enough)

@stas00 (Contributor) commented Aug 5, 2021

That's because some dataset-related files were still left inside the ~/.cache/huggingface/datasets folder. You could maybe wipe just the datasets folder inside your cache.

I think a recent datasets release added a method that prints out the paths of all the different components for a given dataset, though I can't recall its name. It came up when we were discussing a janitor program to clear up space selectively.

@bhavitvyamalik (Contributor, Author) commented:

and I have to kill it before it uses up all RAM. (I have 128GB of it, so it should be more than enough)

The same thing just happened on my machine too. A memory leak somewhere, maybe? Even if you were to load this dataset fully into memory, it shouldn't take more than 4GB. You were doing this earlier for the oscar dataset. Is it working fine for that?
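
(One way to sanity-check that 4GB estimate, assuming the dataset info is populated, is to look at the size reported by the library:)

from datasets import load_dataset

dataset = load_dataset("ascent_kb", split="train")
# size in bytes of the Arrow data backing the dataset, if recorded in the dataset info
print(dataset.info.dataset_size)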

@stas00 (Contributor) commented Aug 6, 2021

Hmm, looks like datasets has changed and won't accept my currently cached oscar-en (crashes), so I'd rather not download 0.5TB again.

Were you able to reproduce the memory blow-up with ascent_kb? It should be a much quicker task to verify.

But yes, oscar worked just fine with .shard() which is what I used to process it fast.

@stas00 (Contributor) commented Aug 6, 2021

What I tried is:

HF_DATASETS_OFFLINE=1 HF_DATASETS_CACHE=cache python -c 'import datasets; ds="oscar"; \
dataset = datasets.load_dataset(ds, "unshuffled_deduplicated_en")["train"]; \
dataset.shard(1000000, 0, contiguous=True).to_json(f"{ds}-0.jsonl", orient="records", lines=True, force_ascii=False)'

and got:

Using the latest cached version of the module from /gpfswork/rech/six/commun/modules/datasets_modules/datasets/oscar/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d (last modified on Fri Aug  6 01:52:35 2021) since it couldn't be found locally at oscar/oscar.py or remotely (OfflineModeIsEnabled).
Reusing dataset oscar (cache/oscar/unshuffled_deduplicated_en/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/load.py", line 755, in load_dataset
    ds = builder_instance.as_dataset(split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/builder.py", line 737, in as_dataset
    datasets = utils.map_nested(
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 203, in map_nested
    mapped = [
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 204, in <listcomp>
    _single_map_nested((function, obj, types, None, True)) for obj in tqdm(iterable, disable=disable_tqdm)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 142, in _single_map_nested
    return function(data_struct)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/builder.py", line 764, in _build_single_dataset
    ds = self._as_dataset(
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/builder.py", line 834, in _as_dataset
    dataset_kwargs = ArrowReader(self._cache_dir, self.info).read(
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/arrow_reader.py", line 217, in read
    return self.read_files(files=files, original_instructions=instructions, in_memory=in_memory)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/arrow_reader.py", line 238, in read_files
    pa_table = self._read_files(files, in_memory=in_memory)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/arrow_reader.py", line 173, in _read_files
    pa_table: Table = self._get_table_from_filename(f_dict, in_memory=in_memory)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/arrow_reader.py", line 308, in _get_table_from_filename
    table = ArrowReader.read_table(filename, in_memory=in_memory)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/arrow_reader.py", line 327, in read_table
    return table_cls.from_file(filename)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/table.py", line 450, in from_file
    table = _memory_mapped_arrow_table_from_file(filename)
  File "/gpfswork/rech/six/commun/conda/stas/lib/python3.8/site-packages/datasets/table.py", line 43, in _memory_mapped_arrow_table_from_file
    memory_mapped_stream = pa.memory_map(filename)
  File "pyarrow/io.pxi", line 782, in pyarrow.lib.memory_map
  File "pyarrow/io.pxi", line 743, in pyarrow.lib.MemoryMappedFile._open
  File "pyarrow/error.pxi", line 122, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 99, in pyarrow.lib.check_status
OSError: Memory mapping file failed: Cannot allocate memory

@bhavitvyamalik (Contributor, Author) commented:

Were you able to reproduce the memory blow-up with ascent_kb? It should be a much quicker task to verify.

Yes, this blows up memory-wise on my machine too.

Yes. I found that a similar error was posted on the forum on 5th March. Since you already know how much time the approach in #2663 (comment) took, can you maybe benchmark just v1 and v2 for now, until we have a fix for this memory blow-up?

@stas00 (Contributor) commented Aug 7, 2021

OK, so I benchmarked using "lama", though it's too small for this kind of test, since sharding is much slower than a single process here.

Results: https://gist.github.com/stas00/dc1597a1e245c5915cfeefa0eee6902c

So sharding does really badly there, and your multi-proc to_json is doing great!

Any suggestions for a somewhat bigger dataset, but not too big? Say, 10 times the size of lama?

@bhavitvyamalik (Contributor, Author) commented Aug 7, 2021

Looks great! I had a few questions/suggestions related to benchmark-datasets-to_json.py:

  1. You have used only the 10_000 and 100_000 batch sizes. Including more batch sizes may help you find the best batch size for your machine and even give you some extra speed-up.
    For example, I found that load_dataset("cc100", lang="eu") with batch size 125_000 took less time than with batch size 100_000 (71.16 sec vs 67.26 sec). Since this dataset has only 2 fields, ['id', 'text'], we can go for a higher batch size here.

  2. Why have you used num_procs 1 and 4 only?

You can use:

  1. dataset = load_dataset("cc100", lang="af"). It has only 2 fields, but there are around 9.9M samples (lama had around 1.3M samples).
  2. dataset = load_dataset("cc100", lang="eu") -> 16M samples (if you want something bigger than 9.9M).
  3. dataset = load_dataset("neural_code_search", 'search_corpus') -> 4.7M samples.

@stas00 (Contributor) commented Aug 10, 2021

Thank you, @bhavitvyamalik

My apologies, at the moment I have not found time to run more benchmarks with the other proposed datasets. I will try to do it later, but I don't want it to hold up your PR; it's definitely a great improvement based on the benchmarks I did run! The comparison to sharding is really just of interest to me, to see whether it's on par or slower.

So if other reviewers are happy, this definitely looks like a great improvement to me and addresses the request I made in the first place.

Why have you used num_procs 1 and 4 only?

Oh, no particular reason, I was just comparing to 4 shards on my desktop. Typically it's sufficient to go from 1 to 2-4 processes to see whether the distributed approach is faster or not. Once you hit larger numbers you often run into bottlenecks like IO, and then the numbers can be less representative. I hope that makes sense.
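
Concretely, such a quick scaling check could look like this (a sketch reusing the to_json API from this PR; the dataset and batch size are just examples):

import time
from datasets import load_dataset

dataset = load_dataset("lama")["train"]
for num_proc in (1, 2, 4):
    start = time.time()
    dataset.to_json(f"lama-{num_proc}.jsonl", batch_size=100_000, num_proc=num_proc)
    print(f"num_proc={num_proc}: {time.time() - start:.2f} sec")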

@lhoestq (Member) left a comment

Cool thanks !

I just have a few comments, especially one regarding memory when num_proc>1

@bhavitvyamalik (Contributor, Author) commented:

Tested it with a larger dataset (srwac) and memory utilisation remained constant, with no swap memory used. @lhoestq, should I also add a test for this? Last time I tried, I got OSError: [Errno 12] Cannot allocate memory in the CircleCI tests.

@lhoestq (Member) left a comment

Nice @bhavitvyamalik this is awesome :)

Indeed we need to have a test for this.
Maybe you can avoid the OSError: [Errno 12] Cannot allocate memory by using a smaller batch_size (maybe <1000) and a small number of processes (maybe 2)?
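
A test along those lines might look roughly like this (a sketch, not necessarily the final test added in this PR; tmp_path is pytest's temporary directory fixture):

import json
from datasets import Dataset

def test_dataset_to_json_multiproc(tmp_path):
    dataset = Dataset.from_dict({"id": list(range(100)), "text": ["foo"] * 100})
    path = str(tmp_path / "out.jsonl")
    # small batch_size and only 2 processes to keep memory usage low in CI
    dataset.to_json(path, batch_size=10, num_proc=2)
    with open(path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f]
    assert len(rows) == len(dataset)
    assert {row["id"] for row in rows} == set(range(100))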

@mariosasko (Collaborator) left a comment

A few nits.

lhoestq and others added 2 commits September 10, 2021 14:49
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
@lhoestq (Member) left a comment

Thanks a lot for this feature @bhavitvyamalik :)

It looks all good!

@lhoestq lhoestq merged commit 4484e41 into huggingface:master Sep 13, 2021
@bhavitvyamalik bhavitvyamalik deleted the to_json branch October 19, 2021 18:24
Successfully merging this pull request may close these issues.

[to_json] add multi-proc sharding support