Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional aggregation control #569

Merged
merged 8 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions intake_esm/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,35 +294,38 @@ def df(self) -> pd.DataFrame:
@property
def has_multiple_variable_assets(self) -> bool:
"""Return True if the catalog has multiple variable assets."""
return self.aggregation_control.variable_column_name in self.columns_with_iterables
if self.aggregation_control:
return self.aggregation_control.variable_column_name in self.columns_with_iterables
return False

def _cast_agg_columns_with_iterables(self) -> None:
"""Cast all agg_columns with iterables to tuple values so as
to avoid hashing issues (e.g. TypeError: unhashable type: 'list')
"""
columns = list(
self.columns_with_iterables.intersection(
set(map(lambda agg: agg.attribute_name, self.aggregation_control.aggregations))
if self.aggregation_control:
columns = list(
self.columns_with_iterables.intersection(
set(map(lambda agg: agg.attribute_name, self.aggregation_control.aggregations))
)
)
)
if columns:
self._df[columns] = self._df[columns].apply(tuple)
if columns:
self._df[columns] = self._df[columns].apply(tuple)

@property
def grouped(self) -> typing.Union[pd.core.groupby.DataFrameGroupBy, pd.DataFrame]:

if self.aggregation_control.groupby_attrs:
self.aggregation_control.groupby_attrs = list(
filter(
functools.partial(_allnan_or_nonan, self.df),
self.aggregation_control.groupby_attrs,
if self.aggregation_control:
if self.aggregation_control.groupby_attrs:
self.aggregation_control.groupby_attrs = list(
filter(
functools.partial(_allnan_or_nonan, self.df),
self.aggregation_control.groupby_attrs,
)
)
)

if self.aggregation_control.groupby_attrs and set(
self.aggregation_control.groupby_attrs
) != set(self.df.columns):
return self.df.groupby(self.aggregation_control.groupby_attrs)
if self.aggregation_control.groupby_attrs and set(
self.aggregation_control.groupby_attrs
) != set(self.df.columns):
return self.df.groupby(self.aggregation_control.groupby_attrs)
return self.df

def _construct_group_keys(self, sep: str = '.') -> dict[str, typing.Union[str, tuple[str]]]:
Expand Down
51 changes: 37 additions & 14 deletions intake_esm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def __init__(
self._validate_derivedcat()

def _validate_derivedcat(self) -> None:
if self.esmcat.aggregation_control is None and len(self.derivedcat):
raise ValueError(
'Variable derivation requires an `aggregation_control` to be specified in the catalog.'
)
for key, entry in self.derivedcat.items():
if self.esmcat.aggregation_control.variable_column_name not in entry.query.keys():
raise ValueError(
Expand Down Expand Up @@ -149,6 +153,10 @@ def keys_info(self) -> pd.DataFrame:

"""
results = self.esmcat._construct_group_keys(sep=self.sep)
if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs:
pass
else:
pass
data = {
key: dict(zip(self.esmcat.aggregation_control.groupby_attrs, results[key]))
for key in results
Expand All @@ -167,7 +175,7 @@ def key_template(self) -> str:
str
string template used to create catalog entry keys
"""
if self.esmcat.aggregation_control.groupby_attrs:
if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs:
return self.sep.join(self.esmcat.aggregation_control.groupby_attrs)
else:
return self.sep.join(self.esmcat.df.columns)
Expand Down Expand Up @@ -233,15 +241,21 @@ def __getitem__(self, key: str) -> ESMDataSource:
else:
records = grouped.get_group(internal_key).to_dict(orient='records')

if self.esmcat.aggregation_control:
variable_column_name = self.esmcat.aggregation_control.variable_column_name
aggregations = self.esmcat.aggregation_control.aggregations
else:
variable_column_name = None
aggregations = []
# Create a new entry
entry = ESMDataSource(
key=key,
records=records,
variable_column_name=self.esmcat.aggregation_control.variable_column_name,
variable_column_name=variable_column_name,
path_column_name=self.esmcat.assets.column_name,
data_format=self.esmcat.assets.format,
format_column_name=self.esmcat.assets.format_column_name,
aggregations=self.esmcat.aggregation_control.aggregations,
aggregations=aggregations,
intake_kwargs={'metadata': {}},
)
self._entries[key] = entry
Expand Down Expand Up @@ -366,7 +380,10 @@ def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: t
# step 2: Search for entries required to derive variables in the derived catalogs
# This requires a bit of a hack i.e. the user has to specify the variable in the query
derivedcat_results = []
variables = query.pop(self.esmcat.aggregation_control.variable_column_name, None)
if self.esmcat.aggregation_control:
variables = query.pop(self.esmcat.aggregation_control.variable_column_name, None)
else:
variables = None
dependents = []
derived_cat_subset = {}
if variables:
Expand Down Expand Up @@ -488,19 +505,21 @@ def nunique(self) -> pd.Series:
dtype: int64
"""
nunique = self.esmcat.nunique()
nunique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = len(
self.derivedcat.keys()
)
if self.esmcat.aggregation_control:
nunique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = len(
self.derivedcat.keys()
)
return nunique

def unique(self) -> pd.Series:
"""Return unique values for given columns in the
catalog.
"""
unique = self.esmcat.unique()
unique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = list(
self.derivedcat.keys()
)
if self.esmcat.aggregation_control:
unique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = list(
self.derivedcat.keys()
)
return unique

@pydantic.validate_arguments
Expand Down Expand Up @@ -586,9 +605,13 @@ def to_dataset_dict(
return {}

if (
self.esmcat.aggregation_control.variable_column_name
in self.esmcat.aggregation_control.groupby_attrs
) and len(self.derivedcat) > 0:
self.esmcat.aggregation_control
and (
self.esmcat.aggregation_control.variable_column_name
in self.esmcat.aggregation_control.groupby_attrs
)
and len(self.derivedcat) > 0
):
raise NotImplementedError(
f'The `{self.esmcat.aggregation_control.variable_column_name}` column name is used as a groupby attribute: {self.esmcat.aggregation_control.groupby_attrs}. '
'This is not yet supported when computing derived variables.'
Expand Down Expand Up @@ -618,7 +641,7 @@ def to_dataset_dict(
storage_options=storage_options,
)

if aggregate is not None and not aggregate:
if aggregate is not None and not aggregate and self.esmcat.aggregation_control:
self = deepcopy(self)
self.esmcat.aggregation_control.groupby_attrs = []
if progressbar is not None:
Expand Down
10 changes: 5 additions & 5 deletions intake_esm/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _open_dataset(
ds = ds.set_coords(scalar_variables)
ds = ds[variables]
ds.attrs[OPTIONS['vars_key']] = variables
else:
elif varname:
ds.attrs[OPTIONS['vars_key']] = varname

ds = _expand_dims(expand_dims, ds)
Expand Down Expand Up @@ -126,11 +126,11 @@ def __init__(
self,
key: pydantic.StrictStr,
records: list[dict[str, typing.Any]],
variable_column_name: pydantic.StrictStr,
path_column_name: pydantic.StrictStr,
data_format: typing.Optional[DataFormat],
format_column_name: typing.Optional[pydantic.StrictStr],
*,
variable_column_name: typing.Optional[pydantic.StrictStr] = None,
aggregations: typing.Optional[list[Aggregation]] = None,
requested_variables: list[str] = None,
preprocess: typing.Callable = None,
Expand All @@ -148,12 +148,12 @@ def __init__(
records: list of dict
A list of records, each of which is a dictionary
mapping column names to values.
variable_column_name: str
The column name of the variable name.
path_column_name: str
The column name of the path.
data_format: DataFormat
The data format of the data.
variable_column_name: str, optional
The column name of the variable name.
aggregations: list of Aggregation, optional
A list of aggregations to apply to the data.
requested_variables: list of str, optional
Expand Down Expand Up @@ -220,7 +220,7 @@ def _open_dataset(self):
datasets = [
_open_dataset(
record[self.path_column_name],
record[self.variable_column_name],
record[self.variable_column_name] if self.variable_column_name else None,
xarray_open_kwargs=_get_xarray_open_kwargs(
record['_data_format_'], self.xarray_open_kwargs, self.storage_options
),
Expand Down
64 changes: 64 additions & 0 deletions tests/sample-catalogs/catalog-dict-records-noagg.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"esmcat_version": "0.1.0",
"id": "aws-cesm1-le-noagg",
"description": "This is an ESM catalog for CESM1 Large Ensemble Zarr dataset publicly available on Amazon S3 (us-west-2 region), without any aggregation info.",
"catalog_dict": [
{
"component": "atm",
"frequency": "daily",
"experiment": "20C",
"variable": "FLNS",
"path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLNS.zarr"
},
{
"component": "atm",
"frequency": "daily",
"experiment": "20C",
"variable": "FLNSC",
"path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLNSC.zarr"
},
{
"component": "atm",
"frequency": "daily",
"experiment": "20C",
"variable": "FLUT",
"path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLUT.zarr"
},
{
"component": "atm",
"frequency": "daily",
"experiment": "20C",
"variable": "FSNS",
"path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FSNS.zarr"
},
{
"component": "atm",
"frequency": "daily",
"experiment": "20C",
"variable": "FSNSC",
"path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FSNSC.zarr"
}
],
"attributes": [
{
"column_name": "component",
"vocabulary": ""
},
{
"column_name": "frequency",
"vocabulary": ""
},
{
"column_name": "experiment",
"vocabulary": ""
},
{
"column_name": "variable",
"vocabulary": ""
}
],
"assets": {
"column_name": "path",
"format": "zarr"
}
}
19 changes: 18 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def func_multivar(ds):
cdf_cat_sample_cmip6,
mixed_cat_sample_cmip6,
multi_variable_cat,
noagg_cat,
sample_df,
sample_esmcat_data,
zarr_cat_aws_cesm,
Expand All @@ -58,6 +59,7 @@ def func_multivar(ds):
'obj, sep, read_csv_kwargs',
[
(catalog_dict_records, '.', None),
(noagg_cat, '.', None),
(cdf_cat_sample_cmip6, '/', None),
(zarr_cat_aws_cesm, '.', None),
(zarr_cat_pangeo_cmip6, '*', None),
Expand Down Expand Up @@ -99,6 +101,18 @@ def func(ds):
intake.open_esm_datastore(catalog_dict_records, registry=registry)


def test_impossible_derivedcat():
    """Derived variables must be rejected for a catalog lacking ``aggregation_control``.

    ``noagg_cat`` points at a sample catalog that defines no
    ``aggregation_control`` section. Opening it together with a non-empty
    derived-variable registry is expected to raise ``ValueError`` with the
    'Variable derivation requires an `aggregation_control`' message.
    """
    registry = intake_esm.DerivedVariableRegistry()

    # Register a single derived variable so the registry is non-empty; the
    # function body itself is never expected to run in this test.
    @registry.register(variable='FOO', query={'variable': ['FLNS', 'FLUT']})
    def func(ds):
        ds['FOO'] = ds.FLNS + ds.FLUT
        return ds

    with pytest.raises(ValueError, match='Variable derivation requires an `aggregation_control`'):
        intake.open_esm_datastore(noagg_cat, registry=registry)

@pytest.mark.parametrize(
'obj, sep, read_csv_kwargs',
[
Expand All @@ -107,6 +121,7 @@ def func(ds):
(cdf_cat_sample_cmip5, '.', None),
(cdf_cat_sample_cmip6, '*', None),
(catalog_dict_records, '.', None),
(noagg_cat, '.', None),
({'esmcat': sample_esmcat_data, 'df': sample_df}, '.', None),
],
)
Expand All @@ -116,7 +131,9 @@ def test_catalog_unique(obj, sep, read_csv_kwargs):
nuniques = cat.nunique()
assert isinstance(uniques, pd.Series)
assert isinstance(nuniques, pd.Series)
assert len(uniques.keys()) == len(cat.df.columns) + 1 # for derived_variable entry
assert len(uniques.keys()) == len(cat.df.columns) + (
0 if obj is noagg_cat else 1
) # for derived_variable entry


def test_catalog_contains():
Expand Down
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
cdf_cat_sample_cmip5 = os.path.join(here, 'sample-catalogs/cmip5-netcdf.json')
cdf_cat_sample_cesmle = os.path.join(here, 'sample-catalogs/cesm1-lens-netcdf.json')
catalog_dict_records = os.path.join(here, 'sample-catalogs/catalog-dict-records.json')
noagg_cat = os.path.join(here, 'sample-catalogs/catalog-dict-records-noagg.json')
zarr_cat_aws_cesm = (
'https://raw.githubusercontent.com/NCAR/cesm-lens-aws/master/intake-catalogs/aws-cesm1-le.json'
)
Expand Down