Skip to content

Commit

Permalink
Merge pull request #59 from bopen/workflow-refactor
Browse files Browse the repository at this point in the history
Refactor high level workflow to accept intermediate inputs
  • Loading branch information
alexamici authored Jan 23, 2024
2 parents 9a6986f + e501efc commit 4677d96
Showing 1 changed file with 119 additions and 79 deletions.
198 changes: 119 additions & 79 deletions sarsen/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,113 @@ def simulate_acquisition(
return acquisition


def map_simulate_acquisition(
dem_ecef: xr.DataArray,
position_ecef: xr.DataArray,
template_raster: xr.DataArray,
correct_radiometry: Optional[str] = None,
) -> xr.Dataset:
acquisition_template = xr.Dataset(
data_vars={
"slant_range_time": template_raster,
"azimuth_time": template_raster.astype("datetime64[ns]"),
}
)
include_variables = {"slant_range_time", "azimuth_time"}
if correct_radiometry is not None:
acquisition_template["gamma_area"] = template_raster
include_variables.add("gamma_area")

acquisition = xr.map_blocks(
simulate_acquisition,
dem_ecef,
kwargs={
"position_ecef": position_ecef,
"include_variables": include_variables,
},
template=acquisition_template,
)
return acquisition


def do_terrain_correction(
product: datamodel.SarProduct,
dem_raster: xr.DataArray,
correct_radiometry: Optional[str] = None,
interp_method: xr.core.types.InterpOptions = "nearest",
grouping_area_factor: Tuple[float, float] = (3.0, 3.0),
radiometry_chunks: int = 2048,
radiometry_bound: int = 128,
) -> tuple[xr.DataArray, Optional[xr.DataArray]]:
logger.info("pre-process DEM")

dem_ecef = xr.map_blocks(
scene.convert_to_dem_ecef, dem_raster, kwargs={"source_crs": dem_raster.rio.crs}
)
dem_ecef = dem_ecef.drop_vars(dem_ecef.rio.grid_mapping)

logger.info("simulate acquisition")

template_raster = dem_raster.drop_vars(dem_raster.rio.grid_mapping) * 0.0

acquisition = map_simulate_acquisition(
dem_ecef, product.state_vectors(), template_raster, correct_radiometry
)

simulated_beta_nought = None
if correct_radiometry is not None:
logger.info("simulate radiometry")

grid_parameters = product.grid_parameters(grouping_area_factor)

if correct_radiometry == "gamma_bilinear":
gamma_weights = radiometry.gamma_weights_bilinear
elif correct_radiometry == "gamma_nearest":
gamma_weights = radiometry.gamma_weights_nearest

acquisition = acquisition.persist()

simulated_beta_nought = chunking.map_ovelap(
obj=acquisition,
function=gamma_weights,
chunks=radiometry_chunks,
bound=radiometry_bound,
kwargs=grid_parameters,
template=template_raster,
)
simulated_beta_nought.attrs["long_name"] = "terrain-simulated beta nought"

simulated_beta_nought.x.attrs.update(dem_raster.x.attrs)
simulated_beta_nought.y.attrs.update(dem_raster.y.attrs)
simulated_beta_nought.rio.set_crs(dem_raster.rio.crs)

logger.info("calibrate image")

beta_nought = product.beta_nought()

logger.info("terrain-correct image")

# HACK: we monkey-patch away an optimisation in xr.DataArray.interp that actually makes
# the interpolation much slower when indeces are dask arrays.
with mock.patch("xarray.core.missing._localize", lambda o, i: (o, i)):
geocoded = product.interp_sar(
beta_nought,
azimuth_time=acquisition.azimuth_time,
slant_range_time=acquisition.slant_range_time,
method=interp_method,
)

if correct_radiometry is not None:
geocoded = geocoded / simulated_beta_nought
geocoded.attrs["long_name"] = "terrain-corrected gamma nought"

geocoded.x.attrs.update(dem_raster.x.attrs)
geocoded.y.attrs.update(dem_raster.y.attrs)
geocoded.rio.set_crs(dem_raster.rio.crs)

return geocoded, simulated_beta_nought


def terrain_correction(
product: datamodel.SarProduct,
dem_urlpath: str,
Expand Down Expand Up @@ -95,6 +202,8 @@ def terrain_correction(
raise ValueError(
f"{correct_radiometry=}. Must be one of: {allowed_correct_radiometry}"
)
if simulated_urlpath is not None and correct_radiometry is None:
raise ValueError("Simulation cannot be saved")
if output_urlpath is None and simulated_urlpath is None:
raise ValueError("No output selected")

Expand All @@ -121,64 +230,18 @@ def terrain_correction(
f"{product.product_type=}. Must be one of: {allowed_product_types}"
)

logger.info("pre-process DEM")

dem_ecef = xr.map_blocks(
scene.convert_to_dem_ecef, dem_raster, kwargs={"source_crs": dem_raster.rio.crs}
)
dem_ecef = dem_ecef.drop_vars(dem_ecef.rio.grid_mapping)

logger.info("simulate acquisition")

template_raster = dem_raster.drop_vars(dem_raster.rio.grid_mapping) * 0.0
acquisition_template = xr.Dataset(
data_vars={
"slant_range_time": template_raster,
"azimuth_time": template_raster.astype("datetime64[ns]"),
}
)
include_variables = {"slant_range_time", "azimuth_time"}
if correct_radiometry is not None:
acquisition_template["gamma_area"] = template_raster
include_variables.add("gamma_area")

acquisition = xr.map_blocks(
simulate_acquisition,
dem_ecef,
kwargs={
"position_ecef": product.state_vectors(),
"include_variables": include_variables,
},
template=acquisition_template,
geocoded, simulated_beta_nought = do_terrain_correction(
product=product,
dem_raster=dem_raster,
correct_radiometry=correct_radiometry,
interp_method=interp_method,
grouping_area_factor=grouping_area_factor,
radiometry_chunks=radiometry_chunks,
radiometry_bound=radiometry_bound,
)

if correct_radiometry is not None:
logger.info("simulate radiometry")

grid_parameters = product.grid_parameters(grouping_area_factor)

if correct_radiometry == "gamma_bilinear":
gamma_weights = radiometry.gamma_weights_bilinear
elif correct_radiometry == "gamma_nearest":
gamma_weights = radiometry.gamma_weights_nearest

acquisition = acquisition.persist()

simulated_beta_nought = chunking.map_ovelap(
obj=acquisition,
function=gamma_weights,
chunks=radiometry_chunks,
bound=radiometry_bound,
kwargs=grid_parameters,
template=template_raster,
)
simulated_beta_nought.attrs["long_name"] = "terrain-simulated beta nought"

simulated_beta_nought.x.attrs.update(dem_raster.x.attrs)
simulated_beta_nought.y.attrs.update(dem_raster.y.attrs)
simulated_beta_nought.rio.set_crs(dem_raster.rio.crs)

if simulated_urlpath is not None:
assert simulated_beta_nought is not None
if output_urlpath is not None:
simulated_beta_nought.persist()

Expand All @@ -199,32 +262,9 @@ def terrain_correction(
maybe_delayed.compute()

if output_urlpath is None:
assert simulated_beta_nought is not None
return simulated_beta_nought

logger.info("calibrate image")

beta_nought = product.beta_nought()

logger.info("terrain-correct image")

# HACK: we monkey-patch away an optimisation in xr.DataArray.interp that actually makes
# the interpolation much slower when indeces are dask arrays.
with mock.patch("xarray.core.missing._localize", lambda o, i: (o, i)):
geocoded = product.interp_sar(
beta_nought,
azimuth_time=acquisition.azimuth_time,
slant_range_time=acquisition.slant_range_time,
method=interp_method,
)

if correct_radiometry is not None:
geocoded = geocoded / simulated_beta_nought
geocoded.attrs["long_name"] = "terrain-corrected gamma nought"

geocoded.x.attrs.update(dem_raster.x.attrs)
geocoded.y.attrs.update(dem_raster.y.attrs)
geocoded.rio.set_crs(dem_raster.rio.crs)

logger.info("save output")

maybe_delayed = geocoded.rio.to_raster(
Expand Down

0 comments on commit 4677d96

Please sign in to comment.