[Refactor] Improve the performance of temporal group averaging #689

tomvothecoder merged 8 commits into `main` from `refactor/688-temp-api-perf`
Conversation
Force-pushed from 7594df4 to 0d56ed5
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@           Coverage Diff            @@
##              main     #689   +/-   ##
=========================================
  Coverage   100.00%  100.00%
=========================================
  Files           15       15
  Lines         1544     1546       +2
=========================================
+ Hits          1544     1546       +2
```

View full report in Codecov by Sentry.
Replace `.load()` with `.astype("timedelta64[ns]")` for clarity
tomvothecoder left a comment:

My initial self-review. The GH Actions build is passing.
```diff
-# 5. Calculate the departures for the data variable.
-# ----------------------------------------------------------------------
-# This step allows us to perform xarray's grouped arithmetic to
-# calculate departures.
-dv_obs = ds_obs[data_var].copy()
-self._labeled_time = self._label_time_coords(dv_obs[self.dim])
-dv_obs_grouped = self._group_data(dv_obs)
-
-# 5. Align time dimension names using the labeled time dimension name.
-# ----------------------------------------------------------------------
-# The climatology's time dimension is renamed to the labeled time
-# dimension in step #4 above (e.g., "time" -> "season"). xarray requires
-# dimension names to be aligned to perform grouped arithmetic, which we
-# use for calculating departures in step #5. Otherwise, this error is
-# raised: "`ValueError: incompatible dimensions for a grouped binary
-# operation: the group variable '<FREQ ARG>' is not a dimension on the
-# other argument`".
-dv_climo = ds_climo[data_var]
-dv_climo = dv_climo.rename({self.dim: self._labeled_time.name})
-
-# 6. Calculate the departures for the data variable.
-# ----------------------------------------------------------------------
-# departures = observation - climatology
-with xr.set_options(keep_attrs=True):
-    dv_departs = dv_obs_grouped - dv_climo
-
-dv_departs = self._add_operation_attrs(dv_departs)
-ds_obs[data_var] = dv_departs
+ds_departs = self._calculate_departures(ds_obs, ds_climo, data_var)
```
Refactored this block of code into `self._calculate_departures()` for readability.
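As context for the grouped-arithmetic approach used in the departures calculation, here is a minimal, self-contained sketch on hypothetical data (a toy monthly series, not xcdat's internals) showing how xarray computes departures as observation minus a per-group climatology:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy monthly series spanning two years (stand-in for the observational data).
time = pd.date_range("2000-01-01", periods=24, freq="MS")
obs = xr.DataArray(np.arange(24.0), coords={"time": time}, dims="time")

# Monthly climatology: the mean of each calendar month across both years.
climo = obs.groupby("time.month").mean()

# departures = observation - climatology, via xarray's grouped binary
# arithmetic (each time step is matched against its month's climatology).
departs = obs.groupby("time.month") - climo
```

For January, the two values 0.0 and 12.0 have a climatology of 6.0, so the departures are -6.0 and +6.0.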
```python
self._labeled_time = self._label_time_coords(dv[self.dim])
dv = dv.assign_coords({self.dim: self._labeled_time})
```
Addresses bottleneck #1 from the PR description: replace the time coordinates with the labeled time coordinates directly for grouping, rather than adding the labeled time coordinates as auxiliary coordinates on the time dimension (which slows things down in xarray for some reason; I need to ask the xarray forum).
```diff
 # warning please use the scalar types `np.float64`, or string notation.`
-if isinstance(time_lengths.data, Array):
-    time_lengths.load()
+time_lengths = time_lengths.astype("timedelta64[ns]")
```
Addresses bottleneck #2 from the PR description.
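A minimal sketch of the casting approach on made-up data (a standalone example, not the actual `_get_weights()` code): casting to `"timedelta64[ns]"` and then `float64` stays lazy for Dask-backed arrays, instead of forcing an eager `.load()`:

```python
import numpy as np
import xarray as xr

# Toy month lengths (stand-in for the differences of the time bounds).
time_lengths = xr.DataArray(
    np.array([31, 28, 31], dtype="timedelta64[D]"),
    dims="time",
    name="time_lengths",
)

# Cast to nanosecond precision first (sidesteps NumPy's warning about
# ambiguous timedelta units), then to float64 to enable arithmetic.
time_lengths = time_lengths.astype("timedelta64[ns]").astype(np.float64)
weights = time_lengths / time_lengths.sum()
```

Both `astype` calls are elementwise and dtype-only, so on a Dask-backed `DataArray` they build lazy tasks rather than pulling the data into memory.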
```python
dv = dv.assign_coords({self.dim: self._labeled_time})
dv_gb = dv.groupby(self.dim)
```
Addresses bottleneck #1 from the PR description: replace the time coordinates with the labeled time coordinates directly for grouping, rather than adding the labeled time coordinates as auxiliary coordinates on the time dimension (which slows things down in xarray for some reason; I need to ask the xarray forum).
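A small illustration of this "fast path" on hypothetical data (season labels stand in for the labeled time coordinates produced by `_label_time_coords()`): the dimension coordinate itself is overwritten with the group labels, and grouping happens on the dimension rather than on an auxiliary coordinate:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Six monthly time steps with values 0..5.
time = pd.date_range("2000-01-01", periods=6, freq="MS")
da = xr.DataArray(np.arange(6.0), coords={"time": time}, dims="time")

# Group labels: the season of each time step ("DJF", "MAM", "JJA").
labels = da["time"].dt.season.values

# Overwrite the "time" dimension coordinate with the labels, then group
# by the dimension itself instead of an auxiliary coordinate.
da_labeled = da.assign_coords(time=labels)
result = da_labeled.groupby("time").mean()
```

Jan/Feb fall in "DJF" (mean 0.5), Mar/Apr/May in "MAM" (mean 3.0), and Jun in "JJA" (5.0).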
```diff
 time_grouped = xr.DataArray(
-    name="_".join(df_dt_components.columns),
+    name=self.dim,
     data=dt_objects,
-    coords={self.dim: time_coords[self.dim]},
+    coords={self.dim: dt_objects},
     dims=[self.dim],
     attrs=time_coords[self.dim].attrs,
 )
```
Addresses bottleneck #1 from the PR description.
```diff
-if self._mode in ["group_average", "climatology"]:
-    self._weights = self._weights.rename({self.dim: f"{self.dim}_original"})
-    # Only keep the original time coordinates, not the ones labeled
-    # by group.
-    self._weights = self._weights.drop_vars(self._labeled_time.name)
+weights = self._weights.assign_coords({self.dim: self._dataset[self.dim]})
+weights = weights.rename({self.dim: f"{self.dim}_original"})

-ds[self._weights.name] = self._weights
+ds[weights.name] = weights
```
Reassign the original, unlabeled time coordinates back to the weights `xr.DataArray`, then rename its dimension to `"time_original"` to avoid conflicting with the labeled time coordinates (now called `"time"`).
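A sketch of that renaming step on toy data (the `time_wts` and `time_original` names follow the PR's convention; the values are made up):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Weights defined on the original (unlabeled) time steps.
time = pd.date_range("2000-01-01", periods=4, freq="MS")
weights = xr.DataArray(
    np.full(4, 0.25), coords={"time": time}, dims="time", name="time_wts"
)

# Rename the dimension so the weights can live alongside the grouped
# output, whose labeled dimension is now called "time".
weights = weights.rename({"time": "time_original"})
ds = weights.to_dataset()
```

After the rename, `ds` carries the weights on `time_original` with no `time` dimension to clash with the labeled one.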
```python
dv_departs = dv_departs.assign_coords({self.dim: ds_obs[self.dim]})
ds_departs[data_var] = dv_departs
```
Reassign the grouped, unlabeled time coordinates back to the final departures' time coordinates (since the labeled, grouped time coordinates sometimes drop the year from the time coordinates).
Hi @chengzhuzhang, this PR is ready for review. After refactoring, I managed to cut down the runtime as follows:
I also performed a regression test using the same e3sm_diags dataset between the two branches.

Benchmarking script:

```python
# %%
import xarray as xr
import xcdat as xc

### 1. Using temporal.climatology from xcdat
file_path = "/global/cfs/cdirs/e3sm/e3sm_diags/postprocessed_e3sm_v2_data_for_e3sm_diags/20221103.v2.LR.amip.NGD_v3atm.chrysalis/arm-diags-data/PRECT_sgpc1_198501_201412.nc"
ds = xc.open_dataset(file_path)
branch = "dev"

# %%
# 1. Calculate annual climatology
# -------------------------------
ds_annual_cycle = ds.temporal.climatology("PRECT", "month", keep_weights=True)
ds_annual_cycle.to_netcdf(f"temporal_climatology_{branch}.nc")

"""
main
--------------------------
CPU times: user 33 s, sys: 2.41 s, total: 35.4 s
Wall time: 35.4 s

refactor/688-temp-api-perf
--------------------------
CPU times: user 5.85 s, sys: 2.88 s, total: 8.72 s
Wall time: 8.78 s
"""

# %%
# 2. Calculate annual departures
# ------------------------------
ds_annual_cycle_anom = ds.temporal.departures("PRECT", "month", keep_weights=True)
ds_annual_cycle_anom.to_netcdf(f"temporal_departures_{branch}.nc")

"""
main
--------------------------
CPU times: user 1min 9s, sys: 4.8 s, total: 1min 14s
Wall time: 1min 14s

refactor/688-temp-api-perf
--------------------------
CPU times: user 11.6 s, sys: 4.32 s, total: 15.9 s
Wall time: 15.9 s
"""

# %%
# 3. Calculate monthly group averages
# -----------------------------------
ds_annual_avg = ds.temporal.group_average("PRECT", "month", keep_weights=True)
ds_annual_avg.to_netcdf(f"temporal_group_average_{branch}.nc")

"""
main
--------------------------
CPU times: user 33.5 s, sys: 2.27 s, total: 35.8 s
Wall time: 35.9 s

refactor/688-temp-api-perf
--------------------------
CPU times: user 5.59 s, sys: 2.06 s, total: 7.65 s
Wall time: 7.65 s
"""
```

Regression testing script:

```python
import glob

import xarray as xr

# Get the filepaths for the dev and main branches
dev_filepaths = sorted(glob.glob("qa/issue-688/dev/*.nc"))
main_filepaths = sorted(glob.glob("qa/issue-688/main/*.nc"))

for fp, mp in zip(dev_filepaths, main_filepaths):
    print(f"Comparing {fp} and {mp}")

    # Load the datasets
    dev_ds = xr.open_dataset(fp)
    main_ds = xr.open_dataset(mp)

    # Compare the datasets
    try:
        xr.testing.assert_identical(dev_ds, main_ds)
    except AssertionError as e:
        print(f"Datasets are not identical: {e}")
    else:
        print("Datasets are identical")
```

Next step
```python
weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum()
weights.name = f"{self.dim}_wts"

# Validate the sum of weights for each group is 1.0.
```
Checking that the sum of each weight group equals 1.0 seems like a good feature to have, but if it degrades performance a lot, we can exclude it. Maybe this check can be implemented in testing instead (if it is not included there yet). Also, the `_get_weights()` description needs to be updated to reflect that the sum is no longer validated.
We should expect the logic of `_get_weights()` to be correct, so this assertion should not be necessary at runtime (especially with the performance hit).

I like your suggestion of making it a unit test instead. I will push a commit with this change soon.
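A sketch of what such a unit test might look like (the test name and the season-based grouping are illustrative, not xcdat's actual test code):

```python
import numpy as np
import pandas as pd
import xarray as xr


def test_group_weight_sums_equal_one():
    # Month lengths for one year, to be weighted within each season group.
    time = pd.date_range("2000-01-01", periods=12, freq="MS")
    days = time.days_in_month.to_numpy().astype(np.float64)
    time_lengths = xr.DataArray(days, coords={"time": time}, dims="time")

    # Normalize each time step's length by its group's total length.
    grouped = time_lengths.groupby("time.season")
    weights = grouped / grouped.sum()

    # The weights within every group should sum to 1.0.
    sums = weights.groupby("time.season").sum()
    np.testing.assert_allclose(sums.values, np.ones(sums.size))


test_group_weight_sums_equal_one()
```

This keeps the correctness guarantee in the test suite while removing the per-call assertion from the hot path.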
xcdat/temporal.py
```diff
-if weighted and keep_weights:
+if keep_weights:
     self._weights = ds_climo.time_wts
     ds_obs = self._keep_weights(ds_obs)
```
I noticed this if statement changed from `if weighted and keep_weights`; should it be kept the same?
Thank you for catching this. I reverted the conditional.
chengzhuzhang left a comment:

Hi Tom, thank you for the PR! I think it looks great; I just have minor comments for you to consider.
- Check if sum of each weight group equals 1.0
- Update `_get_weights()` docs to remove validation portion
Description
TODO:
- In `_get_weights()`, loading time lengths into memory is slow (lines) -- replace with casting to `"timedelta64[ns]"` then `float64`
- In `_get_weights()`, performing validation to check that the sums of weights for each group add up to 1 is slow (lines) -- remove this unnecessary assertion
- Identify performance optimizations -- I don't think this is necessary right now
- Compare `groupby` with vs. without the `flox` package
- main

Checklist
If applicable: