Merge pull request #880 from pnuu/feature-zarr-resample-luts
 Replace Numpy files with zarr for resampling LUT caching
pnuu authored Sep 6, 2019
2 parents ffc21c2 + 365d761 commit ff5ed76
Showing 3 changed files with 184 additions and 113 deletions.
185 changes: 120 additions & 65 deletions satpy/resample.py
@@ -15,7 +15,9 @@
#
# You should have received a copy of the GNU General Public License along with
# satpy. If not, see <http://www.gnu.org/licenses/>.
"""Satpy provides multiple resampling algorithms for resampling geolocated
"""Satpy resampling module.
Satpy provides multiple resampling algorithms for resampling geolocated
data to uniform projected grids. The easiest way to perform resampling in
Satpy is through the :class:`~satpy.scene.Scene` object's
:meth:`~satpy.scene.Scene.resample` method. Additional utility functions are
@@ -144,11 +146,19 @@
LOG = getLogger(__name__)

CACHE_SIZE = 10
+NN_COORDINATES = {'valid_input_index': ('y1', 'x1'),
+                  'valid_output_index': ('y2', 'x2'),
+                  'index_array': ('y2', 'x2', 'z2')}
+BIL_COORDINATES = {'bilinear_s': ('x1', ),
+                   'bilinear_t': ('x1', ),
+                   'valid_input_index': ('x2', ),
+                   'index_array': ('x1', 'n')}

resamplers_cache = WeakValueDictionary()


def hash_dict(the_dict, the_hash=None):
"""Calculate a hash for a dictionary."""
if the_hash is None:
the_hash = hashlib.sha1()
the_hash.update(json.dumps(the_dict, sort_keys=True).encode('utf-8'))
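As an aside, the cache key above comes from hashing a JSON dump of the resampler kwargs. A minimal runnable sketch of the idea; the function's return statement is folded out of the diff, so treating it as returning the updated hash object is our assumption:

import hashlib
import json

def hash_dict(the_dict, the_hash=None):
    # Sketch of the cache-key derivation shown in the hunk above.
    if the_hash is None:
        the_hash = hashlib.sha1()
    the_hash.update(json.dumps(the_dict, sort_keys=True).encode('utf-8'))
    return the_hash  # assumption: the elided return hands back the hash

# sort_keys=True makes the digest independent of dict insertion order:
a = hash_dict({'radius_of_influence': 50000, 'epsilon': 0}).hexdigest()
b = hash_dict({'epsilon': 0, 'radius_of_influence': 50000}).hexdigest()
assert a == b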
@@ -321,7 +331,6 @@ def __init__(self, source_geo_def, target_geo_def):
Geolocation definition for the area to resample data to.
"""

self.source_geo_def = source_geo_def
self.target_geo_def = target_geo_def

@@ -394,12 +403,13 @@ def resample(self, data, cache_dir=None, mask_area=None, **kwargs):
cache_id = self.precompute(cache_dir=cache_dir, **kwargs)
return self.compute(data, cache_id=cache_id, **kwargs)

-    def _create_cache_filename(self, cache_dir=None, **kwargs):
-        """Create filename for the cached resampling parameters"""
+    def _create_cache_filename(self, cache_dir=None, prefix='',
+                               fmt='.zarr', **kwargs):
+        """Create filename for the cached resampling parameters."""
cache_dir = cache_dir or '.'
hash_str = self.get_hash(**kwargs)

-        return os.path.join(cache_dir, 'resample_lut-' + hash_str + '.npz')
+        return os.path.join(cache_dir, prefix + hash_str + fmt)
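With the new `prefix` and `fmt` arguments, the same helper names both the legacy NumPy caches and the new zarr stores. A toy illustration (hash string made up):

import os

def make_cache_filename(cache_dir, hash_str, prefix='', fmt='.zarr'):
    # Mirrors _create_cache_filename: directory + prefix + hash + extension.
    cache_dir = cache_dir or '.'
    return os.path.join(cache_dir, prefix + hash_str + fmt)

make_cache_filename(None, 'abc123', prefix='nn_lut-')
# -> './nn_lut-abc123.zarr'
make_cache_filename(None, 'abc123', prefix='resample_lut-', fmt='.npz')
# -> './resample_lut-abc123.npz'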


class KDTreeResampler(BaseResampler):
@@ -427,6 +437,7 @@ class KDTreeResampler(BaseResampler):
"""

def __init__(self, source_geo_def, target_geo_def):
"""Init KDTreeResampler."""
super(KDTreeResampler, self).__init__(source_geo_def, target_geo_def)
self.resampler = None
self._index_caches = {}
@@ -479,61 +490,96 @@ def precompute(self, mask=None, radius_of_influence=None, epsilon=0,
self.resampler.get_neighbour_info(mask=mask)
self.save_neighbour_info(cache_dir, mask=mask, **kwargs)

-    def _apply_cached_indexes(self, cached_indexes, persist=False):
-        """Reassign various resampler index attributes."""
-        # cacheable_dict = {}
-        for elt in ['valid_input_index', 'valid_output_index',
-                    'index_array', 'distance_array']:
-            val = cached_indexes[elt]
-            if isinstance(val, tuple):
-                val = cached_indexes[elt][0]
-            elif isinstance(val, np.ndarray):
-                val = da.from_array(val, chunks=CHUNK_SIZE)
-            elif persist and isinstance(val, da.Array):
-                cached_indexes[elt] = val = val.persist()
-            setattr(self.resampler, elt, val)
+    def _apply_cached_index(self, val, idx_name, persist=False):
+        """Reassign resampler index attributes."""
+        if isinstance(val, np.ndarray):
+            val = da.from_array(val, chunks=CHUNK_SIZE)
+        elif persist and isinstance(val, da.Array):
+            val = val.persist()
+        setattr(self.resampler, idx_name, val)
+        return val

+    def _check_numpy_cache(self, cache_dir, mask=None,
+                           **kwargs):
+        """Check if there's Numpy cache file and convert it to zarr."""
+        fname_np = self._create_cache_filename(cache_dir,
+                                               prefix='resample_lut-',
+                                               mask=mask, fmt='.npz',
+                                               **kwargs)
+        fname_zarr = self._create_cache_filename(cache_dir, prefix='nn_lut-',
+                                                 mask=mask, fmt='.zarr',
+                                                 **kwargs)
+
+        if os.path.exists(fname_np) and not os.path.exists(fname_zarr):
+            import warnings
+            warnings.warn("Using Numpy files as resampling cache is "
+                          "deprecated.")
+            LOG.warning("Converting resampling LUT from .npz to .zarr")
+            zarr_out = xr.Dataset()
+            with np.load(fname_np, 'r') as fid:
+                for idx_name, coord in NN_COORDINATES.items():
+                    zarr_out[idx_name] = (coord, fid[idx_name])
+
+            # Write indices to Zarr file
+            zarr_out.to_zarr(fname_zarr)
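A self-contained sketch of the same .npz-to-.zarr conversion, with toy index arrays standing in for real LUTs (shapes and file names made up):

import numpy as np
import xarray as xr
import dask.array as da

# Toy stand-ins for the resampler's index arrays.
np.savez('resample_lut-demo.npz',
         valid_input_index=np.ones((4, 5), dtype=bool),
         valid_output_index=np.ones((3, 6), dtype=bool),
         index_array=np.zeros((3, 6, 2), dtype=np.int64))

nn_coordinates = {'valid_input_index': ('y1', 'x1'),
                  'valid_output_index': ('y2', 'x2'),
                  'index_array': ('y2', 'x2', 'z2')}

# Same pattern as _check_numpy_cache: name each array's dimensions so
# xarray can serialize the whole set as one zarr store.
zarr_out = xr.Dataset()
with np.load('resample_lut-demo.npz', 'r') as fid:
    for idx_name, coord in nn_coordinates.items():
        zarr_out[idx_name] = (coord, fid[idx_name])
zarr_out.to_zarr('nn_lut-demo.zarr')

# Reading a single LUT back lazily, as load_neighbour_info() does:
lut = da.from_zarr('nn_lut-demo.zarr', 'index_array')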

def load_neighbour_info(self, cache_dir, mask=None, **kwargs):
"""Read index arrays from either the in-memory or disk cache."""
mask_name = getattr(mask, 'name', None)
-        filename = self._create_cache_filename(cache_dir,
+        cached = {}
+        self._check_numpy_cache(cache_dir, mask=mask_name, **kwargs)
+
+        filename = self._create_cache_filename(cache_dir, prefix='nn_lut-',
mask=mask_name, **kwargs)
-        if kwargs.get('mask') in self._index_caches:
-            self._apply_cached_indexes(self._index_caches[kwargs.get('mask')])
-        elif cache_dir:
-            cache = np.load(filename, mmap_mode='r', allow_pickle=True)
-            # copy the dict so we can modify it's keys
-            new_cache = dict(cache.items())
-            cache.close()
-            self._apply_cached_indexes(new_cache)  # modifies cache dict in-place
-            self._index_caches[mask_name] = new_cache
-        else:
-            raise IOError
+        for idx_name in NN_COORDINATES.keys():
+            if mask_name in self._index_caches:
+                cached[idx_name] = self._apply_cached_index(
+                    self._index_caches[mask_name][idx_name], idx_name)
+            elif cache_dir:
+                try:
+                    cache = da.from_zarr(filename, idx_name)
+                    if idx_name == 'valid_input_index':
+                        # valid input index array needs to be boolean
+                        cache = cache.astype(np.bool)
+                except ValueError:
+                    raise IOError
+                cache = self._apply_cached_index(cache, idx_name)
+                cached[idx_name] = cache
+            else:
+                raise IOError
+        self._index_caches[mask_name] = cached

def save_neighbour_info(self, cache_dir, mask=None, **kwargs):
"""Cache resampler's index arrays if there is a cache dir."""
if cache_dir:
mask_name = getattr(mask, 'name', None)
+            cache = self._read_resampler_attrs()
filename = self._create_cache_filename(
-                cache_dir, mask=mask_name, **kwargs)
+                cache_dir, prefix='nn_lut-', mask=mask_name, **kwargs)
LOG.info('Saving kd_tree neighbour info to %s', filename)
-            cache = self._read_resampler_attrs()
-            # update the cache in place with persisted dask arrays
-            self._apply_cached_indexes(cache, persist=True)
+            zarr_out = xr.Dataset()
+            for idx_name, coord in NN_COORDINATES.items():
+                # update the cache in place with persisted dask arrays
+                cache[idx_name] = self._apply_cached_index(cache[idx_name],
+                                                           idx_name,
+                                                           persist=True)
+                zarr_out[idx_name] = (coord, cache[idx_name])
+
+            # Write indices to Zarr file
+            zarr_out.to_zarr(filename)

self._index_caches[mask_name] = cache
-            np.savez(filename, **cache)
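The `persist=True` path evaluates each dask index array once and keeps the chunks in memory, so the same values feed both the zarr write and the in-memory cache without recomputation. Illustratively:

import dask.array as da

arr = da.random.random((1000, 1000), chunks=(250, 250))
arr = arr.persist()  # evaluate once; keep chunks in memory, stay a dask array
total = arr.sum().compute()  # reuses the persisted chunks
mean = arr.mean().compute()  # no recomputation of the random data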

def _read_resampler_attrs(self):
"""Read certain attributes from the resampler for caching."""
return {attr_name: getattr(self.resampler, attr_name)
-                for attr_name in [
-                    'valid_input_index', 'valid_output_index',
-                    'index_array', 'distance_array']}
+                for attr_name in NN_COORDINATES.keys()}

def compute(self, data, weight_funcs=None, fill_value=np.nan,
with_uncert=False, **kwargs):
"""Resample data."""
del kwargs
LOG.debug("Resampling " + str(data.name))
LOG.debug("Resampling %s", str(data.name))
res = self.resampler.get_sample_from_neighbour_info(data, fill_value)
return update_resampled_coords(data, res, self.target_geo_def)

@@ -583,6 +629,7 @@ class EWAResampler(BaseResampler):
"""

def __init__(self, source_geo_def, target_geo_def):
"""Init EWAResampler."""
super(EWAResampler, self).__init__(source_geo_def, target_geo_def)
self.cache = {}

@@ -599,7 +646,7 @@ def resample(self, *args, **kwargs):
return super(EWAResampler, self).resample(*args, **kwargs)

def _call_ll2cr(self, lons, lats, target_geo_def, swath_usage=0):
"""Wrapper around ll2cr for handling dask delayed calls better."""
"""Wrap ll2cr() for handling dask delayed calls better."""
new_src = SwathDefinition(lons, lats)

swath_points_in_grid, cols, rows = ll2cr(new_src, target_geo_def)
@@ -664,7 +711,7 @@ def precompute(self, cache_dir=None, swath_usage=0, **kwargs):

def _call_fornav(self, cols, rows, target_geo_def, data,
grid_coverage=0, **kwargs):
"""Wrapper to run fornav as a dask delayed."""
"""Wrap fornav() to run as a dask delayed."""
num_valid_points, res = fornav(cols, rows, target_geo_def,
data, **kwargs)

@@ -734,23 +781,21 @@ def compute(self, data, cache_id=None, fill_value=0, weight_count=10000,


class BilinearResampler(BaseResampler):
-
"""Resample using bilinear."""

def __init__(self, source_geo_def, target_geo_def):
"""Init BilinearResampler."""
super(BilinearResampler, self).__init__(source_geo_def, target_geo_def)
self.resampler = None

def precompute(self, mask=None, radius_of_influence=50000, epsilon=0,
-                   reduce_data=True, nprocs=1,
-                   cache_dir=False, **kwargs):
+                   reduce_data=True, cache_dir=False, **kwargs):
"""Create bilinear coefficients and store them for later use.
Note: The `mask` keyword should be provided if geolocation may be valid
where data points are invalid. This defaults to the `mask` attribute of
the `data` numpy masked array passed to the `resample` method.
"""

del kwargs

if self.resampler is None:
@@ -771,37 +816,44 @@ def precompute(self, mask=None, radius_of_influence=50000, epsilon=0,
self.save_bil_info(cache_dir, **kwargs)

    def load_bil_info(self, cache_dir, **kwargs):
-
+        """Load bilinear resampling info from cache directory."""
if cache_dir:
filename = self._create_cache_filename(cache_dir,
-                                                   prefix='resample_lut_bil_',
+                                                   prefix='bil_lut-',
**kwargs)
-            cache = np.load(filename)
-            for elt in ['bilinear_s', 'bilinear_t', 'valid_input_index',
-                        'index_array']:
-                if isinstance(cache[elt], tuple):
-                    setattr(self.resampler, elt, cache[elt][0])
-                else:
-                    setattr(self.resampler, elt, cache[elt])
-            cache.close()
+            for val in BIL_COORDINATES.keys():
+                try:
+                    cache = da.from_zarr(filename, val)
+                    if val == 'valid_input_index':
+                        # valid input index array needs to be boolean
+                        cache = cache.astype(np.bool)
+                    # Compute the cache arrays
+                    cache = cache.compute()
+                except ValueError:
+                    raise IOError
+                setattr(self.resampler, val, cache)

else:
raise IOError
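Note the `.compute()` above: unlike the nearest-neighbour loader, which hands lazy dask arrays to the resampler, the bilinear loader materializes the LUTs as NumPy arrays. A toy round-trip showing the difference (store and variable names made up):

import dask.array as da
import xarray as xr

# Write a toy one-variable store first so the example actually runs.
xr.Dataset({'bilinear_s': (('x1',), da.zeros(10, chunks=5))}).to_zarr(
    'bil_lut-demo.zarr')

lazy = da.from_zarr('bil_lut-demo.zarr', 'bilinear_s')  # dask array, lazy
eager = lazy.compute()                                  # plain np.ndarray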

def save_bil_info(self, cache_dir, **kwargs):
"""Save bilinear resampling info to cache directory."""
if cache_dir:
filename = self._create_cache_filename(cache_dir,
-                                                   prefix='resample_lut_bil_',
+                                                   prefix='bil_lut-',
**kwargs)
-            LOG.info('Saving kd_tree neighbour info to %s', filename)
-            cache = {'bilinear_s': self.resampler.bilinear_s,
-                     'bilinear_t': self.resampler.bilinear_t,
-                     'valid_input_index': self.resampler.valid_input_index,
-                     'index_array': self.resampler.index_array}
-
-            np.savez(filename, **cache)
+            LOG.info('Saving BIL neighbour info to %s', filename)
+            zarr_out = xr.Dataset()
+            for idx_name, coord in BIL_COORDINATES.items():
+                var = getattr(self.resampler, idx_name)
+                if isinstance(var, np.ndarray):
+                    var = da.from_array(var, chunks=CHUNK_SIZE)
+                var = var.rechunk(CHUNK_SIZE)
+                zarr_out[idx_name] = (coord, var)
+            zarr_out.to_zarr(filename)
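The unconditional `rechunk(CHUNK_SIZE)` matters because zarr stores want uniform chunk sizes, while arrays coming out of the resampler may carry irregular dask chunks. A sketch of the case it guards against (the CHUNK_SIZE value here is an assumption, not satpy's configured one):

import dask.array as da
import xarray as xr

CHUNK_SIZE = 4096  # assumption for the sketch

# Concatenation produces irregular chunks: ((3000, 7000),)
var = da.concatenate([da.zeros(3000, chunks=3000),
                      da.zeros(7000, chunks=7000)])
var = var.rechunk(CHUNK_SIZE)  # uniform chunks, safe to serialize
xr.Dataset({'bilinear_t': (('x1',), var)}).to_zarr('bil_t-demo.zarr')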

def compute(self, data, fill_value=None, **kwargs):
"""Resample the given data using bilinear interpolation"""
"""Resample the given data using bilinear interpolation."""
del kwargs

if fill_value is None:
@@ -838,6 +890,7 @@ class NativeResampler(BaseResampler):
"""

def resample(self, data, cache_dir=None, mask_area=False, **kwargs):
"""Run NativeResampler."""
# use 'mask_area' with a default of False. It wouldn't do anything.
return super(NativeResampler, self).resample(data,
cache_dir=cache_dir,
@@ -846,7 +899,7 @@ def resample(self, data, cache_dir=None, mask_area=False, **kwargs):

@staticmethod
def aggregate(d, y_size, x_size):
"""Average every 4 elements (2x2) in a 2D array"""
"""Average every 4 elements (2x2) in a 2D array."""
if d.ndim != 2:
# we can't guarantee what blocks we are getting and how
# it should be reshaped to do the averaging.
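The docstring above describes averaging every 2x2 block; the body is folded out of the diff, but the standard reshape-then-mean version of that idea in plain NumPy looks like this (fixed factor of 2, even dimensions assumed):

import numpy as np

def average_2x2(d):
    """Average every 2x2 block of a 2D array (even-sized dims assumed)."""
    h, w = d.shape
    # Split each axis into (blocks, 2), then average over the block axes.
    return d.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))

average_2x2(np.arange(16.).reshape(4, 4))
# array([[ 2.5,  4.5],
#        [10.5, 12.5]])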
@@ -867,6 +920,7 @@ def aggregate(d, y_size, x_size):

@classmethod
def expand_reduce(cls, d_arr, repeats):
"""Expand reduce."""
if not isinstance(d_arr, da.Array):
d_arr = da.from_array(d_arr, chunks=CHUNK_SIZE)
if all(x == 1 for x in repeats.values()):
@@ -900,6 +954,7 @@ def _calc_chunks(c, c_size):
"directions")

def compute(self, data, expand=True, **kwargs):
"""Resample data with NativeResampler."""
if isinstance(self.target_geo_def, (list, tuple)):
# find the highest/lowest area among the provided
test_func = max if expand else min
