Skip to content

Commit

Permalink
Python: Add XArrayRasterSource
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Dec 18, 2023
1 parent a59191d commit c51455d
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 9 deletions.
1 change: 1 addition & 0 deletions python/src/exactextract/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
GDALRasterSource,
NumPyRasterSource,
RasterioRasterSource,
XArrayRasterSource,
)
from .writer import Writer, JSONWriter, GDALWriter
20 changes: 19 additions & 1 deletion python/src/exactextract/exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
JSONFeatureSource,
GeoPandasFeatureSource,
)
from .raster_source import RasterSource, GDALRasterSource, RasterioRasterSource
from .raster_source import RasterSource, GDALRasterSource, RasterioRasterSource, XArrayRasterSource
from .operation import Operation
from .processor import FeatureSequentialProcessor, RasterSequentialProcessor
from .writer import JSONWriter
Expand Down Expand Up @@ -62,6 +62,24 @@ def prep_raster(rast, band=None, name_root=None, names=None):
except ImportError:
pass

try:
import rioxarray
import xarray

if isinstance(rast, xarray.core.dataarray.DataArray):
if band:
return [XArrayRasterSource(rast, band)]
else:
if not names:
names = [f"{name_root}_{i+1}" for i in range(rast.rio.count)]
return [
XArrayRasterSource(rast, i+1, name=names[i])
for i in range(rast.rio.count)
]

except ImportError:
pass

raise Exception("Unhandled raster datatype")


Expand Down
71 changes: 71 additions & 0 deletions python/src/exactextract/raster_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import os
import pathlib

Expand Down Expand Up @@ -126,3 +127,73 @@ def read_window(self, x0, y0, nx, ny):
from rasterio.windows import Window

return self.ds.read(self.band_idx, window=Window(x0, y0, nx, ny))


class XArrayRasterSource(RasterSource):
def __init__(self, ds, band_idx=1, *, name=None):
super().__init__()

if isinstance(ds, (str, os.PathLike)):
import rioxarray
import xarray

ds = xarray.open_dataarray(ds)

self.ds = ds
if self.ds.rio.crs is None:
# Set a default CRS to prevent clip_box from
# complaining that we don't have one
self.ds.rio.set_crs('EPSG:4326', inplace=True)
self.band_idx = band_idx
self.band_dim = self._band_dim(self.ds)
self.bounds = self.ds.rio.bounds()

if name:
self.set_name(name)


@staticmethod
def _band_dim(ds):
dims = list(ds.dims)
dims.remove(ds.rio.x_dim)
dims.remove(ds.rio.y_dim)

if len(dims) == 0:
return None
elif len(dims) == 1:
return dims[0]
else:
raise Exception("Cannot handle >1 non-spatial dimension")


def res(self):
return tuple(abs(x) for x in self.ds.rio.resolution())


def extent(self):
return self.bounds


def nodata_value(self):
return self.ds.rio.nodata


def read_window(self, x0, y0, nx, ny):
lats = self.ds[self.ds.rio.y_dim]
flipped = bool(len(lats) > 1 and lats[1] > lats[0])

if flipped:
y0 = self.ds.rio.height - y0 - ny

selection = {}
if self.band_dim is not None:
selection[self.band_dim] = self.ds[self.band_dim][self.band_idx - 1]
selection[self.ds.rio.x_dim] = self.ds[self.ds.rio.x_dim][x0 : x0+nx]
selection[self.ds.rio.y_dim] = self.ds[self.ds.rio.y_dim][y0 : y0+ny]

ret = self.ds.sel(**selection).to_numpy()

if flipped:
ret = np.flipud(ret)

return ret
6 changes: 5 additions & 1 deletion python/tests/test_exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ def open_with_lib(fname, libname):
elif libname == "rasterio":
rasterio = pytest.importorskip("rasterio")
return rasterio.open(fname)
elif libname == "xarray":
rioxarray = pytest.importorskip("rioxarray")
xarray = pytest.importorskip("xarray")
return xarray.open_dataarray(fname)
elif libname == "ogr":
ogr = pytest.importorskip("osgeo.ogr")
return ogr.Open(fname)
Expand All @@ -331,7 +335,7 @@ def open_with_lib(fname, libname):
return gp.read_file(fname)


@pytest.mark.parametrize("rast_lib", ("gdal", "rasterio"))
@pytest.mark.parametrize("rast_lib", ("gdal", "rasterio", "xarray"))
@pytest.mark.parametrize("vec_lib", ("ogr", "fiona", "geopandas"))
@pytest.mark.parametrize(
"arr,expected",
Expand Down
33 changes: 26 additions & 7 deletions python/tests/test_raster_source.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import pytest

from exactextract import GDALRasterSource, RasterioRasterSource
from exactextract import GDALRasterSource, RasterioRasterSource, XArrayRasterSource


@pytest.fixture()
def global_half_degree(tmp_path):
from osgeo import gdal

fname = str(tmp_path / "test.tif")
fname = str(tmp_path / "test.nc")

drv = gdal.GetDriverByName("GTiff")
ds = drv.Create(fname, 720, 360)
nx = 720
ny = 360

drv = gdal.GetDriverByName("NetCDF")
ds = drv.Create(fname, nx, ny, eType=gdal.GDT_Int32)
gt = (-180.0, 0.5, 0.0, 90.0, 0.0, -0.5)
ds.SetGeoTransform(gt)
band = ds.GetRasterBand(1)
band.SetNoDataValue(6)

data = np.arange(nx * ny).reshape(ny, nx)
band.WriteArray(data)

ds = None

return fname


@pytest.mark.parametrize("Source", (GDALRasterSource, RasterioRasterSource))
@pytest.mark.parametrize(
"Source", (GDALRasterSource, RasterioRasterSource, XArrayRasterSource)
)
def test_gdal_raster(global_half_degree, Source):
try:
src = Source(global_half_degree, 1)
Expand All @@ -32,6 +42,15 @@ def test_gdal_raster(global_half_degree, Source):

assert src.res() == (0.50, 0.50)
assert src.extent() == pytest.approx((-180, -90, 180, 90))
assert src.nodata_value() == 6

assert src.read_window(0, 0, 10, 10).shape == (10, 10)
window = src.read_window(4, 5, 2, 3)

assert window.shape == (3, 2)
np.testing.assert_array_equal(
window.astype(np.float64),
np.array([[3604, 3605], [4324, 4325], [5044, 5045]], np.float64),
)

if Source != XArrayRasterSource:
assert src.nodata_value() == 6
assert window.dtype == np.int32

0 comments on commit c51455d

Please sign in to comment.