torch20 transfer #86

Merged (9 commits) on Aug 14, 2024
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -25,7 +25,7 @@ tensorboard>=2.2.0
PyYAML>=5.1
geowombat@git+https://github.com/jgrss/geowombat.git
tsaug@git+https://github.com/jgrss/tsaug.git
setuptools==59.5.0
setuptools>=70
numpydoc
sphinx
sphinx-automodapi
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[build-system]
requires = [
'setuptools>=65.5.1',
'setuptools>=70',
'wheel',
'numpy<2,>=1.22',
]
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
@@ -24,7 +24,7 @@ package_dir=
packages=find:
include_package_data = True
setup_requires =
setuptools>=65.5.1
setuptools>=70
wheel
numpy<2,>=1.22
python_requires =
@@ -63,7 +63,6 @@ install_requires =
geowombat@git+https://github.com/jgrss/geowombat.git
tsaug@git+https://github.com/jgrss/tsaug.git
pygrts@git+https://github.com/jgrss/[email protected]
setuptools>=65.5.1

[options.extras_require]
docs = numpydoc
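
Note (not part of the diff): all three build files now declare the same setuptools floor. A minimal sketch for checking an environment against that floor, assuming the packaging library is available:

    # Illustrative only: verify the installed setuptools meets the new ">=70" floor.
    from importlib.metadata import version
    from packaging.version import Version

    installed = Version(version("setuptools"))
    assert installed >= Version("70"), f"setuptools {installed} is older than 70"
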
144 changes: 100 additions & 44 deletions src/cultionet/data/create.py
@@ -1,3 +1,4 @@
import logging
import typing as T
from pathlib import Path

@@ -11,6 +12,7 @@
import torch
import xarray as xr
from affine import Affine
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster, progress
from rasterio.windows import Window, from_bounds
from scipy.ndimage import label as nd_label
@@ -71,19 +73,22 @@ def reshape_and_mask_array(
num_bands: int,
gain: float,
offset: int,
apply_gain: bool = True,
) -> xr.DataArray:
"""Reshapes an array and masks no-data values."""

src_ts_stack = xr.DataArray(
# Date are stored [(band x time) x height x width]
dtype = 'float32' if apply_gain else 'int16'

time_series = xr.DataArray(
# Data are stored [(band x time) x height x width]
(
data.data.reshape(
num_bands,
num_time,
data.gw.nrows,
data.gw.ncols,
).transpose(1, 0, 2, 3)
).astype('float32'),
).astype(dtype),
dims=('time', 'band', 'y', 'x'),
coords={
'time': range(num_time),
@@ -94,12 +99,18 @@ def reshape_and_mask_array(
attrs=data.attrs.copy(),
)

with xr.set_options(keep_attrs=True):
time_series = (src_ts_stack.gw.mask_nodata() * gain + offset).fillna(0)
if apply_gain:

with xr.set_options(keep_attrs=True):
# Mask and scale the data
time_series = (
time_series.gw.mask_nodata() * gain + offset
).fillna(0)

return time_series
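
With apply_gain=False, the prediction path keeps the stacked values as raw int16 and skips the mask-and-scale step; the write stage (store.py below) then casts them to int32 without undoing a gain. A minimal sketch of the mask-and-scale pattern that apply_gain=True still performs, using plain xarray and an assumed no-data value of 32767 in place of geowombat's mask_nodata():

    import numpy as np
    import xarray as xr

    gain, offset = 1e-4, 0
    raw = xr.DataArray(
        np.array([[32767, 5000], [0, 10000]], dtype='int16'),  # 32767 = assumed no-data marker
        dims=('y', 'x'),
    )

    # apply_gain=True path: mask no-data, scale, zero-fill, and cast to float32
    scaled = (raw.where(raw != 32767) * gain + offset).fillna(0).astype('float32')

    # apply_gain=False path: keep the stored integer values untouched
    deferred = raw
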


@threadpool_limits.wrap(limits=1, user_api="blas")
def create_predict_dataset(
image_list: T.List[T.List[T.Union[str, Path]]],
region: str,
@@ -113,26 +124,36 @@ def create_predict_dataset(
padding: int = 101,
num_workers: int = 1,
compress_method: T.Union[int, str] = 'zlib',
use_cluster: bool = True,
):
"""Creates a prediction dataset for an image."""

# Read windows larger than the re-chunk window size
read_chunksize = 1024
while True:
if read_chunksize < window_size:
read_chunksize *= 2
else:
break

with gw.config.update(ref_res=ref_res):
with gw.open(
image_list,
stack_dim="band",
band_names=list(range(1, len(image_list) + 1)),
resampling=resampling,
chunks=512,
chunks=read_chunksize,
) as src_ts:

# Get the time and band count
num_time, num_bands = get_image_list_dims(image_list, src_ts)

time_series: xr.DataArray = reshape_and_mask_array(
time_series = reshape_and_mask_array(
data=src_ts,
num_time=num_time,
num_bands=num_bands,
gain=gain,
offset=offset,
apply_gain=False,
)

# Chunk the array into the windows
@@ -172,42 +193,77 @@
trim=False,
)

with dask.config.set(
{
"distributed.worker.memory.terminate": False,
"distributed.comm.retry.count": 10,
"distributed.comm.timeouts.connect": 5,
"distributed.scheduler.allowed-failures": 20,
}
):
with LocalCluster(
processes=True,
n_workers=num_workers,
threads_per_worker=1,
memory_target_fraction=0.97,
memory_limit="4GB", # per worker limit
) as cluster:
with Client(cluster) as client:
with BatchStore(
data=time_series,
write_path=process_path,
res=ref_res,
resampling=resampling,
region=region,
start_date=pd.to_datetime(
Path(image_list[0]).stem, format=date_format
).strftime("%Y%m%d"),
end_date=pd.to_datetime(
Path(image_list[-1]).stem, format=date_format
).strftime("%Y%m%d"),
window_size=window_size,
padding=padding,
compress_method=compress_method,
gain=gain,
) as batch_store:
save_tasks = batch_store.save(time_series_array)
results = client.persist(save_tasks)
progress(results)
if use_cluster:
with dask.config.set(
{
"distributed.worker.memory.terminate": False,
"distributed.comm.retry.count": 10,
"distributed.comm.timeouts.connect": 5,
"distributed.scheduler.allowed-failures": 20,
"distributed.worker.memory.pause": 0.95,
"distributed.worker.memory.target": 0.97,
"distributed.worker.memory.spill": False,
"distributed.scheduler.worker-saturation": 1.0,
}
):
with LocalCluster(
processes=True,
n_workers=num_workers,
threads_per_worker=1,
memory_limit="6GB", # per worker limit
silence_logs=logging.ERROR,
) as cluster:
with Client(cluster) as client:
with BatchStore(
data=time_series,
write_path=process_path,
res=ref_res,
resampling=resampling,
region=region,
start_date=pd.to_datetime(
Path(image_list[0]).stem,
format=date_format,
).strftime("%Y%m%d"),
end_date=pd.to_datetime(
Path(image_list[-1]).stem,
format=date_format,
).strftime("%Y%m%d"),
window_size=window_size,
padding=padding,
compress_method=compress_method,
) as batch_store:
save_tasks = batch_store.save(
time_series_array
)
results = client.gather(
client.persist(save_tasks)
)
progress(results)

else:

with dask.config.set(
scheduler='processes', num_workers=num_workers
):
with BatchStore(
data=time_series,
write_path=process_path,
res=ref_res,
resampling=resampling,
region=region,
start_date=pd.to_datetime(
Path(image_list[0]).stem, format=date_format
).strftime("%Y%m%d"),
end_date=pd.to_datetime(
Path(image_list[-1]).stem, format=date_format
).strftime("%Y%m%d"),
window_size=window_size,
padding=padding,
compress_method=compress_method,
) as batch_store:
save_tasks = batch_store.save(time_series_array)
with ProgressBar():
save_tasks.compute()


class ReferenceArrays:
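
The new use_cluster switch selects between two Dask execution paths: a distributed LocalCluster with per-worker memory limits and live progress reporting, or the single-machine multiprocessing scheduler with a synchronous progress bar. A minimal, self-contained sketch of the same pattern, with a stand-in computation in place of the BatchStore write graph:

    import dask
    import dask.array as da
    from dask.diagnostics import ProgressBar
    from dask.distributed import Client, LocalCluster, progress


    def run(use_cluster: bool = True, num_workers: int = 2) -> float:
        # Stand-in for the lazy task graph returned by BatchStore.save()
        lazy_result = da.random.random((4_096, 4_096), chunks=512).mean()

        if use_cluster:
            # Distributed path: one process per worker, explicit per-worker memory limit
            with LocalCluster(
                processes=True,
                n_workers=num_workers,
                threads_per_worker=1,
                memory_limit="6GB",
            ) as cluster:
                with Client(cluster) as client:
                    persisted = client.persist(lazy_result)
                    progress(persisted)  # live progress in the console
                    return float(persisted.compute())

        # Local path: multiprocessing scheduler plus a synchronous progress bar
        with dask.config.set(scheduler='processes', num_workers=num_workers):
            with ProgressBar():
                return float(lazy_result.compute())


    if __name__ == '__main__':  # required for process-based workers on spawn platforms
        print(run(use_cluster=True))
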
4 changes: 1 addition & 3 deletions src/cultionet/data/store.py
@@ -31,7 +31,6 @@ def __init__(
window_size: int,
padding: int,
compress_method: Union[int, str],
gain: float,
):
self.data = data
self.res = res
@@ -43,7 +42,6 @@
self.window_size = window_size
self.padding = padding
self.compress_method = compress_method
self.gain = gain

def __setitem__(self, key: tuple, item: np.ndarray) -> None:
time_range, index_range, y, x = key
@@ -87,7 +85,7 @@ def write_batch(self, x: np.ndarray, w: Window, w_pad: Window):
)

x = einops.rearrange(
torch.from_numpy(x / self.gain).to(dtype=torch.int32),
torch.from_numpy(x.astype('int32')).to(dtype=torch.int32),
't c h w -> 1 c t h w',
)

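
With no gain to undo, write_batch now casts the raw values straight to int32 before adding the batch axis. A tiny sketch of that rearrange step with an assumed (time, channels, height, width) block:

    import einops
    import numpy as np
    import torch

    x = np.zeros((12, 4, 64, 64), dtype='int16')  # assumed shape: (time, channels, height, width)
    batch = einops.rearrange(
        torch.from_numpy(x.astype('int32')),
        't c h w -> 1 c t h w',
    )
    print(batch.shape)  # torch.Size([1, 4, 12, 64, 64])
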
1 change: 1 addition & 0 deletions src/cultionet/enums/__init__.py
@@ -54,6 +54,7 @@ class LossTypes(StrEnum):
CLASS_BALANCED_MSE = "ClassBalancedMSELoss"
TANIMOTO_COMPLEMENT = "TanimotoComplementLoss"
TANIMOTO = "TanimotoDistLoss"
TANIMOTO_COMBINED = "TanimotoCombined"
TOPOLOGY = "TopologyLoss"


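
Because LossTypes is a StrEnum, the new member can be passed anywhere the plain string key is expected. A minimal illustration, assuming the cultionet package is importable:

    from cultionet.enums import LossTypes

    # StrEnum members compare equal to their string values
    assert LossTypes.TANIMOTO_COMBINED == "TanimotoCombined"
    assert LossTypes.TANIMOTO == "TanimotoDistLoss"
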
1 change: 1 addition & 0 deletions src/cultionet/losses/__init__.py
@@ -1,6 +1,7 @@
from .losses import (
BoundaryLoss,
ClassBalancedMSELoss,
CombinedLoss,
LossPreprocessing,
TanimotoComplementLoss,
TanimotoDistLoss,