Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create tests for diag_fig() robustness, make auxfuncs work with silent option #76

Merged
merged 4 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.12.noxesmf"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.12.noxesmf","3.12.plotfuncs"]
steps:
- uses: actions/checkout@v4
- name: Create conda environment
Expand Down
28 changes: 28 additions & 0 deletions ci/environment-py3.12.plotfuncs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: test_env_xagg_38
channels:
- conda-forge
dependencies:
- python=3.12
############## These will have to be adjusted to your specific project
- numpy
- scipy
- xarray
- pandas
- netcdf4
- geopandas >= 0.12.0
- shapely
- xesmf >= 0.7.0 # These versions and explicit loads are to fix an issue in xesmf's call to cf_xarray (possibly through esmpy)
- cf_xarray >= 0.5.1
- esmf >= 8.1.0
- esmpy >= 8.1.0
- pytables
- cartopy
- matplotlib
- cmocean
##############
- pytest
- pip:
- codecov
- pytest-cov
- coverage[toml]
# - tables >= 3.7.0 # For exporting hd5 files through wm.to_file() (3.6.0 may have issues)
47 changes: 13 additions & 34 deletions docs/source/notebooks/full_run.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "broken-labor",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -237,36 +237,7 @@
"id": "a0ceaf52-bc04-4630-a017-8157b24ce9d0",
"metadata": {},
"source": [
"Let's verify if the aggregation was successful. The `weightmap` class can produce diagnostic figures that show a given polygon + the grid cells of the original raster dataset that overlap it. (This feature is still a bit experimental and finicky, and as of v0.3.2.0 needs a little bit of manual processing) "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d136f646-836b-4c0a-8d5c-e84e5e25a38b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"adjusting grid... (this may happen because only a subset of pixels were used for aggregation for efficiency - i.e. [subset_bbox=True] in xa.pixel_overlaps())\n",
"grid adjustment successful\n"
]
}
],
"source": [
"# Load `subset_find()`, which allows you to find one grid within another\n",
"from xagg.auxfuncs import subset_find\n",
"\n",
"# weightmap.diag_fig() takes two required arguments: some information about\n",
"# a grid, and either the polygons of the raster grid, or the raster grid\n",
"# itself to calculate the polygons. \n",
"\n",
"# Let's get the raster grid.\n",
"# To match the internal indexing of `weightmap`, we need to subset the `ds`\n",
"# TODO: move this step internally to `weightmap.diag_fig()`\n",
"grid_polygon_info = subset_find(ds,weightmap.source_grid)"
"Let's verify if the aggregation was successful. The `weightmap` class can produce diagnostic figures that show a given polygon + the grid cells of the original raster dataset that overlap it."
]
},
{
Expand All @@ -290,7 +261,15 @@
"# Create diagnostic figure of the polygon with index 50 in `gdf` (in this \n",
"# case, a county in Montana). You can verify this is the 50th row in `gdf`\n",
"# by printing `gdf.loc[50]`. \n",
"weightmap.diag_fig(50,grid_polygon_info)"
"weightmap.diag_fig(50,ds)"
]
},
{
"cell_type": "markdown",
"id": "075f03b0-f099-44fe-a157-8b1825849bb6",
"metadata": {},
"source": [
"This diagnostic figure shows the weight used for each grid, relative to the grid cell with the largest overlap. For this county in Montana, the middle grid cell has the largest area overlap, followed by the grid cell to the left. The bottom right grid cell only barely touches the county and therefore has a much lower relative weight in the aggregation calculation. "
]
},
{
Expand Down Expand Up @@ -334,7 +313,7 @@
"source": [
"# Let's use the \"NAME\" (aka, county name) column to plot a \n",
"# diagnostic plot of Los Angeles county\n",
"weightmap.diag_fig({'NAME':'Los Angeles'},grid_polygon_info)"
"weightmap.diag_fig({'NAME':'Los Angeles'},ds)"
]
},
{
Expand Down Expand Up @@ -1138,7 +1117,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.12.1"
}
},
"nbformat": 4,
Expand Down
61 changes: 61 additions & 0 deletions tests/test_diag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import pandas as pd
import numpy as np
import xarray as xr
import geopandas as gpd
from matplotlib import pyplot as plt

from xagg.core import aggregate
from xagg.wrappers import pixel_overlaps

try:
from cartopy import crs as ccrs
from matplotlib import pyplot as pyplot
import cmocean
_has_plotpckgs=True
except ImportError:
# To be able to test the rest with environments without xesmf
_has_plotpckgs=False

##### diag_fig() tests #####

# Load sample data
ds = xr.open_dataset('data/climate_data/tas_Amon_CCSM4_rcp85_monthavg_20700101-20991231.nc')
gdf = gpd.read_file('data/geo_data/UScounties.shp')
# Subset manually to not go through slow subset_find in every test
ds = ds.sel(lat=slice(30,55),lon=slice(360-130,360-65))

# Calculate overlaps
wm = pixel_overlaps(ds,gdf)

if _has_plotpckgs:
def test_diag_fig_noerror():
# Test whether error occurs when calling diag_fig()
# With some random location
wm.diag_fig(50,ds)

def test_diag_fig_isfig():
# Test to make sure a figure, axis is returned
fig,ax = wm.diag_fig(50,ds)

assert isinstance(fig,plt.Figure)
assert isinstance(ax,plt.Axes)

def test_diag_fig_subsetbypolyid():
# Test to make sure the right location is returned
# when using an integer poly index
fig,ax = wm.diag_fig(50,ds)

assert ax.get_title() == 'Poly #50: Sanders; Montana; 30; 089; 30089'

def test_diag_fig_subsetbyname():
# Test to make sure the right location is returned
# when using a column dictionary
fig,ax=wm.diag_fig({'NAME':'Los Angeles'},ds)

assert ax.get_title() == 'Poly #2384: Los Angeles; California; 06; 037; 06037'
else:
def test_diag_fig_noimport():
# Should raise ImportError in the no-plot environment
with pytest.raises(ImportError):
wm.diag_fig(50,ds)
10 changes: 8 additions & 2 deletions xagg/auxfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
import os
import re
from .options import get_options

def normalize(a,drop_na = False):
""" Normalizes the vector `a`
Expand Down Expand Up @@ -187,7 +188,7 @@ def fix_ds(ds,var_cipher = {'latitude':{'latitude':'lat','longitude':'lon'},
def get_bnds(ds,wrap_around_thresh='dynamic',
break_window_width=3,
break_thresh_x=2,
silent=False):
silent=None):

""" Builds vectors of lat/lon bounds if not present in `ds`

Expand Down Expand Up @@ -239,6 +240,9 @@ def get_bnds(ds,wrap_around_thresh='dynamic',
already existed, or with new variables "lat_bnds" and "lon_bnds"
if not.
"""
if silent is None:
silent = get_options()['silent']

#----------- Setup -----------
if (type(wrap_around_thresh) == str) and (wrap_around_thresh != 'dynamic'):
raise ValueError('`wrap_around_thresh` must either be numeric or the string "dynamic"; instead, it is '+str(wrap_around_thresh)+'.')
Expand Down Expand Up @@ -378,7 +382,7 @@ def get_bnds(ds,wrap_around_thresh='dynamic',
return ds


def subset_find(ds0,ds1,silent=False):
def subset_find(ds0,ds1,silent=None):
""" Finds the grid of `ds1` in `ds0`, and subsets `ds0` to the grid in `ds1`

Parameters
Expand All @@ -402,6 +406,8 @@ def subset_find(ds0,ds1,silent=False):
The input `ds0`, subset to the locations in `ds1`.

"""
if silent is None:
silent = get_options()['silent']

if 'loc' not in ds0.sizes:
ds0 = ds0.stack(loc = ('lat','lon'))
Expand Down
12 changes: 9 additions & 3 deletions xagg/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import warnings
import os
import re
from .options import get_options
from .options import get_options,set_options
from .auxfuncs import subset_find

try:
import cartopy
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self,agg,source_grid,geometry,overlap_da=None,weights='nowghts'):
self.weights = weights
self.overlap_da = overlap_da

def diag_fig(self,poly_id,ds):
def diag_fig(self,poly_id,ds,fig=None,ax=None):
""" Create a diagnostic figure showing overlap between pixels and a given polygon

See `xagg.diag.diag_fig()` for more info.
Expand All @@ -50,9 +51,14 @@ def diag_fig(self,poly_id,ds):
except ImportError:
raise ImportError('`wm.diag_fig()` separately requires `cartopy`, `matplotlib`, and `cmocean` to function; make sure these are installed first.')

# Adjust grids between the input ds and the weightmap grid (in case subset to
# bbox was used)
with set_options(silent=True):
ds = subset_find(ds,self.source_grid)

# Plot diagnostic figure
diag_fig(self,poly_id,ds)
fig,ax=diag_fig(self,poly_id,ds,fig=fig,ax=ax)
return fig,ax

def to_file(self,fn,overwrite=False):
""" Save a copy of the weightmap, to avoid recalculating it
Expand Down
2 changes: 2 additions & 0 deletions xagg/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,5 @@ def diag_fig(wm,poly_id,pix_overlap_info,
# Gridlines
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=['x','y','bottom','left'],
linewidth=1, color='gray', alpha=0.5, linestyle=':')

return fig,ax
Loading