Skip to content

Commit

Permalink
Restrict to latitude range first to concatenate fewer latitudes if ne…
Browse files Browse the repository at this point in the history
…eded
  • Loading branch information
NoraLoose committed Nov 4, 2024
1 parent 9445437 commit 4f5e4a7
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions roms_tools/setup/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,13 @@ def check_if_global(self, ds) -> bool:

return is_global

def concatenate_longitudes(self, end="upper", verbose=False):
def concatenate_longitudes(self, ds, end="upper", verbose=False):
"""Concatenates fields in dataset twice along the longitude dimension.
Parameters
----------
ds: xr.Dataset
The dataset to be concatenated. The longitude dimension must be present in this dataset.
end : str, optional
Specifies which end to shift the longitudes.
Options are:
Expand All @@ -571,14 +573,13 @@ def concatenate_longitudes(self, end="upper", verbose=False):
Returns
-------
None
The method updates the internal dataset `self.ds` with the concatenated fields
along the longitude dimension.
ds_concatenated : xr.Dataset
The concatenated dataset.
"""

if verbose:
start_time = time.time()
ds = self.ds

ds_concatenated = xr.Dataset()

lon = ds[self.dim_names["longitude"]]
Expand All @@ -601,8 +602,6 @@ def concatenate_longitudes(self, end="upper", verbose=False):
[lon_minus360, lon, lon_plus360], dim=self.dim_names["longitude"]
)

ds_concatenated[self.dim_names["longitude"]] = lon_concatenated

for var in ds.data_vars:
if self.dim_names["longitude"] in ds[var].dims:
field = ds[var]
Expand All @@ -625,12 +624,15 @@ def concatenate_longitudes(self, end="upper", verbose=False):
else:
ds_concatenated[var] = ds[var]

object.__setattr__(self, "ds", ds_concatenated)
ds_concatenated[self.dim_names["longitude"]] = lon_concatenated

if verbose:
print(
f"Concatenating the data along the longitude dimension: {time.time() - start_time:.3f} seconds"
)

return ds_concatenated

def post_process(self):
"""Placeholder method to be overridden by subclasses for dataset post-
processing.
Expand Down Expand Up @@ -686,7 +688,13 @@ def choose_subdomain(

margin = self.resolution * buffer_points

lon = self.ds[self.dim_names["longitude"]]
# Select the subdomain in latitude direction (so that we have to concatenate fewer latitudes below if concatenation is necessary)
subdomain = self.ds.sel(
**{
self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
}
)
lon = subdomain[self.dim_names["longitude"]]

if self.is_global:
# Concatenate only if necessary
Expand All @@ -695,19 +703,25 @@ def choose_subdomain(
if (lon_min - margin > (lon + 360).min()) and (
lon_max + margin < (lon + 360).max()
):
self.ds[self.dim_names["longitude"]] = lon + 360
lon = self.ds[self.dim_names["longitude"]]
subdomain[self.dim_names["longitude"]] = lon + 360
lon = subdomain[self.dim_names["longitude"]]
else:
self.concatenate_longitudes(end="upper", verbose=verbose)
subdomain = self.concatenate_longitudes(
subdomain, end="upper", verbose=verbose
)
lon = subdomain[self.dim_names["longitude"]]
if lon_min - margin < lon.min():
# See if shifting by -360 degrees helps
if (lon_min - margin > (lon - 360).min()) and (
lon_max + margin < (lon - 360).max()
):
self.ds[self.dim_names["longitude"]] = lon - 360
lon = self.ds[self.dim_names["longitude"]]
subdomain[self.dim_names["longitude"]] = lon - 360
lon = subdomain[self.dim_names["longitude"]]
else:
self.concatenate_longitudes(end="lower", verbose=verbose)
subdomain = self.concatenate_longitudes(
subdomain, end="lower", verbose=verbose
)
lon = subdomain[self.dim_names["longitude"]]

else:
# Adjust longitude range if needed to match the expected range
Expand All @@ -730,12 +744,9 @@ def choose_subdomain(
if lon_min - margin < 0:
lon_min += 360
lon_max += 360

# Select the subdomain

subdomain = self.ds.sel(
# Select the subdomain in longitude direction
subdomain = subdomain.sel(
**{
self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
}
)
Expand Down

0 comments on commit 4f5e4a7

Please sign in to comment.