Skip to content

Commit

Permalink
Merge pull request #76 from ks905383/diag_fig_test
Browse files Browse the repository at this point in the history
Create tests for diag_fig() robustness, make auxfuncs work with silent option
  • Loading branch information
ks905383 authored Aug 21, 2024
2 parents 738941a + c18ede4 commit 5d9ac33
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.12.noxesmf"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.12.noxesmf","3.12.plotfuncs"]
steps:
- uses: actions/checkout@v4
- name: Create conda environment
Expand Down
28 changes: 28 additions & 0 deletions ci/environment-py3.12.plotfuncs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: test_env_xagg_38
channels:
- conda-forge
dependencies:
- python=3.12
############## These will have to be adjusted to your specific project
- numpy
- scipy
- xarray
- pandas
- netcdf4
- geopandas >= 0.12.0
- shapely
- xesmf >= 0.7.0 # These versions and explicit loads are to fix an issue in xesmf's call to cf_xarray (possibly through esmpy)
- cf_xarray >= 0.5.1
- esmf >= 8.1.0
- esmpy >= 8.1.0
- pytables
- cartopy
- matplotlib
- cmocean
##############
- pytest
- pip:
- codecov
- pytest-cov
- coverage[toml]
# - tables >= 3.7.0 # For exporting hd5 files through wm.to_file() (3.6.0 may have issues)
47 changes: 13 additions & 34 deletions docs/source/notebooks/full_run.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "broken-labor",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -237,36 +237,7 @@
"id": "a0ceaf52-bc04-4630-a017-8157b24ce9d0",
"metadata": {},
"source": [
"Let's verify if the aggregation was successful. The `weightmap` class can produce diagnostic figures that show a given polygon + the grid cells of the original raster dataset that overlap it. (This feature is still a bit experimental and finicky, and as of v0.3.2.0 needs a little bit of manual processing) "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d136f646-836b-4c0a-8d5c-e84e5e25a38b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"adjusting grid... (this may happen because only a subset of pixels were used for aggregation for efficiency - i.e. [subset_bbox=True] in xa.pixel_overlaps())\n",
"grid adjustment successful\n"
]
}
],
"source": [
"# Load `subset_find()`, which allows you to find one grid within another\n",
"from xagg.auxfuncs import subset_find\n",
"\n",
"# weightmap.diag_fig() takes two required arguments: some information about\n",
"# a grid, and either the polygons of the raster grid, or the raster grid\n",
"# itself to calculate the polygons. \n",
"\n",
"# Let's get the raster grid.\n",
"# To match the internal indexing of `weightmap`, we need to subset the `ds`\n",
"# TODO: move this step internally to `weightmap.diag_fig()`\n",
"grid_polygon_info = subset_find(ds,weightmap.source_grid)"
"Let's verify if the aggregation was successful. The `weightmap` class can produce diagnostic figures that show a given polygon + the grid cells of the original raster dataset that overlap it."
]
},
{
Expand All @@ -290,7 +261,15 @@
"# Create diagnostic figure of the polygon with index 50 in `gdf` (in this \n",
"# case, a county in Montana). You can verify this is the 50th row in `gdf`\n",
"# by printing `gdf.loc[50]`. \n",
"weightmap.diag_fig(50,grid_polygon_info)"
"weightmap.diag_fig(50,ds)"
]
},
{
"cell_type": "markdown",
"id": "075f03b0-f099-44fe-a157-8b1825849bb6",
"metadata": {},
"source": [
"This diagnostic figure shows the weight used for each grid, relative to the grid cell with the largest overlap. For this county in Montana, the middle grid cell has the largest area overlap, followed by the grid cell to the left. The bottom right grid cell only barely touches the county and therefore has a much lower relative weight in the aggregation calculation. "
]
},
{
Expand Down Expand Up @@ -334,7 +313,7 @@
"source": [
"# Let's use the \"NAME\" (aka, county name) column to plot a \n",
"# diagnostic plot of Los Angeles county\n",
"weightmap.diag_fig({'NAME':'Los Angeles'},grid_polygon_info)"
"weightmap.diag_fig({'NAME':'Los Angeles'},ds)"
]
},
{
Expand Down Expand Up @@ -1138,7 +1117,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.12.1"
}
},
"nbformat": 4,
Expand Down
61 changes: 61 additions & 0 deletions tests/test_diag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import pandas as pd
import numpy as np
import xarray as xr
import geopandas as gpd
from matplotlib import pyplot as plt

from xagg.core import aggregate
from xagg.wrappers import pixel_overlaps

try:
from cartopy import crs as ccrs
from matplotlib import pyplot as pyplot
import cmocean
_has_plotpckgs=True
except ImportError:
# To be able to test the rest with environments without xesmf
_has_plotpckgs=False

##### diag_fig() tests #####

# Load sample data
ds = xr.open_dataset('data/climate_data/tas_Amon_CCSM4_rcp85_monthavg_20700101-20991231.nc')
gdf = gpd.read_file('data/geo_data/UScounties.shp')
# Subset manually to not go through slow subset_find in every test
ds = ds.sel(lat=slice(30,55),lon=slice(360-130,360-65))

# Calculate overlaps
wm = pixel_overlaps(ds,gdf)

if _has_plotpckgs:
def test_diag_fig_noerror():
# Test whether error occurs when calling diag_fig()
# With some random location
wm.diag_fig(50,ds)

def test_diag_fig_isfig():
# Test to make sure a figure, axis is returned
fig,ax = wm.diag_fig(50,ds)

assert isinstance(fig,plt.Figure)
assert isinstance(ax,plt.Axes)

def test_diag_fig_subsetbypolyid():
# Test to make sure the right location is returned
# when using an integer poly index
fig,ax = wm.diag_fig(50,ds)

assert ax.get_title() == 'Poly #50: Sanders; Montana; 30; 089; 30089'

def test_diag_fig_subsetbyname():
# Test to make sure the right location is returned
# when using a column dictionary
fig,ax=wm.diag_fig({'NAME':'Los Angeles'},ds)

assert ax.get_title() == 'Poly #2384: Los Angeles; California; 06; 037; 06037'
else:
def test_diag_fig_noimport():
# Should raise ImportError in the no-plot environment
with pytest.raises(ImportError):
wm.diag_fig(50,ds)
10 changes: 8 additions & 2 deletions xagg/auxfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
import os
import re
from .options import get_options

def normalize(a,drop_na = False):
""" Normalizes the vector `a`
Expand Down Expand Up @@ -187,7 +188,7 @@ def fix_ds(ds,var_cipher = {'latitude':{'latitude':'lat','longitude':'lon'},
def get_bnds(ds,wrap_around_thresh='dynamic',
break_window_width=3,
break_thresh_x=2,
silent=False):
silent=None):

""" Builds vectors of lat/lon bounds if not present in `ds`
Expand Down Expand Up @@ -239,6 +240,9 @@ def get_bnds(ds,wrap_around_thresh='dynamic',
already existed, or with new variables "lat_bnds" and "lon_bnds"
if not.
"""
if silent is None:
silent = get_options()['silent']

#----------- Setup -----------
if (type(wrap_around_thresh) == str) and (wrap_around_thresh != 'dynamic'):
raise ValueError('`wrap_around_thresh` must either be numeric or the string "dynamic"; instead, it is '+str(wrap_around_thresh)+'.')
Expand Down Expand Up @@ -378,7 +382,7 @@ def get_bnds(ds,wrap_around_thresh='dynamic',
return ds


def subset_find(ds0,ds1,silent=False):
def subset_find(ds0,ds1,silent=None):
""" Finds the grid of `ds1` in `ds0`, and subsets `ds0` to the grid in `ds1`
Parameters
Expand All @@ -402,6 +406,8 @@ def subset_find(ds0,ds1,silent=False):
The input `ds0`, subset to the locations in `ds1`.
"""
if silent is None:
silent = get_options()['silent']

if 'loc' not in ds0.sizes:
ds0 = ds0.stack(loc = ('lat','lon'))
Expand Down
12 changes: 9 additions & 3 deletions xagg/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import warnings
import os
import re
from .options import get_options
from .options import get_options,set_options
from .auxfuncs import subset_find

try:
import cartopy
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self,agg,source_grid,geometry,overlap_da=None,weights='nowghts'):
self.weights = weights
self.overlap_da = overlap_da

def diag_fig(self,poly_id,ds):
def diag_fig(self,poly_id,ds,fig=None,ax=None):
""" Create a diagnostic figure showing overlap between pixels and a given polygon
See `xagg.diag.diag_fig()` for more info.
Expand All @@ -50,9 +51,14 @@ def diag_fig(self,poly_id,ds):
except ImportError:
raise ImportError('`wm.diag_fig()` separately requires `cartopy`, `matplotlib`, and `cmocean` to function; make sure these are installed first.')

# Adjust grids between the input ds and the weightmap grid (in case subset to
# bbox was used)
with set_options(silent=True):
ds = subset_find(ds,self.source_grid)

# Plot diagnostic figure
diag_fig(self,poly_id,ds)
fig,ax=diag_fig(self,poly_id,ds,fig=fig,ax=ax)
return fig,ax

def to_file(self,fn,overwrite=False):
""" Save a copy of the weightmap, to avoid recalculating it
Expand Down
2 changes: 2 additions & 0 deletions xagg/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,5 @@ def diag_fig(wm,poly_id,pix_overlap_info,
# Gridlines
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=['x','y','bottom','left'],
linewidth=1, color='gray', alpha=0.5, linestyle=':')

return fig,ax

0 comments on commit 5d9ac33

Please sign in to comment.