This repository has been archived by the owner on Jun 11, 2024. It is now read-only.

Don't pass the massive xr.Dataset between processes! #22

Open
JackKelly opened this issue May 11, 2023 · 8 comments
Labels
bug Something isn't working

Comments

@JackKelly
Member

JackKelly commented May 11, 2023

Describe the bug

The code works at the moment. But it does something that's slow, CPU-intensive, and memory-intensive: it passes massive xr.Datasets between processes.

The main process does this:

    # Run the processes!
    with multiprocessing.Pool() as pool:
        for ds in pool.imap(load_grib_files, tasks):
            append_to_zarr(ds, destination_zarr_path)

This requires every ds to be pickled and copied from the worker process to the main process through a pipe, which is very slow (multiple seconds) for large objects.
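
For a rough sense of scale (a minimal sketch, not a benchmark of the real pipeline), pickling a single ~160 MB array already takes a noticeable amount of time, and imap additionally has to push those bytes through the pipe and unpickle them in the main process:

from time import perf_counter
import pickle

import numpy as np

# Roughly the size of one worker's output: 1000 x 1000 x 20 float64s = ~160 MB.
arr = np.random.rand(1000, 1000, 20)

t0 = perf_counter()
payload = pickle.dumps(arr, protocol=pickle.HIGHEST_PROTOCOL)
print(f'Pickled {len(payload) / 1e6:.0f} MB in {perf_counter() - t0:.2f} seconds.')
# And this is only part of the cost: the bytes still have to be copied through
# the pipe and unpickled on the other side.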

Experiment

This minimal example takes 6 seconds to run, and uses a trivially small amount of RAM and CPU. Each worker process creates a 160 MB array. But, crucially, each worker process doesn't pass that array back to the main process:

from multiprocessing.pool import Pool
from time import sleep
import numpy as np


# task executed in a worker process
def task(identifier: int):
    # generate a 160 MB array of random values:
    arr = np.random.rand(1000, 1000, 20)
    # report a message
    print(f'Task {identifier} executing.', flush=True)
    # block for a (random) moment
    sleep(arr[0, 0, 0])
 
if __name__ == '__main__':
    # create and configure the process pool
    with Pool() as pool:
        # issue tasks to the process pool
        pool.imap(task, range(50))
        # shutdown the process pool
        pool.close()
        # wait for all issued tasks to complete
        pool.join()

(this code is adapted from here)

If we just append the line return arr at the end of the task function (so each worker process pickles the array and attempts to send it to the main process) then the script runs for 30 seconds using max CPU, and then consumes all the RAM on my desktop before crashing!

Expected behavior

I think the fix is simple: We just tell each worker process to save the dataset. I'm as sure as I can be that imap will still guarantee that the processes run in order, even if the processes take different amounts of time to complete.
UPDATE: I was wrong! imap runs tasks in arbitrary order, so we can't simply append to the Zarr in the (arbitrary) order the tasks complete.

Additional context

The code used to use a "chain of locks"... but that proved unreliable, so the "chain of locks" was replaced with imap in commit 33330bf. Replacing the "chain of locks" with imap was definitely the right thing to do (much simpler code; much more stable!). We just need to make sure we don't pass massive datasets between processes 🙂.

@JackKelly added the bug label May 11, 2023
@JackKelly self-assigned this May 11, 2023
@JackKelly
Member Author

After chatting with @jacobbieker... it turns out I was wrong! imap runs tasks in arbitrary order!

Evidence:

from multiprocessing.pool import Pool
from time import sleep

import numpy as np


# task executed in a worker process
def task(identifier: int):
    # generate a 160 MB array of random values:
    rng = np.random.default_rng(seed=identifier)
    arr = rng.random((1000, 1000, 20))
    sleep_time_secs = arr[0, 0, 0] * 4
    print(f'Task {identifier} sleeping for {sleep_time_secs:.3f} secs...', flush=True)
    sleep(sleep_time_secs)
    print(f'Task {identifier} DONE!', flush=True)
    
 
if __name__ == '__main__':
    # create and configure the process pool
    with Pool() as pool:
        # issue tasks to the process pool
        pool.imap(task, range(50))
        # shutdown the process pool
        pool.close()
        # wait for all issued tasks to complete
        pool.join()

Produces this output:

(nwp) jack@jack-NUC:~/dev/ocf/nwp/scripts$ time python test_imap.py 
Task 2 sleeping for 1.046 secs...
Task 3 sleeping for 0.343 secs...
Task 1 sleeping for 2.047 secs...
Task 5 sleeping for 3.220 secs...
Task 7 sleeping for 2.500 secs...
Task 0 sleeping for 2.548 secs...
Task 4 sleeping for 3.772 secs...
Task 6 sleeping for 2.153 secs...
Task 3 DONE!
Task 8 sleeping for 1.308 secs...
Task 2 DONE!
Task 9 sleeping for 3.481 secs...
Task 8 DONE!
Task 10 sleeping for 3.824 secs...
Task 1 DONE!
Task 11 sleeping for 0.514 secs...
Task 6 DONE!
Task 12 sleeping for 1.003 secs...
Task 7 DONE!
Task 11 DONE!
...
Task 47 sleeping for 2.967 secs...
Task 48 sleeping for 1.551 secs...
Task 46 sleeping for 3.622 secs...
Task 40 DONE!
Task 49 sleeping for 1.451 secs...
Task 43 DONE!
Task 48 DONE!
Task 42 DONE!
Task 45 DONE!
Task 41 DONE!
Task 49 DONE!
Task 47 DONE!
Task 46 DONE!

real    0m16.123s
user    0m5.739s
sys     0m3.232s

@JackKelly
Member Author

Two solutions spring to mind:

  1. Can we write to Zarr in arbitrary order? Maybe Zarr can do this out-of-the-box now? Or maybe we need to "lazily pre-allocate" the entire array first?
  2. Failing that, each worker process could write a netcdf file to disk, and the main process could load that netcdf file and write it to the zarr. Something like this:
    # Run the processes!
    with multiprocessing.Pool() as pool:
        for netcdf_filename in pool.imap(convert_grib_files_to_netcdf, tasks):
            append_netcdf_to_zarr(netcdf_filename, destination_zarr_path)
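
For what it's worth, a minimal sketch of what those two helpers could look like (convert_grib_files_to_netcdf and append_netcdf_to_zarr are the hypothetical names from the snippet above; load_grib_files and append_to_zarr are the existing functions; the temporary-file handling is just illustrative):

import tempfile

import xarray as xr


def convert_grib_files_to_netcdf(task) -> str:
    # Runs in a worker process. Loads the GRIB files as before, but writes the
    # result to a temporary NetCDF file instead of returning the Dataset.
    ds = load_grib_files(task)
    with tempfile.NamedTemporaryFile(suffix='.nc', delete=False) as f:
        netcdf_filename = f.name
    ds.to_netcdf(netcdf_filename)
    return netcdf_filename  # only this short string gets pickled back to the main process


def append_netcdf_to_zarr(netcdf_filename: str, destination_zarr_path) -> None:
    # Runs in the main process: load the NetCDF file and append it to the Zarr
    # using the existing append_to_zarr function.
    ds = xr.open_dataset(netcdf_filename)
    append_to_zarr(ds, destination_zarr_path)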

@JackKelly
Member Author

Good: Option 1 (from the comment above) sounds viable. The xarray docs suggest that we can write to the Zarr in arbitrary order and in parallel if we first create the relevant zarr metadata. Some relevant quotes from the xarray docs:

you can use region to write to limited regions of existing arrays in an existing Zarr store. This is a good option for writing data in parallel from independent processes.
To scale this up to writing large datasets, the first step is creating an initial Zarr store without writing all of its array data.
...
Concurrent writes with region are safe as long as they modify distinct chunks in the underlying Zarr arrays (or use an appropriate lock).

@JackKelly
Member Author

But, before making this change, I'll run some experiments with the code as is, to get a feel for whether this is even a problem!

@JackKelly
Member Author

Converting two NWP init times (using Wholesale1 & Wholesale2) takes 54 seconds on my NUC, and very nearly runs out of RAM.

Downcasting the dataset to float16 before passing it from the worker process to the main process speeds this up to 41 seconds, which hints that there's considerable overhead in passing the object between processes.

@JackKelly
Member Author

Not passing anything back to the main process (and hence not writing anything to disk) takes 32 seconds.

@JackKelly
Member Author

JackKelly commented May 12, 2023

I've done some experiments using dataset.to_zarr(region=...) and it's looking very do-able to write Zarr chunks in arbitrary order, in parallel, after first constructing the metadata. I think it could work something like this...

Each xr.Dataset will contain two DataArrays: the "UKV" data, and a "chunk_exists" DataArray. "chunk_exists" is a 1D boolean array, with one chunk per element (so, yeah, the individual chunks will be tiny!), which just indicates which chunks have actually been written to disk completely.

Why? Consider what happens if we write metadata saying we've got 1 year of NWP init times for 2022, but then the script crashes after only writing 4 arbitrary init time chunks to disk. When we re-run the script, it will see that the init_time coords extend to the end of 2022. So how will the script know that it hasn't finished converting all the 2022 grib files to Zarr chunks? We could do something like ds["UKV"].isnull().max(dim=["variable", "step", "y", "x"]) but that would load all the Zarr chunks! We could write individual files to disk to indicate which chunks have been written, but it's tidier if we keep that data inside the Zarr (it should be easy to delete this data if needed).

In the main process, before launching the pool of workers:

  • Get the full list of NWP init times on disk
  • If the Zarr already exists then:
    • lazily load the zarr, and eagerly load the "chunk_exists" 1D bool array (from the same zarr).
    • init_times_to_load_from_grib = all_grib_init_times_on_disk - init_times_with_chunks_which_exist_in_zarr (see the sketch after this list)
    • If there are any init_times_to_load_from_grib:
      • if every init time to write already exists in the Zarr's "init_time" coordinates, then we don't have to touch the Zarr's metadata: we can just go ahead and concurrently convert GRIB data to Zarr chunks.
      • if new GRIB data exists which sits between the start and end of the Zarr's "init_time" coords, but for which no actual coords exist, then I think all we can do is ignore those GRIB files, log a noisy warning, and, when the script finishes, tell users how many files were ignored. I don't think we can insert new coords inside existing coords. However, I think we can guarantee that this will never happen if we always append contiguous "init_time" coords, such that the "init_time" coords are always contiguous (even after appending an arbitrary number of times). Then users' code can look at the "chunk_exists" array to see where data doesn't exist.
      • if new GRIB data exists which extends beyond the end of the "init_time" coord in the Zarr, then we need to append to the Zarr's "init_time" metadata before we can start writing actual chunks (see below).
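
A minimal sketch of that bookkeeping step, assuming the "chunk_exists" variable described above is stored alongside "UKV" in the Zarr (the variable names are the hypothetical ones from the list; all_grib_init_times_on_disk would come from scanning the GRIB filenames and is assumed to be a pandas DatetimeIndex):

import pandas as pd
import xarray as xr

dataset = xr.open_zarr(destination_zarr_path)  # lazy open
chunk_exists = dataset['chunk_exists'].load()  # tiny 1D bool array, so eager loading is cheap

init_times_with_chunks_which_exist_in_zarr = pd.DatetimeIndex(
    chunk_exists['init_time'][chunk_exists.values].values
)
init_times_to_load_from_grib = all_grib_init_times_on_disk.difference(
    init_times_with_chunks_which_exist_in_zarr
)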

If we have to create new metadata or update existing metadata, then in the main process (see the sketch after this list):

  • Create coords arrays. Let's work on the assumption that we'll always set the "init_time" coords to be contiguous (even after appending to the Zarr).
  • Create a dask.array for the NWP data with the required shape, dtype, and chunks. We can get the shape from the length of the coords.
  • Create a "chunk_exists" DataArray.
  • If the target Zarr already exists:
    • then:
      • create an xr.Dataset with the two dask arrays & coords (no need to worry about the attrs).
      • dataset.to_zarr(compute=False, append_dim="init_time")
    • else:
      • lazily open grib data for a single NWP init time to get the attrs.
      • create an xr.Dataset with the dask array, coords, and attrs.
      • dataset.to_zarr(compute=False) to write just the metadata.
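
A hedged sketch of that metadata-only step. The coord values, grid size, dtype and chunking below are placeholders, and "UKV"/"chunk_exists" are the hypothetical variable names from above:

import os

import dask.array
import numpy as np
import pandas as pd
import xarray as xr

destination_zarr_path = 'destination.zarr'  # placeholder path

# Placeholder coords; the real ones would come from the GRIB files on disk.
init_time = pd.date_range('2022-01-01', '2022-12-31T21:00', freq='3H')
variable = ['t', 'si10', 'dswrf']
step = pd.timedelta_range('0H', '36H', freq='1H')
y = np.arange(704)
x = np.arange(548)

dims = ('init_time', 'variable', 'step', 'y', 'x')
coords = {'init_time': init_time, 'variable': variable, 'step': step, 'y': y, 'x': x}
shape = tuple(len(coords[dim]) for dim in dims)
chunks = (1,) + shape[1:]  # one chunk per init_time

dataset = xr.Dataset(
    {
        'UKV': xr.DataArray(
            dask.array.zeros(shape, chunks=chunks, dtype=np.float32),  # dtype TBD
            dims=dims,
        ),
        'chunk_exists': xr.DataArray(
            dask.array.zeros(len(init_time), chunks=1, dtype=bool),
            dims=('init_time',),
        ),
    },
    coords=coords,
)

if os.path.exists(destination_zarr_path):
    # Append to the existing store's metadata without writing any chunk data.
    # (Here `dataset` should contain only the init_times being appended.)
    dataset.to_zarr(destination_zarr_path, compute=False, append_dim='init_time')
else:
    # Write just the metadata and the (small) coord arrays; no chunk data yet.
    # (The real code would also copy attrs from one GRIB file, as described above.)
    dataset.to_zarr(destination_zarr_path, compute=False)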

When we actually write data to disk, we can use imap_unordered.

We can write actual chunks like this:

# The drop_vars is necessary otherwise Zarr will try to
# overwrite variable, step, y, and x coord arrays.
dataset.drop_vars(['variable', 'step', 'y', 'x']).to_zarr(
    "test_regions.zarr",
    region={"init_time": slice(10, 20)},  # integer index slice.
    )
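
Putting it together, the writing loop might look roughly like this (a sketch: write_one_init_time and task.init_time_index are hypothetical, load_grib_files, tasks and destination_zarr_path are the existing ones, and locking / error handling is glossed over):

import multiprocessing

import xarray as xr


def write_one_init_time(task) -> None:
    # Runs in a worker process: convert one init_time and write it straight
    # into its pre-allocated slot in the Zarr.
    ds = load_grib_files(task)
    idx = task.init_time_index  # hypothetical: this init_time's integer position in the Zarr
    region = {'init_time': slice(idx, idx + 1)}
    ds.drop_vars(['variable', 'step', 'y', 'x']).to_zarr(destination_zarr_path, region=region)
    # Only mark the chunk as complete after the data write has succeeded.
    done = xr.Dataset({'chunk_exists': ('init_time', [True])})
    done.to_zarr(destination_zarr_path, region=region)


if __name__ == '__main__':
    with multiprocessing.Pool() as pool:
        # Nothing big is returned, and order no longer matters, so imap_unordered is fine.
        for _ in pool.imap_unordered(write_one_init_time, tasks):
            pass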

@JackKelly
Member Author

On second thoughts... this isn't a priority for me, especially if I downsample the NWPs in the worker process before passing them to the Zarr-writing process.

The next task I plan to work on is downsampling the NWPs, ready for the National PV forecasting experiments.
