diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index be09eb2..e3d5252 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -32,10 +32,10 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          pip install tensorflow-cpu==2.12.0
-         pip install jax==0.4.14
-         pip install jaxlib==0.4.14
+         pip install jax==0.4.20
+         pip install jaxlib==0.4.20
          pip install -r docker/requirements.txt
          pip install -e .
      - name: Test with pytest
        run: |
-         pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow"
+         pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow and not integration"
diff --git a/.gitignore b/.gitignore
index 724b2a9..816a88f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,7 +41,7 @@ pip-delete-this-directory.txt
 htmlcov/
 .tox/
 .nox/
-.coverage
+.coverage*
 .coverage.*
 .cache
 nosetests.xml
@@ -137,6 +137,7 @@ dmypy.json
 # notebook
 *.ipynb
 notebooks/
+!examples/segmentation/inference.ipynb
 # hydra outputs
 outputs/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 68d072e..8ce5426 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,6 +5,7 @@ repos:
     rev: v4.5.0
     hooks:
       - id: check-added-large-files
+        args: ["--maxkb=15000"]
       - id: check-ast
       - id: check-byte-order-marker
       - id: check-builtin-literals
@@ -27,36 +28,16 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: 23.10.0
+    rev: 23.11.0
     hooks:
       - id: black
        args:
          - --line-length=100
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.6.1
+    rev: v1.7.1
     hooks:
       # https://github.com/python/mypy/issues/4008#issuecomment-582458665
       - id: mypy
-        name: mypy-imgx
-        files: ^imgx/
-        entry: mypy imgx/
-        pass_filenames: false
-        args:
-          [
-            --strict-equality,
-            --disallow-untyped-calls,
-            --disallow-untyped-defs,
-            --disallow-incomplete-defs,
-            --check-untyped-defs,
-            --disallow-untyped-decorators,
-            --warn-redundant-casts,
-            --warn-unused-ignores,
-            --no-warn-no-return,
-            --warn-unreachable,
-          ]
-      - id: mypy
-        name: mypy-imgx_datasets
-        files: ^imgx_datasets/
-        entry: mypy imgx_datasets/
+        name: mypy
        pass_filenames: false
        args:
          [
@@ -64,6 +45,7 @@ repos:
            --disallow-untyped-calls,
            --disallow-untyped-defs,
            --disallow-incomplete-defs,
+           --disallow-any-generics,
            --check-untyped-defs,
            --disallow-untyped-decorators,
            --warn-redundant-casts,
@@ -72,7 +54,7 @@ repos:
            --warn-unreachable,
          ]
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: v3.0.3
+    rev: v3.1.0
     hooks:
       - id: prettier
        args:
@@ -80,7 +62,7 @@ repos:
          - --prose-wrap=always
          - --tab-width=2
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: "v0.1.1"
+    rev: "v0.1.6"
     hooks:
       - id: ruff
   - repo: https://github.com/pre-commit/mirrors-pylint
diff --git a/README.md b/README.md
index 8eb25e9..818340d 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,8 @@
 # ImgX-DiffSeg

-ImgX-DiffSeg is a Jax-based deep learning toolkit (now using Flax) for biomedical image
-segmentation.
+ImgX-DiffSeg is a Jax-based deep learning toolkit using Flax for biomedical image segmentation.
-This repository currently includes the implementation of the following work
+This repository includes the implementations of the following works:

 - [A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models](https://arxiv.org/abs/2308.16355)
 - [Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation](https://arxiv.org/abs/2303.06040)

@@ -11,6 +10,19 @@ This repository currently includes the implementation of the following work
 :construction: **The codebase is still under active development for more enhancements and applications.** :construction:

+- November 2023:
+  - :warning: Upgraded JAX to 0.4.20.
+  - :warning: Removed the Haiku-specific modification to convolutional layers. This may impact
+    model performance.
+  - :smiley: Added an example notebook for inference on a single image without TFDS.
+  - Added integration tests for training, validation and testing.
+  - Refactored the config.
+  - Added `patch_size` and `scale_factor` to the data config.
+  - Moved the loss config from the main config to the task config.
+  - Refactored the code, including defining the `imgx/task` submodule.
+- October 2023: :sunglasses: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
+  [Flax](https://github.com/google/flax) following Google DeepMind's recommendation.
+
 :mailbox: Please feel free to [create an issue](https://github.com/mathpluscode/ImgX-DiffSeg/issues/new/choose) to request features or [reach out](https://orcid.org/0000-0002-1184-7421) for collaborations. :mailbox:

@@ -61,11 +73,6 @@ See the [readme](imgx_datasets/README.md) for further details.
 - Gradient clipping and accumulation.
 - [Early stopping](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html).

-**Changelog**
-
-- October 2023: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
-  [Flax](https://github.com/google/flax) following Google DeepMind's recommendation.
-
 ## Installation

 ### TPU with Docker
@@ -112,8 +119,7 @@ The following instructions have been tested only for TPU-v3-8. The docker contai

 ### GPU with Docker

-The following instructions have been tested only for CUDA == 11.4.1 and CUDNN == 8.2.0. The docker
-container uses non-root user.
+CUDA >= 11.8 is required. The docker container uses a non-root user.
 [Docker image used may be removed.](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md)

 1. Build the docker image inside the repository.
@@ -141,7 +147,7 @@ container uses non-root user.
    where

    - `--rm` removes the container once you exit it.
-   - `-v` maps the `ImgX` folder into the container.
+   - `-v` maps the current folder into the container.

 3. Install the package inside the container.
@@ -214,12 +220,10 @@ export DATASET_NAME="brats2021_mr"

 # Vanilla segmentation
 imgx_train data=${DATASET_NAME} task=seg
-imgx_valid --log_dir wandb/latest-run/
 imgx_test --log_dir wandb/latest-run/

 # Diffusion-based segmentation
 imgx_train data=${DATASET_NAME} task=gaussian_diff_seg
-imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
 imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
 imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
 imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
@@ -259,10 +263,26 @@
 Run the command below to run the tests and get a coverage report. As each JAX test requires two
 threads, 8 CPUs are required in total.
 ```bash
-pytest --cov=imgx -n 4 imgx
+pytest --cov=imgx -n 4 imgx -k "not integration"
 pytest --cov=imgx_datasets -n 4 imgx_datasets
 ```

+`-k "not integration"` excludes integration tests, which require downloading the muscle ultrasound
+and AMOS CT data sets.
+
+For integration tests, run the command below. `-s` enables printing stdout. This test may take
+40-60 minutes.
+
+```bash
+pytest imgx/integration_test.py -s
+```
+
+To test the Jupyter notebooks, run the command below.
+
+```bash
+pytest --nbmake examples/**/*.ipynb
+```
+
 ## References

 - [Segment Anything (PyTorch)](https://github.com/facebookresearch/segment-anything)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 8d81867..732912c 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -75,10 +75,10 @@ COPY docker/requirements.txt /${USER}/requirements.txt

 RUN /${USER}/conda/bin/pip3 install --upgrade pip \
     && /${USER}/conda/bin/pip3 install \
-    jax==0.4.14 \
-    jaxlib==0.4.14+cuda11.cudnn86 \
+    jax==0.4.20 \
+    jaxlib==0.4.20+cuda11.cudnn86 \
     -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
-    && /${USER}/conda/bin/pip3 install tensorflow-cpu==2.12.0 \
+    && /${USER}/conda/bin/pip3 install tensorflow-cpu==2.14.0 \
     && /${USER}/conda/bin/pip3 install -r /${USER}/requirements.txt

 RUN git config --global --add safe.directory /${USER}/ImgX
diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu
index 9885672..e49ed14 100644
--- a/docker/Dockerfile.tpu
+++ b/docker/Dockerfile.tpu
@@ -1,4 +1,4 @@
-FROM mambaorg/micromamba:0.27.0 as conda
+FROM mambaorg/micromamba:1.5.1 as conda

 # Speed up the build, and avoid unnecessary writes to disk
 ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
diff --git a/docker/environment.yml b/docker/environment.yml
index f3341e4..155075a 100644
--- a/docker/environment.yml
+++ b/docker/environment.yml
@@ -3,9 +3,9 @@ channels:
   - defaults
 dependencies:
   - python=3.9
-  - pip=23.0.1
+  - pip=23.3.1
   - pip:
-      - tensorflow-cpu==2.13.0
-      - jax==0.4.14
-      - jaxlib==0.4.14
+      - tensorflow-cpu==2.14.0
+      - jax==0.4.20
+      - jaxlib==0.4.20
       - -r requirements.txt
diff --git a/docker/environment_mac_m1.yml b/docker/environment_mac_m1.yml
index 4565a48..d05505f 100644
--- a/docker/environment_mac_m1.yml
+++ b/docker/environment_mac_m1.yml
@@ -3,10 +3,10 @@ channels:
   - defaults
 dependencies:
   - python=3.9
-  - pip=23.0.1
+  - pip=23.3.1
   - pip:
-      - tensorflow-macos==2.13.0
-      - tensorflow-metal==1.0.1
-      - jax==0.4.14
-      - jaxlib==0.4.14
+      - tensorflow-macos==2.14.0
+      - tensorflow-metal==1.1.0
+      - jax==0.4.20
+      - jaxlib==0.4.20
       - -r requirements.txt
diff --git a/docker/environment_tpu.yml b/docker/environment_tpu.yml
index 446f968..489c050 100644
--- a/docker/environment_tpu.yml
+++ b/docker/environment_tpu.yml
@@ -4,10 +4,10 @@ channels:
   - conda-forge
 dependencies:
   - python=3.9
-  - pip=23.0.1
+  - pip=23.3.1
   - pip:
       - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
-      - tensorflow-cpu==2.13.0
-      - jax[tpu]==0.4.14
-      - jaxlib==0.4.14
+      - tensorflow-cpu==2.14.0
+      - jax[tpu]==0.4.20
+      - jaxlib==0.4.20
       - -r requirements.txt
diff --git a/docker/requirements.txt b/docker/requirements.txt
index e6a80de..ace5b52 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -1,24 +1,26 @@
-SimpleITK==2.3.0
+SimpleITK==2.3.1
 chex==0.1.8
-coverage==7.3.1
-flax==0.7.4
+coverage==7.3.2
+flax==0.7.5
 hydra-core==1.3.2
 kaggle==1.5.16
-numpy==1.24.3 # limited by tensorflow-macos 2.13.0
-opencv-python==4.8.0.76
+nbmake==1.4.6
+numpy==1.26.2
+opencv-python==4.8.1.78
 optax==0.1.7
-pandas==2.1.1
-pre-commit==3.4.0
+pandas==2.1.3
+pre-commit==3.5.0
 protobuf==3.20.3 # https://github.com/tensorflow/datasets/issues/4858
 pytest-cov==4.1.0
+pytest-mock==3.12.0
 pytest-randomly==3.15.0
 pytest-split==0.8.1
-pytest-xdist==3.3.1
-pytest==7.4.2
+pytest-xdist==3.5.0
+pytest==7.4.3
 rdkit-pypi==2022.9.5
-rich==13.5.3
-ruff==0.0.291
+rich==13.7.0
+ruff==0.1.6
 tensorflow-datasets==4.9.3
-torch==2.0.1 # for testing only
-wandb==0.15.11
-wily==1.24.2
+torch==2.1.1 # for testing only
+wandb==0.16.0
+wily==1.25.0
diff --git a/examples/segmentation/BB_anon_348_1.png b/examples/segmentation/BB_anon_348_1.png
new file mode 100644
index 0000000..98b17f3
Binary files /dev/null and b/examples/segmentation/BB_anon_348_1.png differ
diff --git a/examples/segmentation/BB_anon_348_1_mask.png b/examples/segmentation/BB_anon_348_1_mask.png
new file mode 100644
index 0000000..de04d59
Binary files /dev/null and b/examples/segmentation/BB_anon_348_1_mask.png differ
diff --git a/examples/segmentation/config.yaml b/examples/segmentation/config.yaml
new file mode 100644
index 0000000..9e21b35
--- /dev/null
+++ b/examples/segmentation/config.yaml
@@ -0,0 +1,83 @@
+data:
+  name: muscle_us
+  loader:
+    max_num_samples_per_split: -1
+    patch_shape:
+      - 480
+      - 512
+    patch_overlap:
+      - 0
+      - 0
+    data_augmentation:
+      max_rotation:
+        - 0.088
+      max_translation:
+        - 10
+        - 10
+      max_scaling:
+        - 0.15
+        - 0.15
+  trainer:
+    max_num_samples: 512000
+    batch_size: 64
+    batch_size_per_replica: 8
+    num_devices_per_replica: 1
+    patch_size:
+      - 2
+      - 2
+    scale_factor:
+      - 2
+      - 2
+task:
+  name: segmentation
+  model:
+    _target_: imgx.model.Unet
+    remat: true
+    num_spatial_dims: 2
+    patch_size:
+      - 2
+      - 2
+    scale_factor:
+      - 2
+      - 2
+    num_channels:
+      - 8
+      - 16
+      - 32
+      - 64
+    out_channels: 2
+    num_heads: 8
+    widening_factor: 4
+    num_transform_layers: 1
+  loss:
+    dice: 1.0
+    cross_entropy: 0.0
+    focal: 20.0
+  early_stopping:
+    metric: mean_binary_dice_score_without_background
+    mode: max
+    min_delta: 0.0001
+    patience: 10
+debug: false
+seed: 0
+half_precision: true
+optimizer:
+  name: adamw
+  kwargs:
+    b1: 0.9
+    b2: 0.999
+    weight_decay: 1.0e-08
+  grad_norm: 1.0
+  lr_schedule:
+    warmup_steps: 100
+    decay_steps: 10000
+    init_value: 1.0e-05
+    peak_value: 0.0008
+    end_value: 5.0e-05
+logging:
+  root_dir: null
+  log_freq: 10
+  save_freq: 100
+  wandb:
+    project: imgx
+    entity: entity
diff --git a/examples/segmentation/files/ckpt/checkpoint_1300/checkpoint b/examples/segmentation/files/ckpt/checkpoint_1300/checkpoint
new file mode 100644
index 0000000..0c335e6
Binary files /dev/null and b/examples/segmentation/files/ckpt/checkpoint_1300/checkpoint differ
diff --git a/examples/segmentation/inference.ipynb b/examples/segmentation/inference.ipynb
new file mode 100644
index 0000000..dfa60a0
--- /dev/null
+++ b/examples/segmentation/inference.ipynb
@@ -0,0 +1,208 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fb9844d4-1d21-45a6-8383-61d18bdf96e4",
+   "metadata": {},
+   "source": [
+    "# Segmentation Inference Example\n",
+    "\n",
+    "This notebook demonstrates inference on a custom data sample instead of using a pre-defined TFDS dataset. To execute this notebook, please follow the [README](https://github.com/mathpluscode/ImgX-DiffSeg) to install the `imgx` package locally."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "eac79551-2ff5-4bd0-b403-bfe45b794db2",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/yunguanfu/miniforge3/envs/imgx/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from pathlib import Path\n",
+    "import numpy as np\n",
+    "import jax\n",
+    "from flax.training import common_utils\n",
+    "from omegaconf import OmegaConf\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from imgx.task.segmentation.experiment import SegmentationExperiment\n",
+    "from imgx_datasets.constant import IMAGE, LABEL\n",
+    "from imgx_datasets.save import load_2d_grayscale_image"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "39287037-c4c7-46cb-b678-2858a2af814e",
+   "metadata": {},
+   "source": [
+    "## Load a model\n",
+    "\n",
+    "The model is a four-layer U-Net with channels [8, 16, 32, 64], trained with supervised learning."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "b6115f20-beb6-43db-8ead-8e4d90bc69fc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config_path = \"config.yaml\"  # backup config stored\n",
+    "ckpt_dir = \"files/ckpt\"\n",
+    "step = 1300"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "be7db5f1-fc70-4775-adc4-2abbb8e9b4a1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = OmegaConf.load(config_path)\n",
+    "run = SegmentationExperiment(config=config)\n",
+    "train_state, _ = run.train_init(ckpt_dir=ckpt_dir, step=step)  # still loads data from tfds"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "62df3c6b-bef7-45e2-93e1-a3aa6c94bea6",
+   "metadata": {},
+   "source": [
+    "## Evaluate on a custom image\n",
+    "\n",
+    "`BB_anon_348_1` is a training sample from the `muscle_us` dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "22f79c47-1e54-48d2-b187-8dce98190066",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "seed = 0\n",
+    "out_dir = Path(\"outputs\")\n",
+    "\n",
+    "image = load_2d_grayscale_image(\"BB_anon_348_1.png\", dtype=np.float32)\n",
+    "label = load_2d_grayscale_image(\"BB_anon_348_1_mask.png\")\n",
+    "# image.shape = (1, 1, 480, 512, 1), label.shape = (1, 1, 480, 512)\n",
+    "# the first axis is the shard axis for pmap, the second axis is the batch axis\n",
+    "batch = {IMAGE: image[None, None, ..., None], LABEL: label[None, None, ...]}\n",
+    "uids = [\"BB_anon_348_1\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "a23030e3-77a9-420f-9599-ca69ea1d61f7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device_cpu = jax.devices(\"cpu\")[0]\n",
+    "key = jax.random.PRNGKey(seed)\n",
+    "key = common_utils.shard_prng_key(key)\n",
+    "\n",
+    "metrics, label_pred, key = run.eval_batch(\n",
+    "    train_state=train_state,\n",
+    "    key=key,\n",
+    "    batch=batch,\n",
+    "    uids=uids,\n",
+    "    device_cpu=device_cpu,\n",
+    "    out_dir=out_dir,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c443fddd-b876-4b7f-be23-88d0db08de40",
+   "metadata": {},
+   "source": [
+    "## Visualize output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "7c148884-844d-4635-bf6c-4e978f8c2a70",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'binary_dice_score_class_0': 0.9822485446929932, 'binary_dice_score_class_1': 0.934889554977417, 'centroid_dist_class_0': 1.7067033052444458, 'centroid_dist_class_1': 3.3515679836273193, 'class_0_proportion_label': 0.7809703946113586, 'class_0_proportion_pred': 0.7905650734901428, 'class_1_proportion_label': 0.21902620792388916, 'class_1_proportion_pred': 0.209431454539299, 'hausdorff_dist_class_0': 11.313708498984761, 'hausdorff_dist_class_1': 25.09479504037782, 'iou_class_0': 0.9651163220405579, 'iou_class_1': 0.8777395486831665, 'mean_binary_dice_score': 0.9585690498352051, 'mean_binary_dice_score_without_background': 0.934889554977417, 'mean_centroid_dist': 2.5291357040405273, 'mean_centroid_dist_without_background': 3.3515679836273193, 'mean_hausdorff_dist': 18.204251769681292, 'mean_hausdorff_dist_without_background': 25.09479504037782, 'mean_iou': 0.9214279651641846, 'mean_iou_without_background': 0.8777395486831665, 'mean_mean_surface_dist': 4.810380502239952, 'mean_mean_surface_dist_without_background': 7.301036206416621, 'mean_normalised_surface_dice': 0.4292188610791067, 'mean_normalised_surface_dice_without_background': 0.13163716814159293, 'mean_stability': 0.5477132797241211, 'mean_stability_without_background': 0.32302427291870117, 'mean_surface_dist_class_0': 2.3197247980632825, 'mean_surface_dist_class_1': 7.301036206416621, 'normalised_surface_dice_class_0': 0.7268005540166205, 'normalised_surface_dice_class_1': 0.13163716814159293, 'stability_class_0': 0.772402286529541, 'stability_class_1': 0.32302427291870117}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# get scalar values\n",
+    "print(jax.tree_map(lambda x: x.item(), metrics))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "fe6aa939-1488-4ab4-b53f-da22b42d1838",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAggAAAChCAYAAACicEdsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAACQ3klEQVR4nO39WawkaXoWjj+5RkZERmbknnn22rqqZrpnut3tmfEyYDODZZCFkI0suOACS0gIgcQFMgJk0HBjCV8gjRAIEJYQAglbAsQmIVkg7LE99my9d1V11dlzXyIzMjIjl8jI38X5P+/kmeU/7TlVPae6v0dqdVWdU3kyI6Pyfb/3fZbIer1eQ0FBQUFBQUFhA9Ef9RNQUFBQUFBQuH5QDYKCgoKCgoLCd0E1CAoKCgoKCgrfBdUgKCgoKCgoKHwXVIOgoKCgoKCg8F1QDYKCgoKCgoLCd0E1CAoKCgoKCgrfhfgP+xfDMESj0YBlWYhEIk/zOSl8jLBerzEej7G1tYVo9MPpV9W9q/A0oO5dhecVH/Te/cANwnw+x3w+l9/X63V84hOfuNqzVFD4/+Hs7Aw7OzvP5LHVvavwLKHuXYXnFT/o3v3ADcKv//qv40tf+tJ3/flf+St/BalUCovFAgCQTCahaRri8TjCMITnechms9B1HdPpFMvlEolEAmEYgiaO6/Ua6/Uamqah3+8jkUgglUoBAGazGcbjMcIwRDKZxHK5RBAESCQSAIDlcgnTNLFcLjEcDrFcLhGGIeLxOHzfRzQaRSKRQCKRQBAE0DQNkUgEyWQS0+kU0WgUYRhC0zTMZjOs12uk02l5fnyMVCqFyWSCWCyGeDyOxWKB2WwGwzCgaZr8QzZNU74ehiHCMIRhGEgkEvB9H+PxGJZlyfVZr9dYLpeIRqMIggDL5RLZbBaTyQSmaWI6nWK1WiESiSAej2O9XmO1WiEej8OyLHkfkskk1us1YrEYDMOA53lYr9dIJBJyfUejETKZDBKJBFarFQB81ymEH0bj8Rjr9RrJZFKuh6Zp8H0fALBYLDCZTBCPx7FarbBcLjGfz6HrOoIgwGq1wmKxkOsxn88RBMHFTRePIxaLIZFI4Hd+53cA4NJredr4fveugsLTgLp3FZ5X/KB7N/JBrZa/s5N1XRe7u7v4W3/rbyEajWK9XkvhXa1WUhiXyyUymQx6vR7CMEQul5PiPZ/PEYYhyuUydF3HZDK5VPhd14VhGJhMJohEItJQsFBqmoblconpdIrFYoGtrS1Mp1N5ftFoFPF4HLPZDGEYIpVKSdMSj8eRTqcRi8Xk567XawRBgDAMpSAnEolLjQFw0dDE43FpdiaTCaLRKCKRyKViz+8n+NhsDqLRKJLJpHw9mUzK14bDofzMzabENE10u11Eo1EYhoFMJoPJZCJNg+u60rAAQCwWk/eH143XMhKJSCMWi8WwWCzk+fO5h2GIIAgwm82k0YrFYohEItIYsDEhfN+X7+F14eONRiMEQYBYLIZUKoV/+2//rfx5JpP5ILfinxjf795VUHgaUPeuwvOKH3TvfuAJgqZp0DTtu/58d3cXlmUhlUqh3+/DdV2Uy2WkUil4nofZbCan/EQigeFwCNu2EYYh8vk8ptMphsMhGo2GFBw2DclkEq7rIpfLIZ1Oy5Si1+vB9305xSYSCSwWCzx8+BDpdBq5XA6JRALlchmmaSIIAuzs7MD3fWQyGSl6vu+jWq3CdV24rivTisVigVKphF6vB8dxsLu7i+VyiUgkgtFohPl8jkQiIaf7Wq2GyWSC5XKJZDIJ0zTlsQeDARzHgW3bUhi3trYwGAxgGAY6nQ6WyyV0XcdqtYJhGEin03I9NU1DEAQYjUZIp9PwfV/2j4vFArFYDPl8HqPRCJqmwbIs9Ho9Oe2zKZrNZohGo0in0xiNRiiVSojFYjg5OcF0OkUqlZLJjmmaOD09RRAESKfTmEwmCMNQpg78c9M0UalUAEAanSAIUK1WUS6X0ev1AACmaaJYLCISicDzPBiGgbfffluauWeN73fvKihcd6h7V+FHiQ88QfhOuK6LbDaLX/qlX0IsFpNCznWDaZqwLEsagOFwKMWUpIggCLBYLGTUX61W0e/3pRBNJhOkUilpBvL5PEqlEhKJBNrttoz4Y7EYVquVjLUzmQx0XUe9XpcRShiGMhLXdV1G4TzFd7tdZLNZrNdrZLNZrFYrJJNJWJYlqwA2BsPhUJqA5XKJ+/fvQ9d1nJ+fy/Xp9/vI5/MoFAqYzWZIJpNySmcjtFgspHDO53O4rgvP85BOp7FcLmUCE4vFEIvFMB6Pkc1mEYvFkEwmpegHQYBkMgnP82SywLVEEATy+3g8Ds/zZOIyn8+haRrG4zFSqRQ0TUM2m4XjODBNE4VCQaY31WoV0+lU3otcLocgCNButzGdTlGtVqWB4vOKRqMyoQGA6XSKeDwukyZd1/Ebv/EbAJ7tKez73bsKCk8D6t5VeF7x1CYI3w+VSgXr9fpSISyVSphMJshkMrAsS07zN27cwGg0wpMnT6Tw2raN6XSK9XqNaDSKWq0G27YxHo9lP55KpWAYBqLRKBqNBgDgJ3/yJxGNRrFYLOTvzedzaS5GoxE++clPwnVdVKtVAMCjR49g2zY0TcP5+bmc+IfDIW7duiXMzjAMkUgkhJOwXq9Rr9eRSqVgmiai0ai8Vt/30Ww2EYYhHMdBv9+X03Wv18Px8TFyuZysFCzLguM4iEQiCMMQ/X4f8XgcqVQKy+VSViLL5RKj0Ug4FGEYwnVdDAYD+L4vq5pGo4FsNotIJIJYLIZ+vy+NQhAEyOVyWK/X8DxPuAGpVEr+fDQaIZ/PIxqNYjKZ4I033kAymcRkMoGmaQjDUHgNXOeQC8Lrt1qt8K1vfQsAoOs6lsulTF6Gw+GldUsYhrAsC/1+X94XBQUFBYXrhys3CCyYJOitVivMZjPM53McHR1dIqz1+304jiO/5tiaxTiZTCIMQ9RqNWiaJusInraTyaQUpzfeeAO+7yMMQ9y4cQOPHz+W9cRwOBTyYSwWQ6vVQiwWw40bN7C9vY3j42PUajVUKhX5GZ1OB5qmwTRNOfFaloX5fI5oNIq7d+8iEong/PxcTvjkJkynU9y4cQOvvfYalssl3nrrLZRKJcxmM5mucLIxnU7lOvFxVqsVKpWKnOpHoxF0XZfXO5lMUC6XYdu2XCfHcdBqtZBOp4Wn4Ps+DMPA7u4uisXipTXMCy+8IN9nmibq9TpM08RsNsObb74JwzAQBAG2trbk1M8pBZub5XIJTdOE/MjGx3Vd3L17F57nydShUCgIITMajULXdZimiU6nI9fWdd2r3n4KCgoKCs8IV24QvvnNb8qYfrFYCHfAcRw5fcdiMXziE5/A7u4uXNdFr9dDv9/HbDaTE/TW1paM0weDgTDiHz16BADY399HrVbDcDgUVvxsNkMqlcL5+bmcVG/cuIFqtSorARbVbreLJ0+eYLlcStHL5/PwfV9UE5xYhGEIAMhmsxiNRpdIeOv1GqlUCvF4HLZtCxfBcRw4joPxeCyykdFohJ
2dHSQSCbRaLbz77rtC9qPCIR6PYzKZ4PXXXwcA4RXEYjGk02nM53OZLHBiwmmD7/tIp9NIp9Ny/TOZDFqtFnzfh+/70DQNW1tbGI/H6Ha7QjScTqc4OTlBpVLBzs6OTHI2NbGr1QrRaBT5fB7pdFpWByQ2jsdjFAoFLBYLIUGm02kUCgXE43E8fvwYuVwOo9EIyWQS0WgUn/rUp5BOp9HtdnF8fHzV209BQUFB4RnhyhyEP/Nn/oycDgeDAeLxuEjbKpUKXn75ZSlwsVgMw+EQkUgE9Xodq9UKo9EIx8fHiEajQqxj0ZnNZgAgJ2rDMOC6LubzuZAgqURggVytVvA8D77vwzRN2XWHYSgqh1gsJtMFnpqpCPA8DwBgGAb29/dRrVYxmUxkKsLmZjKZoNFooFQqYTgcYjabCQ/CcRx5vSzaiUQChUIBmqbJlIXNh+M4sCxLVhWWZcl0gqf1IAhQKBTgeZ48d05cOMHZlI6yEQAuFCH8M5KedF2XacDmSZ7Ext3dXcznc2mEOMEJggDz+RyGYWA8HmMwGGA2m2E6ncIwDHS7XcTjcdRqNQAXDc9oNBIlRr/fl8bq7OxMGiO1x1V4XqHuXYXnFT/o3r1yg/CFL3xBxvAkxvGUyQkCCyf9EbrdLkql0iXt/nA4xHg8Fj8CMvpJJKQnQSQSQalUkpN/EASYTqdSVDcLME/apmmKp0A6nZYTfCQSgWEYAIBoNArTNEVeOBwOkUqlkE6ncXp6inQ6jfF4jNVqJWP6fr8PXdcveQHEYjGRJfFUTekii+gmQdIwDJimiU9/+tMYjUZIpVJot9sYDocinyRHg8U5m82KuoGciuVyiXa7DeCCB2BZFrLZLIIggOd5qFQq0HUd6/Ua7XZbJjiDwQD5fF6uXb/fFyXGaDSCZVlYLpfwPA/FYhG9Xg/r9VoaI04TdF0HcEEsnU6nolBgE2PbNlzXlWtCrsjbb7/9gW7Upwn1IavwNKHuXYXnFc+8Qbh//z4ikYj8t7OzIwQ8sv5JgOOpmgUjk8mIDBDApRNvEAQIgkD8DihxJBOfaokwDMXIh01BuVzGfD5Hs9mE7/solUqyNgAgOn/LsrBer0UWOBqNZAIyHA6haRomkwlyuZwUfur65/M5kskkEomEnLR935diu+lvYBjGJcUFTYM4kp/P5xiPx8jlctI08JrFYjF88pOfRD6fRyQSwXQ6RRiGyGQychIfDAZYLBYigxwOh3jy5AkODw+FGAlcnOZpkkQSYzabFT8H8hhIzmRzQp7IfD5HJBKRBoe+DlxN8D8qNgCgWCxiNBqh1WqJ90I0GoXrugjDEKenpx/oRn2aUB+yCk8T6t5VeF7xzFUML730EvL5vHADRqMRRqMRgIuCH4/HMRgMYFkWZrMZHMfBarWC7/vo9/sy5ud+m3K8RCIBwzDQbreF0Lder+WkSwLg5oqBRffRo0eIx+PIZDIwDEMMhDKZjEwVNosfpw+FQkHG/uQk2LYNAPA8D7quI5VKyal5OBwCANLptJgvkTgJQIplIpEQDwDbtqWZ2trawq1bt5DJZESZ0O12pfk4OztDIpFAo9HAgwcP5NpEIhE0m00p9JQTLpdLpFIpZDIZaJomXgfD4VC4DLVaTVYUdIek7NFxHJnapFIpTKdTjMdjIUyygRqNRshms5jP50in0+IMOZ1OhVxKZUWr1cJkMoFt27AsSyYKnufBcRxpEBQUFBQUrheu3CD89b/+1+G6Lh49eoR33nlHTvzxeBxBEMB1XViWJU3AbDZDsVgUkiH34KlUSk6dnDLM53NhwVOml8/n5XFIzuMKgH9eLpeF0MeTMU+1PB3zey3LEgOlyWQiK4ggCESKybUC7ZwBCOeB/gScZNCtkBMPAGi1WgAu5Jo0D6rVatB1XU78nKJMp1O0221ZPdC4iZwMrhb4/DzPE4kjizevVblcltdAMiTNlxKJBPr9PpbLpRhd1Wo18SqgkoOrA5op0f/AcRzoui5mWJsyRr6X+Xwev/RLv4Rbt27h3XffxYMHD/DOO+9gMplgsVgoFYOCgoLCNcaVVwyvvPKKnOipkTcMQ0benucJL8AwDIRhKAWPZkNhGCKbzcpUgD4DNPkJwxDj8RiJRAKmaYpbInMVdF3H8fGxFFYWUK4j+H1sKrhPp5qBxEDHcZDJZOC6LlKplJAkWfxSqZT4M2yuPgzDkPF7EAQiYyTBsFqtwvM8VKtVUW6QU7FcLlEqlcS7gORKFtrBYID1ei0ESvoMsLnga1sul6IO4AqHfApd16VZoq8CAFkB+b4v8tHNYs/nQZ4IfQ80TRPZIx9rsVhA0zQxraIMlIZTnBqsVivkcjmUy2X4vo//9t/+GwA1plV4fqHuXYXnFc98xbBpPcyT42g0QiKRuMQVIPmOCoR0Oi3+CCQuRiIR5HI5IcZxcjCdTiXwKQgCcfcjme7k5ASpVAp37twRHwEWzu3tbei6LnkAp6encsJnE5FIJCQjgE6ADE5ar9diKrRer2GaprgPmqaJTCaD2Wwm1sdcJ2xvb4sPgGEYsmLp9/vid0B7ZeYarNdrOZUDF41JPp/HYrGQa0nSpeM4omBgM8GGStd1OI6DXC4nKgMSPA3DEJ4BuRHpdFoCsGzbxnq9RhiGYoy0Wq1QLBYlr4IkTVpUb/JB2CCMx2NZhziOI7yDeDyOTqeDbrcrDZiCgoKCwvXDlRsEGvgMh0N4nicFkgWVZDQWsGq1KrJHnppJiAMgXAbf90XaaBgGarWa5CesVisJf9ra2oJt27h79y5eeuklDIdDPHz4EAAkEGo4HAqJkfbC1OXzZKtpmuQWcBLAxoCM+00yYhAEGI/HQias1WpIpVKi2Egmk8jn85jP58LB6Pf7QpCk26JlWeh2u6hWq6K+ME1THotTDl3XJbuCBMsgCGS9wOaCtspcT/A0Txkppyt8/dlsVn7WppcBX+dLL72ETCaD1Wolv2ajx4nEcrnE4eEh+v0+Dg8PRWoKQDgiVHbQnOk7PRcUFBQUFK4XrtwgcIzN4J5NVz2eJGezGXRdR6lUwvn5ubD8yT8g2Y4FTtM0Ye1z/O04DrLZrBQlbkbokthsNvGVr3wFtm2LXfPW1pY85pMnT/Dee+/J+mI6ncLzPCl08/lc9vyUE/JnUO1A2SElip7nIZ/Po1qtwrIsCanilGM8HosXAk/e1WoV2WwW5+fnGAwG6Pf7MjVptVpIpVLSTFCVwZUNpxabhkW8frFYTAKvVqsVMpmMvC/b29uIxWI4OzuDrutyuufK5jtzHwBga2sLL7zwgig5er0evvWtb6HRaIi9cq/Xk8ejGgSAGDhRTko+Cv8eJzybPA0FBQUFheuFK3MQfvqnf1rGy2Tt5/N5/Ok//aflhH9wcIByuSxWxhx1x2IxvPnmmxiPx0gmk+JI+MlPfhKr1Qqnp6d4//334boums2m7OPn8zna7TY8z5Miw/E3rYg5vuZJfT6fyx6da4FYLCZ5A9y98zTO6cZmOmMul4PruqjVasIt2Nvbw8nJiXAKwjBEu90WS2nKHSlnBC4S2orFInRdR
7FYRCaTuZSGSGIiCZXNZhOu6wqxkCdy8jnYXJFYmMlk5HROaaFhGMKrYOPEYCxyDzYdG4fDoUx2+B+bHIZKMdI7DEMkk0mk02kMh8NL0lKubmjSxJUHfScODw8BqD2uwvMLde8qPK945j4Ir776qhjfsMjati0+/bPZDJPJRIx6yLwnYc80TazXa9y+fRu+78sIvdfrIZvNyiohn89LESXBcDwe47333oPjOFJo6FOwXq/FiCmRSAiJzzAMIf2NRiMp4jwRM9GQpkYk9fGkXKlUsFqtkM/nhd/gOA7m8znu3bsngUnj8Rij0QidTgfHx8dIpVLY39+X9QQdFyuVCqbTqez44/E4XnzxRSyXS3S7XViWhUwmg36/j0ePHkHTNImI5rSBDcCTJ09EXZHNZpFOpxGNRnH79m3k83nUajW88847mE6nmM/nODs7Q6PRkKhqXqvVaiWrCMpB2TyRbLoZlEXS6WaDRW4DmwE2M5sNxmq1EkMl9SGr8LxC3bsKzyueeYPwuc99Tk7mwIUsrlqt4uDgALZt47333oNt26JUKBQKsic/PT3FbDbDYrFAq9WCYRiwbRupVEpc91iouF7Y39/HbDZDp9NBPp/H5z//eQAXI/G9vT0Za89mMxwfHws5r9vtioxvOp2Kv0AQBDg4OJAYaFoEx2IxkfRxstFqtYRgp+s6XnjhBZEphmEok4rlcolbt27hhRdeELLmYrHA6empcBZ4Og+CAOfn5+IXwLCqzYTKs7Mz2LaNarWKeDyOGzduwPM8PHr0SKYhAGQ18ejRI+i6LtkLN27cwMOHD1Gv1yVzwfd9jEYjkRxynUEPBE6ESObkioA/i4TNMAzFZIpkVDZXbCAMwxBOAqdNdMJkRLb6kFV4XqHuXYXnFc+8Qfipn/opISZWq1Up8LPZDNlsFnfv3sXe3h6Ojo5wdnYmKgNK9WazGTzPw3q9xu7urkwj0uk0er0exuMxptOp7NRJoiOPwTRN7OzsIAgC3LlzR9Ic6Z+QTCbR6XTEEGg8HsNxHMxmM1SrVVEQ6Lp+afROEiBPucvlEq7rQtd1yYPodDrCYaCXQy6XAwBZUXDUv7OzI66CnuchmUwiEomg1+tJLgUbmxdeeAGNRkPIgslkErZtY2dnB9PpFD/xEz8hUsr33nsPtVoNlmXhwYMHaLVa2Nvbk8L89ttvw3VdTKdTyYygBBEAms3mJWfDfD4vDRkbFhoxUSpK2SoVHsxb4PpjM0wrmUzKeod8CTYjnCR8kBv1aUJ9yCo8Tah7V+F5xTNvED796U/DsizEYjEYhoFkMolsNisufiwSR0dHMgpfLpdieDSZTMQ5kDLGXC6HTCYjUch7e3uIx+MYDoc4PT2FbduisU8kEnBdV1wP+Vwsy8Le3p40E/V6HUdHR8K6H41GeOGFF1Aul/Ho0SPMZjM0m005UcfjcdH2czefSqWEHMgcgiAIsLe3h3v37sk6IRaLodvtyonacRwZz9OeOBqNioUzVQZcdzDemvJMygwZlWzbNvL5vHgyMHOBiZhhGCKRSGA4HIox0quvvopMJoPj42Ok02lks1k8ePAA3W4XjuPIBIhJmoZhyGqBYVDfSSz0fV+yJjgRmc1mcF1XVBL0dQCAfD6Pfr8PAPL+dbvdD3SjPk2oD1mFpwl17yo8r3jmDcLnP/95sUhOJBIyht6UEtJMJ5lM4uWXX4ZhGOh0Ojg8PEShUEC9XsdgMBBWPTMI4vE49vb2RNEwm81kLcDfU4EwHA7lObAQFwoFKe4sYLQ5ZoRxPB7H4eGhjMxpIcznAHzbICmVSgkHIAxDyZ3gNKDT6WC9XuP09FQsjSeTCUqlEsbjsRRvEv7IX+Bzo4qAHhGcTFBuyVM8o5zH47E8P9M00Wq1xHeBqw76L0QiERQKBbz66qtIpVJ49OiREC+Bi4lHq9WC4zjo9XrCJ+AJ3/d9kUoSm9LI7+WqyEnDeDwWq2wSG/v9PubzubhMftw+ZH/If3ZPFYw3V7gaPm73rsJHB8/cKIm2ykwcNAxDDHcmkwl838d6vcbe3h5c18VXv/pV8f4PwxD1eh2VSgWlUgm3bt1Cr9fDZDJBKpVCsVjEcDiUnXk6nUa73RY3wU3yIU/KlPlthhtxvM0VwGw2w3A4RL1eRy6XE+4DVQgczzNTgs8llUrh3XffFRkhcyHo9zAYDJBMJnH//n1pXOjMSIKl7/vCwaCk0/d9LBYLVCoVtNttaSQqlQqKxSIA4Pz8XP4OJxNcd5CsyHUITafYAPGajMdj/K//9b+k0dE0DY7jYDQa4e7du7h165Y0F6PRCO12G+PxGL7vo16vi/8CiaOcFnmeJ0qITcMm13WlqaBjI+8Jekp8HHEdmgPgh3seqqlQ+Kjhe/07INH9O//s44Yrf0I7joNSqSQnZeBCqgdcyO3IaOfonWNrJhECkFH0+++/L80BZYJhGEquQL1eR61WQ6FQkKaEkwTyHzbd/ZjESO+E8XgsUslIJIJKpYJqtSquho7j4Pz8HNFoFL1eD/l8HgAk2XG1Wl0Kcep2u7hx4wYKhQLu37+PwWCA4XCI5XIp3x+JRMRHYbFY4Pbt20LqI7kRALa3t0WFQH8DAKjX65K3wMnGrVu3cHR0JPwHehWwYSHvwTRNsV6mq2GlUhFJKVcW1WpVpiIvvvgidF1HLBbDvXv3YBgGCoUCfvd3fxcPHz7E6emprH7G47FMG+Lx+KWgJ8oY2ex4ngfTNPHaa6/BdV20220sFgshKSo8H/h+TcXH8cNT4fnEB2mMv9f3fL9G4qOMKzcI9DM4ODi4ZF/MU248HkexWBTVQqFQkDH6cDiUnTuNiICLhuHg4EAKInAxCrFtG7ZtY3t7W4iOsVgMjuOIgZCu68hkMjLWj8fjODo6kgmCYRgyUqGz4uc+9zmxIz44OECxWESz2UQkEkEikRDPhDt37kjw02bhJ39iPp+jUqnAcRwkk0kcHx8jCAIhHB4eHsqahMmSVCFEo1FUKhXxDUin02g2m5ITwZ/T6/XE/phFWdd18RfgWmQymSAIAkmKXC6XwjXI5XISvw0Ah4eHSCQSyGazaLVacm17vR62trawWCzw1/7aX0M+nxfL7PPzcxwfH+P3f//3MZvNUK/XAQDj8VhiuimTjEajKJfL8j6u12tsb29jtVrh9ddfv+ot+FzhukwPnjbUaevjgQ96/16n9/9Z/pv7QY99na7DD4MrNwg07On1epjNZsjn8xKAxHE99+m2bYtMsdVqQdd1tNtt0d9HIhEUi0V4nofRaIRkMomf/dmfhWma+L3f+z20221ks1nU63Wx6rVtGz/90z8te+5Op4MHDx7g6OhIGhSS9eLxuLg0fvrTn5aCx0L66quv4u7du2KMFI1GsbW1BcuyRIrJYntycoKTkxMUi0Xk83lJrXRdF0+ePBFvgfl8LhkP6XQamqahWq3KiZ+5EuPxWKYdJAgygTKVSiGZTF5KbiwUCqLKmEwmEnbluq5EXNOF0vM8xGIx2LaN+/fvi2U1PRMsy5JUy6997WvyWO+88w7S6TQA4Ld/+7eR
z+fx2muv4eDgAOl0GltbW/iJn/gJzOdzHB0dSfOy6ahI4iIbFk5hut2u8BsUPnpQDcPzj6sU1s2/+6N873/UDfnz/u/gyiTFl19+WciItPadz+e4c+eOpBTyZLtcLlGr1VAsFvHyyy/DcRz8l//yXyRCmYx9KhxGoxEikQhSqRQODg5gmiZGoxFM0xRvgk6nIxkDTDF85ZVXEIah8AWAixE+1xDMamCBvXfvHvL5vKwetre3MZ1OZXfe6/WE3NjpdOB5npAEqThg8UsmkxKf3Ov1hBtB1UEmk0Emk0E6ncb5+Tl6vd4ln4F0Oo3BYIBsNgvHcSQ0CoAoCDYnD1yn0EmREwSSDLmuASDyR8dxUCgU5D3jtacUMpPJSMM3n88xHo8xn89h27ZMYDzPQ6VSkT9jumatVsPJyQkajQZarRbS6TQKhQJSqRQ8z8N7770nNtSO46DRaAD4eBC9ftQfVtcdz9uHJ/FRvHef9r36Yb+3z8O/tetwvz9zFcNnP/tZpFIpJBIJiTrWdR2FQgHr9Rqj0Qj5fB6ZTAZhGIobYKPREFIjR+W0/U2lUkilUpjP5+j3+9je3ka/30cYhqhWq5JCSHMh0zRFAWFZFra3t7FYLPCFL3wB2WwW7XYbzWYTwMXEw/d9ABeyu+FwiMFggN3dXTx8+BDL5RLL5VIIi0xe5MqEzUA0GkU+nxeFRCwWE9OgF154AXfu3EE6ncYf/MEfoNvtYjgciqeDYRgolUpIpVLi2kijIeDCbIqKAs/zhIBI4yGuZTht4euigmG5XMr6hnyQIAigaRrm87nYPmezWVF8UI6o6zo0TcPOzg6Ojo7Q7XYxmUwkiZHuiPSl2JRjVioVFAoFWJYl049ms4lmsykpmaZpIggCnJ2dAQAePXr0gW7UpwnVIDw/uA4foj8IH7V791ndpx/We/m8/jv7Udzrz7xB+Nmf/VkxLSK3oN1uI51Oi+8B2faMGga+nbQYiUREishCw5UDzXooCaSaYLlc4s6dO9jb28NsNkO5XEapVILjOBgOh1gsFtjZ2ZHnxLH/2dkZ2u22GCMFQQDbtiXEqNvtCkmQfge6rouOf7lc4sGDB2LkNB6PUavVMB6Pxfkxk8lIEJSu6zLGNwwDAKSgUoKo6zpyuRwKhQKePHmC8/NzkXoOh0ORZvJxmNS4WCykmSFRcTQaIZvNYjweI5FIiA/E2dmZTC84MaCNM9+PUqkkyZK2baNQKGAwGGC5XKLRaIiygbkPJB+yAcvn81itVshms8hmsygWi7hz5w6KxSJ830er1YLneWg0Gjg/P0cikcDh4eHHZoLwvH5oXTdcx4bho3bvfhj36p9EJfBx+rfzYd/fz7xBePnll5HL5eA4DlKpFABcIs9R4kd/AsoSk8kk+v0+dF2XcCDKEVnESDykUoGcAKYmzmYzHBwcoNlswrZt8TfQNE3kfo1GA2dnZxJtXCwWZU9OPwWaHm0aO/Hnl0olIewVCgVEIhEcHR0hk8ng3r17l5wAu92urA/m87kYJGWzWYxGI5RKJYxGIzEYopySaYzlchnHx8ciA2XxNk1TFCAAJPaZKhCaMK1WK+E8cN1hGIYYKum6LoZPnuchCALcvHkTnU5H3CoppWQiZaVSwXA4lNRMqlLY3AVBgFgsJmROvr/kdpTLZVSrVazXa5nk1Ot1ZLNZxONxPHjw4APdqE8TqkH4aOC6NAsftXtX3as/WnyY9/Uz90EgaY9791wuh8FgAOBCAskEwmw2Kyfx+XyOcrmMz3/+8+j3+3j77bdlHE5CXTwelxPx7du3cfPmTclN0DQN3W4Xg8EArVYLpmmi3W5jNpvhj//4j9HtdkVtQLkhzYXOz8+FwEdDJxr4VKtV5HI5tFotkQMyGnq9XqPb7SKfz0sGwWKxwMHBgfAQuBrZzB9YrVZot9uIxWI4Pj4WvoamaSIL7fV6WK1WePLkiZA7yecgsXK1WqFUKgmfg80BnQ/X6zV2dnZQrVYlYptyx0KhIM8DuFg31Go16LqOJ0+eiH8BcNFscB2yWq1g27Y0G3zvaH6VTqcRhiFc15UpUDwex3K5lL9HrgHf28VigVqtBtd1r3rrPTdQH7jPBs87Aey64nud7hU+PJALdx1w5QbhpZdekgLc6/VkWpDJZGDbtvgk0HiIWvjt7W28+eab8DwPN27cQBAE6Pf7sj8fj8cyzl4sFvjGN76BRCKBb37zmzJhqFQqkp/wxS9+EePxGOv1GoPBAJlMBo1GA77v4+HDh+LUSCUESYmWZeETn/gEdnZ24Ps+3njjDTEcGgwGKJfLMnWg4mB7exsHBwcyoeD+/+joCLPZTNYJqVRKjJF0XZcJCqOSmYlAEmIqlUI+nxdiJImBjuNgOp2K/PETn/gEPM/Dt771LSE2MsthOByiVCqh1+thvV6j0WggFotB13WRUfLnDYdDRCIRRKNR4TmQjJlOpyV9kURDz/PgOA7CMBReBe2h+Z6QbEmXRQBC5oxEIpI3kc/nYRgGHj9+fNVbUEEBwPX6YH2eoZoDBeLKDQIdA8fjsZyuWZio36cl8P3799Hr9fBHf/RHQmgk+TAajaJarYoREiWA7XYbiUQCn/jEJ0ST7ziOBAKdnp6KYuHk5EQIgEyNjMfj+PEf/3EhP8bjcTSbTXS7XVSrVTiOgz/6oz/CV77yFXS7XWkaisUibNvGcrkU3sJsNoPjODg4OMDx8TGAC/voyWSChw8fyvh+a2tLpgzT6RRPnjxBuVxGLBbD+fm5cCmAixTKt956S/IjeD1ZxJvNJoIgkAlFNpvF+fm5FGaqRnZ2dtDv98WjgWFIlmWJpwPVEky8jMVikrPA5oqSVCY6np6eCm+Cr5WkRv6cdDotboq0gk6n07J6IhGS7oqz2Qyz2UwIlB9lqA/bDxeqSVBQeHq4MgfhL/2lvyTWw4wypvERSXAc29Mx0XVdIa+RCR8EASzLQrVaFeLbYrFAOp2WUXkqlcJyuYRhGMhms1iv1+j3+7LTZnJjoVDA8fExcrkcwjDEarWSZkXTNOzv70PTNBlz12o1mKaJs7MztFotlMtlPH78WPgPi8UCmUxGnB3Pzs6EOElXwXg8jmq1itFoJBbLjKi+efOmNAqDwUACoOiL8MUvfhGlUglf//rX8eabb2K1WsF1XbFCzuVysuagguD27dsYDAbiksjrG4Yhbt++jVgsJu6IJFUCF2sfKiHIN/jEJz6BWCyGwWCAwWAgr5NWyeR80ESJcc0AZHVgmqZYPsfjcTF4Wi6X6HQ66PV6YsfMaQMjt4GP3h6XUA3Ch4/rRvR6mlAchI8HPqx7+JmTFH/hF35ByH2UNXqeh+3tbdy5c0ecA1lImT3Ak+qdO3ews7MDXdfx+PFj1Ot1FItFyWFot9vY2dlBuVyG53nCLwAg424GBpHVT9OkRCIBx3GQTqeF1KdpGnRdR6lUkhMxuQrcmzMJsdvtYjweCx9A0zTcvHkTf+Ev/AXhG6xWKwyHQxwfH+O9995Dq9XCo0ePEASBpFSyeLKQWpYFAGK1zOdBUma32xXeAfkYtVoNP/m
TP4l+vy8rmPPzc/kZwEWjk8/nxdWy3+/j0aNHohrwfR+6rsOyLMzncyE1MpuBKxRmQpCrwckGXTNJKqVXxXQ6leuhaZo0VlyncKoQiUQkLZMNBomXH7UPWUB90P4ocZ2IXk8TqkH4+ODDuIefeYPwxS9+UeSJ3DUnEgnk83kp2FQLmKYpWQXcVTNCOJFISOFktoKmaTBNE41GAwcHB+j3+zg+PpaAqM3kwM3fZzIZGctvEg3J0mdgE0/dhmFI2mGlUpERfiwWg6Zp+MIXvoDPfvazODs7w1tvvYXFYoFGo3Fpz04ZJxsg5j7QT4EFNJ1OS6Fut9twXRe7u7vIZDLI5/OYz+eXvAdoyczxP70M6MVw584dmKYJz/NEWuk4joRJdbtdFItFmKaJ9957T6ywmc1QLBYRiUTQbDYltZI20slkEsPhEIZhYGtrC7lcTmSWZ2dnWCwWsgKhj0K1WkUkEkGr1cJ8PsdsNgMAaUBmsxnS6TRisRjOzs4kD+Kj9iELqA/aHzWuyynsaeJZ37vqnr0+uA4NwpU5CKVSCavVSsbZ0+kUxWJR3PPeffdd+L6PYrEIy7IwmUyQSCSQy+UkGIlseD6epmmyI280GhJ7/NJLL+Hu3btotVrigthqtSTVMZvNyokbAHZ2dtDtdkWtQHJgq9WCbdvCGXjhhRcwn89Rr9fhuq6caulG+LWvfQ2/93u/J5bHrusKobLX66FUKomEEQBOT09lF2rbtlwXEv7oKWAYBkzThG3bmM/nGAwGIq2k7HKxWEiipa7rmE6n2NraAnAxgViv18jn8wiCAIeHhyJP1DRNJJyu6woHwfM8ZLNZ9Pt9pNNpmewwPIuZFOSUcB3S6XQwHA5RKBRgmiby+fylVQsloFQ57O7uiorlhRdewO7uLqrVKqrVKt555x288cYbqNVq+B//439c9Ra8llAftAoKCs87rjxB+MIXvoBoNCose8MwcOvWLRSLRfT7fbE71nVdZIs0MppOp0LISyaTODg4kD9vNpvwfR+j0UhWCgBQLBbF258+/8lkEru7u7h//z7effddLJdL2c9HIhHkcjl0Oh2MRiMkEgmUy2Ukk0kMBgPZ7XNFQWtm0zSRzWZlb26aJsbjMRaLBcbjMQzDkCI+mUzEfpgZColEAoPBQEbyxWIRN27ckAYFgIQskShYrVbFH2LT3pjPi49Vq9XkMX3fx3A4FOMhNgD0iVgsFqIeYHASba8ZqsQVCrMT+HPo3ZDNZmUCwPVQOp3GaDTCYDAQwijjoBOJBHRdx/7+Pj73uc9hvV7j/PwcrVYLx8fHGAwG6Pf7iEajePPNNwF8tE5hgGoQrguuwynsaUJNED4+uA737pUbhJ//+Z9HPB5Hr9eDbdsy5p7P53KaXy6XsvcnCW42myGXy8G2bfHm51ifQURMaqQJElcD/X5fihx32wBErrdcLnHz5k0xWNq0CmYwEYl/vV5PGhHGNHO6YFkWstksBoMBer2eFM1KpQLLsnB6eiokRpL/MpnMJYfETqeDxWIhREySChOJBHZ2dhCJRHB+fo7ZbIZsNiuFNggCOfnbtg3TNOWEzokLrZ1rtRrS6TQePnyIer0u3ADbtmWlwAlNJBIRIys+51wuh+FwKJJSfi9XNORt0N+CPgixWEyaDDYwL730EkzTxMOHDxGLxcR7AYDkcsxmM4zHY6RSKUmBVB+yCs8C1+FD9mlC3bsfH1yHe/fKKwbGCR8cHEiRCYIAuVwO8/kcAMQ6uNFowDCMS5bEmqZha2tLmPXMJEin08hkMrh//z5SqRQqlQqWy6U0B6PRSKKgOcrmCP/mzZvIZDKS4vjWW29hvV5LrkGr1UK9XofneUin0+j3+8hkMmL2w+CiXq8nxEE6JpqmiUwmA9d1cf/+fcznc7z++uvI5XIAIBOExWKBTqeDIAiQz+eloSEZMR6P4/T0FNPpFAcHB2I4xe9j6BNJjuv1GrVaDQBwfHwsxdv3fRwfHyOTycjEgrbJVDAEQSCullSFBEGA2WwmK6JMJiNul1xhjMdj7O7uIgxDtNtt3L59W0iR1WpVGir6XjA7o16vy1RhsVhgMpnI8yChku+hgoKCgsJlXBep7pUbhMFgIDv0arWK2WyGXq8nZMRN9jzd+kzThOM4khvAvTUJeWT8M5NhMBgI0dGyLFQqlUucBVoLP3jwAPV6Ha+//roEI4VheMmauNVqSfojo6gZfkRFxPn5uVgsz2azSxLJSCSC09NTJBIJvPHGGwiCANlsVmKYKclk8FI0GkWr1ZII5UKhgK2tLVkzsEFJpVJot9sStpTL5aQhYBIkLaF3dnak+NLbwLZt5HI5JJNJidA2DEOKs+/7EiDFa0MvBp7sgyBAqVQSl8TPfe5ziEQiGAwG4oPA1wlAiKBc9axWK5ycnMDzPHlN8/kcuq7D932s12vhcdDBUkU+KzxLKF8EhecR1+W+vXKDwGCgarUKABIIxALISOD3338fxWIRQRCg2WwKb4H7d8/z0Ol0RBKYSCQwGo2wt7eHbDaLRqOByWSCVCqFBw8eyAj/7OxM3AtpbVwqlXDv3j0kk0k0m024rivyumQyiVqtJnK8Xq8HTdMkzIjrEM/zUK1Woes6Tk9PJTSJ05HZbAZd12Vknkgk5NQ8m82wu7uLra0t2LaNer0uU5O3335bfk9XwUgkInJM7vfDMEQikcB0OkWj0UChUEC1WsXR0ZE0O3SprFar2N7eRqFQkAhluhxyHEnyItcDnU5HijVXBZygHB8fI5vN4lvf+pZMIki21DRNHCLpe0H3RV3XZYXCxE3gcgPC1E02Xt1u96q3oIKCgoLCM8BTCWs6ODgAcCFl63Q6ckKMRqO4f/8+crmchPRsGinRrZAuhVwbBEGAYrGIWCyGSCQingAs0DzNc3xvGAaq1ao4B5Jz0Ov1MJlM5LQNQHblACSDIJvNYnt7G7ZtI5/Po9Fo4MGDBxJKRKnfYrFApVLBer0W4mChUIDv+1KAAeDg4AD5fF7G7dFoFJ7nIR6P4+DgAPV6HYPBANlsFuVyWQiAnBDcvn1bYqg9z8POzo4U5V6vh1gsJnHNVCYwA2K5XOK1115DvV7H8fGxOCqyeSGvYr1eyxQFgEx3+Bh8fzdjo/P5PKLRKBaLhZg20XVxMpnA8zwYhoHBYIDZbCY8CTZR8/lcFCk0TGIjqfa4Cs8Kz/okpu5dhWeBjwQHwTRNseOlxv/GjRuYTqcSj0znvXK5fGnPb9s2bNtGOp3GYDCA67oYDocwTfOSmoCOiADExpdxxSxW7XYbmqbhrbfeQiqVwv7+PizLwpMnTwBA8hSYekgjHwCSk7BarXB+fo4bN27Atm288847aDabIjeMRqNwHEfSFTVNg2EYiEaj6HQ62N7elvTEd955B8ViEfv7+7h79y7CMMR7772H4XCIH//xH0cqlcLp6Sl6vZ7YEe/t7WF7e1tUElwVMOzI933xQYjH49J8UKlg2za2trZEkglcTA64urFtW0Kn2NAYhiE8kmKxKPJFTkTy+bzwJqLRqDQobA6oumDxb7Va0lAMh0
N5Hzk5oBcCvSEUFBQUFK4nrjxB+NznPoft7W3E43Gs12uJMubqodvtIh6Po1wuy858uVwiGo3i4OAAN27cQKlUwo/92I9hPB7DsiwcHx/ja1/7mjglco9OCR1PqjyZb74Eqh+4dx+PxzBNU/bq8/lc1AHJZFL24ul0Gq1WC6lUSsKVdnd3oWkaHj58CM/zxMqYxkH1el0Ifuv1WtYVdBckgY+kw3Q6LXJJPlcGR3HlEQQB6vW6GEhtciVIxGRhPjg4wP7+vrgiWpaFw8NDtFotpNNpnJ6eotvtQtd1AJBEzXQ6DcdxZLpSKpUkxZEyx00OBf0iOJkpl8tIJBKSGkkOx2g0kskOmwMqVjiBAS6aBUovDw8PAahTmMKzg5ogfHCoe/f64CMxQWBBKxaLMAwDe3t7aLVaeO2112RXfnx8jEajIcZDVAJQHXB0dISvfvWrcF1XdP90RkwkElJEz8/Pkc1mJRmSvgqcVszncxSLRUSjUWxvb8tenKS5TqcjJkrcv9MPoNVqSdNhWRZs24brumKnHI1GJbuBCgmqFYbDoaw2giCApmmYTqfisUC1xGg0QjKZRKlUQiaTQb/fx2g0QqfTkV1/LBZDuVxGt9uVBEXaWLPZoO8BADmxDwYDaaTK5TJOT08l1Mp1XazXa1mhpNNpVKtVTCYTPHr0CGdnZyiVSqhUKmg0GhiNRpIcyaJerVZFJcJrtL29jdlshna7jX6/j1wuJ+sWyinJJ5nP54hEIpKUSY6FgoKCgsL1xFOROdKmOB6PYzQaIZ1O4/DwEOPxGPfu3UM+n5eToud5cvJkXgALF3MNGAe8ydBvt9vo9XpySucUAoCcqieTCU5PT8UiGLjIJ6BSIZ1OSyQ1Mxai0Sg0TYNlWUKQ5D6du3MWZ6YWUsPP5zgcDoVfQbtpJkeuVivhScxmM0SjUTx58gS+78tj1mo1LBYLDAYDOI6Dfr8PALISIf+AqxkmZNKXgP4TAMSXYWdnB2EYot/vC3mQ5lHkYbiuK7JIqieAC2UFCaWPHz8WnwtyB46Pj2EYhshYE4kEFosF4vE4otGo8BuCIJAmab1eS5PFtRHfPwUFBQWFy7gOSoYrNwhkxXPXvLW1hWw2i0wmg9PTUynwpmlid3cX9XpdRtnJZBJhGKLb7SIajcK2bSn2PP3T4rdUKkkxByByyFarJSuFxWKBYrGIO3fuAAD29vYwGAywXq9x7949pFIpHB4eim9CEAQ4Pz+HZVkSgLRYLNBut2HbtuQo8ISeTCZRKBRg2zYWiwW+/vWvw7ZtsR/m9xeLRSQSCRwdHUnWA10cAeCVV15BtVoVxcDJyQnm8zk0TRMpoW3byGaz6Ha76Pf7ME1TRvibJ/xsNotarSYR28vlUhQHtIMGIM0QizJXAFypdDod+fp0OhUVBAt9LBbDdDqVLIX1eo39/X1RWnB6cnZ2BgDi0MhVRCQSueR14bquPLePGtjEKSgoKDzPuDIH4dVXX0Umk4FpmgAulAxsAFjAyLSnbW+r1ZKi0e12YRgGKpWKmCuNRiPk8/lLRQ+48BagZwKTCOmcaBiGmPycn58jn89D13Vks1nE43H0+330ej2sVitUKhXs7+/j7OwMsVgMR0dHACCR1Z7niYUwpwqLxQIvvPACCoUCHjx4gMlkAsdxpDi3Wi30ej0Z33ueJ8XRsiwcHBzgrbfeklhqNgLkK9RqNSmcNJvaLOgkRHLnD1zsj3haJ4cAgDReNEQKggCFQgEAxPaaUxjP85DL5VAoFBCGIZrNpqw6NicBsVhMuBZBEEhDwekGVRqr1Uo8JagW8X1fJK2GYch7SVIjX8tHZY8LqF3udcOzPIl9lO5ddd9eL/yo+TNXniAUi0Xs7e1JFPNkMpFgIZ4oS6WSFJ9ut4tMJiP2veQl8Gt0VaRckmZFXAOEYSgjfF3XpZDxxH1+fg7btmEYBra3t4X9z4wIMvofP34sWQIk/bGxoQqiWq2KU+Hdu3elCeEqwTRNtNttTKdTlMtlfPKTn5QURSZP0j3w/fffR6VSQa1WQyKREHvl5XIp0xYWUZ7KmS7Z7/cvWUrTX8CyLIxGI4nCpjslACnEYRhiOByi1+vBsiwp6OPxWPgQlGeyWWDzNZlMpNBPJhPYto3pdIrFYiENDdcXAC5lWdAVErhoWGj7zCkNVRIKCgoKCtcTT8Uo6Rvf+AZSqZScMGOxmEjhKG0k2Y9JgCTysZGIx+NSIBmCZFmWGCDRvpfFZrNxKBaLWC6XaLVa4m3AtQUVB/v7+zg+Psb5+bnI/Li3p0Mi9+h87u+//z4sy8LP//zP4+TkREiXtE6mIqNWq8G2bQwGA7RaLWQyGdy5cwe3b99GEATY29sTUuSbb76Jo6MjCXbiNUwkEmJCBFwU1e3tbSwWC+zt7WE6nQpfYz6fCweiVCqh3W6L+VI0GhV3yc3nzMTFfD4vRkb0c+BagRkQnucJUZSNXCaTwWg0EtdEqkAYRU21AiWOpmlKtgYAaRyn06k0GuQwKCgoKChcPzwVFUOpVBLDI3IDxuMxer0enjx5gjAMsb29jVwuJ0x46vkLhQKGwyE0TUM+nxc9/tbWFh4+fCh2vqVSSU7OtDZmYWWQ0t7eHsIwRK/XExIk3RdJknRdVwx82NAsl0vZ3TPh0fd9vPTSS6I+YA7Ee++9J5kOzEsYjUYSIsVGaTwe48UXX5Rpwfn5Od58801ZGZCAwoaGkxDLsuQkz1M2TZW4vqGXQKlUwoMHD+D7vkg1aVlNQiE9DEzTxI0bN3D37l2k02k0Gg10u10hJDabTQyHQ6TTaezs7GCxWKDb7QqZMZ1OI5fLwfM85PN5uWZUSmQyGSSTSbRaLWm2GP7EiYGu61iv1+LjkMvlRBmioKCgoHC9cOUGod/vYzgcioyNJ1tN08R3AACazSYcx5FY5EwmA8Mw0Gq1xHmPY/HJZCJ5DgcHBzg4OBBzHZLkKEtkBgQVEZVKRVIkDcMQ85+joyPM53Os12skk0lkMhmcnZ1JrDPdHLmfp+NfEAR46623YNs2fuzHfgyZTAZ/+Id/iHa7jeFwiHg8Lk1IMplELpfDkydP0Gg0MJvNMJ1ORV756quvIplM4tGjR/I1yv1isRgqlYoQABeLBXzfR7vdFjMkejDQuZI/t9frIQgC9Ho97OzsXMqhyGaziEaj2N/fRy6XQ7/fR7fbFekpGzX6ErRaLTSbTcnRoCU0Jxe+7wO4UDpwUsCUS04vxuMxVqsVXNcVSWQ8Hkc8HkepVJJVjsphUFBQUPje+FErGICnQFL8zGc+811yNRYzJhOm02kJa2KBr1Qq0HUdhmFItDGthjVNk4AgThdu3Lghv2bM82QykSCg0WgEwzAQhiH29/cRiUTwhS98AScnJ+JMuL29jaOjIxweHsJ1XWHp8+Rr2zbK5TImk4k4JY5GI6xWKyEbLpdLVKtVrNdrFAoF1Go1RCIR3Lt3D2+99Rb+6I/+SMKl+HqpsKDag6d8TdNw7949VKtVHB4e4ujoSFIZF4sFNE1DJpNBtVpFGIYoFArivcDmgDyP4XCIcrmM8
XgM13XFi2Bvbw/5fB61Wg3Hx8colUrodrui7uj3+5jNZkJUYWPA95AWznSf9H1fJg2e58kKIggCef+Hw6GQLbmOWK/Xl4yq1us1crkcHj16BOCjRfQCFNnrukGRFD8Y1H17fXAdjJKu3CDcuXMHyWRS0v5IWuOOmXyExWIhRjrlchnlchm7u7tYLBZoNBqyy2YKIU+l8Xgctm1D0zRsb2+jXC6j2WyKKVCr1RJtv2VZuHHjhkgMY7GYyOo8z4Pv++h0OgAgYU/FYhHlchkPHjyQEy+9CtLptMQSkzzIsCbLspDL5SQzwbZt7O7u4tGjR6JgiMfj8DwPxWIR4/EY3W4X+XxewpnCMJTdPVMWV6uVeBCQ4xCJRMTemT4DsVgMN2/ehKZpePnllzGdToVj4TiOTHPY9MxmM2nC6O1A90sGLKXTaQAXnIhNBUMkEpG1B0mdDLUifwGAfP+mPJKOkuR1TCYTWVfE43E8ePDgA92oTxOqQfh44UfNBH+aUA3CxwfXoUF4KkZJiUQClmUJH4EkN+YdkHi4XC7FLnkymeCdd96RiGIaFG2G+5C4uFqtJKQpl8tJHHK5XMYnPvEJdDoddDod9Pt9fOtb35Jit7e3h36/Lz+PVsb0KxgMBjg5OYFpmrJbZ7PB6UU0GhWPg1qtBsuycHR0JE0Nr4HjOOh2u/K4QRCII2O73Rb2/nA4hK7rwvjnNIFJiYVCQdwH2TAw5wD4Nv+C/7HoM+ypWq1if39fMhRSqZSETdElEYBMXorFIuLxOE5OTmStYZomqtUqkskker2eqCgMw0A2m5X8CzYB8XhcVhlUKvBrVEVQikr5aqVSwc2bN6VBUFBQUFC4XvjAE4T5fH6Jde66LnZ3d3FwcADbtpFMJsWamA9Ju+StrS0YhiGWuxzrM5qZkwMAUtC3t7clvpkWzI7jSDAR/QRoaFStVjEej2UqQZKe4zhIpVJoNpsyTk8mk8I3sG0bN27cAAB8/vOfx5MnT/Dw4UPZ52uahldeeQW7u7v4zd/8TbiuK02L7/tyGienIJ1Oi7oCuDB0yuVy+FN/6k+h2Wzia1/7GoIgkAYoHo9L0BWJfLQw5v6+Xq/D8zxsbW2JY+XBwQGazSYGgwF6vZ6sAkqlEgCIL8WjR49wfn4uJEgSGNkoUK1Al0Y2NvSBKBaL4vnAKG/gQmK5Wq1gGAam06l4MziOg4ODA1SrVWiaBtd15bXE43FsbW2hXC5D13WMx2P81b/6VwE821PY97t3nyXUSez64HmeIKh79+OL52qC8Ou//uv40pe+9F1/fnBwAM/z0Ov1kMvlhJnPnTQlhwxAymQyqNVqMAzj0v57Op1KIBJVENlsVvIaBoMB5vO5GBcxACmRSEjRq9Vq6HQ6SKfTGA6H0og0m01RNfB0zqRCTdNwfHyMaDSK//gf/6OsNO7evYt4PA5d1/H//t//QxiG4nA4GAzQbreRTCalwNOciAZF4/FYjJ8ODg5wdHSEd955B4vFAqZpSmz0pz/9aTSbTbz77rsIw1CCl1KpFEzTlPXD3t4ejo+PJT3x7bfflsnNzs4Out2uRC/7vo9Hjx6JIoMrAeZOkCSay+Xw4MEDmQAwcImhTFSYABd+F1wZJRIJHB4einqB6wwAKJfLCMNQVkU3btzAnTt3JDFyNBrBcRy8/fbbaDabf5J7+YfG97t3FRSuO9S9+/HEdSAoAk9hgvDaa68Jcc00TYRhKEz8WCwmZMXVaoV0Og3LsoSVb5qmcASi0Sh83xfnQxYj4Nt2zgCERc8TKxsEkvfYQEynU4xGI+FAkH/A4l2pVBCJRC4Vz2w2i0QiIdwARkgznIm5BnQGTCaTsmPnnp5piZyQAJCmgdI/27ZxcHCAs7MzBEGAYrEIz/Og6zq63S4ajQYikYg4VDKWmb4C8/kcyWRS/BjoacCcCHoscJ1BPwOmV04mE0SjUXn9NGxar9fI5/Mol8vCAaG9M3+2aZpwHAeVSkU4BptrIK4U+LqZuzCbzdDpdGTiwGTMP/iDPwCgTmEKzwbX4RR2Fah79+OJD6tBeGoTBE3TZGy+iZs3byKTyaBer0u08yc/+UkEQYCjoyOcnJxIwiItejVNE7JdtVrFpz71KazXa7iui6OjI+ECcBXQ7/dhGIYoHFiASGZkcWOuQzqdFpe+dDqNfr8vqgSeqGm7zMAkJg1GIhEZq69WK3S7XXF0pKSR/guTyQSJRAKJRAK5XA63bt3CycnJpbTFUqmE7e1tjEYjmKYpDUelUkEikcDJyQkePXokzP7NZoAcAs/zkM1mpdGyLAvj8VgUBLSi7na7GI/HcBxHVhWxWAzb29toNpuibuCUIZVKIQxDWStYlgVN0zAej8WNkmoGy7KE/KjrOtrttjQnbKyY0Om6rthgLxYLnJ2dSbojPSvYSH4Y+H737rOEymP40eO6nMKugh/FvaugQFxZxfArv/IrqFQqUlxJmhsMBtIUMH+BCYYAZFceBIHo+gGI4qHb7cpenafZSCSC6XQqngNkzFNREIlERJqoaRo8zxOjICZFzudzyTRgkuFsNpPiSMJftVrFkydPsF6vRa5HuSIA5HI5aJqGVqslxk97e3tYr9c4PDwUI6cgCPDKK69gb28PxWIR7733HhqNhpy6k8mk+BCMx2MAEEtkNi58vjyRs2jzuVEeSh4EuR+e58nXaDHteR7m87nwHsrlskhEqbDodDpCjKSxEacswMWHFps2TkqYC0EOymZmBHDBR2H+BgB5/959910AHy0mOKEahB8trssp7GlCKXA+Hrgu9+6VVQzcZ3MXncvlJBehWCxKMXvppZcQhqGcXDl6brfb6Pf7WK/X0i33ej1kMhkxGKK+PxaLIZ/PS1YBDYAY+jSbzWCaJvr9vtg0L5dLiTkmUZIuhLlcDuv1Gi+//DIODw8RBAFGo5F4LFB1QV7CcDhEoVBAOp3G1taWBD/du3dP1A/z+Ry/+Iu/KKqH999/HwcHB4hEIvjv//2/y+pjMpmIyyAlh5PJRF4f7aQLhQKy2axcE03T8P7774sbIQAZ9/d6PbTbbXlMrh1qtdql3AP6EdCHgk0ZG5Rbt24BAHq9npAZ+fPoncDmjx4Y5CCwqaG1NsHJBwmcq9VKZTEoPDN8FKYHCh9PXKd798oThM997nNIpVKIRCIoFApIJBL4zGc+g3v37iEIAjx+/BiNRkMIaYZhYHd3F9lsFvV6Ha1WCycnJ3Ii9zwPi8UChmHIjpw8Ap5mJ5MJNE0TgiBzG8bjsRgGLRYL5PN5UVgMh0PU63X4vi/kRtu2pUg2Gg04joPVaiUSQWYsHB4eykic5kC0jy6VSsjn88hms0K8fOutt2TN0O12YVkWYrEYer2eZCKs12v4vg/HcbC7u4tGoyHx2HzNdHgslUpIpVKo1+sAIDwH7vS73S7m87kUX3oOMDab/IOtrS14ngfXdWVdQbdFBkwxAwKA8B02m69isSgcjM2JEJsdWm3Tw4HrpCAIJEKaiZx8v4CP3ikMUCexHxU+7A9Yde8qPE18mPfvM58gkCDX7XYxGo2EuPeV
r3wFAKTgkzy4Wq1wenoqTomr1QqapiEIAjnlc13BPT/1+bFYTE7W0WgUk8lEPATG47FEGgdBIPttWhV3u13s7OwAuDjtMmZ5uVzi8PBQTtuJRALHx8fQdR2np6dot9tYLpfiYRCLxTAej+XkfHp6KvJNkgNLpZLIM4ELZn+lUsF0OsXZ2ZlILBmPTC+AT37yk7KiAYBGoyGTC8/zcO/ePaTTaQwGAxQKBRQKBXQ6HSEt9ft9ISkGQQDbtoVjkEql0Gq15NTOdQCJhzdv3hSL5tlsJjbSQRBI0zGfz9Hr9ZBMJuF5HlKplJhL0dqaExq+n+RucAJBpQe5Dh9lKB7Ch4/rdPpSUPiT4rrdv1duEOh9QHMiTgGCIMDOzo7kJtCoaNN/n3vzWCyGcrksxYXyPe7lk8kkisUigAsb30gkgn6/L9wGyhyPjo5kR08/Ap7IDw4O5DlwT04jJLopZjIZ8QwYj8d48uQJdF2XqOODgwPJfLh16xai0SiazSZarZYkVpL4R8lmtVqVVcfOzg7W6zUajYbIQNPpNP7sn/2z6PV6+MY3viF+BCRwUkIaBAFef/11pNNpTKdTmKaJ+Xwu3IJkMolarYZsNotut4t+vy/XgRHVVETQnInZEWxcaHlNT4bZbIZcLodoNCp5GTS0isfj4qZIAiKbD8o/qfSwLAuTyQTz+VxcG8k9ef311696CyooALh+H64KCn8SXMf798oNAkfbvu8LOx0AptMpGo2GfN9iscBoNJKQHxZTuhg6jgMAsqqYTCbwPA83bty4ND5PJBIwTVNyAcrlskwjUqmUhDXRpjmfz0uTwIlBNBpFuVzG1taWWDIPBgO89dZb4n64s7ODUqkE3/fFkpkcg1wuh1qtJqN9ygMjkYioJW7evCkmSRyvP3nyROSbjKHWdR3/+3//b5muhGGIyWQC4OKUzz1+oVDA7u6u+Cskk0kp8HRuNAwDmUwG2WwWd+/exYMHD8TVkteBqZur1UqmJuQUkD/g+744PXa7XSGGchLA6QCVKHxOdE+kLDWfz2M2m6HZbArvgc6LlmVd9dZ7LqCmCB8OruOH6/MMdd8+OzxP9+qVOQif/vSnEQSBeB5Qmsg9OvMXGNjE3fWm1TAA0f1zz0/DHeBiF07ZIW/cWCwmRkhMLqRToGVZYgWcSqXEKtg0TViWhfl8jnw+L0mI6/Ua0WhUGPl8jrquw/M8OYW7rit/zuaAp2FGLVNpAFycpDVNk0yE/f19hGGI0WiE4XCIRqOBVqslhki7u7tYr9eSN6HrOgCIjJMmVPP5XMyiVqsVCoWC5DSMx2PJdViv1zBNE7Zty2qCj8Hpiud5iEQi0rh5nodSqSQk0EQigcFgAN/3pSFjodc0Db7vw3Vd8TegkoJ/P5vNwvd9WS+QnMi1D1UhH8U9LqE+aJ8NrssH7Uf53gXU/fv9cF3uv6vgmXMQAEguANUBu7u7cF0X/X5fvt7pdGRikM1mRZXA0TRjkcfjMXRdR6vVQhAEyOfzQmyjLJCNBaVy+/v7CIIAT548ge/7WCwWWK1WolK4ffu2aP37/T6m0ylOTk4AXMRQs7khS3/TI0DTNBwdHcnP5XidLoKj0Qi5XA7pdBqO4yAIAqxWK2SzWTlp03DowYMHcBxH5IOJREImCbFYDAcHB/LndExkkuSmdwHJkCzkNFIpl8sS38ymptPpoF6vC++B3g9cB8TjcWmiKKl0HAeFQkGUDnt7e2i328hms9A0TRoPNhT5fB4ARJLKaQKbNF3Xsb+/L4oQ/sNKp9P4N//m3zyNW/BaQ53Gnh4+Ch/Kzxs+Dvcv76vN16nutafUIDCyeTMvgRr4TT18EAQiH+Som3tp4MLOl7wEmu5omgbbtiU/YDabYTQaYX9/H9lsVk7jyWQS1WpVTtTc+w+HQ7TbbckDcBwHhmFIM1Aul2XCsVqtMJvNUC6XxdRpPp9LgBJ9BLgSodNiEATChdjd3UW/38f5+TkAYHt7G9VqFUEQoNFoyHQjmUxKwfc8D4VCAaenp1KUN0l9n/vc56BpmqgYXNdFsViUUz/5DJ1OR9wN4/G4XBuaHnmeh1gsJmZRbLRYxHmNo9Eozs/PkUgk0Gg0xCWRLpd8zfF4HC+99BLS6TTS6bRwO2imtFgspGGo1+tIJBL47Gc/i0gkgjfeeEOcMj8O+F4fQArfG+qD+frh/9978rzf05uvTd17l3HlBmGxWCAej8tKgc0CExNZ7Dd33/1+X+yCSSakHp/Jg+v1Whj8m0V6uVyiVqtB0zScnZ2hVCqJTbLjONA0DcvlEr1eT2KiO52OrD9omEQsl0vJgeD+nMTI7e1tGcvXajXs7e2h1Wrh7OwML774IgaDAVzXRTKZlNdNK1R6NqTTaRwdHSGVSuGll14S1QS9FWKxmBAKS6US9vf3sV6vJVaZZlBscjiVYeARm4HFYoFKpQLgwn66VquJ94DjOCKhJAHx5OREwqboUMn3Yb1eiykSAFGgRKNRmbaUy2XE43GJoaYi4t69ezKRGAwGqFarKJfL+Jmf+Rl4nofXX38djUYD4/FYbLM/TvheH0DP+wcsoD5YP874fu/9d97XH/Yk4gf9PHXP/mBcuUGg3p1rBkYFp1IpnJ6eIhaLiWSRyY+9Xg+PHz/G2dkZbty4gXK5LI5/dBWcTCYSjkR+Ab0LcrkcRqMRXn31Vdy8eROO4+D9998X8h4fYzgcSo4ApxQ060kmkzISZwPDtcjZ2Zn8Wtd1ZLNZvPvuu3j06JGoB/r9/iXSZSQSQalUQqlUQrFYRCaTwXK5RL1exyuvvCKBU41GQxQAnGD8+T//57G1tYXxeIxGo4EwDFEsFsXCmXkNt2/fljUJ8yI6nQ6ePHkiz2GxWODRo0dIJpPQdV2mEXRTnM1maLVaKBaLmM1maLfbUuTpwBiPx2VdUqlUsFwuxfbZsiwsl0scHx+LzTLf+1gsJq+PUwaSQFutFh4+fIhMJgPXdeG6rkxZPu5QH1QKH0V8r/v6WU3Svt+/IfVv62q4Mknxp37qp4ShzuJNEqLneUJISyaTQjRcLBbY3d2V4u37vuzNKZtkOiL19dzxx+NxVCoVCYGi+c7Ozo7Y/47HY5kMxGIxIRXS4XE8HssJ1vd9/PRP/zTW6zW+/vWvI5VKYTAYyATj5s2bYpbEFQDjmNfrNSzLQqVSwd7eHnZ3d/HWW28J+Y6jf1onn5+fi88DUydJ6MzlciJLBCCmUJzO8PEYn8zpSy6XQz6fR6FQQLlcxvvvvy9plpRxkm9Ag6d2u43T01O5BuQkhGEong7lcllWEJVKBaZpygSBqwNySng/9Ho9MWlqtVqIRqMYDofiHklTKq4hgiDA22+/DeCjT/RS+OhC3bsKzyueOUmRPv1hGMJ1XTQaDWQyGSwWC3HwI6GNXgRBEMionHwFOg0ylTGVSl0y61kulxJCxBjmbrcriYYc42uahlwuBwCiUshms3L65Sjfsiy89NJLWK1WaLVaGI/HIr2rVCriwGiaJvb29hCLxfDyyy/j5s2bePToEfb29hAEAcrlMpbLJU5OTvA
7v/M7AC5SDDudDqrVKhaLheQ43L17F2EY4uzsTBwQObJvt9uS+cDpxmAwwN7eHhKJBPb393Hr1i3s7e1hOp3i8PBQ8iMeP36MR48e4fd+7/dkMkElBhUNjHPeTLrUdV3MqZjEyQkCnzdNsBiTTZOk/f19JBIJHB0diRkWY6KpauGKxDRNlEolJJNJRCIR8YlQVssKCgoK1xdPRebIog1cxDEzNZFa/fV6LSZKPJ3XajUAEHIbCXXb29tIpVLodDqiWiAXYTqdwrZtbG1tyYiqXq9fGnMDF03LdDpFJpMRDwQWX576fd9HpVLBnTt3xFXx5ZdfhmVZaLfbePTo0SW1RSwWE5UF+QBUFFiWJTHJ9CSgKsCyLImdHg6HSKfTaDabYjNMRQDlfrqui7UyI5JHo5EoHsrlMgBIwFS/30e32xW76Nlshhs3bsj1sG1bSJXpdFr8KgBgPB6jWCxK48DUSlok0xOBdsrkb+i6LtfMtm0JpaIXBf0byD+hI2av18N0OkU+n5egKRolqVOYwvMKde8qPK945hMExi3TerdQKGC5XMI0Tei6DtM0hcAXBAFeeuklAECr1ZITPouSZVk4OjpCOp3Gzs4OJpOJFB2uHAzDQKvVAnDRXPC0Tf+CVCqFnZ0dsQVmFgHlj4lEAq1WC7FYDD/+4z+OW7duwfd9vPvuu/if//N/AriYADSbTeEwkFBZKpUkn+HmzZuwLAvf/OY3MZ/PJaVxb28PnudhNBqJyoHSSRIPDw4OsLW1ha2tLSyXSynwAEQd4Pu+ZDcsl0sYhoHRaIRGoyHXnERAxmOT4EgiYq1WE8kopwLpdFrWFTdv3pSsCMoo6fMAfNu0ajabSQNDSep3rgtIUiUBNBqNYrFYyMoolUohm82iUCggCALM53OxolZQUFBQuH648gThM5/5DLa2tlAoFBCGIXq9HgAIQ53Jfhxh0xSIO+zFYiEOh6VSSRQQmUxGApq4r98055lMJkgmkyiXy2IepOs6NE1Dt9tFJpNBpVJBPB6X8T0Jj3RZ5GnWMAyJMM5kMigWixiNRqjX63BdV8KR7ty5I2N43/dxfn4uVsyr1UqyCajSoCqAhkbAReGkbwFfOwmNxWIRhUIBpmmKkuDs7Ez+HotxJpMR0yie9LmqMAwDlUpFIrNpTBSJROA4jjQde3t7QthsNpviT7FYLJDJZETGSQkmcxOYvMmp0XK5xHg8huu6yGQykmfB50WJJM2ZyCuJRCJYLpfodDoA1ClM4fmFuncVnlc88wkCi8J6vZbI4EQiIdbBe3t70HVdioTv+2Ki47qusN0LhYLkCqxWKxwfHwshLpvNIpfLidUv2fpUN3CfzQRF5jucnJwIwe/8/BzZbFZMijZZ/77vyz6/0+mg2WyK9wJJjpPJBG+//bbwCSaTCabTKSqViqgTotGo+A0AEMIkmx1OFBiNXCgU4DiOWCEvFgucnZ0hm81KbDVVCHyNXJXw71A5woaCCZKz2UxIiFz1RKNRec28Xv1+X6Y+nU4HnufJFMA0TZimKXLI9XotCZPkJZimKVHfsVgMpVJJij9XCozsZo4EJyRcdSgoKCgoXD9cuUG4desWIpEIqtWqeO/3ej38wR/8AeLxOE5OTmQqYNu2aFPH47F4HxiGgdlshuPjY1QqFTFTokXv6ekpWq0WqtWqsOx1XYfrumi1WtIQtNttGZe7risn7mQyid3dXQyHQ8l8YOEksdFxHIlBJhgYNRwOpcFhI+O6rqQ6hmGISqUingUAUCqVEIahFPhoNCokyel0isVigcFgINHIjuPIiJ/uiMViEaZpCgeBq5ZoNCpyy1arJfkUbIYSiYSsbCzLQjQaFeOi3d1dVCoVuK6L4+NjudZ0faSnA2WUlJtS3tjr9ZDL5ZDL5aRZ4td938doNJKobsosyaWIx+OiZnFdV5EUFRQUFK4xnkqa43Q6xdHREe7evYs7d+7g1q1buH//PiKRiJjiUDZHi+B4PI5erydkxF6vh3Q6jV6vB9/3JR8gm83CMAxxY+REYblcSoFiUTcMQ3b5AIQYOZlM0G63Yds2TNMUAx8mNwLAycmJTDpu376NwWCA4XAo+QGUAMZiMTx58kTWCtPpFJqmicdAPp/HZz7zGUSjUbz77rvo9XqYTCZYr9di0UyOBk/8nFb0+33xP8hms+IXwOLOBEbaUEejUZRKJQyHQ7kW0+kUBwcHuH//PiaTiawGyC0YjUY4PDzEbDYDAMm04KqDpERyQBqNBorFIqbTKVarFarVqjxXTgJI2OT16XQ6KBaLsG1bVk3ValV8FtgkUJmioKCgoHD9cGUOwmuvvSZF2bZtkRDmcjmsVitxG6REjuN1mumQUBiNRkVuWCqVpLCRaU8PASYsnp6eirVzv98XW2XTNCUNkjv5RCIh5LxIJIKtrS3U63XUajXhOGya/GyO8cmVGI1GME0TBwcHl0yUJpOJ+AsAwMHBAUzTxP7+PgCIVLPZbCKZTGI0GuH8/PySbHA0GsHzPPEdIEeDP3u5XMrPBy5cFKlgWCwWYnrUaDSwtbUF13WF00EJaiQSkSYgDEMcHR0hn89LM8YGjHHOnCLM53OYpilOi1SWdLtdWXOQcMigLf56NBrJRIcNDScTnDAoFYPC8w517yo8r3jmHASGElG3T8IgCYHUvtMq+ezsTEbgnBiQMLe7uyurg263K4TE4+Njkc7V63X4vi/ufLPZDPv7+5hMJmLmQ1JgrVaT0y0LZTqdRqfTQSQSwenpKTKZjJgS0YNhOp1iZ2dHTuwMTYpEInj48KHYCp+cnKBQKKDZbAqHIQgCtNtt4VCQT6BpGtrttoQzMYCK14s20PF4XMb3dIVkJsVkMkEul5NmhP4SnU5HiIHdbhe6riMWi6HZbMr4f7VaiSRxPp9ja2sLg8EAi8VCLJ9jsZjYKpPMyGvTbreFj7FcLkURMZ/Pkc1mxWuCsd7RaBRbW1sYDociFWVMNx/Htm1pEBQUFBQUrheu3CD0+33EYjFkMhkZn3NUr+u6FOPZbCas++VyiVgshuVyifl8LqY/dEAkeW2zcPZ6PYlELhQKlwiJi8VCdu6UCW5vbwv50fM8TKdTKfT5fF5kmYlEAqPRSNIis9ksptMp2u025vM5xuMxut0uEokELMsSE6hoNIrd3V2EYYharSaF9+TkRDgDJAPSJTIajYqRk23bYi197949mKYp1tQkQHKF0Gq1hIRIjkG32wVwoW7gdMQ0TVFL0H+C15HETMY3p1IpbG9vY7VaCRlzkzhIvkA6nZavk+zYbDaRyWQkUpphWcPhEK+++ioKhYK8p3R1ZLNDt8v5fI7BYHDV209BQUFB4Rnhyg1COp1Go9GQwl4ul1EulyXVkU6AmxwC4GKUPR6PZZTfbDZRKBRkbL5pORyGoZghDQYDlEolKWjU5NOqWNM02YVrmiYZAzzJb0467ty5g/l8LmFDpmnC8zyxL+73+1IESVJcrVbIZDLodrs4ODgQw6KXXnpJeAPkItBGuVQqAbhoplKplLy20WiEfr+PdDqNw8NDmQxs+kaYpok7d+
7IFGQymYjygg0S+Q31eh3JZBKapklQUzQaFVIkMxx4XT3PEwXHer0WEiSbi9FoJE0bbaFpMLXZYFmWhVgshu3tbQyHQ/R6PcTjcbz//vtYr9fQNE2ao2KxKA3W9vY2vvnNb171FlRQUFBQeAa4Mgfhz/25PyeRvgDkpL9er2EYhhSdfD6P8Xgs+36uJSqVCiKRiOzbZ7MZHMdBLpcTR0aqHWhhzOAh8hTo2kfZI70XKIXM5/NIJpOYTqeXcgB835cJA02XUqmUkAk31xOMmm40GrBtG7quY3t7G81mUxwGX3zxRSFC0mqYHgNccTAumsRAXqNMJoOtrS1MJhNxLqQzYqPRkLwJKiU8z0OpVEI6nUar1RK1BY2pSOIcDofCbWCo1mq1QiKRQLlcxmQyAQDhBtC7gSsV/p6rA7ot0kabzzOfz0PTNAAQ8iUfl792XVckoMx+ePPNNwGoPa7C8wt17yo8r3jmHAQSzo6Pj0V+p+u6kNPIfvc8D7PZDKZpolgs4u7du2L4Q7kiT56RSATZbFYkgjTbITM+Ho/DdV0xD6LhERuCfr8vBZFqAyoRWOzCMEQqlUKj0cCdO3dkAsIcAYZMTSYT7O7u4pVXXhEJ52KxwHA4hOd5GI/HkgPx1a9+FcDFeD+Xy+H09FRSKIGLGGZOIAqFgtgwHx8fS1GdzWbiJcDXF4/HxWcinU6jXC6LWZHjONjZ2bkke7RtW64B3wf+XMMw0O/34fu+GBzRpCmbzSKTyVxacazXa/k9mzJmOFBhwVyMdDqNTCaD+XyO8/Nz5PN55HI5TKdTydmIx+PodDoiFVVQUFBQuJ648gThi1/8ImzbRiKRQKFQkCI3m83EV4DrACYTHhwcALiQ3oVhCMdxxGnx7OwMpVIJmUxGiny328VwOIRlWSiXy/B9H6ZpIplMShKk67pyomUgFNUHtB8mi37T6Y8nZ1o40yGQO3Xu6zlRyGQymE6nACBFGoAoBsgFoFqDhZuWzePxWNYA7XZbHAiz2SyKxaIQ/zatlOkXQLUHPRYajYZwForFIoIgkGkKY6xJ7jQMA9PpVCY9JEeSUDqdTqWBc11XVhibtwe9Ejb/z6aQ6ZkECZ9Uq/DrtMOmYdLx8TEAdQpTeH6h7l2F5xUfSprjZDJBpVKRvTb31pPJRBj/HG2Px2M0Gg0cHBwIM59ywXQ6jWq1KnJH7r65r2YjMJvN0O12xfZY13UUi8VLun8aKvHFT6dTBEEgjPrVaoVcLofxeCyNCy2GN9UHy+US8Xgck8lEkhsZ4xyLxeB5njQZTIocDAYiCywUCohGo2JBzQkG3Q6DIMBgMJAVytHREVarlRA4S6US8vk8BoOBNAiO48g+n6mMo9FIPAc4eeHJn6mKjJk2TRPj8RiRSESaBE5qGKlNciKbATY/jI+mIyYVKVxL8P2wLEtSIzkFsW1b+BHJZFKRFBUUFBSuMa7cILzwwgtIJpMIggCTyURIfkEQiMyOqYc0FSoWi0gmk9B1HVtbW+JrcHZ2huVyiUqlIkTDRCIh8kAy7MlJ4KkcgCQS8uRL90DHcRCNRlEoFDCfz2FZlhAdqS5gAW02m9ja2pKVCAv13bt3MZ1OcXZ2Btu28fLLL0vRfPHFF7G9vQ3XdTGfzxGPx1EsFiUgqdfryerEsiwkEglZazDlkVME4ILISCJiKpWSKQNXDLFYTOSKpVJJfpZhGDg4OEAQBDg/P5eJTCKREDJnLpcTG2Rd1zGdThGJROC6LsIwvOT8yAaAKx8Wf153TdNkOhCGIYbDochLSToFIA2QruuIRCIol8vi55DJZNBoNK56CyooKCgoPANcuUEYjUZSmLvdLqbTKXzfR6lUwk/8xE/I2JkJfhzte56Hx48fiyFPPp8X5j0LDvMH+DX+R5Md13Vxfn6OMAxl785GgvkQbERSqZSoCe7cuQNd15FMJvHkyRPxCZjNZqjX61gul7Kvp4VzPp/Hyy+/DNM0kcvlcHBwICoBnqoty4Ku6xgMBjI1sCxLGiLaKVMqmUgkRMnheR4Wi4UQ/lKpFOr1OorFopgbcXVC8yXyK/j9yWQSs9kMtm2LZXIQBJfIhZyQkPOwWq2kGalWqwAgqwtKR0mw3CQw0neCjRKnCdlsFtFoVBqT4XCIarUq0dvdbhedTkfkmQoKCgoK1xNX5iD83b/7d7FcLtFut8WBj/G+0WgUZ2dnEoxE+R8JgLFYTBoCAJKxsFwuAXzbhAkABoMBBoMBbNsWgiKLHQBRIJBcl8/nL53EGUmcSCQwnU4vuSySkc8pB2WVmqbBdV186lOfwv3797G1tSVN0Hw+R7PZRL/fF/tjqgcoX9zd3UW9XheeRSwWw2w2k+JO/kIYhojFYqK26HQ6GAwG8vq5urEsS+SerVZLfA/i8TgKhQJ6vZ4kXvLnjcdjpNNpWUMw7ZHBSfSIID+DWRnkZ7AR43oDgKwVAFxySFyv1yIFpYqCEkq+Z1ydaJomclNA7XEVnl+oe1fhecUPunev3CD85b/8lxGPx8VOdzQaCUN/Z2cH5XIZg8EAk8lEDHtogxwEAXq9nkgJPc8TH4NeryfSP04TmITIUzfNleiF4HkeqtWqcAw8z4Ou6xgOh8hkMqhWqwiCQDwTPM+7lCjI0Tkbimq1ilwuh/l8jul0iq9+9auIRCJCuOx2u3KiZgElD8CyLDSbTSFp8kRt27aM8CmJZIOk6zoajcalEb5t22i329JUcTriuq7INUulEiqVCsbjsXgXMBmTpMp+vw/LspBMJmGapvA12BwAgOd5khPB5mM+n0uDt/ne8flsBi5RHkpTpcVigdVqhXQ6LcRNwzBE4jibzfAf/sN/+EA36tOE+pBVeJpQ967C84oPxWqZpkBsEJbLJW7evCkRv7Ztw/M8NBoNPH78GMPhUIx5eKpfr9cSi7xYLMRB8fHjx1KwbNuGYRjI5XLo9/vi8JdIJGAYBtLpNFKplIy96ey4Wq3g+z7eeecdmKaJZrMpY3BKASnZ297elrF+r9dDt9vFyckJZrOZeAYAgOM40vywwDM3gtkKPF3T84DrlWw2i3g8Lq6Htm1jNBpJQ0EyJU/9bL5IfGSGRbfbFfdITgeYh8B1BFc+tm1LIBaVETRiSqVSwh0BIE6XzNBgIFYkEgGAS4mRVEPk83mJiY5GoxiPx2JM5TiOvEYqXL6zuVBQUFBQuF54KlbLPOGbpomf+7mfw6c+9SmJdq7X63jnnXdweHiIfr8P4NvyQO7QOWZOJBJwXVdG1CTmcV/d7/dFzlgoFOA4jpD1qOenFDGdTiMIAjHmoe8BLY5938dgMEAikZC0RRa2k5MTpFIpFItFyW9IJBKoVquSrEgFQBiG6PV6iEQiME0TjuOISVOhUMBoNBJVADkLNDSinwKtiLnuACAhSCza4/EYlUpFrke/30ehUIBt2zg+Ppb4bDohTiYTcT4kb6FarcqpnmROwzBkasGGhHLEnZ0dpFIpkTrS3pphTYyv5uRkM61ye3tbwp50XUev1xMlRDabFe6DgoKCgsL1xJUbBKY3Oo4jhMDZbIaHDx9iM
[... several thousand characters of base64-encoded PNG data omitted: the cell's rendered figure shows the input image, the ground-truth label, and the predicted label side by side ...]", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3)\n", + "axs[0].imshow(image, cmap='gray', interpolation='nearest')\n", + "axs[1].imshow(label, cmap='gray', interpolation='nearest')\n", + "axs[2].imshow(label_pred[0], cmap='gray', interpolation='nearest')\n", + "for i in range(3):\n", + " axs[i].xaxis.set_tick_params(labelbottom=False)\n", + " axs[i].yaxis.set_tick_params(labelleft=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c45aa552-042c-43a4-8219-f68d6a8f6259", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/imgx/conf/config.yaml b/imgx/conf/config.yaml index 41dd749..d9f9f39 100644 --- a/imgx/conf/config.yaml +++ b/imgx/conf/config.yaml @@ -1,18 +1,14 @@ defaults: - data: muscle_us - task: gaussian_diff_seg + # config below overwrites the values above + # https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/default_composition_order/ + - _self_ debug: False seed: 0 half_precision: True -loss: - dice: 1.0 - cross_entropy: 0.0 - focal: 20.0 - mse: 0.0 - vlb: 0.0 - optimizer: name: "adamw" kwargs: @@ -28,6 +24,7 @@ optimizer: end_value: 5e-05 logging: + root_dir: log_freq: 10 save_freq: 500 wandb: diff --git a/imgx/conf/data/amos_ct.yaml b/imgx/conf/data/amos_ct.yaml index 73e7e2a..0044b58 100644 --- a/imgx/conf/data/amos_ct.yaml +++ b/imgx/conf/data/amos_ct.yaml @@ -14,3 +14,6 @@ trainer: batch_size: 8 # all model replicas are updated every `batch_size` samples batch_size_per_replica: 1 # each model replicate takes `batch_size_per_replica` samples per step num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices + +patch_size: [2, 2, 2] +scale_factor: [2, 2, 2] diff --git a/imgx/conf/data/brats2021_mr.yaml b/imgx/conf/data/brats2021_mr.yaml index 35c8eb8..04b29b9 100644 --- a/imgx/conf/data/brats2021_mr.yaml +++ b/imgx/conf/data/brats2021_mr.yaml @@ -14,3 +14,6 @@ trainer: batch_size: 8 # all model replicas are updated every `batch_size` samples batch_size_per_replica: 1 # each model replicate takes `batch_size_per_replica` samples per step num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices + +patch_size: [2, 2, 2] +scale_factor: [2, 2, 2] diff --git a/imgx/conf/data/male_pelvic_mr.yaml b/imgx/conf/data/male_pelvic_mr.yaml index 1de4dca..e9d709d 100644 --- a/imgx/conf/data/male_pelvic_mr.yaml +++ b/imgx/conf/data/male_pelvic_mr.yaml @@ -14,3 +14,6 @@ trainer: batch_size: 8 # all model replicas are updated every `batch_size` samples batch_size_per_replica: 1 # each model replicate takes `batch_size_per_replica` samples per step num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices + +patch_size: [2, 2, 2] +scale_factor: [2, 2, 2] diff --git a/imgx/conf/data/muscle_us.yaml b/imgx/conf/data/muscle_us.yaml index d8351ed..6706b26 100644 --- a/imgx/conf/data/muscle_us.yaml +++ b/imgx/conf/data/muscle_us.yaml @@ -14,3 +14,6 @@ trainer: batch_size: 64 # all model replicas are updated every `batch_size` samples batch_size_per_replica: 8 # each model replicate 
takes `batch_size_per_replica` samples per step num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices + +patch_size: [2, 2] +scale_factor: [2, 2] diff --git a/imgx/conf/task/gaussian_diff_seg.yaml b/imgx/conf/task/gaussian_diff_seg.yaml index 803c210..eab7179 100644 --- a/imgx/conf/task/gaussian_diff_seg.yaml +++ b/imgx/conf/task/gaussian_diff_seg.yaml @@ -31,12 +31,20 @@ model: _target_: imgx.model.Unet remat: True num_spatial_dims: 3 - patch_size: 2 + patch_size: MISSING # data dependent, will be set after loading config + scale_factor: MISSING # data dependent, will be set after loading config num_channels: [32, 64, 128, 256] - out_channels: MISSING # will be set after loading config + out_channels: MISSING # data dependent, will be set after loading config num_heads: 8 widening_factor: 4 +loss: + dice: 1.0 + cross_entropy: 0.0 + focal: 20.0 + mse: 0.0 + vlb: 0.0 + early_stopping: # used on validation set metric: "mean_binary_dice_score_without_background" mode: "max" diff --git a/imgx/conf/task/seg.yaml b/imgx/conf/task/seg.yaml index a866da1..b414c06 100644 --- a/imgx/conf/task/seg.yaml +++ b/imgx/conf/task/seg.yaml @@ -4,11 +4,18 @@ model: _target_: imgx.model.Unet remat: True num_spatial_dims: 3 - patch_size: 2 + patch_size: MISSING # data dependent, will be set after loading config + scale_factor: MISSING # data dependent, will be set after loading config num_channels: [32, 64, 128, 256] - out_channels: MISSING # will be set after loading config + out_channels: MISSING # data dependent, will be set after loading config num_heads: 8 widening_factor: 4 + num_transform_layers: 1 + +loss: + dice: 1.0 + cross_entropy: 0.0 + focal: 20.0 early_stopping: # used on validation set metric: "mean_binary_dice_score_without_background" diff --git a/imgx/config.py b/imgx/config.py index 80db555..549dc89 100644 --- a/imgx/config.py +++ b/imgx/config.py @@ -1,7 +1,7 @@ """Module for configuration related functions.""" -def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict: +def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict: # type:ignore[type-arg] """Flat a nested dict. Args: diff --git a/imgx/data/__init__.py b/imgx/data/__init__.py index a402718..c6d0719 100644 --- a/imgx/data/__init__.py +++ b/imgx/data/__init__.py @@ -1 +1,9 @@ """Module to handle data.""" +from __future__ import annotations + +from typing import Callable + +import jax +import jax.numpy as jnp + +AugmentationFn = Callable[[jax.Array, dict[str, jnp.ndarray]], dict[str, jnp.ndarray]] diff --git a/imgx/data/affine.py b/imgx/data/affine.py new file mode 100644 index 0000000..e2acbdd --- /dev/null +++ b/imgx/data/affine.py @@ -0,0 +1,389 @@ +"""Affine transformation for image and label.""" +from __future__ import annotations + +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +from omegaconf import DictConfig + +from imgx.data import AugmentationFn +from imgx.data.util import get_batch_size +from imgx.data.warp import batch_grid_sample, get_coordinate_grid +from imgx_datasets import INFO_MAP +from imgx_datasets.constant import FOREGROUND_RANGE, IMAGE, LABEL + + +def get_2d_rotation_matrix( + radians: jnp.ndarray, +) -> jnp.ndarray: + """Return 2d rotation matrix given radians. + + The affine transformation applies as following: + [x, = [[* * 0] * [x, + y, [* * 0] y, + 1] [0 0 1]] 1] + + Args: + radians: tuple of one values, correspond to xy planes. + + Returns: + Rotation matrix of shape (3, 3). 
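+
+    Example (an illustrative sketch): a rotation by pi/2 maps the homogeneous
+    point [1, 0, 1] to [0, 1, 1]:
+        rot = get_2d_rotation_matrix(jnp.array([jnp.pi / 2]))
+        jnp.matmul(rot, jnp.array([1.0, 0.0, 1.0]))  # ~[0., 1., 1.]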
+ """ + sin, cos = jnp.sin(radians[0]), jnp.cos(radians[0]) + return jnp.array( + [ + [cos, -sin, 0.0], + [sin, cos, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + +def get_3d_rotation_matrix( + radians: jnp.ndarray, +) -> jnp.ndarray: + """Return 3d rotation matrix given radians. + + The affine transformation applies as following: + [x, = [[* * * 0] * [x, + y, [* * * 0] y, + z, [* * * 0] z, + 1] [0 0 0 1]] 1] + + Args: + radians: tuple of three values, correspond to yz, xz, xy planes. + + Returns: + Rotation matrix of shape (4, 4). + """ + affine = jnp.eye(4) + + # rotation of yz around x-axis + sin, cos = jnp.sin(radians[0]), jnp.cos(radians[0]) + affine_ax = jnp.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, cos, -sin, 0.0], + [0.0, sin, cos, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + affine = jnp.matmul(affine_ax, affine) + + # rotation of zx around y-axis + sin, cos = jnp.sin(radians[1]), jnp.cos(radians[1]) + affine_ax = jnp.array( + [ + [cos, 0.0, sin, 0.0], + [0.0, 1.0, 0.0, 0.0], + [-sin, 0.0, cos, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + affine = jnp.matmul(affine_ax, affine) + + # rotation of xy around z-axis + sin, cos = jnp.sin(radians[2]), jnp.cos(radians[2]) + affine_ax = jnp.array( + [ + [cos, -sin, 0.0, 0.0], + [sin, cos, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + affine = jnp.matmul(affine_ax, affine) + + return affine + + +def get_rotation_matrix( + radians: jnp.ndarray, +) -> jnp.ndarray: + """Return rotation matrix given radians. + + Args: + radians: correspond to rotate around each axis. + + Returns: + Rotation matrix of shape (n+1, n+1). + + Raises: + ValueError: if not 2D or 3D. + """ + if radians.size == 1: + return get_2d_rotation_matrix(radians) + if radians.size == 3: + return get_3d_rotation_matrix(radians) + raise ValueError("Only support 2D/3D rotations.") + + +def get_translation_matrix( + shifts: jnp.ndarray, +) -> jnp.ndarray: + """Return 3d translation matrix given shifts. + + For example, the 3D affine transformation applies as following: + [x, = [[1 0 0 *] * [x, + y, [0 1 0 *] y, + z, [0 0 1 *] z, + 1] [0 0 0 1]] 1] + + Args: + shifts: correspond to each axis shift. + + Returns: + Translation matrix of shape (n+1, n+1). + """ + ndims = shifts.size + shifts = jnp.concatenate([shifts, jnp.array([1.0])]) + return jnp.concatenate( + [ + jnp.eye(ndims + 1, ndims), + shifts[:, None], + ], + axis=1, + ) + + +def get_scaling_matrix( + scales: jnp.ndarray, +) -> jnp.ndarray: + """Return scaling matrix given scales. + + For example, the 3D affine transformation applies as following: + [x, = [[* 0 0 0] * [x, + y, [0 * 0 0] y, + z, [0 0 * 0] z, + 1] [0 0 0 1]] 1] + + Args: + scales: correspond to each axis scaling. + + Returns: + Affine matrix of shape (n+1, n+1). + """ + scales = jnp.concatenate([scales, jnp.array([1.0])]) + return jnp.diag(scales) + + +def get_affine_matrix( + radians: jnp.ndarray, + shifts: jnp.ndarray, + scales: jnp.ndarray, +) -> jnp.ndarray: + """Return an affine matrix from parameters. + + The matrix is not squared, as the last row is not needed. For rotation, + translation, and scaling matrix, they are kept for composition purpose. + For example, the 3D affine transformation applies as following: + [x, = [[* * * *] * [x, + y, [* * * *] y, + z, [* * * *] z, + 1] [0 0 0 1]] 1] + + Args: + radians: correspond to rotate around each axis. + shifts: correspond to each axis shift. + scales: correspond to each axis scaling. + + Returns: + Affine matrix of shape (n+1, n+1). 
+ """ + affine_rot = get_rotation_matrix(radians) + affine_shift = get_translation_matrix(shifts) + affine_scale = get_scaling_matrix(scales) + return jnp.matmul(affine_shift, jnp.matmul(affine_scale, affine_rot)) + + +def batch_get_random_affine_matrix( + key: jax.Array, + max_rotation: jnp.ndarray, + min_translation: jnp.ndarray, + max_translation: jnp.ndarray, + max_scaling: jnp.ndarray, +) -> jnp.ndarray: + """Get a batch of random affine matrices. + + Args: + key: jax random key. + max_rotation: maximum rotation in radians, (1,) for 2d and (2,) for 3d. + min_translation: minimum translation in pixel/voxels, (num_spatial_dims,). + max_translation: maximum translation in pixel/voxels, (num_spatial_dims,). + max_scaling: maximum scaling difference in pixel/voxels, (num_spatial_dims,). + + Returns: + Affine matrix of shape (batch, n+1, n+1), n is num_spatial_dims. + """ + key_radian, key_shift, key_scale = jax.random.split(key, num=3) + radians = jax.random.uniform( + key=key_radian, + shape=max_rotation.shape, + minval=-max_rotation, + maxval=max_rotation, + ) + shifts = jax.random.uniform( + key=key_shift, + shape=max_translation.shape, + minval=min_translation, + maxval=max_translation, + ) + scales = jax.random.uniform( + key=key_scale, + shape=max_scaling.shape, + minval=1.0 - max_scaling, + maxval=1.0 + max_scaling, + ) + # vmap on first axis, which is a batch + return jax.vmap(get_affine_matrix)(radians, shifts, scales) + + +def apply_affine_to_grid(grid: jnp.ndarray, affine_matrix: jnp.ndarray) -> jnp.ndarray: + """Apply affine matrix to grid. + + The grid has non-negative coordinates, means the origin is at a corner. + Need to shift the grid such that the origin is at center, + then apply affine, then shift the origin back. + + Args: + grid: grid coordinates, of shape (n, d1, ..., dn). + grid[:, i1, ..., in] = [i1, ..., in] + affine_matrix: shape (n+1, n+1) + + Returns: + Grid with updated coordinates. + """ + # (n+1, d1, ..., dn) + extended_grid = jnp.concatenate([grid, jnp.ones((1,) + grid.shape[1:])], axis=0) + + # shift to center + shift = (jnp.array(grid.shape[1:]) - 1) / 2 + shift_matrix = get_translation_matrix(-shift) # (n+1, n+1) + # (n+1, n+1) * (n+1, d1, ..., dn) = (n+1, d1, ..., dn) + extended_grid = jnp.einsum("ji,i...->j...", shift_matrix, extended_grid) + + # affine + # (n+1, n+1) * (n+1, d1, ..., dn) = (n+1, d1, ..., dn) + extended_grid = jnp.einsum("ji,i...->j...", affine_matrix, extended_grid) + + # shift to corner + shift_matrix = get_translation_matrix(shift)[:-1, :] # (n, n+1) + # (n, n+1) * (n+1, d1, ..., dn) = (n, d1, ..., dn) + extended_grid = jnp.einsum("ji,i...->j...", shift_matrix, extended_grid) + + return extended_grid + + +def batch_apply_affine_to_grid(grid: jnp.ndarray, affine_matrix: jnp.ndarray) -> jnp.ndarray: + """Apply batch of affine matrix to grid. + + Args: + grid: grid coordinates, of shape (n, d1, ..., dn). + grid[:, i1, ..., in] = [i1, ..., in] + affine_matrix: shape (batch, n+1, n+1). + + Returns: + Grid with updated coordinates, shape (batch, n, d1, ..., dn). + """ + return jax.vmap(apply_affine_to_grid, in_axes=(None, 0))(grid, affine_matrix) + + +def batch_random_affine_transform( + key: jax.Array, + batch: dict[str, jnp.ndarray], + image_shape: tuple[int, ...], + grid: jnp.ndarray, + max_rotation: jnp.ndarray, + max_translation: jnp.ndarray, + max_scaling: jnp.ndarray, +) -> dict[str, jnp.ndarray]: + """Keep image and label only. + + Args: + key: jax random key. + batch: dict having images or labels, or foreground_range. 
+ images have shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c) + labels have shape (batch, d1, ..., dn) + batch should not have other keys such as UID. + if foreground_range exists, it's pre-calculated based on label, it's + pre-calculated because nonzero function is not jittable. + image_shape: image spatial shape, (d1, ..., dn). + grid: grid coordinates, of shape (n, d1, ..., dn). + grid[:, i1, ..., in] = [i1, ..., in] + max_rotation: maximum rotation in radians, shape = (batch, ...). + max_translation: maximum translation in pixel/voxels, + shape = (batch, d1, ..., dn). + max_scaling: maximum scaling difference in pixel/voxels, + shape = (batch, d1, ..., dn). + + Returns: + Augmented dict having image and label, shapes are not changed. + """ + batch_size = get_batch_size(batch) + + # (batch, ...) + max_rotation = jnp.tile(max_rotation[None, ...], (batch_size, 1)) + max_translation = jnp.tile(max_translation[None, ...], (batch_size, 1)) + min_translation = -max_translation + max_scaling = jnp.tile(max_scaling[None, ...], (batch_size, 1)) + + # refine translation to avoid removing classes + if FOREGROUND_RANGE in batch: + shape = jnp.array(image_shape) + shape = jnp.tile(shape[None, ...], (batch_size, 1)) + max_translation = jnp.minimum(max_translation, shape - 1 - batch[FOREGROUND_RANGE][..., -1]) + min_translation = jnp.maximum(min_translation, -batch[FOREGROUND_RANGE][..., 0]) + + # (batch, n+1, n+1) + affine_matrix = batch_get_random_affine_matrix( + key=key, + max_rotation=max_rotation, + min_translation=min_translation, + max_translation=max_translation, + max_scaling=max_scaling, + ) + + # (batch, n, d1, ..., dn) + grid = batch_apply_affine_to_grid(grid=grid, affine_matrix=affine_matrix) + + resampled_batch = {} + for k, v in batch.items(): + if LABEL in k: + # assume label related keys have label in name + resampled_batch[k] = batch_grid_sample(x=v, grid=grid, order=0) + elif IMAGE in k: + # assume image related keys have image in name + resampled_batch[k] = batch_grid_sample(x=v, grid=grid, order=1) + elif k == FOREGROUND_RANGE: + # not needed anymore + continue + else: + raise ValueError(f"Unknown key {k} in batch.") + return resampled_batch + + +def get_random_affine_augmentation_fn(config: DictConfig) -> AugmentationFn: + """Return a data augmentation function for random affine transformation. + + Args: + config: entire config. + + Returns: + A data augmentation function. 
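+
+    Example (an illustrative sketch, assuming a loaded Hydra config and a
+    batch dict of image/label arrays, optionally with FOREGROUND_RANGE):
+        aug_fn = get_random_affine_augmentation_fn(config)
+        batch = aug_fn(jax.random.PRNGKey(0), batch)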
+ """ + dataset_info = INFO_MAP[config.data.name] + image_shape = dataset_info.image_spatial_shape + grid = get_coordinate_grid(shape=image_shape) + max_rotation = np.array(config.data.loader.data_augmentation.max_rotation) + max_translation = np.array(config.data.loader.data_augmentation.max_translation) + max_scaling = np.array(config.data.loader.data_augmentation.max_scaling) + return partial( + batch_random_affine_transform, + image_shape=image_shape, + grid=grid, + max_rotation=max_rotation, + max_translation=max_translation, + max_scaling=max_scaling, + ) diff --git a/imgx/data/augmentation_test.py b/imgx/data/affine_test.py similarity index 82% rename from imgx/data/augmentation_test.py rename to imgx/data/affine_test.py index a2ab532..6321153 100644 --- a/imgx/data/augmentation_test.py +++ b/imgx/data/affine_test.py @@ -1,6 +1,5 @@ """Test function for data augmentation.""" - import chex import jax import jax.numpy as jnp @@ -8,17 +7,16 @@ from absl.testing import parameterized from chex._src import fake -from imgx.data.augmentation import ( +from imgx.data.affine import ( batch_apply_affine_to_grid, batch_get_random_affine_matrix, batch_random_affine_transform, - batch_resample_image_label, get_affine_matrix, get_rotation_matrix, get_scaling_matrix, get_translation_matrix, ) -from imgx.metric.centroid import get_coordinate_grid +from imgx.data.warp import get_coordinate_grid from imgx_datasets.constant import FOREGROUND_RANGE, IMAGE, LABEL @@ -464,7 +462,7 @@ def test_values( max_rotation: np.ndarray, max_translation: np.ndarray, max_scaling: np.ndarray, - expected_shape: tuple, + expected_shape: tuple[int, ...], ) -> None: """Test affine matrix values. @@ -582,125 +580,6 @@ def test_values( chex.assert_trees_all_equal(got, batch_expected) -class TestResample(chex.TestCase): - """Test apply_affine_to_grid.""" - - @chex.all_variants() - @parameterized.product( - ( - { - "image": np.asarray( - [ - [ - [2.0, 1.0, 0.0], - [0.0, 3.0, 4.0], - ], - [ - [2.0, 1.0, 0.0], - [0.0, 3.0, 4.0], - ], - ], - ), - "label": np.asarray( - [ - [ - [2.0, 1.0, 0.0], - [0.0, 3.0, 4.0], - ], - [ - [2.0, 1.0, 0.0], - [0.0, 3.0, 4.0], - ], - ], - ), - "grid": np.asarray( - [ - # first image, un changed - [ - # x axis - [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], - # y axis - [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], - ], - # second image, changed - # (0.4, 0) x-axis linear interpolation - # (0, 0.6) y-axis linear interpolation - # (0.4, 1.6) x/y-axis linear interpolation - # (1.0, 3.0) out of boundary - [ - # x axis - [[0.4, 0.0, 0.4], [1.0, 1.0, 1.0]], - # y axis - [[0.0, 0.6, 1.6], [0.0, 3.0, 2.0]], - ], - ] - ), # (batch=2, n=2, d1=2, d2=3) - "expected_image": np.asarray( - [ - [ - [2.0, 1.0, 0.0], - [0.0, 3.0, 4.0], - ], - [ - [1.2, 1.4, 1.68], - [0.0, 0.0, 4.0], - ], - ], - ), - "expected_label": np.asarray( - [ - [ - [2.0, 1.0, 0.0], - [0.0, 3.0, 4.0], - ], - [ - [2.0, 1.0, 0.0], - [0.0, 0.0, 4.0], - ], - ], - ), - }, - ), - num_channels=[0, 1, 2], - ) - def test_shapes( - self, - image: np.ndarray, - label: np.ndarray, - grid: np.ndarray, - expected_image: np.ndarray, - expected_label: np.ndarray, - num_channels: int, - ) -> None: - """Test affine matrix values. - - Test affine matrix shapes, and test random seed impact. - - Args: - image: input image batch. - label: input label batch. - grid: batch of grid with affine applied. - expected_image: expected image. - expected_label: expected label. - num_channels: number of channels to add to image. 
- """ - if num_channels == 1: - image = image[..., None] - expected_image = expected_image[..., None] - elif num_channels > 1: - reps = (1,) * (len(image.shape) - 1) + (num_channels,) - image = np.tile(image[..., None], reps) - expected_image = np.tile(expected_image[..., None], reps) - - batch = {IMAGE: image, LABEL: label} - got = self.variant(batch_resample_image_label)( - batch=batch, - grid=grid, - ) - expected = {IMAGE: expected_image, LABEL: expected_label} - chex.assert_trees_all_close(got, expected) - - class TestRandomAffineTransformation(chex.TestCase): """Test batch_random_affine_transform.""" @@ -728,7 +607,7 @@ def test_shapes( max_rotation: np.ndarray, max_translation: np.ndarray, max_scaling: np.ndarray, - image_shape: tuple, + image_shape: tuple[int, ...], ) -> None: """Test affine matrix values. @@ -754,6 +633,7 @@ def test_shapes( got = self.variant(batch_random_affine_transform)( key=key, batch=batch, + image_shape=image_shape, grid=grid, max_rotation=max_rotation, max_translation=max_translation, diff --git a/imgx/data/augmentation.py b/imgx/data/augmentation.py index 1b69762..c57947f 100644 --- a/imgx/data/augmentation.py +++ b/imgx/data/augmentation.py @@ -2,414 +2,16 @@ from __future__ import annotations from collections.abc import Sequence -from functools import partial -from typing import Callable import jax import jax.numpy as jnp -import numpy as np -from jax.scipy.ndimage import map_coordinates -from omegaconf import DictConfig -from imgx.data.patch import batch_patch_random_sample -from imgx.metric.centroid import get_coordinate_grid -from imgx_datasets import INFO_MAP -from imgx_datasets.constant import FOREGROUND_RANGE, IMAGE, LABEL +from imgx.data import AugmentationFn -def get_2d_rotation_matrix( - radians: jnp.ndarray, -) -> jnp.ndarray: - """Return 2d rotation matrix given radians. - - The affine transformation applies as following: - [x, = [[* * 0] * [x, - y, [* * 0] y, - 1] [0 0 1]] 1] - - Args: - radians: tuple of one values, correspond to xy planes. - - Returns: - Rotation matrix of shape (3, 3). - """ - sin, cos = jnp.sin(radians[0]), jnp.cos(radians[0]) - return jnp.array( - [ - [cos, -sin, 0.0], - [sin, cos, 0.0], - [0.0, 0.0, 1.0], - ] - ) - - -def get_3d_rotation_matrix( - radians: jnp.ndarray, -) -> jnp.ndarray: - """Return 3d rotation matrix given radians. - - The affine transformation applies as following: - [x, = [[* * * 0] * [x, - y, [* * * 0] y, - z, [* * * 0] z, - 1] [0 0 0 1]] 1] - - Args: - radians: tuple of three values, correspond to yz, xz, xy planes. - - Returns: - Rotation matrix of shape (4, 4). - """ - affine = jnp.eye(4) - - # rotation of yz around x-axis - sin, cos = jnp.sin(radians[0]), jnp.cos(radians[0]) - affine_ax = jnp.array( - [ - [1.0, 0.0, 0.0, 0.0], - [0.0, cos, -sin, 0.0], - [0.0, sin, cos, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - affine = jnp.matmul(affine_ax, affine) - - # rotation of zx around y-axis - sin, cos = jnp.sin(radians[1]), jnp.cos(radians[1]) - affine_ax = jnp.array( - [ - [cos, 0.0, sin, 0.0], - [0.0, 1.0, 0.0, 0.0], - [-sin, 0.0, cos, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - affine = jnp.matmul(affine_ax, affine) - - # rotation of xy around z-axis - sin, cos = jnp.sin(radians[2]), jnp.cos(radians[2]) - affine_ax = jnp.array( - [ - [cos, -sin, 0.0, 0.0], - [sin, cos, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - affine = jnp.matmul(affine_ax, affine) - - return affine - - -def get_rotation_matrix( - radians: jnp.ndarray, -) -> jnp.ndarray: - """Return rotation matrix given radians. 
- - Args: - radians: correspond to rotate around each axis. - - Returns: - Rotation matrix of shape (n+1, n+1). - - Raises: - ValueError: if not 2D or 3D. - """ - if radians.size == 1: - return get_2d_rotation_matrix(radians) - if radians.size == 3: - return get_3d_rotation_matrix(radians) - raise ValueError("Only support 2D/3D rotations.") - - -def get_translation_matrix( - shifts: jnp.ndarray, -) -> jnp.ndarray: - """Return 3d translation matrix given shifts. - - For example, the 3D affine transformation applies as following: - [x, = [[1 0 0 *] * [x, - y, [0 1 0 *] y, - z, [0 0 1 *] z, - 1] [0 0 0 1]] 1] - - Args: - shifts: correspond to each axis shift. - - Returns: - Translation matrix of shape (n+1, n+1). - """ - ndims = shifts.size - shifts = jnp.concatenate([shifts, jnp.array([1.0])]) - return jnp.concatenate( - [ - jnp.eye(ndims + 1, ndims), - shifts[:, None], - ], - axis=1, - ) - - -def get_scaling_matrix( - scales: jnp.ndarray, -) -> jnp.ndarray: - """Return scaling matrix given scales. - - For example, the 3D affine transformation applies as following: - [x, = [[* 0 0 0] * [x, - y, [0 * 0 0] y, - z, [0 0 * 0] z, - 1] [0 0 0 1]] 1] - - Args: - scales: correspond to each axis scaling. - - Returns: - Affine matrix of shape (n+1, n+1). - """ - scales = jnp.concatenate([scales, jnp.array([1.0])]) - return jnp.diag(scales) - - -def get_affine_matrix( - radians: jnp.ndarray, - shifts: jnp.ndarray, - scales: jnp.ndarray, -) -> jnp.ndarray: - """Return an affine matrix from parameters. - - The matrix is not squared, as the last row is not needed. For rotation, - translation, and scaling matrix, they are kept for composition purpose. - For example, the 3D affine transformation applies as following: - [x, = [[* * * *] * [x, - y, [* * * *] y, - z, [* * * *] z, - 1] [0 0 0 1]] 1] - - Args: - radians: correspond to rotate around each axis. - shifts: correspond to each axis shift. - scales: correspond to each axis scaling. - - Returns: - Affine matrix of shape (n+1, n+1). - """ - affine_rot = get_rotation_matrix(radians) - affine_shift = get_translation_matrix(shifts) - affine_scale = get_scaling_matrix(scales) - return jnp.matmul(affine_shift, jnp.matmul(affine_scale, affine_rot)) - - -def batch_get_random_affine_matrix( - key: jax.random.PRNGKeyArray, - max_rotation: jnp.ndarray, - min_translation: jnp.ndarray, - max_translation: jnp.ndarray, - max_scaling: jnp.ndarray, -) -> jnp.ndarray: - """Get a batch of random affine matrices. - - Args: - key: jax random key. - max_rotation: maximum rotation in radians. - min_translation: minimum translation in pixel/voxels. - max_translation: maximum translation in pixel/voxels. - max_scaling: maximum scaling difference in pixel/voxels. - - Returns: - Affine matrix of shape (batch, n+1, n+1). - """ - key_radian, key_shift, key_scale = jax.random.split(key, num=3) - radians = jax.random.uniform( - key=key_radian, - shape=max_rotation.shape, - minval=-max_rotation, - maxval=max_rotation, - ) - shifts = jax.random.uniform( - key=key_shift, - shape=max_translation.shape, - minval=min_translation, - maxval=max_translation, - ) - scales = jax.random.uniform( - key=key_scale, - shape=max_scaling.shape, - minval=1.0 - max_scaling, - maxval=1.0 + max_scaling, - ) - # vmap on first axis, which is a batch - return jax.vmap(get_affine_matrix)(radians, shifts, scales) - - -def apply_affine_to_grid(grid: jnp.ndarray, affine_matrix: jnp.ndarray) -> jnp.ndarray: - """Apply affine matrix to grid. 
- - The grid has non-negative coordinates, means the origin is at a corner. - Need to shift the grid such that the origin is at center, - then apply affine, then shift the origin back. - - Args: - grid: grid coordinates, of shape (n, d1, ..., dn). - grid[:, i1, ..., in] = [i1, ..., in] - affine_matrix: shape (n+1, n+1) - - Returns: - Grid with updated coordinates. - """ - # (n+1, d1, ..., dn) - extended_grid = jnp.concatenate([grid, jnp.ones((1,) + grid.shape[1:])], axis=0) - - # shift to center - shift = (jnp.array(grid.shape[1:]) - 1) / 2 - shift_matrix = get_translation_matrix(-shift) # (n+1, n+1) - # (n+1, n+1) * (n+1, d1, ..., dn) = (n+1, d1, ..., dn) - extended_grid = jnp.einsum("ji,i...->j...", shift_matrix, extended_grid) - - # affine - # (n+1, n+1) * (n+1, d1, ..., dn) = (n+1, d1, ..., dn) - extended_grid = jnp.einsum("ji,i...->j...", affine_matrix, extended_grid) - - # shift to corner - shift_matrix = get_translation_matrix(shift)[:-1, :] # (n, n+1) - # (n, n+1) * (n+1, d1, ..., dn) = (n, d1, ..., dn) - extended_grid = jnp.einsum("ji,i...->j...", shift_matrix, extended_grid) - - return extended_grid - - -def batch_apply_affine_to_grid(grid: jnp.ndarray, affine_matrix: jnp.ndarray) -> jnp.ndarray: - """Apply batch of affine matrix to grid. - - Args: - grid: grid coordinates, of shape (n, d1, ..., dn). - grid[:, i1, ..., in] = [i1, ..., in] - affine_matrix: shape (batch, n+1, n+1). - - Returns: - Grid with updated coordinates, shape (batch, n, d1, ..., dn). - """ - return jax.vmap(apply_affine_to_grid, in_axes=(None, 0))(grid, affine_matrix) - - -def batch_resample_image_label( - batch: dict[str, jnp.ndarray], - grid: jnp.ndarray, -) -> dict[str, jnp.ndarray]: - """Apply batch of affine matrix to image and label. - - Args: - batch: dict having image and label. - image may have shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c) - label has shape (batch, d1, ..., dn) - grid: grid coordinates, of shape (batch, n, d1, ..., dn). - - Returns: - Updated image and label, of same shape. - """ - image = batch[IMAGE] - label = batch[LABEL] - - # check shapes - if image.ndim not in [label.ndim, label.ndim + 1]: - raise ValueError( - f"image and label must have same ndim or ndim+1, " - f"got {image.ndim} and {label.ndim} " - f"for image and label, correspondingly." - ) - - # vmap on batch axis - resample_image_vmap = jax.vmap( - partial( - map_coordinates, - order=1, - mode="constant", - cval=0.0, - ), - in_axes=(0, 0), - ) - if image.ndim == label.ndim + 1: - # vmap on channel axis - ch_axis = image.ndim - 1 - resample_image_vmap = jax.vmap( - resample_image_vmap, - in_axes=(ch_axis, None), - out_axes=ch_axis, - ) - # vmap on batch axis - resample_label_vmap = jax.vmap( - partial( - map_coordinates, - order=0, - mode="constant", - cval=0.0, - ), - in_axes=(0, 0), - ) - image = resample_image_vmap(image, grid) - label = resample_label_vmap(label, grid) - return {IMAGE: image, LABEL: label} - - -def batch_random_affine_transform( - key: jax.random.PRNGKeyArray, - batch: dict[str, jnp.ndarray], - grid: jnp.ndarray, - max_rotation: jnp.ndarray, - max_translation: jnp.ndarray, - max_scaling: jnp.ndarray, -) -> dict[str, jnp.ndarray]: - """Keep image and label only. - - Args: - key: jax random key. - batch: dict having image and label. - image may have shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c) - label has shape (batch, d1, ..., dn) - grid: grid coordinates, of shape (n, d1, ..., dn). 
- grid[:, i1, ..., in] = [i1, ..., in] - max_rotation: maximum rotation in radians, shape = (batch, ...). - max_translation: maximum translation in pixel/voxels, - shape = (batch, d1, ..., dn). - max_scaling: maximum scaling difference in pixel/voxels, - shape = (batch, d1, ..., dn). - - Returns: - Augmented dict having image and label, shapes are not changed. - """ - # (batch, ...) - batch_size = batch[IMAGE].shape[0] - max_rotation = jnp.tile(max_rotation[None, ...], (batch_size, 1)) - max_translation = jnp.tile(max_translation[None, ...], (batch_size, 1)) - min_translation = -max_translation - max_scaling = jnp.tile(max_scaling[None, ...], (batch_size, 1)) - - # refine translation to avoid remove classes - shape = jnp.array(batch[LABEL].shape[1:]) - shape = jnp.tile(shape[None, ...], (batch_size, 1)) - max_translation = jnp.minimum(max_translation, shape - 1 - batch[FOREGROUND_RANGE][..., -1]) - min_translation = jnp.maximum(min_translation, -batch[FOREGROUND_RANGE][..., 0]) - # (batch, n+1, n+1) - affine_matrix = batch_get_random_affine_matrix( - key=key, - max_rotation=max_rotation, - min_translation=min_translation, - max_translation=max_translation, - max_scaling=max_scaling, - ) - - # (batch, n, d1, ..., dn) - grid = batch_apply_affine_to_grid(grid=grid, affine_matrix=affine_matrix) - - return batch_resample_image_label( - batch=batch, - grid=grid, - ) - - -def build_aug_fn_from_fns( - fns: Sequence[Callable], -) -> Callable: +def chain_aug_fns( + fns: Sequence[AugmentationFn], +) -> AugmentationFn: """Combine a list of data augmentation functions. Args: @@ -419,55 +21,10 @@ def build_aug_fn_from_fns( A data augmentation function. """ - def aug_fn( - key: jax.random.PRNGKeyArray, batch: dict[str, jnp.ndarray] - ) -> dict[str, jnp.ndarray]: + def aug_fn(key: jax.Array, batch: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]: keys = jax.random.split(key, num=len(fns)) for k, fn in zip(keys, fns): batch = fn(k, batch) return batch return aug_fn - - -def build_aug_fn_from_config( - config: DictConfig, -) -> Callable: - """Return a data augmentation function. - - Args: - config: entire config. - - Returns: - A data augmentation function. 
- """ - data_config = config.data - dataset_name = data_config["name"] - dataset_info = INFO_MAP[dataset_name] - image_shape = dataset_info.image_spatial_shape - da_config = data_config.loader.data_augmentation - patch_shape = data_config.loader.patch_shape - - grid = get_coordinate_grid(shape=image_shape) - max_rotation = np.array(da_config["max_rotation"]) - max_translation = np.array(da_config["max_translation"]) - max_scaling = np.array(da_config["max_scaling"]) - - aug_fns = [ - partial( - batch_random_affine_transform, - grid=grid, - max_rotation=max_rotation, - max_translation=max_translation, - max_scaling=max_scaling, - ), - partial( - batch_patch_random_sample, - image_shape=image_shape, - patch_shape=patch_shape, - ), - ] - - if len(aug_fns) == 1: - return aug_fns[0] - return build_aug_fn_from_fns(aug_fns) diff --git a/imgx/data/iterator.py b/imgx/data/iterator.py index 39eed39..ad6ff67 100644 --- a/imgx/data/iterator.py +++ b/imgx/data/iterator.py @@ -8,20 +8,23 @@ import jax import jax.numpy as jnp import jax.scipy +import numpy as np import tensorflow as tf +import tensorflow.experimental.numpy as tnp import tensorflow_datasets as tfds from absl import logging from omegaconf import DictConfig from imgx.data.util import get_foreground_range, maybe_pad_batch, tf_to_numpy from imgx.device import shard -from imgx.optim import get_half_precision_dtype +from imgx.train_state import get_half_precision_dtype from imgx_datasets.constant import ( FOREGROUND_RANGE, IMAGE, LABEL, TEST_SPLIT, TRAIN_SPLIT, + UID, VALID_SPLIT, ) @@ -39,37 +42,49 @@ ) -def create_image_label_dict_from_dict( - x: dict[str, tf.Tensor], +def remove_uid_from_dict( + batch: dict[str, tf.Tensor], ) -> dict[str, tf.Tensor]: """Create a dict from inputs. Args: - x: dict having image, label, and other tensors. + batch: dict potentially having uid. Returns: - Dict having image and label. + Dict not having uid. """ - return { - IMAGE: x[IMAGE], - LABEL: x[LABEL], - } + return {k: v for k, v in batch.items() if k != UID} -def add_foreground_range_in_dict( - x: dict[str, tf.Tensor], +def add_foreground_range( + batch: dict[str, tf.Tensor], ) -> dict[str, tf.Tensor]: - """Add FOREGROUND_RANGE in input dict. + """Add FOREGROUND_RANGE in input dict if there are labels. Args: - x: dict having some attributes. + batch: dict maybe having label. Returns: - Dict having FOREGROUND_RANGE. + Dict having FOREGROUND_RANGE of shape (ndim, 2). """ + foreground_ranges = [] + for k, v in batch.items(): + if LABEL in k: + # (ndim, 2) + foreground_ranges.append(get_foreground_range(v)) + if len(foreground_ranges) == 0: + # no labels + return batch + # (num_labels, ndim, 2) + foreground_range = tnp.stack(foreground_ranges, axis=0) + # (ndim, 2) + foreground_range = tnp.stack( + [tnp.min(foreground_range[:, :, 0], axis=0), tnp.max(foreground_range[:, :, 1], axis=0)], + axis=-1, + ) return { - FOREGROUND_RANGE: get_foreground_range(x[LABEL]), - **x, + FOREGROUND_RANGE: foreground_range, + **batch, } @@ -101,47 +116,45 @@ def load_split_from_image_tfds_builder( returns -1 for training. """ is_train = split == TRAIN_SPLIT - # Prepare arguments. shuffle_buffer_size = shuffle_buffer_size or (8 * batch_size) - # Download data. + # download data builder.download_and_prepare() - # Each host is responsible for a fixed subset of data. 
+ # each host is responsible for a fixed subset of data if is_train: split = tfds.even_splits(split, jax.process_count())[jax.process_index()] dataset = builder.as_dataset( split=split, ) - # Shrink data set if required + # shrink data set if required if max_num_samples > 0: logging.info(f"Taking first {max_num_samples} data samples for split {split}.") dataset = dataset.take(max_num_samples) - # Caching. + # caching dataset = dataset.cache() num_steps = -1 # not set for training if is_train: - # First repeat then batch. + # first repeat then batch dataset = dataset.repeat() - # Augmentation should be done after repeat for true randomness. + # augmentation should be done after repeat for true randomness # remove uid and calculate foreground range (deterministic) dataset = dataset.map( - create_image_label_dict_from_dict, + remove_uid_from_dict, num_parallel_calls=tf.data.experimental.AUTOTUNE, ) dataset = dataset.map( - add_foreground_range_in_dict, + add_foreground_range, num_parallel_calls=tf.data.experimental.AUTOTUNE, ) - # Shuffle after augmentation to avoid loading non-augmented images into - # buffer. + # shuffle after augmentation to avoid loading non-augmented images into buffer dataset = dataset.shuffle(shuffle_buffer_size, seed=shuffle_seed) dataset = dataset.batch(batch_size, drop_remainder=True) else: - # First batch then repeat. + # first batch then repeat dataset = dataset.batch(batch_size, drop_remainder=False) num_steps = tf.data.experimental.cardinality(dataset).numpy() if split == VALID_SPLIT: @@ -154,7 +167,9 @@ def load_split_from_image_tfds_builder( if dtype != jnp.float32: def cast_fn(batch: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]: - batch[IMAGE] = tf.cast(batch[IMAGE], tf.dtypes.as_dtype(dtype)) + for k in batch: + if IMAGE in k: + batch[k] = tf.cast(batch[k], tf.dtypes.as_dtype(dtype)) return batch dataset = dataset.map(cast_fn) @@ -173,7 +188,7 @@ def get_image_iterator( shuffle_seed: int, max_num_samples: int, dtype: jnp.dtype = jnp.float32, -) -> tuple[Iterator, int]: +) -> tuple[Iterator[dict[str, np.ndarray]], int]: """Returns iterator from builder. Args: @@ -217,8 +232,6 @@ def get_image_tfds_dataset( ) -> DatasetIterator: """Returns generators for the dataset train, valid, and test sets. - TODO: not necessary to init all iterators at once. - Args: dataset_name: Data set name. config: entire config. diff --git a/imgx/data/iterator_test.py b/imgx/data/iterator_test.py index 6228164..51ed02c 100644 --- a/imgx/data/iterator_test.py +++ b/imgx/data/iterator_test.py @@ -1,5 +1,5 @@ """Test image data iterators, requires building datasets first.""" - +from __future__ import annotations import chex import jax @@ -10,9 +10,9 @@ from chex._src import fake from omegaconf import DictConfig -from imgx.data.iterator import DatasetIterator, get_image_tfds_dataset +from imgx.data.iterator import DatasetIterator, add_foreground_range, get_image_tfds_dataset from imgx_datasets import AMOS_CT, BRATS2021_MR, INFO_MAP, MALE_PELVIC_MR, MUSCLE_US -from imgx_datasets.constant import IMAGE, LABEL, UID +from imgx_datasets.constant import FOREGROUND_RANGE, IMAGE, LABEL, UID # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. @@ -228,3 +228,73 @@ def test_labels( raise ValueError( f"{err_paths} have less than {num_classes} classes including background." 
) + + +@pytest.mark.parametrize( + ("batch", "expected"), + [ + ( + {LABEL: np.array([0, 1, 2, 3])}, + np.array([[1, 3]]), + ), + ( + {LABEL: np.array([1, 2, 3, 0])}, + np.array([[0, 2]]), + ), + ( + {LABEL: np.array([1, 2, 3, 4])}, + np.array([[0, 3]]), + ), + ( + {LABEL: np.array([0, 1, 2, 3, 4, 0, 0])}, + np.array([[1, 4]]), + ), + ( + {LABEL: np.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 0, 0, 0]])}, + np.array([[0, 1], [1, 3]]), + ), + ( + { + LABEL + "_1": np.array([0, 1, 2, 3]), + LABEL + "_2": np.array([1, 2, 3, 0]), + }, + np.array([[0, 3]]), + ), + ( + { + LABEL + "_1": np.array([0, 1, 2, 3]), + LABEL + "_2": np.array([0, 2, 3, 0]), + }, + np.array([[1, 3]]), + ), + ( + {}, + None, + ), + ], + ids=[ + "1d-left", + "1d-right", + "1d-none", + "1d-both", + "2d", + "1d two labels all foreground", + "1d two labels some foreground", + "no foreground", + ], +) +def test_get_foreground_range( + batch: dict[str, np.ndarray], + expected: np.ndarray | None, +) -> None: + """Test get_translation_range return values. + + Args: + batch: batch may have labels. + expected: expected range, if None means there is no labels in batch. + """ + got = add_foreground_range(batch) + if expected is None: + assert FOREGROUND_RANGE not in got + else: + chex.assert_trees_all_equal(got[FOREGROUND_RANGE], expected) diff --git a/imgx/data/patch.py b/imgx/data/patch.py index 796ea02..6e3118d 100644 --- a/imgx/data/patch.py +++ b/imgx/data/patch.py @@ -1,50 +1,49 @@ """Script for image patching.""" +from __future__ import annotations + +from functools import partial + import jax import jax.numpy as jnp import numpy as np from omegaconf import DictConfig +from imgx.data import AugmentationFn +from imgx.data.util import get_batch_size from imgx_datasets import INFO_MAP from imgx_datasets.constant import IMAGE, LABEL def batch_patch_random_sample( - key: jax.random.PRNGKeyArray, + key: jax.Array, batch: dict[str, jnp.ndarray], - image_shape: jnp.ndarray, - patch_shape: jnp.ndarray, + image_shape: tuple[int, ...], + patch_shape: tuple[int, ...], ) -> dict[str, jnp.ndarray]: """Randomly crop patch from image and label. The crop per sample in the batch is different. Image and label have no channel dimension. + The crop in each sample is the same. + Args: key: jax random key. - batch: dict having image and label. - image may have shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c) - label has shape (batch, d1, ..., dn) + batch: dict having images or labels, and foreground_range. + images have shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c) + labels have shape (batch, d1, ..., dn) + batch should not have other keys such as UID. image_shape: image spatial shape, (d1, ..., dn). - patch_shape: patch size, shape = (p1, ..., pn), + patch_shape: patch size shape, (p1, ..., pn), patch_shape should <= image_shape for all dimensions. Returns: Augmented dict having image and label. image and label all have shapes (batch, p1, ..., pn). """ - image = batch[IMAGE] - label = batch[LABEL] - - # check shapes - if image.ndim not in [label.ndim, label.ndim + 1]: - raise ValueError( - f"image and label must have same ndim or ndim+1, " - f"got {image.ndim} and {label.ndim} " - f"for image and label, correspondingly." 
- ) + batch_size = get_batch_size(batch) # define sample range - batch_size = image.shape[0] indice_range = jnp.array(image_shape) - jnp.array(patch_shape) # sample a corner for each sample in the batch @@ -70,34 +69,44 @@ def slice_per_sample(x: jnp.ndarray, start_indices_i: jnp.ndarray) -> jnp.ndarra return jax.lax.dynamic_slice(x, start_indices_i, patch_shape) # crop patch - # vmap on batch axis - slice_image_vmap = jax.vmap( - slice_per_sample, - in_axes=(0, 0), - ) - if image.ndim == label.ndim + 1: - # vmap on channel axis - ch_axis = image.ndim - 1 - slice_image_vmap = jax.vmap( - slice_image_vmap, - in_axes=(ch_axis, None), - out_axes=ch_axis, - ) - # (batch, p1, ..., pn) or (batch, p1, ..., pn, c) - image = slice_image_vmap(image, start_indices) - # vmap on batch axis - # (batch, p1, ..., pn) - label = jax.vmap( - slice_per_sample, - in_axes=(0, 0), - )(label, start_indices) - return {IMAGE: image, LABEL: label} + cropped_batch = {} + for k, v in batch.items(): + if LABEL in k: + # assume label related keys have label in name + # vmap on batch axis + # (batch, p1, ..., pn) + if v.ndim != len(image_shape) + 1: + raise ValueError(f"Label {k} has wrong ndim {v.ndim}.") + cropped_batch[k] = jax.vmap( + slice_per_sample, + in_axes=(0, 0), + )(v, start_indices) + elif IMAGE in k: + # assume image related keys have image in name + if v.ndim not in [len(image_shape) + 1, len(image_shape) + 2]: + raise ValueError(f"Image {k} has wrong ndim {v.ndim}.") + # vmap on batch axis + slice_image_vmap = jax.vmap( + slice_per_sample, + in_axes=(0, 0), + ) + if v.ndim == len(image_shape) + 2: + # vmap on channel axis + ch_axis = len(image_shape) + 1 + slice_image_vmap = jax.vmap( + slice_image_vmap, + in_axes=(ch_axis, None), + out_axes=ch_axis, + ) + # (batch, p1, ..., pn) or (batch, p1, ..., pn, c) + cropped_batch[k] = slice_image_vmap(v, start_indices) + return cropped_batch def get_patch_grid( - image_shape: tuple, - patch_shape: tuple, - patch_overlap: tuple, + image_shape: tuple[int, ...], + patch_shape: tuple[int, ...], + patch_overlap: tuple[int, ...], ) -> np.ndarray: """Get start_indices per patch following a grid. @@ -135,7 +144,7 @@ def get_patch_grid( def batch_patch_grid_sample( x: jnp.ndarray, start_indices: np.ndarray, - patch_shape: tuple, + patch_shape: tuple[int, ...], ) -> jnp.ndarray: """Extract patch following a grid. @@ -149,6 +158,8 @@ def batch_patch_grid_sample( Patched, has shapes (batch, num_patches, p1, ..., pn) or (batch, num_patches, p1, ..., pn, c). """ + if x.ndim not in [len(patch_shape) + 1, len(patch_shape) + 2]: + raise ValueError(f"Image has wrong ndim {x.ndim}.") def slice_per_sample( x: jnp.ndarray, @@ -235,7 +246,7 @@ def add_patch_with_channel( def batch_patch_grid_mean_aggregate( x_patch: jnp.ndarray, start_indices: np.ndarray, - image_shape: tuple, + image_shape: tuple[int, ...], ) -> jnp.ndarray: """Aggregate patches by average on overlapping area following a grid. @@ -314,3 +325,27 @@ def get_patch_shape_grid_from_config( patch_overlap=patch_overlap, ) return patch_shape, patch_start_indices + + +def get_random_patch_fn( + config: DictConfig, +) -> AugmentationFn: + """Return a data augmentation function for patching. + + Args: + config: entire config. + + Returns: + A data augmentation function. 
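+
+    Example (an illustrative sketch): the returned function has the common
+    AugmentationFn signature, so it can be chained via chain_aug_fns:
+        patch_fn = get_random_patch_fn(config)
+        batch = patch_fn(jax.random.PRNGKey(0), batch)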
+ """ + dataset_info = INFO_MAP[config.data.name] + image_shape = dataset_info.image_spatial_shape + patch_shape = config.data.loader.patch_shape + if all(i <= j for i, j in zip(image_shape, patch_shape)): + # no need to patch + return lambda _, batch: batch + return partial( + batch_patch_random_sample, + image_shape=image_shape, + patch_shape=patch_shape, + ) diff --git a/imgx/data/patch_test.py b/imgx/data/patch_test.py index 68719ff..b39096a 100644 --- a/imgx/data/patch_test.py +++ b/imgx/data/patch_test.py @@ -48,8 +48,8 @@ class TestBatchPatchRandomSample(chex.TestCase): ) def test_shapes( self, - patch_shape: tuple, - image_shape: tuple, + patch_shape: tuple[int, ...], + image_shape: tuple[int, ...], num_channels: int, ) -> None: """Test random cropped patch shapes. @@ -194,9 +194,9 @@ class TestGetPatchGrid(chex.TestCase): ) def test_values( self, - patch_shape: tuple, - image_shape: tuple, - patch_overlap: tuple, + patch_shape: tuple[int, ...], + image_shape: tuple[int, ...], + patch_overlap: tuple[int, ...], expected: jnp.ndarray, ) -> None: """Test get_patch_grid return values. @@ -242,9 +242,9 @@ class TestBatchPatchGridSample(chex.TestCase): ) def test_shapes( self, - patch_shape: tuple, - image_shape: tuple, - patch_overlap: tuple, + patch_shape: tuple[int, ...], + image_shape: tuple[int, ...], + patch_overlap: tuple[int, ...], num_patches: int, num_channels: int, ) -> None: @@ -335,9 +335,9 @@ class TestAddPatchWithChannel(chex.TestCase): ) def test_values( self, - patch_shape: tuple, - image_shape: tuple, - start_indices: tuple, + patch_shape: tuple[int, ...], + image_shape: tuple[int, ...], + start_indices: tuple[int, ...], num_channels: int, ) -> None: """Test add_patch_with_channel shapes. @@ -422,9 +422,9 @@ class TestBatchPatchGridMeanAggregate(chex.TestCase): ) def test_shapes( self, - patch_shape: tuple, - image_shape: tuple, - patch_overlap: tuple, + patch_shape: tuple[int, ...], + image_shape: tuple[int, ...], + patch_overlap: tuple[int, ...], num_channels: int, ) -> None: """Test batch_patch_grid_mean_aggregate shapes. diff --git a/imgx/data/util.py b/imgx/data/util.py index b692be3..00b53d4 100644 --- a/imgx/data/util.py +++ b/imgx/data/util.py @@ -8,7 +8,27 @@ import tensorflow as tf import tensorflow.experimental.numpy as tnp -from imgx_datasets.constant import IMAGE +from imgx_datasets.constant import IMAGE, LABEL + + +def get_batch_size( + batch: dict[str, jnp.ndarray], +) -> int: + """Get batch size from a batch. + + Args: + batch: dict having images or labels. + + Returns: + Batch size. + """ + for k in batch: + if (LABEL in k) or (IMAGE in k): + # assume label related keys have label in name + return batch[k].shape[0] + raise ValueError( + f"No label or image in batch to get batch size, batch contains {batch.keys()}." + ) def maybe_pad_batch( @@ -56,8 +76,8 @@ def maybe_pad_batch( Raises: ValueError: if configs are conflicting. 
""" - sample_tensor = batch[IMAGE] - batch_pad = batch_size - sample_tensor.shape[batch_dim] + curr_batch_size = get_batch_size(batch) + batch_pad = batch_size - curr_batch_size if is_train and batch_pad != 0: raise ValueError( @@ -72,7 +92,7 @@ def maybe_pad_batch( def zero_pad(array: np.ndarray) -> np.ndarray: pad_with = [(0, 0)] * batch_dim + [(0, batch_pad)] + [(0, 0)] * (array.ndim - batch_dim - 1) - return np.pad(array, pad_with, mode="constant") + return np.pad(array, pad_with) padded_batch = jax.tree_map(zero_pad, batch) return padded_batch @@ -100,7 +120,7 @@ def _unpad_array(x: jnp.ndarray) -> jnp.ndarray: return jax.tree_map(_unpad_array, pytree) -def tf_to_numpy(batch: dict) -> np.ndarray: +def tf_to_numpy(batch: dict[str, tf.Tensor]) -> np.ndarray: """Convert an input batch from tf Tensors to numpy arrays. Args: diff --git a/imgx/data/warp.py b/imgx/data/warp.py new file mode 100644 index 0000000..a8a1310 --- /dev/null +++ b/imgx/data/warp.py @@ -0,0 +1,96 @@ +"""Module for image/lavel warping.""" +from __future__ import annotations + +from functools import partial + +import jax +import jax.numpy as jnp +from jax._src.scipy.ndimage import map_coordinates + + +def get_coordinate_grid(shape: tuple[int, ...]) -> jnp.ndarray: + """Generate a grid with given shape. + + This function is not jittable as the output depends on the value of shapes. + + Args: + shape: shape of the grid, (d1, ..., dn). + + Returns: + grid: grid coordinates, of shape (n, d1, ..., dn). + grid[:, i1, ..., in] = [i1, ..., in] + """ + return jnp.stack( + jnp.meshgrid( + *(jnp.arange(d) for d in shape), + indexing="ij", + ), + axis=0, + dtype=jnp.float32, + ) + + +def batch_grid_sample( + x: jnp.ndarray, + grid: jnp.ndarray, + order: int, + constant_values: float = 0.0, +) -> jnp.ndarray: + """Apply sampling to input. + + https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + + Args: + x: shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c). + grid: grid coordinates, of shape (batch, n, d1, ..., dn). + order: interpolation order, 0 for nearest, 1 for linear. + constant_values: constant value for out of bound coordinates. + + Returns: + Same shape as x. + """ + if x.ndim not in [grid.ndim - 1, grid.ndim]: + raise ValueError(f"Input x has shape {x.shape}, grid has shape {grid.shape}.") + + # vmap on batch axis + sample_vmap = jax.vmap( + partial( + map_coordinates, + order=order, + mode="constant", + cval=constant_values, + ), + in_axes=(0, 0), + ) + if x.ndim == grid.ndim: + # vmap on channel axis + ch_axis = x.ndim - 1 + sample_vmap = jax.vmap( + sample_vmap, + in_axes=(ch_axis, None), + out_axes=ch_axis, + ) + return sample_vmap(x, grid) + + +def warp_image( + x: jnp.ndarray, + ddf: jnp.ndarray, + order: int, +) -> jnp.ndarray: + """Warp the image with the deformation field. + + TODO: grid is a constant, can be precomputed. + + Args: + x: shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c). + ddf: deformation field, of shape (batch, d1, ..., dn, n). + order: interpolation order, 0 for nearest, 1 for linear. + + Returns: + warped image, of shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c). 
+    """
+    # (batch, n, d1, ..., dn)
+    grid = get_coordinate_grid(shape=ddf.shape[1:-1])
+    grid += jnp.moveaxis(ddf, -1, 1)
+    return batch_grid_sample(x, grid, order=order)
diff --git a/imgx/data/warp_test.py b/imgx/data/warp_test.py
new file mode 100644
index 0000000..d92b4c0
--- /dev/null
+++ b/imgx/data/warp_test.py
@@ -0,0 +1,182 @@
+"""Test warp and grid functions."""
+
+from __future__ import annotations
+
+from functools import partial
+
+import chex
+import numpy as np
+from absl.testing import parameterized
+from chex._src import fake
+
+from imgx.data.warp import batch_grid_sample, get_coordinate_grid
+
+
+# Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
+def setUpModule() -> None:  # pylint: disable=invalid-name
+    """Fake multi-devices."""
+    fake.set_n_cpu_devices(2)
+
+
+class TestGrid(chex.TestCase):
+    """Test get_coordinate_grid."""
+
+    @chex.variants(without_jit=True)
+    @parameterized.named_parameters(
+        (
+            "1d",
+            (2,),
+            np.asarray([[0.0, 1.0]]),
+        ),
+        (
+            "2d",
+            (3, 2),
+            np.asarray(
+                [
+                    [
+                        [0.0, 0.0],
+                        [1.0, 1.0],
+                        [2.0, 2.0],
+                    ],
+                    [
+                        [0.0, 1.0],
+                        [0.0, 1.0],
+                        [0.0, 1.0],
+                    ],
+                ],
+            ),
+        ),
+    )
+    def test_values(self, shape: tuple[int, ...], expected: np.ndarray) -> None:
+        """Test exact values.
+
+        Args:
+            shape: shape of the grid, (d1, ..., dn).
+            expected: expected coordinates.
+        """
+        got = self.variant(get_coordinate_grid)(
+            shape=shape,
+        )
+        chex.assert_trees_all_equal(got, expected)
+
+
+class TestResample(chex.TestCase):
+    """Test batch_grid_sample."""
+
+    @chex.all_variants()
+    @parameterized.product(
+        (
+            {
+                "image": np.asarray(
+                    [
+                        [
+                            [2.0, 1.0, 0.0],
+                            [0.0, 3.0, 4.0],
+                        ],
+                        [
+                            [2.0, 1.0, 0.0],
+                            [0.0, 3.0, 4.0],
+                        ],
+                    ],
+                ),
+                "label": np.asarray(
+                    [
+                        [
+                            [2.0, 1.0, 0.0],
+                            [0.0, 3.0, 4.0],
+                        ],
+                        [
+                            [2.0, 1.0, 0.0],
+                            [0.0, 3.0, 4.0],
+                        ],
+                    ],
+                ),
+                "grid": np.asarray(
+                    [
+                        # first image, unchanged
+                        [
+                            # x axis
+                            [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
+                            # y axis
+                            [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]],
+                        ],
+                        # second image, changed
+                        # (0.4, 0) x-axis linear interpolation
+                        # (0, 0.6) y-axis linear interpolation
+                        # (0.4, 1.6) x/y-axis linear interpolation
+                        # (1.0, 3.0) out of boundary
+                        [
+                            # x axis
+                            [[0.4, 0.0, 0.4], [1.0, 1.0, 1.0]],
+                            # y axis
+                            [[0.0, 0.6, 1.6], [0.0, 3.0, 2.0]],
+                        ],
+                    ]
+                ),  # (batch=2, n=2, d1=2, d2=3)
+                "expected_image": np.asarray(
+                    [
+                        [
+                            [2.0, 1.0, 0.0],
+                            [0.0, 3.0, 4.0],
+                        ],
+                        [
+                            [1.2, 1.4, 1.68],
+                            [0.0, 0.0, 4.0],
+                        ],
+                    ],
+                ),
+                "expected_label": np.asarray(
+                    [
+                        [
+                            [2.0, 1.0, 0.0],
+                            [0.0, 3.0, 4.0],
+                        ],
+                        [
+                            [2.0, 1.0, 0.0],
+                            [0.0, 0.0, 4.0],
+                        ],
+                    ],
+                ),
+            },
+        ),
+        num_channels=[0, 1, 2],
+    )
+    def test_shapes(
+        self,
+        image: np.ndarray,
+        label: np.ndarray,
+        grid: np.ndarray,
+        expected_image: np.ndarray,
+        expected_label: np.ndarray,
+        num_channels: int,
+    ) -> None:
+        """Test batch_grid_sample values.
+
+        Test interpolation values, with and without channel axes.
+
+        Args:
+            image: input image batch.
+            label: input label batch.
+            grid: batch of grid with affine applied.
+            expected_image: expected image.
+            expected_label: expected label.
+            num_channels: number of channels to add to image.
+        """
+        if num_channels == 1:
+            image = image[..., None]
+            expected_image = expected_image[..., None]
+        elif num_channels > 1:
+            reps = (1,) * (len(image.shape) - 1) + (num_channels,)
+            image = np.tile(image[..., None], reps)
+            expected_image = np.tile(expected_image[..., None], reps)
+
+        got_image = self.variant(partial(batch_grid_sample, order=1))(
+            x=image,
+            grid=grid,
+        )
+        chex.assert_trees_all_close(got_image, expected_image)
+        got_label = self.variant(partial(batch_grid_sample, order=0))(
+            x=label,
+            grid=grid,
+        )
+        chex.assert_trees_all_close(got_label, expected_label)
diff --git a/imgx/device.py b/imgx/device.py
index 598a00b..fc203e3 100644
--- a/imgx/device.py
+++ b/imgx/device.py
@@ -85,11 +85,12 @@ def _shard_array(array: jnp.ndarray) -> jnp.ndarray:
     return jax.tree_map(_shard_array, pytree)
 
 
-def unshard(pytree: chex.ArrayTree) -> chex.ArrayTree:
+def unshard(pytree: chex.ArrayTree, device: jax.Device) -> chex.ArrayTree:
     """Reshapes arrays from [ndev, bs, ...] to [host_bs, ...].
 
     Args:
         pytree: A pytree of arrays to be unsharded.
+        device: device to put.
 
     Returns:
         Unsharded data.
@@ -99,4 +100,5 @@ def _unshard_array(array: jnp.ndarray) -> jnp.ndarray:
         ndev, bs = array.shape[:2]
         return array.reshape((ndev * bs,) + array.shape[2:])
 
+    pytree = jax.device_put(pytree, device)
     return jax.tree_map(_unshard_array, pytree)
diff --git a/imgx/diffusion/diffusion.py b/imgx/diffusion/diffusion.py
index 56a6c88..b3494cc 100644
--- a/imgx/diffusion/diffusion.py
+++ b/imgx/diffusion/diffusion.py
@@ -16,9 +16,7 @@ class Diffusion:
     num_timesteps: int
     noise_fn: Callable[..., jnp.ndarray]
 
-    def sample_noise(
-        self, key: jax.random.KeyArray, shape: Sequence[int], dtype: jnp.dtype
-    ) -> jnp.ndarray:
+    def sample_noise(self, key: jax.Array, shape: Sequence[int], dtype: jnp.dtype) -> jnp.ndarray:
         """Return noise of the same shape as input.
 
         Define this function to avoid defining a random key.
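+
+        Note: ``jax.random.KeyArray`` is deprecated in newer JAX releases;
+        typed PRNG keys are plain ``jax.Array`` values, e.g.
+        ``key: jax.Array = jax.random.PRNGKey(0)``.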
@@ -121,7 +119,7 @@ def variational_lower_bound(
 
     def sample(
         self,
-        key: jax.random.KeyArray,
+        key: jax.Array,
         model_out: jnp.ndarray,
         x_t: jnp.ndarray,
         t_index: jnp.ndarray,
diff --git a/imgx/diffusion/gaussian/sampler.py b/imgx/diffusion/gaussian/sampler.py
index a1b0a17..5253a5f 100644
--- a/imgx/diffusion/gaussian/sampler.py
+++ b/imgx/diffusion/gaussian/sampler.py
@@ -13,7 +13,7 @@ class DDPMSampler(GaussianDiffusion):
 
     def sample(
         self,
-        key: jax.random.KeyArray,
+        key: jax.Array,
         model_out: jnp.ndarray,
         x_t: jnp.ndarray,
         t_index: jnp.ndarray,
@@ -58,7 +58,7 @@ class DDIMSampler(GaussianDiffusion):
 
     def sample(
         self,
-        key: jax.random.KeyArray,
+        key: jax.Array,
         model_out: jnp.ndarray,
         x_t: jnp.ndarray,
         t_index: jnp.ndarray,
diff --git a/imgx/diffusion/time_sampler.py b/imgx/diffusion/time_sampler.py
index a5d476b..932a9c1 100644
--- a/imgx/diffusion/time_sampler.py
+++ b/imgx/diffusion/time_sampler.py
@@ -89,7 +89,7 @@ def t_index_to_t(self, t_index: jnp.ndarray) -> jnp.ndarray:
 
     def sample_uniformly(
         self,
-        key: jax.random.KeyArray,
+        key: jax.Array,
         t_index_minval: jnp.ndarray,
         t_index_maxval: jnp.ndarray,
     ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
@@ -153,7 +153,7 @@ def t_probs_from_loss_count(
 
     def sample_with_importance(
         self,
-        key: jax.random.KeyArray,
+        key: jax.Array,
         t_index_minval: jnp.ndarray,
         t_index_maxval: jnp.ndarray,
         probs: jnp.ndarray,
@@ -207,7 +207,7 @@ def sample_with_importance(
 
     def sample(
         self,
-        key: jax.random.KeyArray,
+        key: jax.Array,
         batch_size: int,
         t_index_min: int,
         t_index_max: int,
diff --git a/imgx/experiment.py b/imgx/experiment.py
index 58a6707..1acff15 100644
--- a/imgx/experiment.py
+++ b/imgx/experiment.py
@@ -5,12 +5,15 @@
 import chex
 import jax
+import jax.numpy as jnp
 import tensorflow as tf
+from absl import logging
 from flax import jax_utils
 from omegaconf import DictConfig
 
 from imgx.data.iterator import get_image_tfds_dataset
-from imgx.segmentation.train_state import TrainState
+from imgx.metric.util import aggregate_pmap_metrics
+from imgx.train_state import TrainState
 from imgx_datasets import INFO_MAP
 
 
@@ -24,8 +27,15 @@ def __init__(self, config: DictConfig) -> None:
             config: experiment config.
         """
         # Do not use accelerators in data pipeline.
-        tf.config.experimental.set_visible_devices([], device_type="GPU")
-        tf.config.experimental.set_visible_devices([], device_type="TPU")
+        try:
+            tf.config.set_visible_devices([], device_type="GPU")
+            tf.config.set_visible_devices([], device_type="TPU")
+        except RuntimeError:
+            logging.error(
+                f"Failed to set visible devices, the data pipeline may be using GPUs/TPUs. "
+                f"Visible GPU devices: {tf.config.get_visible_devices('GPU')}. "
+                f"Visible TPU devices: {tf.config.get_visible_devices('TPU')}."
+            )
 
         # save config
         self.config = config
@@ -45,6 +55,9 @@ def __init__(self, config: DictConfig) -> None:
         self.valid_iter = jax_utils.prefetch_to_device(self.valid_iter, 2)
         self.test_iter = jax_utils.prefetch_to_device(self.test_iter, 2)
 
+        self.p_train_step = None  # To be defined in train_init
+        self.p_eval_step = None  # To be defined in train_init
+
     def train_init(
         self, ckpt_dir: Path | None = None, step: int | None = None
     ) -> tuple[TrainState, int]:
@@ -60,8 +73,8 @@ def train_init(
         raise NotImplementedError
 
     def train_step(
-        self, train_state: TrainState, key: jax.random.PRNGKeyArray
-    ) -> tuple[TrainState, jax.random.PRNGKeyArray, chex.ArrayTree]:
+        self, train_state: TrainState, key: jax.Array
+    ) -> tuple[TrainState, jax.Array, chex.ArrayTree]:
        """Perform a training step.
Args: @@ -73,16 +86,24 @@ def train_step( - new random key. - metric dict. """ - raise NotImplementedError + batch = next(self.train_iter) + train_state, key, metrics = self.p_train_step( # pylint: disable=not-callable + train_state, + batch, + key, + ) + metrics = aggregate_pmap_metrics(metrics) + metrics = jax.tree_map(lambda x: x.item(), metrics) # tensor to values + return train_state, key, metrics def eval_step( self, train_state: TrainState, - key: jax.random.PRNGKeyArray, + key: jax.Array, split: str, out_dir: Path | None = None, - ) -> tuple[jax.random.PRNGKeyArray, chex.ArrayTree]: - """Evaluation on entire validation data set. + ) -> tuple[jax.Array, chex.ArrayTree]: + """Evaluation on entire validation/test data set. Args: train_state: training state. @@ -95,3 +116,33 @@ def eval_step( metric dict. """ raise NotImplementedError + + def eval_batch( + self, + train_state: TrainState, + key: jax.Array, + batch: dict[str, jnp.ndarray], + uids: list[str], + device_cpu: jax.Device, + out_dir: Path | None, + reference_suffix: str = "mask_preprocessed", + output_suffix: str = "mask_pred", + ) -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jax.Array]: + """Evaluate a batch. + + Args: + train_state: training state. + key: random key. + batch: batch data without uid. + uids: uids in the batch. + device_cpu: cpu device. + out_dir: output directory, if not None, predictions will be saved. + reference_suffix: suffix of reference image. + output_suffix: suffix of output image. + + Returns: + metrics, each item has shape (num_shards*batch,). + label_pred: predicted label, of shape (num_shards*batch, *spatial_shape). + key: random key. + """ + raise NotImplementedError diff --git a/imgx/integration_test.py b/imgx/integration_test.py new file mode 100644 index 0000000..0f39105 --- /dev/null +++ b/imgx/integration_test.py @@ -0,0 +1,82 @@ +"""Test experiments train, valid, and test. + +mocker.patch, https://pytest-mock.readthedocs.io/en/latest/ +""" +import shutil +from tempfile import TemporaryDirectory + +import pytest +from pytest_mock import MockFixture + +from imgx.run_test import main as run_test +from imgx.run_train import main as run_train +from imgx.run_valid import main as run_valid + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "dataset", + ["muscle_us", "amos_ct"], +) +def test_segmentation_train_valid_test(mocker: MockFixture, dataset: str) -> None: + """Test train, valid, and test. + + Args: + mocker: mocker, a wrapper of unittest.mock. + dataset: dataset name. + """ + with TemporaryDirectory() as temp_dir: + mocker.resetall() + mocker.patch.dict("os.environ", {"WANDB_MODE": "offline"}) + mocker.patch( + "sys.argv", + ["pytest", "debug=true", "task=seg", f"data={dataset}", f"logging.root_dir={temp_dir}"], + ) + run_train() # pylint: disable=no-value-for-parameter + mocker.patch( + "sys.argv", + ["pytest", f"--log_dir={temp_dir}/wandb/latest-run"], + ) + run_valid() + run_test() + shutil.rmtree(temp_dir) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "dataset", + ["muscle_us", "amos_ct"], +) +def test_diffusion_segmentation_train_valid_test(mocker: MockFixture, dataset: str) -> None: + """Test train, valid, and test. + + Args: + mocker: mocker, a wrapper of unittest.mock. + dataset: dataset name. 
+    """
+    with TemporaryDirectory() as temp_dir:
+        mocker.resetall()
+        mocker.patch.dict("os.environ", {"WANDB_MODE": "offline"})
+        mocker.patch(
+            "sys.argv",
+            [
+                "pytest",
+                "debug=true",
+                "task=gaussian_diff_seg",
+                f"data={dataset}",
+                f"logging.root_dir={temp_dir}",
+            ],
+        )
+        run_train()  # pylint: disable=no-value-for-parameter
+        mocker.patch(
+            "sys.argv",
+            [
+                "pytest",
+                "--num_timesteps=2",
+                "--sampler=DDPM",
+                f"--log_dir={temp_dir}/wandb/latest-run",
+            ],
+        )
+        run_valid()
+        run_test()
+        shutil.rmtree(temp_dir)
diff --git a/imgx/loss/__init__.py b/imgx/loss/__init__.py
index 4fd53b3..5952d84 100644
--- a/imgx/loss/__init__.py
+++ b/imgx/loss/__init__.py
@@ -1,9 +1,10 @@
 """Package for loss functions."""
 from imgx.loss.cross_entropy import cross_entropy, focal_loss
-from imgx.loss.dice import dice_loss
+from imgx.loss.dice import dice_loss, dice_loss_from_masks
 
 __all__ = [
     "cross_entropy",
     "focal_loss",
     "dice_loss",
+    "dice_loss_from_masks",
 ]
diff --git a/imgx/loss/cross_entropy_test.py b/imgx/loss/cross_entropy_test.py
index 3feac93..6e531e6 100644
--- a/imgx/loss/cross_entropy_test.py
+++ b/imgx/loss/cross_entropy_test.py
@@ -1,4 +1,4 @@
-"""Test dice loss functions."""
+"""Test cross entropy loss functions."""
 
 import chex
 import jax
diff --git a/imgx/loss/dice.py b/imgx/loss/dice.py
index d7a077b..e036627 100644
--- a/imgx/loss/dice.py
+++ b/imgx/loss/dice.py
@@ -3,10 +3,9 @@
 import jax.numpy as jnp
 
 
-def dice_loss(
-    logits: jnp.ndarray,
+def dice_loss_from_masks(
+    mask_pred: jnp.ndarray,
     mask_true: jnp.ndarray,
-    classes_are_exclusive: bool,
 ) -> jnp.ndarray:
     """Mean dice loss, smaller is better.
 
@@ -14,19 +13,12 @@
     This is to avoid the need of smoothing and potentially nan gradients.
 
     Args:
-        logits: unscaled prediction, (batch, ..., num_classes).
+        mask_pred: binary masks, (batch, ..., num_classes).
         mask_true: binary masks, (batch, ..., num_classes).
-        classes_are_exclusive: classes are exclusive, i.e. no overlap.
 
     Returns:
         Dice loss value of shape (batch, num_classes).
     """
-    mask_pred = jax.lax.cond(
-        classes_are_exclusive,
-        jax.nn.softmax,
-        jax.nn.sigmoid,
-        logits,
-    )
     reduce_axis = tuple(range(mask_pred.ndim))[1:-1]
     # (batch, num_classes)
     numerator = 2.0 * jnp.sum(mask_pred * mask_true, axis=reduce_axis)
@@ -38,3 +30,30 @@
         x=1.0 - numerator / denominator,
         y=jnp.nan,
     )
+
+
+def dice_loss(
+    logits: jnp.ndarray,
+    mask_true: jnp.ndarray,
+    classes_are_exclusive: bool,
+) -> jnp.ndarray:
+    """Mean dice loss, smaller is better.
+
+    Losses are not calculated for classes absent from the label.
+    This is to avoid the need of smoothing and potentially nan gradients.
+
+    Args:
+        logits: unscaled prediction, (batch, ..., num_classes).
+        mask_true: binary masks, (batch, ..., num_classes).
+        classes_are_exclusive: classes are exclusive, i.e. no overlap.
+
+    Returns:
+        Dice loss value of shape (batch, num_classes).
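+
+    Example (illustrative): for exclusive classes the logits are softmaxed
+    before the soft dice overlap is computed, so
+
+        dice_loss(logits, mask_true, classes_are_exclusive=True)
+
+    matches dice_loss_from_masks(jax.nn.softmax(logits), mask_true).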
+    """
+    mask_pred = jax.lax.cond(
+        classes_are_exclusive,
+        jax.nn.softmax,
+        jax.nn.sigmoid,
+        logits,
+    )
+    return dice_loss_from_masks(mask_pred, mask_true)
diff --git a/imgx/loss/segmentation.py b/imgx/loss/segmentation.py
new file mode 100644
index 0000000..9a3ef00
--- /dev/null
+++ b/imgx/loss/segmentation.py
@@ -0,0 +1,78 @@
+"""Vanilla segmentation loss."""
+from __future__ import annotations
+
+import jax.numpy as jnp
+from omegaconf import DictConfig
+
+from imgx.loss import cross_entropy, dice_loss, focal_loss
+from imgx.metric import class_proportion
+from imgx_datasets.dataset_info import DatasetInfo
+
+
+def segmentation_loss(
+    logits: jnp.ndarray,
+    label: jnp.ndarray,
+    dataset_info: DatasetInfo,
+    loss_config: DictConfig,
+) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
+    """Calculate segmentation loss with auxiliary losses and return metrics.
+
+    Args:
+        logits: unnormalised logits of shape (batch, ..., num_classes).
+        label: label of shape (batch, ...).
+        dataset_info: dataset info with helper functions.
+        loss_config: weights of the different loss terms.
+
+    Returns:
+        - calculated loss, of shape (batch,).
+        - metrics, values of shape (batch,).
+    """
+    mask_true = dataset_info.label_to_mask(label, axis=-1)
+    metrics = {}
+
+    # (batch, num_classes)
+    class_prop_batch_cls = class_proportion(mask_true)
+    for i in range(dataset_info.num_classes):
+        metrics[f"class_{i}_proportion_true"] = class_prop_batch_cls[:, i]
+
+    # total loss
+    loss_batch = jnp.zeros((logits.shape[0],), dtype=logits.dtype)
+    if loss_config.get("dice", 0.0) > 0:
+        # (batch, num_classes)
+        dice_loss_batch_cls = dice_loss(
+            logits=logits,
+            mask_true=mask_true,
+            classes_are_exclusive=dataset_info.classes_are_exclusive,
+        )
+        # (batch, )
+        # without background
+        # mask out non-existing classes
+        dice_loss_batch = jnp.mean(
+            dice_loss_batch_cls[:, 1:], axis=-1, where=class_prop_batch_cls[:, 1:] > 0
+        )
+        metrics["dice_loss"] = dice_loss_batch
+        for i in range(dice_loss_batch_cls.shape[-1]):
+            metrics[f"dice_loss_class_{i}"] = dice_loss_batch_cls[:, i]
+        loss_batch += dice_loss_batch * loss_config["dice"]
+
+    if loss_config.get("cross_entropy", 0.0) > 0:
+        # (batch, )
+        ce_loss_batch = cross_entropy(
+            logits=logits,
+            mask_true=mask_true,
+            classes_are_exclusive=dataset_info.classes_are_exclusive,
+        )
+        metrics["cross_entropy_loss"] = ce_loss_batch
+        loss_batch += ce_loss_batch * loss_config["cross_entropy"]
+
+    if loss_config.get("focal", 0.0) > 0:
+        # (batch, )
+        focal_loss_batch = focal_loss(
+            logits=logits,
+            mask_true=mask_true,
+            classes_are_exclusive=dataset_info.classes_are_exclusive,
+        )
+        metrics["focal_loss"] = focal_loss_batch
+        loss_batch += focal_loss_batch * loss_config["focal"]
+    metrics["total_loss"] = loss_batch
+    return loss_batch, metrics
diff --git a/imgx/loss/segmentation_test.py b/imgx/loss/segmentation_test.py
new file mode 100644
index 0000000..2667292
--- /dev/null
+++ b/imgx/loss/segmentation_test.py
@@ -0,0 +1,67 @@
+"""Test segmentation loss."""
+from functools import partial
+
+import chex
+import jax.numpy as jnp
+from absl.testing import parameterized
+from chex._src import fake
+
+from imgx.loss.segmentation import segmentation_loss
+from imgx_datasets import INFO_MAP
+
+
+# Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
+def setUpModule() -> None: # pylint: disable=invalid-name + """Fake multi-devices.""" + fake.set_n_cpu_devices(2) + + +class TestSegmentationLoss(chex.TestCase): + """Test segmentation_loss.""" + + batch_size = 2 + + @chex.all_variants() + @parameterized.product( + dataset_name=sorted(INFO_MAP.keys()), + loss_config=[ + { + "cross_entropy": 1.0, + "dice": 1.0, + "focal": 1.0, + }, + { + "dice": 1.0, + }, + ], + ) + def test_shape( + self, + dataset_name: str, + loss_config: dict[str, float], + ) -> None: + """Test return shapes. + + Args: + dataset_name: dataset name. + loss_config: loss config. + """ + dataset_info = INFO_MAP[dataset_name] + shape = dataset_info.image_spatial_shape + shape = tuple(max(x // 16, 2) for x in shape) # reduce shape to speed up test + logits = jnp.ones( + (self.batch_size, *shape, dataset_info.num_classes), + dtype=jnp.float32, + ) + label = jnp.ones((self.batch_size, *shape), dtype=jnp.int32) + + got_loss_batch, got_metrics = self.variant( + partial( + segmentation_loss, + dataset_info=dataset_info, + loss_config=loss_config, + ) + )(logits, label) + chex.assert_shape(got_loss_batch, (self.batch_size,)) + for v in got_metrics.values(): + chex.assert_shape(v, (self.batch_size,)) diff --git a/imgx/metric/area_test.py b/imgx/metric/area_test.py index d563710..cd7b23e 100644 --- a/imgx/metric/area_test.py +++ b/imgx/metric/area_test.py @@ -5,7 +5,7 @@ from absl.testing import parameterized from chex._src import fake -from imgx.metric.area import class_proportion +from imgx.metric import class_proportion # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. diff --git a/imgx/metric/centroid.py b/imgx/metric/centroid.py index 9671d02..86a9634 100644 --- a/imgx/metric/centroid.py +++ b/imgx/metric/centroid.py @@ -4,28 +4,6 @@ import jax.numpy as jnp -def get_coordinate_grid(shape: tuple[int, ...]) -> jnp.ndarray: - """Generate a grid with given shape. - - This function is not jittable as the output depends on the value of shapes. - - Args: - shape: shape of the grid, (d1, ..., dn). - - Returns: - grid: grid coordinates, of shape (n, d1, ..., dn). - grid[:, i1, ..., in] = [i1, ..., in] - """ - return jnp.stack( - jnp.meshgrid( - *(jnp.arange(d) for d in shape), - indexing="ij", - ), - axis=0, - dtype=jnp.float32, - ) - - def get_centroid( mask: jnp.ndarray, grid: jnp.ndarray, @@ -52,7 +30,7 @@ def get_centroid( denominator = summed_mask[:, None, :] # if mask is not empty return real centroid, else nan centroid = jnp.where(condition=denominator > 0, x=numerator / denominator, y=jnp.nan) - return centroid, summed_mask == 0 + return centroid, jnp.array(summed_mask == 0, dtype=jnp.bool_) def centroid_distance( diff --git a/imgx/metric/centroid_test.py b/imgx/metric/centroid_test.py index cf09811..a44e06c 100644 --- a/imgx/metric/centroid_test.py +++ b/imgx/metric/centroid_test.py @@ -7,8 +7,9 @@ from absl.testing import parameterized from chex._src import fake +from imgx.data.warp import get_coordinate_grid from imgx.metric import centroid_distance -from imgx.metric.centroid import get_centroid, get_coordinate_grid +from imgx.metric.centroid import get_centroid # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
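A minimal usage sketch for the new segmentation_loss above (illustrative; the
weights and inputs are hypothetical, and any loss whose weight is missing or
zero in loss_config is skipped):

    from omegaconf import OmegaConf

    loss_config = OmegaConf.create({"dice": 1.0, "cross_entropy": 0.5})
    loss_batch, metrics = segmentation_loss(
        logits=logits,  # (batch, ..., num_classes)
        label=label,  # (batch, ...)
        dataset_info=dataset_info,  # e.g. INFO_MAP["muscle_us"]
        loss_config=loss_config,
    )
    # loss_batch has shape (batch,) and equals metrics["total_loss"]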
@@ -17,48 +18,6 @@ def setUpModule() -> None:  # pylint: disable=invalid-name
     fake.set_n_cpu_devices(2)
 
 
-class TestGrid(chex.TestCase):
-    """Test get_coordinate_grid."""
-
-    @chex.variants(without_jit=True)
-    @parameterized.named_parameters(
-        (
-            "1d",
-            (2,),
-            np.asarray([[0.0, 1.0]]),
-        ),
-        (
-            "2d",
-            (3, 2),
-            np.asarray(
-                [
-                    [
-                        [0.0, 0.0],
-                        [1.0, 1.0],
-                        [2.0, 2.0],
-                    ],
-                    [
-                        [0.0, 1.0],
-                        [0.0, 1.0],
-                        [0.0, 1.0],
-                    ],
-                ],
-            ),
-        ),
-    )
-    def test_values(self, shape: tuple[int, ...], expected: np.ndarray) -> None:
-        """Test exact values.
-
-        Args:
-            shape: shape of the grid, (d1, ..., dn).
-            expected: expected coordinates.
-        """
-        got = self.variant(get_coordinate_grid)(
-            shape=shape,
-        )
-        chex.assert_trees_all_equal(got, expected)
-
-
 class TestCentroid(chex.TestCase):
     """Test get_centroid."""
 
diff --git a/imgx/metric/dice_test.py b/imgx/metric/dice_test.py
index 546bb59..5210aeb 100644
--- a/imgx/metric/dice_test.py
+++ b/imgx/metric/dice_test.py
@@ -178,7 +178,7 @@ class TestStability(chex.TestCase):
     )
     def test_shapes(
         self,
-        spatial_shape: tuple,
+        spatial_shape: tuple[int, ...],
         num_classes: int,
     ) -> None:
         """Test dice loss values.
diff --git a/imgx/segmentation/metric.py b/imgx/metric/segmentation.py
similarity index 92%
rename from imgx/segmentation/metric.py
rename to imgx/metric/segmentation.py
index ecdeeda..f84f122 100644
--- a/imgx/segmentation/metric.py
+++ b/imgx/metric/segmentation.py
@@ -8,6 +8,7 @@
 import jax.numpy as jnp
 import numpy as np
 
+from imgx.data.warp import get_coordinate_grid
 from imgx.metric import (
     aggregated_surface_distance,
     centroid_distance,
@@ -17,7 +18,6 @@
     normalized_surface_dice_from_distances,
     stability,
 )
-from imgx.metric.centroid import get_coordinate_grid
 from imgx.metric.util import flatten_diffusion_metrics
 from imgx_datasets import DatasetInfo
 
@@ -196,8 +196,9 @@ def get_non_jit_segmentation_metrics_per_step(
 
 
 def get_segmentation_metrics(
-    logits: jnp.ndarray,
-    label: jnp.ndarray,
+    logits: jnp.ndarray | None,
+    label_pred: jnp.ndarray | None,
+    label_true: jnp.ndarray,
     dataset_info: DatasetInfo,
 ) -> tuple[dict[str, jnp.ndarray], jnp.ndarray]:
     """Calculate segmentation metrics.
 
@@ -206,21 +207,31 @@
     Args:
         logits: (batch, ..., num_classes).
-        label: (batch, ...).
+        label_pred: (batch, ...).
+        label_true: (batch, ...).
         dataset_info: dataset information.
 
     Returns:
         Metrics, each metric value has shape (batch, ).
         Predicted label, with potential post-processing.
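+
+        Example (illustrative): exactly one of logits and label_pred is given;
+        the other must be None, e.g.
+
+            metrics, label_pred = get_segmentation_metrics(
+                logits=logits, label_pred=None, label_true=label, dataset_info=info
+            )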
""" + if (logits is None) == (label_pred is None): + raise ValueError("Either logits or label_pred must be None.") + # (batch, ..., num_classes) - mask_true = dataset_info.label_to_mask(label, axis=-1) + mask_true = dataset_info.label_to_mask(label_true, axis=-1) # post process is non-jiitable and maybe time-consuming - label_pred = dataset_info.logits_to_label_with_post_process(logits, axis=-1) + if label_pred is None: + label_pred = dataset_info.logits_to_label_with_post_process(logits, axis=-1) + else: + label_pred = dataset_info.post_process_label(label_pred) mask_pred = dataset_info.label_to_mask(label_pred, axis=-1) spacing = jnp.array(dataset_info.image_spacing) - metrics_confidence_jit = jax.jit(get_jit_segmentation_confidence)(logits) + if logits is None: + metrics_confidence_jit = {} + else: + metrics_confidence_jit = jax.jit(get_jit_segmentation_confidence)(logits) metrics_jit = jax.jit(get_jit_segmentation_metrics)( mask_pred=mask_pred, mask_true=mask_true, spacing=spacing ) diff --git a/imgx/segmentation/metric_test.py b/imgx/metric/segmentation_test.py similarity index 98% rename from imgx/segmentation/metric_test.py rename to imgx/metric/segmentation_test.py index fc24e86..8d5db8b 100644 --- a/imgx/segmentation/metric_test.py +++ b/imgx/metric/segmentation_test.py @@ -6,7 +6,7 @@ import numpy as np from chex._src import fake -from imgx.segmentation.metric import ( +from imgx.metric.segmentation import ( get_jit_segmentation_metrics, get_non_jit_segmentation_metrics, get_non_jit_segmentation_metrics_per_step, diff --git a/imgx/metric/surface_distance.py b/imgx/metric/surface_distance.py index 84029ad..c44cf82 100644 --- a/imgx/metric/surface_distance.py +++ b/imgx/metric/surface_distance.py @@ -16,6 +16,9 @@ from imgx_datasets.preprocess import get_binary_mask_bounding_box +OneArgScalarFunc = Callable[[np.ndarray], float] +TwoArgsScalarFunc = Callable[[np.ndarray, np.ndarray], float] + def get_mask_edges(mask_pred: np.ndarray, mask_true: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Do binary erosion and use XOR for input to get the edges. @@ -70,7 +73,7 @@ def get_surface_distance( def _aggregated_symmetric_surface_distance( dist_pred_true: np.ndarray, dist_true_pred: np.ndarray, - f: Callable, + f: OneArgScalarFunc | TwoArgsScalarFunc, num_args: int, ) -> float: """Aggregate surface distance in a symmetric way. @@ -98,7 +101,7 @@ def _aggregated_symmetric_surface_distance( def _aggregated_surface_distance( mask_pred: np.ndarray, mask_true: np.ndarray, - agg_fn_list: list[Callable], + agg_fn_list: list[OneArgScalarFunc | TwoArgsScalarFunc], num_args_list: list[int], spacing: tuple[float, ...] | None, symmetric: bool = True, @@ -154,7 +157,7 @@ def _aggregated_surface_distance( def aggregated_surface_distance( mask_pred: np.ndarray, mask_true: np.ndarray, - agg_fns: Callable | list[Callable], + agg_fns: OneArgScalarFunc | TwoArgsScalarFunc | list[OneArgScalarFunc | TwoArgsScalarFunc], num_args: int | list[int], spacing: tuple[float, ...] 
| None, symmetric: bool = True, diff --git a/imgx/metric/surface_distance_test.py b/imgx/metric/surface_distance_test.py index a969cfc..e2fa778 100644 --- a/imgx/metric/surface_distance_test.py +++ b/imgx/metric/surface_distance_test.py @@ -1,4 +1,4 @@ -"""Test loss functions.""" +"""Test surface distance functions.""" from __future__ import annotations @@ -12,6 +12,8 @@ from chex._src import fake from imgx.metric.surface_distance import ( + OneArgScalarFunc, + TwoArgsScalarFunc, aggregated_surface_distance, average_surface_distance, get_mask_edges, @@ -195,7 +197,7 @@ class TestSurfaceDistance(chex.TestCase): def test_nan_distance( self, ndims: int, - func: Callable, + func: Callable[[np.ndarray, np.ndarray], np.ndarray], ) -> None: """Test average_surface_distance returns nan given empty inputs. @@ -246,7 +248,7 @@ def test_nan_distance( def test_zero_distance( self, ndims: int, - func: Callable, + func: Callable[[np.ndarray, np.ndarray], np.ndarray], ) -> None: """Test average_surface_distance returns zero given same inputs. @@ -575,7 +577,9 @@ def test_agg_surface_distance( mask_true: np.ndarray, spacing: tuple[float, ...], symmetric: bool, - agg_funcs: Callable | list[Callable], + agg_funcs: OneArgScalarFunc + | TwoArgsScalarFunc + | list[OneArgScalarFunc | TwoArgsScalarFunc], num_args: int | list[int], expected: np.ndarray, ) -> None: diff --git a/imgx/model/basic.py b/imgx/model/basic.py index 87b3e17..ae97337 100644 --- a/imgx/model/basic.py +++ b/imgx/model/basic.py @@ -1,65 +1,19 @@ """Basic functions and modules.""" from __future__ import annotations -from collections.abc import Sequence -from functools import partial from typing import Callable import flax.linen as nn import jax import jax.numpy as jnp -# flax LayerNorm differs from haiku -LayerNorm = partial( - nn.LayerNorm, - epsilon=1e-5, - use_fast_variance=False, -) +class InstanceNorm(nn.Module): + """Instance norm. -def truncated_normal( - stddev: float | jnp.ndarray = 1.0, - mean: float | jnp.ndarray = 0.0, - dtype: jnp.dtype = jnp.float_, -) -> Callable[[jax.random.PRNGKeyArray, jnp.shape, jnp.dtype], jnp.ndarray]: - """Truncated normal initializer as in haiku. - - Args: - stddev: standard deviation of the truncated normal distribution. - mean: mean of the truncated normal distribution. - dtype: dtype of the array. - - Returns: - Initializer function. + The norm is calculated on axes excluding batch and features. """ - def init( - key: jax.random.KeyArray, shape: Sequence[int], dtype: jnp.dtype = dtype - ) -> jnp.ndarray: - """Init function. - - Args: - key: random key. - shape: shape of the array. - dtype: dtype of the array. - """ - real_dtype = jnp.finfo(dtype).dtype - m = jax.lax.convert_element_type(mean, dtype) - s = jax.lax.convert_element_type(stddev, real_dtype) - is_complex = jnp.issubdtype(dtype, jnp.complexfloating) - if is_complex: - shape = [2, *shape] - unscaled = jax.random.truncated_normal(key, -2.0, 2.0, shape, real_dtype) - if is_complex: - unscaled = unscaled[0] + 1j * unscaled[1] - return s * unscaled + m - - return init - - -class InstanceNorm(nn.Module): - """Instance norm.""" - dtype: jnp.dtype = jnp.float32 @nn.compact @@ -67,13 +21,13 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Forward pass. Args: - x: input with batch axis, (batch, ...). + x: input with batch axis, (batch, ..., channel). Returns: Normalised input. 
""" reduction_axes = tuple(range(x.ndim)[slice(1, -1)]) - return LayerNorm( + return nn.LayerNorm( reduction_axes=reduction_axes, )(x) @@ -121,7 +75,7 @@ class MLP(nn.Module): output_size: int activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.gelu kernel_init: Callable[ - [jax.random.PRNGKeyArray, jnp.shape, jnp.dtype], jnp.ndarray + [jax.Array, jnp.shape, jnp.dtype], jnp.ndarray ] = nn.initializers.lecun_normal() dtype: jnp.dtype = jnp.float32 diff --git a/imgx/model/conv.py b/imgx/model/conv.py index fce77c2..91b3e96 100644 --- a/imgx/model/conv.py +++ b/imgx/model/conv.py @@ -1,10 +1,6 @@ -"""Module for convolution layers. - -The kernel initialisation follows haiku's default. -""" +"""Module for convolution layers.""" from __future__ import annotations -from functools import partial from typing import Callable import flax.linen as nn @@ -13,26 +9,6 @@ from imgx.model.basic import InstanceNorm -# flax variance_scaling normalizes the std with the constant -# stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) -# this is different from haiku -Conv = partial( - nn.Conv, - kernel_init=nn.initializers.variance_scaling( - scale=0.87962566103423978**2, - mode="fan_in", - distribution="truncated_normal", - ), -) -ConvTranspose = partial( - nn.ConvTranspose, - kernel_init=nn.initializers.variance_scaling( - scale=0.87962566103423978**2, - mode="fan_in", - distribution="truncated_normal", - ), -) - class ConvNormAct(nn.Module): """Block with conv-norm-act.""" @@ -58,7 +34,7 @@ def __call__( """ return nn.Sequential( [ - Conv( + nn.Conv( features=self.out_channels, kernel_size=(self.kernel_size,) * self.num_spatial_dims, use_bias=False, @@ -93,7 +69,7 @@ def __call__( Array. """ res = x - x = Conv( + x = nn.Conv( features=self.out_channels, kernel_size=(self.kernel_size,) * self.num_spatial_dims, use_bias=False, @@ -101,7 +77,7 @@ def __call__( )(x) x = InstanceNorm(dtype=self.dtype)(x) x = self.activation(x) - x = Conv( + x = nn.Conv( features=self.out_channels, kernel_size=(self.kernel_size,) * self.num_spatial_dims, use_bias=False, @@ -144,7 +120,7 @@ def __call__( t_emb = nn.Dense(self.out_channels, dtype=self.dtype)(t_emb) res = x - x = Conv( + x = nn.Conv( features=self.out_channels, kernel_size=(self.kernel_size,) * self.num_spatial_dims, use_bias=False, @@ -152,7 +128,7 @@ def __call__( )(x) x = InstanceNorm(dtype=self.dtype)(x) x = self.activation(x) - x = Conv( + x = nn.Conv( features=self.out_channels, kernel_size=(self.kernel_size,) * self.num_spatial_dims, use_bias=False, @@ -213,9 +189,8 @@ def __call__( class ConvDownSample(nn.Module): """Down-sample with Conv.""" - num_spatial_dims: int out_channels: int - scale_factor: int + scale_factor: tuple[int, ...] dtype: jnp.dtype = jnp.float32 @nn.compact @@ -234,10 +209,10 @@ def __call__( """ return nn.Sequential( [ - Conv( + nn.Conv( features=self.out_channels, - kernel_size=(self.scale_factor,) * self.num_spatial_dims, - strides=(self.scale_factor,) * self.num_spatial_dims, + kernel_size=self.scale_factor, + strides=self.scale_factor, use_bias=False, dtype=self.dtype, ), @@ -249,9 +224,8 @@ def __call__( class ConvUpSample(nn.Module): """Up-sample with ConvTranspose.""" - num_spatial_dims: int out_channels: int - scale_factor: int + scale_factor: tuple[int, ...] 
dtype: jnp.dtype = jnp.float32 @nn.compact @@ -270,10 +244,10 @@ def __call__( """ return nn.Sequential( [ - ConvTranspose( + nn.ConvTranspose( features=self.out_channels, - kernel_size=(self.scale_factor,) * self.num_spatial_dims, - strides=(self.scale_factor,) * self.num_spatial_dims, + kernel_size=self.scale_factor, + strides=self.scale_factor, use_bias=False, dtype=self.dtype, ), diff --git a/imgx/model/conv_test.py b/imgx/model/conv_test.py new file mode 100644 index 0000000..c7e0f1e --- /dev/null +++ b/imgx/model/conv_test.py @@ -0,0 +1,140 @@ +"""Test conv layers.""" + +import chex +import jax +from absl.testing import parameterized +from chex._src import fake + +from imgx.model.conv import ConvDownSample, ConvUpSample + + +# Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. +def setUpModule() -> None: # pylint: disable=invalid-name + """Fake multi-devices.""" + fake.set_n_cpu_devices(2) + + +class TestConvDownSample(chex.TestCase): + """Test ConvDownSample.""" + + batch = 2 + + @chex.all_variants() + @parameterized.named_parameters( + ( + "1d", + (12,), + (2,), + (6,), + ), + ( + "2d", + (12, 13), + (2, 2), + (6, 7), + ), + ( + "2d - different scale factors", + (12, 13), + (4, 2), + (3, 7), + ), + ( + "3d - large scale factor", + (2, 4, 8), + (4, 4, 4), + (1, 1, 2), + ), + ( + "3d", + (12, 13, 14), + (2, 2, 2), + (6, 7, 7), + ), + ) + def test_shapes( + self, + in_shape: tuple[int, ...], + scale_factor: tuple[int, ...], + out_shape: tuple[int, ...], + ) -> None: + """Test output shapes under different device condition. + + Args: + in_shape: input shape, without batch, channel. + scale_factor: downsample factor. + out_shape: output shape, without batch, channel. + """ + in_channels = 1 + out_channels = 1 + rng = {"params": jax.random.PRNGKey(0)} + conv = ConvDownSample( + out_channels=out_channels, + scale_factor=scale_factor, + ) + x = jax.random.uniform( + jax.random.PRNGKey(0), + shape=(self.batch, *in_shape, in_channels), + ) + out, _ = self.variant(conv.init_with_output)(rng, x) + chex.assert_shape(out, (self.batch, *out_shape, out_channels)) + + +class TestConvUpSample(chex.TestCase): + """Test ConvUpSample.""" + + batch = 2 + + @chex.all_variants() + @parameterized.named_parameters( + ( + "1d", + (3,), + (2,), + (6,), + ), + ( + "2d", + (3, 4), + (2, 2), + (6, 8), + ), + ( + "2d - different scale factors", + (3, 4), + (4, 2), + (12, 8), + ), + ( + "3d", + (2, 3, 4), + (2, 2, 2), + (4, 6, 8), + ), + ) + def test_shapes( + self, + in_shape: tuple[int, ...], + scale_factor: tuple[int, ...], + out_shape: tuple[int, ...], + ) -> None: + """Test output shapes under different device condition. + + Args: + in_shape: input shape, without batch, channel. + scale_factor: up-sampler factor. + out_shape: output shape, without batch, channel. 
+        """
+        in_channels = 1
+        out_channels = 1
+        rng = {"params": jax.random.PRNGKey(0)}
+        conv = ConvUpSample(
+            out_channels=out_channels,
+            scale_factor=scale_factor,
+        )
+        x = jax.random.uniform(
+            jax.random.PRNGKey(0),
+            shape=(self.batch, *in_shape, in_channels),
+        )
+        out, _ = self.variant(conv.init_with_output)(rng, x)
+        chex.assert_shape(out, (self.batch, *out_shape, out_channels))
diff --git a/imgx/model/efficient_attention.py b/imgx/model/efficient_attention.py
index fa5c763..1fa857d 100644
--- a/imgx/model/efficient_attention.py
+++ b/imgx/model/efficient_attention.py
@@ -192,7 +192,7 @@ def dot_product_attention_with_qkv_chunks(
     bias: jnp.ndarray | None = None,
     mask: jnp.ndarray | None = None,
     broadcast_dropout: bool = True,
-    dropout_rng: jax.random.PRNGKeyArray | None = None,
+    dropout_rng: jax.Array | None = None,
     dropout_rate: float = 0.0,
     deterministic: bool = False,
     dtype: jnp.dtype | None = None,
diff --git a/imgx/model/transformer.py b/imgx/model/transformer.py
index 6e4faab..0c100ea 100644
--- a/imgx/model/transformer.py
+++ b/imgx/model/transformer.py
@@ -6,7 +6,7 @@
 import jax.numpy as jnp
 import numpy as np
 
-from imgx.model.basic import MLP, LayerNorm, truncated_normal
+from imgx.model.basic import MLP
 from imgx.model.efficient_attention import dot_product_attention_with_qkv_chunks
 
 
@@ -96,7 +96,7 @@ def __call__(
         if self.add_position_embedding:
             positional_embeddings = self.param(
                 "transformer_positional_embeddings",
-                truncated_normal(stddev=0.02),
+                nn.initializers.truncated_normal(stddev=0.02),
                 (1, seq_len, model_size),
             )
             x += positional_embeddings
@@ -112,8 +112,8 @@ def __call__(
                 attention_fn=attention_fn,
                 kernel_init=kernel_init,
                 dtype=self.dtype,
-            )(inputs_q=h, inputs_kv=h, mask=mask)
-            h_attn = LayerNorm(dtype=self.dtype)(h_attn)
+            )(inputs_q=h, inputs_k=h, inputs_v=h, mask=mask)
+            h_attn = nn.LayerNorm(dtype=self.dtype)(h_attn)
             h = h + h_attn
 
             h_dense = mlp_cls(
@@ -122,10 +122,10 @@ def __call__(
                 dtype=self.dtype,
                 kernel_init=kernel_init,
             )(h)
-            h_dense = LayerNorm(dtype=self.dtype)(h_dense)
+            h_dense = nn.LayerNorm(dtype=self.dtype)(h_dense)
             h = h + h_dense
 
             # save intermediate hidden embeddings
             hidden_embeddings.append(h)
 
-        return LayerNorm(dtype=self.dtype)(h), hidden_embeddings
+        return nn.LayerNorm(dtype=self.dtype)(h), hidden_embeddings
diff --git a/imgx/model/unet/bottom_encoder.py b/imgx/model/unet/bottom_encoder.py
index c597617..cb78050 100644
--- a/imgx/model/unet/bottom_encoder.py
+++ b/imgx/model/unet/bottom_encoder.py
@@ -13,6 +13,7 @@ class BottomImageEncoderUnet(nn.Module):
 
     kernel_size: int = 3  # convolution layer kernel size
     num_heads: int = 8  # for multi head attention
+    num_layers: int = 1  # for transformer encoder
     widening_factor: int = 4  # for key size in MHA
     remat: bool = True  # reduces memory cost at cost of compute speed
     dtype: jnp.dtype = jnp.float32
@@ -48,7 +49,7 @@ def __call__(
         image_emb = image_emb.reshape((batch_size, -1, model_size))
         transformer = TransformerEncoder(
             num_heads=self.num_heads,
-            num_layers=1,
+            num_layers=self.num_layers,
             autoregressive=False,
             widening_factor=self.widening_factor,
             remat=self.remat,
diff --git a/imgx/model/unet/image_encoder.py b/imgx/model/unet/downsample_encoder.py
similarity index 81%
rename from imgx/model/unet/image_encoder.py
rename to imgx/model/unet/downsample_encoder.py
index 5bf6210..6a775a6 100644
--- a/imgx/model/unet/image_encoder.py
+++ b/imgx/model/unet/downsample_encoder.py
@@ -1,4 +1,4 @@
-"""Image encoder for unet."""
+"""Downsample encoder for unet."""
 from __future__ import
annotations import flax.linen as nn @@ -7,13 +7,13 @@ from imgx.model.conv import ConvDownSample, ConvNormAct, ConvResBlock -class ImageEncoderUnet(nn.Module): - """Image encoder module with convolutions for unet.""" +class DownsampleEncoder(nn.Module): + """Down-sample encoder module with convolutions for unet.""" num_spatial_dims: int # 2 or 3 num_channels: tuple[int, ...] # channel at each depth, including the bottom - patch_size: int = 2 # first down sampling layer - scale_factor: int = 2 # spatial down-sampling/up-sampling + patch_size: tuple[int, ...] | int = 2 # first down sampling layer + scale_factor: tuple[int, ...] | int = 2 # spatial down-sampling/up-sampling num_res_blocks: int = 2 # number of residual blocks kernel_size: int = 3 # convolution layer kernel size num_heads: int = 8 # for multi head attention/MHA @@ -49,9 +49,17 @@ def __call__( Returns: List of embeddings from each layer. """ + patch_size = self.patch_size + scale_factor = self.scale_factor + if isinstance(patch_size, int): + patch_size = (patch_size,) * self.num_spatial_dims + if isinstance(scale_factor, int): + scale_factor = (scale_factor,) * self.num_spatial_dims + conv_norm_act_cls = nn.remat(ConvNormAct) if self.remat else ConvNormAct conv_res_block_cls = nn.remat(ConvResBlock) if self.remat else ConvResBlock conv_down_sample_cls = nn.remat(ConvDownSample) if self.remat else ConvDownSample + # encoder raw input x = conv_norm_act_cls( num_spatial_dims=self.num_spatial_dims, @@ -75,11 +83,9 @@ def __call__( # down-sampling for non-bottom layers # spatial shape get halved by 2**(i+1) if i < len(self.num_channels) - 1: - scale_factor = self.patch_size if i == 0 else self.scale_factor x = conv_down_sample_cls( - num_spatial_dims=self.num_spatial_dims, out_channels=self.num_channels[i + 1], - scale_factor=scale_factor, + scale_factor=patch_size if i == 0 else scale_factor, )(x) embeddings.append(x) diff --git a/imgx/model/unet/unet.py b/imgx/model/unet/unet.py index 7f35ed3..d1e3b4b 100644 --- a/imgx/model/unet/unet.py +++ b/imgx/model/unet/unet.py @@ -2,13 +2,14 @@ from __future__ import annotations import flax.linen as nn +import jax import jax.numpy as jnp from imgx.model.basic import sinusoidal_positional_embedding from imgx.model.slice import merge_spatial_dim_into_batch, split_spatial_dim_from_batch from imgx.model.unet.bottom_encoder import BottomImageEncoderUnet -from imgx.model.unet.image_encoder import ImageEncoderUnet -from imgx.model.unet.mask_decoder import MaskDecoderUnet +from imgx.model.unet.downsample_encoder import DownsampleEncoder +from imgx.model.unet.upsample_decoder import UpsampleDecoder class Unet(nn.Module): @@ -26,6 +27,8 @@ class Unet(nn.Module): kernel_size: int = 3 # convolution layer kernel size num_heads: int = 8 # for multi head attention/MHA widening_factor: int = 4 # for key size in MHA + num_transform_layers: int = 1 # for transformer encoder + out_kernel_init: jax.nn.initializers.Initializer = nn.linear.default_kernel_init remat: bool = True # remat reduces memory cost at cost of compute speed dtype: jnp.dtype = jnp.float32 @@ -81,7 +84,7 @@ def __call__( ) # image encoder - embeddings = ImageEncoderUnet( + embeddings = DownsampleEncoder( num_spatial_dims=self.num_spatial_dims, num_channels=self.num_channels, patch_size=self.patch_size, @@ -100,13 +103,14 @@ def __call__( kernel_size=self.kernel_size, num_heads=self.num_heads, widening_factor=self.widening_factor, + num_layers=self.num_transform_layers, remat=self.remat, dtype=self.dtype, )(image_emb=image_emb, t_emb=t_emb) 
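+        # the bottom embedding is appended so that the decoder receives
+        # len(num_channels) * (num_res_blocks + 1) + 1 feature maps in total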
         embeddings.append(image_emb)
 
         # mask decoder
-        out = MaskDecoderUnet(
+        out = UpsampleDecoder(
             num_spatial_dims=self.num_spatial_dims,
             out_channels=self.out_channels,
             num_channels=self.num_channels,
@@ -115,6 +119,7 @@
             num_res_blocks=self.num_res_blocks,
             kernel_size=self.kernel_size,
             widening_factor=self.widening_factor,
+            out_kernel_init=self.out_kernel_init,
             remat=self.remat,
             dtype=self.dtype,
         )(embeddings=embeddings, t_emb=t_emb)
diff --git a/imgx/model/unet/unet_test.py b/imgx/model/unet/unet_test.py
index 30d8e28..5b39acc 100644
--- a/imgx/model/unet/unet_test.py
+++ b/imgx/model/unet/unet_test.py
@@ -197,8 +197,8 @@ def test_output_real_shape(
         chex.assert_shape(out, (self.batch_size, *in_shape, self.out_channels))
 
     @parameterized.named_parameters(
-        ("Unet without time", False, 34194, 26.643322),
-        ("Unet with time", True, 36106, 28.515594),
+        ("Unet without time", False, 34194, 27.610577),
+        ("Unet with time", True, 36106, 29.420588),
     )
     def test_params_count(
         self,
@@ -208,6 +208,8 @@ def test_params_count(
     ) -> None:
         """Count network parameters.
 
+        Changing layer/model names may change the initial parameters norm.
+
         Args:
             with_time: with time or not.
             expected_params_count: expected number of parameters.
diff --git a/imgx/model/unet/mask_decoder.py b/imgx/model/unet/upsample_decoder.py
similarity index 76%
rename from imgx/model/unet/mask_decoder.py
rename to imgx/model/unet/upsample_decoder.py
index d5941cb..faddd5e 100644
--- a/imgx/model/unet/mask_decoder.py
+++ b/imgx/model/unet/upsample_decoder.py
@@ -1,24 +1,25 @@
-"""Mask encoder for unet."""
+"""Upsample decoder for unet."""
 from __future__ import annotations
 
 import flax.linen as nn
 import jax.lax
 import jax.numpy as jnp
 
-from imgx.model.conv import Conv, ConvResBlock, ConvUpSample
+from imgx.model.conv import ConvResBlock, ConvUpSample
 
 
-class MaskDecoderUnet(nn.Module):
-    """Mask decoder module with convolutions for unet."""
+class UpsampleDecoder(nn.Module):
+    """Upsample decoder module with convolutions for unet."""
 
     num_spatial_dims: int  # 2 or 3
     out_channels: int
     num_channels: tuple[int, ...]  # channel at each depth, including the bottom
-    patch_size: int = 2  # first down sampling layer
-    scale_factor: int = 2  # spatial down-sampling/up-sampling
+    patch_size: tuple[int, ...] | int = 2  # first down sampling layer
+    scale_factor: tuple[int, ...]
| int = 2  # spatial down-sampling/up-sampling
     num_res_blocks: int = 2  # number of residual blocks
     kernel_size: int = 3  # convolution layer kernel size
     widening_factor: int = 4  # for key size in MHA
+    out_kernel_init: jax.nn.initializers.Initializer = nn.linear.default_kernel_init
     remat: bool = True  # remat reduces memory cost at cost of compute speed
     dtype: jnp.dtype = jnp.float32
 
@@ -40,10 +41,16 @@ def __call__(
         """
         if len(embeddings) != len(self.num_channels) * (self.num_res_blocks + 1) + 1:
             raise ValueError("UpsampleDecoder input length does not match")
+        patch_size = self.patch_size
+        scale_factor = self.scale_factor
+        if isinstance(patch_size, int):
+            patch_size = (patch_size,) * self.num_spatial_dims
+        if isinstance(scale_factor, int):
+            scale_factor = (scale_factor,) * self.num_spatial_dims
 
         conv_res_block_cls = nn.remat(ConvResBlock) if self.remat else ConvResBlock
         conv_up_sample_cls = nn.remat(ConvUpSample) if self.remat else ConvUpSample
-        conv_cls = nn.remat(Conv) if self.remat else Conv
+        conv_cls = nn.remat(nn.Conv) if self.remat else nn.Conv
 
         # spatial shape get halved by 2**(len(self.num_channels)-1)
         # channel = self.num_channels[-1]
@@ -71,13 +78,9 @@ def __call__(
             # as padding may be added when down-sampling
             skipped_shape = embeddings[-1].shape[1:-1]
             # deconv and pad to make emb of same shape as skipped
-            scale_factor = (
-                self.patch_size if i == len(self.num_channels) - 2 else self.scale_factor
-            )
             x = conv_up_sample_cls(
-                num_spatial_dims=self.num_spatial_dims,
                 out_channels=self.num_channels[-i - 2],
-                scale_factor=scale_factor,
+                scale_factor=patch_size if i == len(self.num_channels) - 2 else scale_factor,
             )(x)
             x = jax.lax.dynamic_slice(
                 x,
@@ -87,5 +90,6 @@ def __call__(
         out = conv_cls(
             features=self.out_channels,
             kernel_size=(1,) * self.num_spatial_dims,
+            kernel_init=self.out_kernel_init,
         )(x)
         return out
diff --git a/imgx/optim.py b/imgx/optim.py
deleted file mode 100644
index 08db988..0000000
--- a/imgx/optim.py
+++ /dev/null
@@ -1,202 +0,0 @@
-"""Module for optimization."""
-from __future__ import annotations
-
-from functools import partial
-from typing import Callable
-
-import chex
-import jax
-import optax
-from absl import logging
-from flax.training import dynamic_scale as dynamic_scale_lib
-from jax import lax
-from jax import numpy as jnp
-from omegaconf import DictConfig
-
-from imgx import REPLICA_AXIS
-from imgx.train_state import TrainState
-
-
-def get_lr_schedule(config: DictConfig) -> optax.Schedule:
-    """Get learning rate scheduler.
-
-    Args:
-        config: entire configuration.
-
-    Returns:
-        Scheduler
-    """
-    return optax.warmup_cosine_decay_schedule(**config.optimizer.lr_schedule)
-
-
-def get_every_k_schedule(config: DictConfig) -> int:
-    """Get k for gradient accumulations.
-
-    Args:
-        config: entire configuration.
-
-    Returns:
-        k, where gradients are accumulated every k step.
-    """
-    num_devices_per_replica = config.data.trainer.num_devices_per_replica
-    batch_size_per_replica = config.data.trainer.batch_size_per_replica
-    num_replicas = jax.local_device_count() // num_devices_per_replica
-    batch_size_per_step = batch_size_per_replica * num_replicas
-    if config.data.trainer.batch_size < batch_size_per_step:
-        raise ValueError(
-            f"Batch size {config.data.trainer.batch_size} is too small. "
-            f"batch_size_per_replica * num_replicas = "
-            f"{batch_size_per_replica} * {num_replicas} = "
-            f"{batch_size_per_step}."
- ) - if config.data.trainer.batch_size % batch_size_per_step != 0: - raise ValueError("Batch size cannot be evenly divided by batch size per step.") - every_k_schedule = config.data.trainer.batch_size // batch_size_per_step - if every_k_schedule > 1: - logging.info( - f"Using gradient accumulation. " - f"Each model duplicate is stored across {num_devices_per_replica} " - f"shard{'s' if num_devices_per_replica > 1 else ''}. " - f"Each step has {batch_size_per_step} samples. " - f"Gradients are averaged every {every_k_schedule} steps. " - f"Effective batch size is {config.data.trainer.batch_size}." - ) - return every_k_schedule - - -def init_optimizer( - config: DictConfig, -) -> optax.GradientTransformation: - """Initialize optimizer. - - Args: - config: entire configuration. - - Returns: - optimizer. - """ - lr_schedule = get_lr_schedule(config) - optimizer = optax.chain( - optax.clip_by_global_norm(config.optimizer.grad_norm), - getattr(optax, config.optimizer.name)(learning_rate=lr_schedule, **config.optimizer.kwargs), - ) - # accumulate gradient when needed - every_k_schedule = get_every_k_schedule(config) - if every_k_schedule == 1: - # no need to accumulate gradient - return optimizer - return optax.MultiSteps(optimizer, every_k_schedule=every_k_schedule) - - -def get_gradients( - train_state: TrainState, - loss_step: Callable[[chex.ArrayTree, chex.ArrayTree], tuple[jnp.ndarray, chex.ArrayTree]] - | Callable[ - [chex.ArrayTree, chex.ArrayTree, jax.random.KeyArray], tuple[jnp.ndarray, chex.ArrayTree] - ], - input_dict: dict[str, chex.ArrayTree], -) -> tuple[dynamic_scale_lib.DynamicScale, jnp.ndarray, chex.ArrayTree, chex.ArrayTree]: - """Get gradients. - - Args: - train_state: training state. - loss_step: loss step function. - input_dict: input to loss_step in additional to params. - - Returns: - dynamic_scale: dynamic scale. - is_fin: whether the gradients are finite. - aux: auxiliary outputs from loss_step. - grads: gradients. - """ - is_fin = None - dynamic_scale = train_state.dynamic_scale - if dynamic_scale: - grad_fn = dynamic_scale.value_and_grad(loss_step, has_aux=True, axis_name=REPLICA_AXIS) - dynamic_scale, is_fin, aux, grads = grad_fn(train_state.params, **input_dict) - # dynamic loss takes care of averaging gradients across replicas - else: - grad_fn = jax.value_and_grad(loss_step, has_aux=True) - aux, grads = grad_fn(train_state.params, **input_dict) - # Re-use same axis_name as in the call to `pmap(...train_step...)` below. - grads = lax.pmean(grads, axis_name=REPLICA_AXIS) - return dynamic_scale, is_fin, aux, grads - - -def update_train_state( - train_state: TrainState, - dynamic_scale: dynamic_scale_lib.DynamicScale, - is_fin: jnp.ndarray, - grads: chex.ArrayTree, -) -> TrainState: - """Update training state. - - Args: - train_state: training state. - dynamic_scale: dynamic scale. - is_fin: whether the gradients are finite. - grads: gradients. - - Returns: - new training state. - """ - new_state = train_state.apply_gradients(grads=grads) - if dynamic_scale: - # if is_fin == False the gradients contain Inf/NaNs and optimizer state and - # params should be restored (= skip this step). 
- new_state = new_state.replace( - opt_state=jax.tree_util.tree_map( - partial(jnp.where, is_fin), - new_state.opt_state, - train_state.opt_state, - ), - params=jax.tree_util.tree_map( - partial(jnp.where, is_fin), new_state.params, train_state.params - ), - dynamic_scale=dynamic_scale, - ) - return new_state - - -def get_optimization_metrics( - grads: chex.ArrayTree, - train_state: TrainState, - config: DictConfig, -) -> dict[str, float]: - """Get optimization metrics. - - Args: - grads: gradients. - train_state: training state. - config: entire configuration. - - Returns: - metrics. - """ - metrics = { - "grad_norm": optax.global_norm(grads), - "params_norm": optax.global_norm(train_state.params), - } - if train_state.dynamic_scale: - metrics["scale"] = train_state.dynamic_scale.scale - - lr_schedule = get_lr_schedule(config) - every_k_schedule = get_every_k_schedule(config) - metrics["lr"] = lr_schedule(train_state.step // every_k_schedule) - - return metrics - - -def get_half_precision_dtype(half_precision: bool) -> jnp.dtype: - """Get half precision dtype. - - Args: - half_precision: whether to use half precision. - - Returns: - dtype. - """ - if not half_precision: - return jnp.float32 - platform = jax.local_devices()[0].platform - return jnp.bfloat16 if platform == "tpu" else jnp.float16 diff --git a/imgx/run_train.py b/imgx/run_train.py index 0a504ed..079b01d 100644 --- a/imgx/run_train.py +++ b/imgx/run_train.py @@ -11,9 +11,9 @@ from omegaconf import DictConfig, OmegaConf from imgx.config import flatten_dict -from imgx.diffusion_segmentation.experiment import DiffusionSegmentationExperiment from imgx.experiment import Experiment -from imgx.segmentation.experiment import SegmentationExperiment +from imgx.task.diffusion_segmentation.experiment import DiffusionSegmentationExperiment +from imgx.task.segmentation.experiment import SegmentationExperiment from imgx.train_state import save_checkpoint from imgx_datasets import INFO_MAP from imgx_datasets.constant import VALID_SPLIT @@ -32,26 +32,26 @@ def set_debug_config(config: DictConfig) -> DictConfig: """ # reduce all model size # due to the attention, deeper model reduces the memory usage - config.task.model.num_channels = (1, 2, 4, 4) + config.task.model.num_channels = (1, 1, 1, 4) # make training shorter n_devices = jax.local_device_count() - config.data.loader.max_num_samples_per_split = 11 + config.data.loader.max_num_samples_per_split = 5 config.data.trainer.batch_size_per_replica = 2 config.data.trainer.batch_size = n_devices * config.data.trainer.batch_size_per_replica - config.data.trainer.max_num_samples = 256 + config.data.trainer.max_num_samples = 25 # make logging more frequent config.logging.log_freq = 1 - config.logging.save_freq = 4 + config.logging.save_freq = 2 # stop early - config.task.early_stopping.patience = 5 + config.task.early_stopping.patience = 1 config.task.early_stopping.min_delta = 0.1 return config -def process_config(config: DictConfig) -> DictConfig: +def process_config(config: DictConfig) -> tuple[DictConfig, list[str]]: """Modify attributes based on config. Args: @@ -59,6 +59,7 @@ def process_config(config: DictConfig) -> DictConfig: Returns: modified config. + tags for logging. 
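+
+        For example (illustrative), a debug diffusion run on amos_ct starts
+        with tags ["amos_ct", "diffusion_segmentation", "debug"]; the
+        recycling/self-conditioning tags are appended when those task options
+        are enabled.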
""" if config.data.trainer.num_devices_per_replica != 1: raise ValueError("Distributed training not supported.") @@ -73,6 +74,10 @@ def process_config(config: DictConfig) -> DictConfig: # as by default model is 3D config.task.model.num_spatial_dims = dataset_info.ndim + # overwrite patch size and scale factor + config.task.model.patch_size = config.data.patch_size + config.task.model.scale_factor = config.data.scale_factor + # set model output channels out_channels = dataset_info.num_classes if config.task.name == "diffusion_segmentation": @@ -83,7 +88,16 @@ def process_config(config: DictConfig) -> DictConfig: out_channels *= 2 config.task.model.out_channels = out_channels - return config + # get tags for logging + tags = [config.data.name, config.task.name] + if config.debug: + tags.append("debug") + if config.task.name == "diffusion_segmentation": + if config.task.recycling.use: + tags.append("recycling") + if config.task.self_conditioning.use: + tags.append("self_conditioning") + return config, tags def get_batch_size_per_step(config: DictConfig) -> int: @@ -132,95 +146,109 @@ def main( # pylint:disable=too-many-statements # update config if config.debug: config = set_debug_config(config) - config = process_config(config) + config, tags = process_config(config) logging.info(OmegaConf.to_yaml(config)) # init wandb - wandb_run = wandb.init( + settings = None + if config.logging.root_dir: + root_dir = Path(config.logging.root_dir).resolve() + root_dir.mkdir(parents=True, exist_ok=True) + settings = wandb.Settings(root_dir=root_dir) + with wandb.init( project=config.logging.wandb.project, entity=config.logging.wandb.entity, config=flatten_dict(dict(config)), - ) - files_dir = Path(wandb_run.settings.files_dir) - # backup config - OmegaConf.save(config=config, f=files_dir / "config_backup.yaml") - ckpt_dir = files_dir / "ckpt" - - # init model - run = build_experiment(config=config) - train_state, step_offset = run.train_init() - key = jax.random.PRNGKey(config.seed) - key = common_utils.shard_prng_key(key) # each replica has a different key - - logging.info( - f"Start training with early stopping on {config.task.early_stopping.metric} " - f"and patience = {config.task.early_stopping.patience}." 
- ) - batch_size_per_step = get_batch_size_per_step(config) - max_num_steps = config.data.trainer.max_num_samples // batch_size_per_step - early_stop = EarlyStopping( - min_delta=config.task.early_stopping.min_delta, patience=config.task.early_stopping.patience - ) - for i in range(1 + step_offset, 1 + max_num_steps): - train_state, key, train_metrics = run.train_step(train_state, key) - train_metrics = {"train_" + k: v for k, v in train_metrics.items()} - metrics = { - "num_samples": i * batch_size_per_step, - **train_metrics, - } - - # save checkpoint if needed - to_save_ckpt = (i > 0) and ((i % config.logging.save_freq == 0) or (i == max_num_steps)) - if to_save_ckpt: - ckpt_path = save_checkpoint( - train_state=train_state, - ckpt_dir=ckpt_dir, - # when early stop, it's patience+1 ckpt - keep=config.task.early_stopping.patience + 1, - ) - key, val_metrics = run.eval_step(train_state=train_state, key=key, split=VALID_SPLIT) - out_dir = Path(ckpt_path) - if config.task.name == "diffusion_segmentation": - out_dir = out_dir / config.task.sampler.name - out_dir.mkdir(parents=True, exist_ok=True) - with open(out_dir / "mean_metrics.json", "w", encoding="utf-8") as f: - json.dump(val_metrics, f, sort_keys=True, indent=4) - - # update early stopping - early_stop_metric = val_metrics[config.task.early_stopping.metric] - if config.task.early_stopping.mode == "max": - early_stop_metric = -early_stop_metric - _, early_stop = early_stop.update(early_stop_metric) - logging.info( - f"Early stop updated {i}: " - f"should_stop={early_stop.should_stop}, " - f"best_metric({config.task.early_stopping.metric})={early_stop.best_metric:.4f}, " - f"patience_count={early_stop.patience_count}, " - f"min_delta={early_stop.min_delta}, " - f"patience={early_stop.patience}." - ) - - # update metrics - # only add prefix after saving to json - val_metrics = {"valid_" + k: v for k, v in val_metrics.items()} + tags=tags, + settings=settings, + ) as wandb_run: + files_dir = Path(wandb_run.settings.files_dir) + logging.info(f"Logging to {files_dir}.") + # backup config + OmegaConf.save(config=config, f=files_dir / "config_backup.yaml") + ckpt_dir = files_dir / "ckpt" + + # init model + run = build_experiment(config=config) + train_state, step_offset = run.train_init() + key = jax.random.PRNGKey(config.seed) + key = common_utils.shard_prng_key(key) # each replica has a different key + + logging.info( + f"Start training with early stopping on {config.task.early_stopping.metric} " + f"and patience = {config.task.early_stopping.patience}." 
+ ) + batch_size_per_step = get_batch_size_per_step(config) + max_num_steps = config.data.trainer.max_num_samples // batch_size_per_step + early_stop = EarlyStopping( + min_delta=config.task.early_stopping.min_delta, + patience=config.task.early_stopping.patience, + ) + for i in range(1 + step_offset, 1 + max_num_steps): + train_state, key, train_metrics = run.train_step(train_state, key) + train_metrics = {"train_" + k: v for k, v in train_metrics.items()} metrics = { - **metrics, - **val_metrics, + "num_samples": i * batch_size_per_step, + **train_metrics, } - metrics_str = {k: v if isinstance(v, int) else f"{v:.2e}" for k, v in metrics.items()} - logging.info(f"Batch {i}: {metrics_str}") - - # log metrics - if config.logging.wandb.project and (i % config.logging.log_freq == 0): - wandb.log(metrics) - - # early stopping - if early_stop.should_stop: - logging.info( - f"Met early stopping criteria with {config.task.early_stopping.metric} = " - f"{early_stop.best_metric} and patience {early_stop.patience_count}, breaking..." - ) - break + + # save checkpoint if needed + to_save_ckpt = (i > 0) and ((i % config.logging.save_freq == 0) or (i == max_num_steps)) + if to_save_ckpt: + ckpt_path = save_checkpoint( + train_state=train_state, + ckpt_dir=ckpt_dir, + # when early stop, it's patience+1 ckpt + keep=config.task.early_stopping.patience + 1, + ) + key, val_metrics = run.eval_step( + train_state=train_state, key=key, split=VALID_SPLIT + ) + out_dir = Path(ckpt_path) + if config.task.name == "diffusion_segmentation": + out_dir = out_dir / config.task.sampler.name + out_dir.mkdir(parents=True, exist_ok=True) + with open(out_dir / "mean_metrics.json", "w", encoding="utf-8") as f: + json.dump(val_metrics, f, sort_keys=True, indent=4) + + # update early stopping + early_stop_metric = val_metrics[config.task.early_stopping.metric] + if config.task.early_stopping.mode == "max": + early_stop_metric = -early_stop_metric + early_stop = early_stop.update(early_stop_metric) + logging.info( + f"Early stop updated {i}: " + f"should_stop={early_stop.should_stop}, " + f"best_metric({config.task.early_stopping.metric})" + f"={early_stop.best_metric:.4f}, " + f"patience_count={early_stop.patience_count}, " + f"min_delta={early_stop.min_delta}, " + f"patience={early_stop.patience}." + ) + + # update metrics + # only add prefix after saving to json + val_metrics = {"valid_" + k: v for k, v in val_metrics.items()} + metrics = { + **metrics, + **val_metrics, + } + metrics_str = { + k: v if isinstance(v, int) else f"{v:.2e}" for k, v in metrics.items() + } + logging.info(f"Batch {i}: {metrics_str}") + + # log metrics + if config.logging.wandb.project and (i % config.logging.log_freq == 0): + wandb.log(metrics) + + # early stopping + if early_stop.should_stop: + logging.info( + f"Met early stopping criteria with {config.task.early_stopping.metric} = " + f"{early_stop.best_metric} and patience {early_stop.patience_count}, breaking." 
+ ) + break if __name__ == "__main__": diff --git a/imgx/segmentation/loss.py b/imgx/segmentation/loss.py deleted file mode 100644 index f2a133b..0000000 --- a/imgx/segmentation/loss.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Vanilla segmentation loss.""" -from __future__ import annotations - -import jax.numpy as jnp -from omegaconf import DictConfig - -from imgx.loss import cross_entropy, focal_loss -from imgx.loss.dice import dice_loss -from imgx.metric.area import class_proportion -from imgx_datasets.dataset_info import DatasetInfo - - -def segmentation_loss( - logits: jnp.ndarray, - label: jnp.ndarray, - dataset_info: DatasetInfo, - loss_config: DictConfig, -) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: - """Calculate segmentation loss with auxiliary losses and return metrics. - - Args: - logits: unnormalised logits of shape (batch, ..., num_classes). - label: label of shape (batch, ...). - dataset_info: dataset info with helper functions. - loss_config: have weights of diff losses. - - Returns: - - calculated loss, of shape (batch,). - - metrics, values of shape (batch,). - """ - mask_true = dataset_info.label_to_mask(label, axis=-1) - metrics = {} - - # class proportion - # (batch, num_classes) - class_prop_batch_cls = class_proportion(mask_true) - for i in range(dataset_info.num_classes): - metrics[f"class_{i}_proportion_true"] = class_prop_batch_cls[:, i] - - # Dice - # (batch, num_classes) - dice_loss_batch_cls = dice_loss( - logits=logits, - mask_true=mask_true, - classes_are_exclusive=dataset_info.classes_are_exclusive, - ) - # (batch, ) - # without background - # mask out non-existing classes - dice_loss_batch = jnp.mean( - dice_loss_batch_cls[:, 1:], axis=-1, where=class_prop_batch_cls[:, 1:] > 0 - ) - metrics["dice_loss"] = dice_loss_batch - for i in range(dice_loss_batch_cls.shape[-1]): - metrics[f"dice_loss_class_{i}"] = dice_loss_batch_cls[:, i] - - # cross entropy - ce_loss_batch = cross_entropy( - logits=logits, - mask_true=mask_true, - classes_are_exclusive=dataset_info.classes_are_exclusive, - ) - metrics["cross_entropy_loss"] = ce_loss_batch - - # focal loss - focal_loss_batch = focal_loss( - logits=logits, - mask_true=mask_true, - classes_are_exclusive=dataset_info.classes_are_exclusive, - ) - metrics["focal_loss"] = focal_loss_batch - - # total loss - loss_batch = jnp.zeros_like(dice_loss_batch) - if loss_config["dice"] > 0: - loss_batch += dice_loss_batch * loss_config["dice"] - if loss_config["cross_entropy"] > 0: - loss_batch += ce_loss_batch * loss_config["cross_entropy"] - if loss_config["focal"] > 0: - loss_batch += focal_loss_batch * loss_config["focal"] - metrics["total_loss"] = loss_batch - return loss_batch, metrics diff --git a/imgx/segmentation/train_state.py b/imgx/segmentation/train_state.py deleted file mode 100644 index 74d6a00..0000000 --- a/imgx/segmentation/train_state.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Training state and checkpoints. 
- -https://github.com/google/flax/blob/main/examples/imagenet/train.py -""" -from __future__ import annotations - -from typing import Callable - -import chex -import flax.linen as nn -import jax -from absl import logging -from flax.training import dynamic_scale as dynamic_scale_lib -from omegaconf import DictConfig - -from imgx.optim import init_optimizer -from imgx.train_state import TrainState - - -def create_train_state( - key: jax.random.PRNGKeyArray, - batch: chex.ArrayTree, - model: nn.Module, - config: DictConfig, - initialized: Callable[[jax.random.PRNGKeyArray, chex.ArrayTree, nn.Module], chex.ArrayTree], -) -> TrainState: - """Create initial training state. - - Args: - key: random key. - batch: batch data for determining input shapes. - model: model. - config: entire configuration. - initialized: function to get initialized model parameters. - - Returns: - initial training state. - """ - dynamic_scale = None - platform = jax.local_devices()[0].platform - if config.half_precision and platform == "gpu": - dynamic_scale = dynamic_scale_lib.DynamicScale() - - params = initialized(key, batch, model) - - # count params - params_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) - logging.info(f"The model has {params_count:,} parameters.") - - tx = init_optimizer(config=config) - state = TrainState.create( - apply_fn=model.apply, - params=params, - tx=tx, - dynamic_scale=dynamic_scale, - ) - return state diff --git a/imgx/task/__init__.py b/imgx/task/__init__.py new file mode 100644 index 0000000..9863028 --- /dev/null +++ b/imgx/task/__init__.py @@ -0,0 +1 @@ +"""Module for different learning tasks.""" diff --git a/imgx/diffusion_segmentation/__init__.py b/imgx/task/diffusion_segmentation/__init__.py similarity index 100% rename from imgx/diffusion_segmentation/__init__.py rename to imgx/task/diffusion_segmentation/__init__.py diff --git a/imgx/diffusion_segmentation/diffusion.py b/imgx/task/diffusion_segmentation/diffusion.py similarity index 100% rename from imgx/diffusion_segmentation/diffusion.py rename to imgx/task/diffusion_segmentation/diffusion.py diff --git a/imgx/diffusion_segmentation/diffusion_step.py b/imgx/task/diffusion_segmentation/diffusion_step.py similarity index 95% rename from imgx/diffusion_segmentation/diffusion_step.py rename to imgx/task/diffusion_segmentation/diffusion_step.py index f8a1785..93c1094 100644 --- a/imgx/diffusion_segmentation/diffusion_step.py +++ b/imgx/task/diffusion_segmentation/diffusion_step.py @@ -7,10 +7,10 @@ from omegaconf import DictConfig from imgx.diffusion.time_sampler import TimeSampler -from imgx.diffusion_segmentation.diffusion import DiffusionSegmentation -from imgx.diffusion_segmentation.train_state import TrainState +from imgx.loss.segmentation import segmentation_loss from imgx.metric.util import aggregate_metrics, aggregate_metrics_for_diffusion -from imgx.segmentation.loss import segmentation_loss +from imgx.task.diffusion_segmentation.diffusion import DiffusionSegmentation +from imgx.task.diffusion_segmentation.train_state import TrainState from imgx_datasets.constant import IMAGE, LABEL from imgx_datasets.dataset_info import DatasetInfo @@ -102,7 +102,7 @@ def get_diffusion_loss_step( diffusion_model: DiffusionSegmentation, time_sampler: TimeSampler, ) -> Callable[ - [chex.ArrayTree, chex.ArrayTree, jax.random.KeyArray], + [chex.ArrayTree, chex.ArrayTree, jax.Array], tuple[jnp.ndarray, tuple[jnp.ndarray, chex.ArrayTree, jnp.ndarray, jnp.ndarray]], ]: """Return loss_step for vanilla diffusion. 
@@ -121,7 +121,7 @@ def get_diffusion_loss_step( def loss_step( params: chex.ArrayTree, batch: chex.ArrayTree, - key: jax.random.KeyArray, + key: jax.Array, ) -> tuple[jnp.ndarray, tuple[jnp.ndarray, chex.ArrayTree, jnp.ndarray, jnp.ndarray]]: """Apply forward and calculate loss.""" image, label = batch[IMAGE], batch[LABEL] diff --git a/imgx/diffusion_segmentation/experiment.py b/imgx/task/diffusion_segmentation/experiment.py similarity index 77% rename from imgx/diffusion_segmentation/experiment.py rename to imgx/task/diffusion_segmentation/experiment.py index e057dba..88fe211 100644 --- a/imgx/diffusion_segmentation/experiment.py +++ b/imgx/task/diffusion_segmentation/experiment.py @@ -16,44 +16,49 @@ from hydra.utils import instantiate from jax import lax from omegaconf import DictConfig +from tqdm import tqdm from imgx import REPLICA_AXIS -from imgx.data.augmentation import build_aug_fn_from_config +from imgx.data import AugmentationFn +from imgx.data.affine import get_random_affine_augmentation_fn +from imgx.data.augmentation import chain_aug_fns from imgx.data.patch import ( batch_patch_grid_mean_aggregate, batch_patch_grid_sample, get_patch_shape_grid_from_config, + get_random_patch_fn, ) from imgx.data.util import unpad -from imgx.device import bind_rng_to_host_or_device, get_first_replica_values +from imgx.device import bind_rng_to_host_or_device, get_first_replica_values, unshard from imgx.diffusion.time_sampler import TimeSampler -from imgx.diffusion_segmentation.diffusion import DiffusionSegmentation -from imgx.diffusion_segmentation.diffusion_step import get_diffusion_loss_step -from imgx.diffusion_segmentation.gaussian_diffusion import ( +from imgx.experiment import Experiment +from imgx.metric.segmentation import get_segmentation_metrics_per_step +from imgx.metric.util import aggregate_pmap_metrics +from imgx.task.diffusion_segmentation.diffusion import DiffusionSegmentation +from imgx.task.diffusion_segmentation.diffusion_step import get_diffusion_loss_step +from imgx.task.diffusion_segmentation.gaussian_diffusion import ( DDIMSegmentationSampler, DDPMSegmentationSampler, GaussianDiffusionSegmentation, ) -from imgx.diffusion_segmentation.recycling_step import get_recycling_loss_step -from imgx.diffusion_segmentation.self_conditioning_step import get_self_conditioning_loss_step -from imgx.diffusion_segmentation.train_state import TrainState, create_train_state -from imgx.experiment import Experiment -from imgx.metric.util import aggregate_pmap_metrics -from imgx.optim import ( +from imgx.task.diffusion_segmentation.recycling_step import get_recycling_loss_step +from imgx.task.diffusion_segmentation.self_conditioning_step import get_self_conditioning_loss_step +from imgx.task.diffusion_segmentation.train_state import TrainState, create_train_state +from imgx.task.segmentation.save import save_segmentation_prediction +from imgx.task.util import decode_uids +from imgx.train_state import ( get_gradients, get_half_precision_dtype, get_optimization_metrics, + restore_checkpoint, update_train_state, ) -from imgx.segmentation.metric import get_segmentation_metrics_per_step -from imgx.train_state import restore_checkpoint from imgx_datasets.constant import IMAGE, LABEL, TEST_SPLIT, UID, VALID_SPLIT from imgx_datasets.dataset_info import DatasetInfo -from imgx_datasets.image_io import save_segmentation_prediction def initialized( - key: jax.random.PRNGKeyArray, + key: jax.Array, batch: chex.ArrayTree, model: nn.Module, dataset_info: DatasetInfo, @@ -118,7 +123,7 @@ def 
get_importance_sampling_metrics( def sample_logits_progressive( train_state: TrainState, image: jnp.ndarray, - key: jax.random.PRNGKeyArray, + key: jax.Array, dataset_info: DatasetInfo, diffusion_model: DiffusionSegmentation, self_conditioning: bool, @@ -169,13 +174,13 @@ def sample_logits_progressive( def train_step( train_state: TrainState, batch: chex.ArrayTree, - key: jax.random.PRNGKeyArray, - aug_fn: Callable[[jax.random.PRNGKeyArray, chex.ArrayTree], chex.ArrayTree], + key: jax.Array, + aug_fn: Callable[[jax.Array, chex.ArrayTree], chex.ArrayTree], dataset_info: DatasetInfo, config: DictConfig, diffusion_model: DiffusionSegmentation, time_sampler: TimeSampler, -) -> tuple[chex.ArrayTree, jax.random.PRNGKeyArray, chex.ArrayTree]: +) -> tuple[TrainState, jax.Array, chex.ArrayTree]: """Perform a training step. Args: @@ -202,7 +207,7 @@ def train_step( dataset_info=dataset_info, diffusion_model=diffusion_model, time_sampler=time_sampler, - loss_config=config.loss, + loss_config=config.task.loss, prev_step=config.task.recycling.prev_step, reverse_step=config.task.recycling.reverse_step, ) @@ -212,7 +217,7 @@ def train_step( dataset_info=dataset_info, diffusion_model=diffusion_model, time_sampler=time_sampler, - loss_config=config.loss, + loss_config=config.task.loss, prev_step=config.task.self_conditioning.prev_step, probability=config.task.self_conditioning.probability, ) @@ -222,7 +227,7 @@ def train_step( dataset_info=dataset_info, diffusion_model=diffusion_model, time_sampler=time_sampler, - loss_config=config.loss, + loss_config=config.task.loss, ) # augment, calculate gradients, update train state @@ -273,13 +278,13 @@ def train_step( def eval_step( train_state: TrainState, batch: chex.ArrayTree, - key: jax.random.PRNGKeyArray, + key: jax.Array, patch_start_indices: np.ndarray, patch_shape: tuple[int, ...], dataset_info: DatasetInfo, config: DictConfig, diffusion_model: DiffusionSegmentation, -) -> tuple[jax.random.PRNGKeyArray, jnp.ndarray]: +) -> tuple[jax.Array, jnp.ndarray]: """Perform an evaluation step. 
Args: @@ -360,24 +365,30 @@ def train_init( image_shape = self.dataset_info.image_spatial_shape chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape) - aug_fn = build_aug_fn_from_config(self.config) + # data augmentation + aug_fns: list[AugmentationFn] = [ + get_random_affine_augmentation_fn(self.config), + get_random_patch_fn(self.config), + ] + aug_fn = chain_aug_fns(aug_fns) aug_rng = jax.random.PRNGKey(self.config["seed"]) batch = aug_fn(aug_rng, batch) # init train state on cpu first dtype = get_half_precision_dtype(self.config.half_precision) model = instantiate(self.config.task.model, dtype=dtype) - train_state = create_train_state( - key=jax.random.PRNGKey(self.config.seed), - batch=batch, - model=model, - config=self.config, - initialized=partial( - initialized, - dataset_info=self.dataset_info, - self_conditioning=self.config.task.self_conditioning.use, - ), - ) + with jax.default_device(jax.devices("cpu")[0]): + train_state = create_train_state( + key=jax.random.PRNGKey(self.config.seed), + batch=batch, + model=model, + config=self.config, + initialized=partial( + initialized, + dataset_info=self.dataset_info, + self_conditioning=self.config.task.self_conditioning.use, + ), + ) # resume training if ckpt_dir is not None: train_state = restore_checkpoint(state=train_state, ckpt_dir=ckpt_dir, step=step) @@ -445,33 +456,13 @@ def train_init( return train_state, step_offset - def train_step( - self, train_state: TrainState, key: jax.random.PRNGKeyArray - ) -> tuple[TrainState, jax.random.PRNGKeyArray, chex.ArrayTree]: - """Perform a training step. - - Args: - train_state: training state. - key: random key. - - Returns: - - new training state. - - new random key. - - metric dict. - """ - batch = next(self.train_iter) - train_state, key, metrics = self.p_train_step(train_state, batch, key) - metrics = aggregate_pmap_metrics(metrics) - metrics = jax.tree_map(lambda x: x.item(), metrics) # tensor to values - return train_state, key, metrics - - def eval_step( + def eval_step( # pylint:disable=too-many-statements self, train_state: TrainState, - key: jax.random.PRNGKeyArray, + key: jax.Array, split: str, out_dir: Path | None = None, - ) -> tuple[jax.random.PRNGKeyArray, dict[str, jnp.ndarray]]: + ) -> tuple[jax.Array, dict[str, jnp.ndarray]]: """Evaluation on entire validation data set. 
Args: @@ -497,58 +488,27 @@ def eval_step( num_samples = 0 lst_metrics = [] lst_uids = [] - for _ in range(num_steps): + for _ in tqdm(range(num_steps), total=num_steps): batch = next(split_iter) # get uids uids = batch.pop(UID) uids = uids.reshape(-1) # remove shard axis - uids = [x.decode("utf-8") if isinstance(x, bytes) else x for x in uids.tolist()] - - # logits (num_shards, batch, *spatial_shape, num_classes) - # loss_metrics, values of shape (num_shards, batch) - key, logits = self.p_eval_step(train_state, batch, key) - - # move to CPU to save memory - logits = jax.device_put(logits, device_cpu) - - # logits (num_shards*batch, *spatial_shape, num_classes) - # metrics, values of shape (num_shards*batch,) - logits = logits.reshape(-1, *logits.shape[2:]) - num_samples += logits.shape[0] - - # label (num_shards*batch, *spatial_shape) - # eval_fn is not jittable, so pmap cannot be used - label = batch[LABEL] - label = label.reshape(-1, *label.shape[2:]) - - # remove padded examples - if 0 in uids: - # the batch was not complete, padded with zero - num_samples_in_batch = uids.index(0) - uids = uids[:num_samples_in_batch] - logits = unpad(logits, num_samples_in_batch) - label = unpad(label, num_samples_in_batch) - - # (batch,) per metric - metrics, label_pred = get_segmentation_metrics_per_step( - logits=logits, - label=label, - dataset_info=self.dataset_info, + uids = decode_uids(uids) + + # evaluate the batch + metrics, label_pred, key = self.eval_batch( + train_state=train_state, + key=key, + batch=batch, + uids=uids, + device_cpu=device_cpu, + out_dir=out_dir, ) + num_samples_in_batch = label_pred.shape[0] + lst_uids += uids[:num_samples_in_batch] + num_samples += num_samples_in_batch lst_metrics.append(metrics) - lst_uids += uids - - # save predictions - if out_dir is None: - continue - for i in range(logits.shape[-1]): - save_segmentation_prediction( - preds=np.array(label_pred[..., i], dtype=int), - uids=uids, - out_dir=out_dir / f"step_{i}", - tfds_dir=self.dataset_info.tfds_preprocessed_dir, - ) # concatenate metrics across all samples # metrics, values of shape (num_samples,) @@ -567,3 +527,65 @@ def eval_step( df_metric.to_csv(out_dir / "metrics_per_sample.csv", index=False) return key, agg_metrics + + def eval_batch( + self, + train_state: TrainState, + key: jax.Array, + batch: dict[str, jnp.ndarray], + uids: list[str], + device_cpu: jax.Device, + out_dir: Path | None, + reference_suffix: str = "mask_preprocessed", + output_suffix: str = "mask_pred", + ) -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jax.Array]: + """Evaluate a batch. + + Args: + train_state: training state. + key: random key. + batch: batch data without uid. + uids: uids in the batch. + device_cpu: cpu device. + out_dir: output directory, if not None, predictions will be saved. + reference_suffix: suffix of reference image. + output_suffix: suffix of output image. + + Returns: + metrics, each item has shape (num_shards*batch,). + label_pred: predicted label, of shape (num_shards*batch, *spatial_shape). + key: random key. 
+ """ + # logits (num_shards*batch, *spatial_shape, num_classes) + # label (num_shards*batch, *spatial_shape) + key, logits = self.p_eval_step(train_state, batch, key) + logits = unshard(logits, device=device_cpu) + label = unshard(batch[LABEL], device=device_cpu) + + # remove padded examples + if "" in uids: + num_samples_in_batch = uids.index("") + uids = uids[:num_samples_in_batch] + logits = unpad(logits, num_samples_in_batch) + label = unpad(label, num_samples_in_batch) + + # (batch,) per metric + metrics, label_pred = get_segmentation_metrics_per_step( + logits=logits, + label=label, + dataset_info=self.dataset_info, + ) + + if out_dir is None: + return metrics, label_pred, key + + for i in range(logits.shape[-1]): + save_segmentation_prediction( + preds=np.array(label_pred[..., i], dtype=int), + uids=uids, + out_dir=out_dir / f"step_{i}", + tfds_dir=self.dataset_info.tfds_preprocessed_dir, + reference_suffix=reference_suffix, + output_suffix=output_suffix, + ) + return metrics, label_pred, key diff --git a/imgx/diffusion_segmentation/gaussian_diffusion.py b/imgx/task/diffusion_segmentation/gaussian_diffusion.py similarity index 98% rename from imgx/diffusion_segmentation/gaussian_diffusion.py rename to imgx/task/diffusion_segmentation/gaussian_diffusion.py index a2d28b9..663f0ee 100644 --- a/imgx/diffusion_segmentation/gaussian_diffusion.py +++ b/imgx/task/diffusion_segmentation/gaussian_diffusion.py @@ -13,7 +13,7 @@ get_gaussian_diffusion_attributes, ) from imgx.diffusion.gaussian.sampler import DDIMSampler, DDPMSampler -from imgx.diffusion_segmentation.diffusion import DiffusionSegmentation +from imgx.task.diffusion_segmentation.diffusion import DiffusionSegmentation @dataclass diff --git a/imgx/diffusion_segmentation/gaussian_diffusion_test.py b/imgx/task/diffusion_segmentation/gaussian_diffusion_test.py similarity index 95% rename from imgx/diffusion_segmentation/gaussian_diffusion_test.py rename to imgx/task/diffusion_segmentation/gaussian_diffusion_test.py index e8850d4..20d7acc 100644 --- a/imgx/diffusion_segmentation/gaussian_diffusion_test.py +++ b/imgx/task/diffusion_segmentation/gaussian_diffusion_test.py @@ -4,7 +4,7 @@ import chex from chex._src import fake -from imgx.diffusion_segmentation.gaussian_diffusion import GaussianDiffusionSegmentation +from imgx.task.diffusion_segmentation.gaussian_diffusion import GaussianDiffusionSegmentation # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
diff --git a/imgx/diffusion_segmentation/recycling_step.py b/imgx/task/diffusion_segmentation/recycling_step.py similarity index 94% rename from imgx/diffusion_segmentation/recycling_step.py rename to imgx/task/diffusion_segmentation/recycling_step.py index 377d283..4143c46 100644 --- a/imgx/diffusion_segmentation/recycling_step.py +++ b/imgx/task/diffusion_segmentation/recycling_step.py @@ -8,9 +8,9 @@ from omegaconf import DictConfig from imgx.diffusion.time_sampler import TimeSampler -from imgx.diffusion_segmentation.diffusion import DiffusionSegmentation -from imgx.diffusion_segmentation.diffusion_step import get_loss_logits_metrics -from imgx.diffusion_segmentation.train_state import TrainState +from imgx.task.diffusion_segmentation.diffusion import DiffusionSegmentation +from imgx.task.diffusion_segmentation.diffusion_step import get_loss_logits_metrics +from imgx.task.diffusion_segmentation.train_state import TrainState from imgx_datasets.constant import IMAGE, LABEL from imgx_datasets.dataset_info import DatasetInfo @@ -24,7 +24,7 @@ def get_recycling_loss_step( prev_step: str, reverse_step: bool, ) -> Callable[ - [chex.ArrayTree, chex.ArrayTree, jax.random.KeyArray], + [chex.ArrayTree, chex.ArrayTree, jax.Array], tuple[jnp.ndarray, tuple[jnp.ndarray, chex.ArrayTree, jnp.ndarray, jnp.ndarray]], ]: """Return loss_step for recycling diffusion. @@ -47,7 +47,7 @@ def get_recycling_loss_step( def loss_step( params: chex.ArrayTree, batch: chex.ArrayTree, - key: jax.random.KeyArray, + key: jax.Array, ) -> tuple[jnp.ndarray, tuple[jnp.ndarray, chex.ArrayTree, jnp.ndarray, jnp.ndarray]]: """Apply forward and calculate loss.""" image, label = batch[IMAGE], batch[LABEL] diff --git a/imgx/diffusion_segmentation/self_conditioning_step.py b/imgx/task/diffusion_segmentation/self_conditioning_step.py similarity index 94% rename from imgx/diffusion_segmentation/self_conditioning_step.py rename to imgx/task/diffusion_segmentation/self_conditioning_step.py index af32f1b..b99c3c4 100644 --- a/imgx/diffusion_segmentation/self_conditioning_step.py +++ b/imgx/task/diffusion_segmentation/self_conditioning_step.py @@ -8,9 +8,9 @@ from omegaconf import DictConfig from imgx.diffusion.time_sampler import TimeSampler -from imgx.diffusion_segmentation.diffusion import DiffusionSegmentation -from imgx.diffusion_segmentation.diffusion_step import get_loss_logits_metrics -from imgx.diffusion_segmentation.train_state import TrainState +from imgx.task.diffusion_segmentation.diffusion import DiffusionSegmentation +from imgx.task.diffusion_segmentation.diffusion_step import get_loss_logits_metrics +from imgx.task.diffusion_segmentation.train_state import TrainState from imgx_datasets.constant import IMAGE, LABEL from imgx_datasets.dataset_info import DatasetInfo @@ -24,7 +24,7 @@ def get_self_conditioning_loss_step( prev_step: str, probability: float, ) -> Callable[ - [chex.ArrayTree, chex.ArrayTree, jax.random.KeyArray], + [chex.ArrayTree, chex.ArrayTree, jax.Array], tuple[jnp.ndarray, tuple[jnp.ndarray, chex.ArrayTree, jnp.ndarray, jnp.ndarray]], ]: """Return loss_step for self-conditioning diffusion. 
@@ -45,7 +45,7 @@ def get_self_conditioning_loss_step( def loss_step( params: chex.ArrayTree, batch: chex.ArrayTree, - key: jax.random.KeyArray, + key: jax.Array, ) -> tuple[jnp.ndarray, tuple[jnp.ndarray, chex.ArrayTree, jnp.ndarray, jnp.ndarray]]: """Apply forward and calculate loss.""" image, label = batch[IMAGE], batch[LABEL] diff --git a/imgx/diffusion_segmentation/train_state.py b/imgx/task/diffusion_segmentation/train_state.py similarity index 92% rename from imgx/diffusion_segmentation/train_state.py rename to imgx/task/diffusion_segmentation/train_state.py index 015c68a..206beef 100644 --- a/imgx/diffusion_segmentation/train_state.py +++ b/imgx/task/diffusion_segmentation/train_state.py @@ -14,8 +14,8 @@ from flax.training import dynamic_scale as dynamic_scale_lib from omegaconf import DictConfig -from imgx.optim import init_optimizer from imgx.train_state import TrainState as BaseTrainState +from imgx.train_state import init_optimizer class TrainState(BaseTrainState): @@ -31,11 +31,11 @@ class TrainState(BaseTrainState): def create_train_state( - key: jax.random.PRNGKeyArray, + key: jax.Array, batch: chex.ArrayTree, model: nn.Module, config: DictConfig, - initialized: Callable[[jax.random.PRNGKeyArray, chex.ArrayTree, nn.Module], chex.ArrayTree], + initialized: Callable[[jax.Array, chex.ArrayTree, nn.Module], chex.ArrayTree], ) -> TrainState: """Create initial training state. diff --git a/imgx/segmentation/__init__.py b/imgx/task/segmentation/__init__.py similarity index 100% rename from imgx/segmentation/__init__.py rename to imgx/task/segmentation/__init__.py diff --git a/imgx/segmentation/experiment.py b/imgx/task/segmentation/experiment.py similarity index 72% rename from imgx/segmentation/experiment.py rename to imgx/task/segmentation/experiment.py index cd72504..76cee99 100644 --- a/imgx/segmentation/experiment.py +++ b/imgx/task/segmentation/experiment.py @@ -14,36 +14,40 @@ from flax import jax_utils from hydra.utils import instantiate from omegaconf import DictConfig +from tqdm import tqdm from imgx import REPLICA_AXIS -from imgx.data.augmentation import build_aug_fn_from_config +from imgx.data import AugmentationFn +from imgx.data.affine import get_random_affine_augmentation_fn +from imgx.data.augmentation import chain_aug_fns from imgx.data.patch import ( batch_patch_grid_mean_aggregate, batch_patch_grid_sample, get_patch_shape_grid_from_config, + get_random_patch_fn, ) from imgx.data.util import unpad -from imgx.device import bind_rng_to_host_or_device, get_first_replica_values +from imgx.device import bind_rng_to_host_or_device, get_first_replica_values, unshard from imgx.experiment import Experiment +from imgx.loss.segmentation import segmentation_loss +from imgx.metric.segmentation import get_segmentation_metrics from imgx.metric.util import aggregate_metrics, aggregate_pmap_metrics -from imgx.optim import ( +from imgx.task.segmentation.save import save_segmentation_prediction +from imgx.task.util import decode_uids +from imgx.train_state import ( + TrainState, + create_train_state, get_gradients, get_half_precision_dtype, get_optimization_metrics, + restore_checkpoint, update_train_state, ) -from imgx.segmentation.loss import segmentation_loss -from imgx.segmentation.metric import get_segmentation_metrics -from imgx.segmentation.train_state import TrainState, create_train_state -from imgx.train_state import restore_checkpoint from imgx_datasets.constant import IMAGE, LABEL, TEST_SPLIT, UID, VALID_SPLIT from imgx_datasets.dataset_info import DatasetInfo -from 
imgx_datasets.image_io import save_segmentation_prediction -def initialized( - key: jax.random.PRNGKeyArray, batch: chex.ArrayTree, model: nn.Module -) -> chex.ArrayTree: +def initialized(key: jax.Array, batch: chex.ArrayTree, model: nn.Module) -> chex.ArrayTree: """Initialize model parameters and batch statistics. Args: @@ -58,8 +62,7 @@ def initialized( def init(*args) -> chex.ArrayTree: # type: ignore[no-untyped-def] return model.init(*args) - image = batch[IMAGE] - variables = jax.jit(init, backend="cpu")({"params": key}, image) + variables = jax.jit(init, backend="cpu")({"params": key}, batch[IMAGE]) return variables["params"] @@ -95,7 +98,7 @@ def loss_step( logits=logits, label=batch[LABEL], dataset_info=dataset_info, - loss_config=config.loss, + loss_config=config.task.loss, ) loss = jnp.mean(loss_batch) loss_metrics = aggregate_metrics(loss_metrics) @@ -107,11 +110,11 @@ def loss_step( def train_step( train_state: TrainState, batch: chex.ArrayTree, - key: jax.random.PRNGKeyArray, - aug_fn: Callable[[jax.random.PRNGKeyArray, chex.ArrayTree], chex.ArrayTree], + key: jax.Array, + aug_fn: Callable[[jax.Array, chex.ArrayTree], chex.ArrayTree], dataset_info: DatasetInfo, config: DictConfig, -) -> tuple[chex.ArrayTree, jax.random.PRNGKeyArray, chex.ArrayTree]: +) -> tuple[TrainState, jax.Array, chex.ArrayTree]: """Perform a training step. Args: @@ -228,7 +231,12 @@ def train_init( image_shape = self.dataset_info.image_spatial_shape chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape) - aug_fn = build_aug_fn_from_config(self.config) + # data augmentation + aug_fns: list[AugmentationFn] = [ + get_random_affine_augmentation_fn(self.config), + get_random_patch_fn(self.config), + ] + aug_fn = chain_aug_fns(aug_fns) aug_rng = jax.random.PRNGKey(self.config["seed"]) batch = aug_fn(aug_rng, batch) @@ -271,33 +279,13 @@ def train_init( return train_state, step_offset - def train_step( - self, train_state: TrainState, key: jax.random.PRNGKeyArray - ) -> tuple[TrainState, jax.random.PRNGKeyArray, chex.ArrayTree]: - """Perform a training step. - - Args: - train_state: training state. - key: random key. - - Returns: - - new training state. - - new random key. - - metric dict. - """ - batch = next(self.train_iter) - train_state, key, metrics = self.p_train_step(train_state, batch, key) - metrics = aggregate_pmap_metrics(metrics) - metrics = jax.tree_map(lambda x: x.item(), metrics) # tensor to values - return train_state, key, metrics - - def eval_step( + def eval_step( # pylint:disable=too-many-statements self, train_state: TrainState, - key: jax.random.PRNGKeyArray, + key: jax.Array, split: str, out_dir: Path | None = None, - ) -> tuple[jax.random.PRNGKeyArray, chex.ArrayTree]: + ) -> tuple[jax.Array, chex.ArrayTree]: """Evaluation on entire validation data set. 
Args: @@ -326,58 +314,27 @@ def eval_step( num_samples = 0 lst_metrics = [] lst_uids = [] - for _ in range(num_steps): + for _ in tqdm(range(num_steps), total=num_steps): batch = next(split_iter) # get uids uids = batch.pop(UID) uids = uids.reshape(-1) # remove shard axis - uids = [x.decode("utf-8") if isinstance(x, bytes) else x for x in uids.tolist()] - - # logits (num_shards, batch, *spatial_shape, num_classes) - # loss_metrics, values of shape (num_shards, batch) - logits = self.p_eval_step(train_state, batch) - - # move to CPU to save memory - logits = jax.device_put(logits, device_cpu) - - # logits (num_shards*batch, *spatial_shape, num_classes) - # metrics, values of shape (num_shards*batch,) - logits = logits.reshape(-1, *logits.shape[2:]) - num_samples += logits.shape[0] - - # label (num_shards*batch, *spatial_shape) - # eval_fn is not jittable, so pmap cannot be used - label = batch[LABEL] - label = label.reshape(-1, *label.shape[2:]) - - # remove padded examples - if 0 in uids: - # the batch was not complete, padded with zero - num_samples_in_batch = uids.index(0) - uids = uids[:num_samples_in_batch] - logits = unpad(logits, num_samples_in_batch) - label = unpad(label, num_samples_in_batch) - - # evaluate - # (batch,) per metric - metrics, label_pred = get_segmentation_metrics( - logits=logits, - label=label, - dataset_info=self.dataset_info, - ) - lst_metrics.append(metrics) - lst_uids += uids + uids = decode_uids(uids) - # save predictions - if out_dir is None: - continue - save_segmentation_prediction( - preds=np.array(label_pred, dtype=int), + # evaluate the batch + metrics, label_pred, key = self.eval_batch( + train_state=train_state, + key=key, + batch=batch, uids=uids, + device_cpu=device_cpu, out_dir=out_dir, - tfds_dir=self.dataset_info.tfds_preprocessed_dir, ) + num_samples_in_batch = label_pred.shape[0] + lst_uids += uids[:num_samples_in_batch] + num_samples += num_samples_in_batch + lst_metrics.append(metrics) # concatenate metrics across all samples # metrics, values of shape (num_samples,) @@ -396,3 +353,65 @@ def eval_step( df_metric.to_csv(out_dir / "metrics_per_sample.csv", index=False) return key, agg_metrics + + def eval_batch( + self, + train_state: TrainState, + key: jax.Array, + batch: dict[str, jnp.ndarray], + uids: list[str], + device_cpu: jax.Device, + out_dir: Path | None, + reference_suffix: str = "mask_preprocessed", + output_suffix: str = "mask_pred", + ) -> tuple[dict[str, jnp.ndarray], jnp.ndarray, jax.Array]: + """Evaluate a batch. + + Args: + train_state: training state. + key: random key. + batch: batch data without uid. + uids: uids in the batch. + device_cpu: cpu device. + out_dir: output directory, if not None, predictions will be saved. + reference_suffix: suffix of reference image. + output_suffix: suffix of output image. + + Returns: + metrics, each item has shape (num_shards*batch,). + label_pred: predicted label, of shape (num_shards*batch, *spatial_shape). + key: random key. 
+        """
+        # logits (num_shards*batch, *spatial_shape, num_classes)
+        # label (num_shards*batch, *spatial_shape)
+        logits = self.p_eval_step(train_state, batch)
+        logits = unshard(logits, device=device_cpu)
+        label = unshard(batch[LABEL], device=device_cpu)
+
+        # remove padded examples
+        if "" in uids:
+            num_samples_in_batch = uids.index("")
+            uids = uids[:num_samples_in_batch]
+            logits = unpad(logits, num_samples_in_batch)
+            label = unpad(label, num_samples_in_batch)
+
+        # (batch,) per metric
+        metrics, label_pred = get_segmentation_metrics(
+            logits=logits,
+            label_pred=None,
+            label_true=label,
+            dataset_info=self.dataset_info,
+        )
+
+        if out_dir is None:
+            return metrics, label_pred, key
+
+        save_segmentation_prediction(
+            preds=np.array(label_pred, dtype=int),
+            uids=uids,
+            out_dir=out_dir,
+            tfds_dir=self.dataset_info.tfds_preprocessed_dir,
+            reference_suffix=reference_suffix,
+            output_suffix=output_suffix,
+        )
+        return metrics, label_pred, key
diff --git a/imgx/task/segmentation/save.py b/imgx/task/segmentation/save.py
new file mode 100644
index 0000000..4ab8ded
--- /dev/null
+++ b/imgx/task/segmentation/save.py
@@ -0,0 +1,52 @@
+"""Segmentation related io (the file cannot be named io.py).
+
+https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import SimpleITK as sitk  # noqa: N813
+
+from imgx_datasets.save import save_image
+
+
+def save_segmentation_prediction(
+    preds: np.ndarray,
+    uids: list[str],
+    out_dir: Path,
+    tfds_dir: Path,
+    reference_suffix: str = "mask_preprocessed",
+    output_suffix: str = "mask_pred",
+) -> None:
+    """Save segmentation predictions.
+
+    Args:
+        preds: (num_samples, ...), the values are integers.
+        uids: (num_samples,).
+        out_dir: output directory.
+        tfds_dir: directory saving preprocessed images and labels.
+        reference_suffix: suffix of reference image.
+        output_suffix: suffix of output image.
+    """
+    if preds.ndim == 3 and np.max(preds) > 1:
+        raise ValueError(
+            f"Prediction values should be 0 or 1, but "
+            f"max value is {np.max(preds)}. "
+            f"Multi-class segmentation for 2D images is not supported."
+        )
+    if preds.ndim not in [3, 4]:
+        raise ValueError(
+            f"Prediction should be 3D or 4D with num_samples axis, but {preds.ndim}D is given."
+        )
+    file_suffix = "nii.gz" if preds.ndim == 4 else "png"
+    out_dir.mkdir(parents=True, exist_ok=True)
+    for i, uid in enumerate(uids):
+        reference_image = sitk.ReadImage(tfds_dir / f"{uid}_{reference_suffix}.{file_suffix}")
+        save_image(
+            image=preds[i, ...],
+            reference_image=reference_image,
+            out_path=out_dir / f"{uid}_{output_suffix}.{file_suffix}",
+            dtype=np.uint8,
+        )
diff --git a/imgx/task/util.py b/imgx/task/util.py
new file mode 100644
index 0000000..24a8753
--- /dev/null
+++ b/imgx/task/util.py
@@ -0,0 +1,25 @@
+"""Shared utility functions."""
+from __future__ import annotations
+
+from jax import numpy as jnp
+
+
+def decode_uids(uids: jnp.ndarray) -> list[str]:
+    """Decode uids.
+
+    Args:
+        uids: uids in bytes or int.
+
+    Returns:
+        decoded uids.
+ """ + decoded = [] + for x in uids.tolist(): + if isinstance(x, bytes): + decoded.append(x.decode("utf-8")) + elif x == 0: + # the batch was not complete, padded with zero + decoded.append("") + else: + raise ValueError(f"uid {x} is not supported.") + return decoded diff --git a/imgx/train_state.py b/imgx/train_state.py index 17f84fe..3e393f8 100644 --- a/imgx/train_state.py +++ b/imgx/train_state.py @@ -4,13 +4,23 @@ """ from __future__ import annotations +from functools import partial from pathlib import Path +from typing import Callable import chex import jax +import jax.numpy as jnp +import optax +from absl import logging +from flax import linen as nn from flax.training import checkpoints from flax.training import dynamic_scale as dynamic_scale_lib from flax.training import train_state as ts +from jax import lax +from omegaconf import DictConfig + +from imgx import REPLICA_AXIS class TrainState(ts.TrainState): @@ -24,6 +34,46 @@ class TrainState(ts.TrainState): dynamic_scale: dynamic_scale_lib.DynamicScale +def create_train_state( + key: jax.Array, + batch: chex.ArrayTree, + model: nn.Module, + config: DictConfig, + initialized: Callable[[jax.Array, chex.ArrayTree, nn.Module], chex.ArrayTree], +) -> TrainState: + """Create initial training state. + + Args: + key: random key. + batch: batch data for determining input shapes. + model: model. + config: entire configuration. + initialized: function to get initialized model parameters. + + Returns: + initial training state. + """ + dynamic_scale = None + platform = jax.local_devices()[0].platform + if config.half_precision and platform == "gpu": + dynamic_scale = dynamic_scale_lib.DynamicScale() + + params = initialized(key, batch, model) + + # count params + params_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + logging.info(f"The model has {params_count:,} parameters.") + + tx = init_optimizer(config=config) + state = TrainState.create( + apply_fn=model.apply, + params=params, + tx=tx, + dynamic_scale=dynamic_scale, + ) + return state + + def restore_checkpoint( state: TrainState, ckpt_dir: Path, step: int | None = None ) -> chex.ArrayTree: @@ -55,3 +105,186 @@ def save_checkpoint(train_state: TrainState, ckpt_dir: Path, keep: int) -> str: step = int(train_state.step) ckpt_path = checkpoints.save_checkpoint_multiprocess(ckpt_dir, train_state, step, keep=keep) return ckpt_path + + +def get_lr_schedule(config: DictConfig) -> optax.Schedule: + """Get learning rate scheduler. + + Args: + config: entire configuration. + + Returns: + Scheduler + """ + return optax.warmup_cosine_decay_schedule(**config.optimizer.lr_schedule) + + +def get_every_k_schedule(config: DictConfig) -> int: + """Get k for gradient accumulations. + + Args: + config: entire configuration. + + Returns: + k, where gradients are accumulated every k step. + """ + num_devices_per_replica = config.data.trainer.num_devices_per_replica + batch_size_per_replica = config.data.trainer.batch_size_per_replica + num_replicas = jax.local_device_count() // num_devices_per_replica + batch_size_per_step = batch_size_per_replica * num_replicas + if config.data.trainer.batch_size < batch_size_per_step: + raise ValueError( + f"Batch size {config.data.trainer.batch_size} is too small. " + f"batch_size_per_replica * num_replicas = " + f"{batch_size_per_replica} * {num_replicas} = " + f"{batch_size_per_step}." 
+    )
+    if config.data.trainer.batch_size % batch_size_per_step != 0:
+        raise ValueError("Batch size cannot be evenly divided by batch size per step.")
+    every_k_schedule = config.data.trainer.batch_size // batch_size_per_step
+    if every_k_schedule > 1:
+        logging.info(
+            f"Using gradient accumulation. "
+            f"Each model duplicate is stored across {num_devices_per_replica} "
+            f"shard{'s' if num_devices_per_replica > 1 else ''}. "
+            f"Each step has {batch_size_per_step} samples. "
+            f"Gradients are averaged every {every_k_schedule} steps. "
+            f"Effective batch size is {config.data.trainer.batch_size}."
+        )
+    return every_k_schedule
+
+
+def init_optimizer(
+    config: DictConfig,
+) -> optax.GradientTransformation:
+    """Initialize optimizer.
+
+    Args:
+        config: entire configuration.
+
+    Returns:
+        optimizer.
+    """
+    lr_schedule = get_lr_schedule(config)
+    optimizer = optax.chain(
+        optax.clip_by_global_norm(config.optimizer.grad_norm),
+        getattr(optax, config.optimizer.name)(learning_rate=lr_schedule, **config.optimizer.kwargs),
+    )
+    # accumulate gradient when needed
+    every_k_schedule = get_every_k_schedule(config)
+    if every_k_schedule == 1:
+        # no need to accumulate gradient
+        return optimizer
+    return optax.MultiSteps(optimizer, every_k_schedule=every_k_schedule)
+
+
+def get_gradients(
+    train_state: TrainState,
+    loss_step: Callable[[chex.ArrayTree, chex.ArrayTree], tuple[jnp.ndarray, chex.ArrayTree]]
+    | Callable[[chex.ArrayTree, chex.ArrayTree, jax.Array], tuple[jnp.ndarray, chex.ArrayTree]],
+    input_dict: dict[str, chex.ArrayTree],
+) -> tuple[dynamic_scale_lib.DynamicScale, jnp.ndarray, chex.ArrayTree, chex.ArrayTree]:
+    """Get gradients.
+
+    Args:
+        train_state: training state.
+        loss_step: loss step function.
+        input_dict: input to loss_step in addition to params.
+
+    Returns:
+        dynamic_scale: dynamic scale.
+        is_fin: whether the gradients are finite.
+        aux: auxiliary outputs from loss_step.
+        grads: gradients.
+    """
+    is_fin = None
+    dynamic_scale = train_state.dynamic_scale
+    if dynamic_scale:
+        grad_fn = dynamic_scale.value_and_grad(loss_step, has_aux=True, axis_name=REPLICA_AXIS)
+        dynamic_scale, is_fin, aux, grads = grad_fn(train_state.params, **input_dict)
+        # dynamic loss takes care of averaging gradients across replicas
+    else:
+        grad_fn = jax.value_and_grad(loss_step, has_aux=True)
+        aux, grads = grad_fn(train_state.params, **input_dict)
+        # Re-use same axis_name as in the call to `pmap(...train_step...)` below.
+        grads = lax.pmean(grads, axis_name=REPLICA_AXIS)
+    return dynamic_scale, is_fin, aux, grads
+
+
+def update_train_state(
+    train_state: TrainState,
+    dynamic_scale: dynamic_scale_lib.DynamicScale,
+    is_fin: jnp.ndarray,
+    grads: chex.ArrayTree,
+) -> TrainState:
+    """Update training state.
+
+    Args:
+        train_state: training state.
+        dynamic_scale: dynamic scale.
+        is_fin: whether the gradients are finite.
+        grads: gradients.
+
+    Returns:
+        new training state.
+    """
+    new_state = train_state.apply_gradients(grads=grads)
+    if dynamic_scale:
+        # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
+        # params should be restored (= skip this step).
+        new_state = new_state.replace(
+            opt_state=jax.tree_util.tree_map(
+                partial(jnp.where, is_fin),
+                new_state.opt_state,
+                train_state.opt_state,
+            ),
+            params=jax.tree_util.tree_map(
+                partial(jnp.where, is_fin), new_state.params, train_state.params
+            ),
+            dynamic_scale=dynamic_scale,
+        )
+    return new_state
+
+
+def get_optimization_metrics(
+    grads: chex.ArrayTree,
+    train_state: TrainState,
+    config: DictConfig,
+) -> dict[str, float]:
+    """Get optimization metrics.
+
+    Args:
+        grads: gradients.
+        train_state: training state.
+        config: entire configuration.
+
+    Returns:
+        metrics.
+    """
+    metrics = {
+        "grad_norm": optax.global_norm(grads),
+        "params_norm": optax.global_norm(train_state.params),
+    }
+    if train_state.dynamic_scale:
+        metrics["scale"] = train_state.dynamic_scale.scale
+
+    lr_schedule = get_lr_schedule(config)
+    every_k_schedule = get_every_k_schedule(config)
+    metrics["lr"] = lr_schedule(train_state.step // every_k_schedule)
+
+    return metrics
+
+
+def get_half_precision_dtype(half_precision: bool) -> jnp.dtype:
+    """Get half precision dtype.
+
+    Args:
+        half_precision: whether to use half precision.
+
+    Returns:
+        dtype.
+    """
+    if not half_precision:
+        return jnp.float32
+    platform = jax.local_devices()[0].platform
+    return jnp.bfloat16 if platform == "tpu" else jnp.float16
diff --git a/imgx_datasets/README.md b/imgx_datasets/README.md
index fd17aad..439c929 100644
--- a/imgx_datasets/README.md
+++ b/imgx_datasets/README.md
@@ -14,8 +14,7 @@ make build_dataset
 ### Description

 This data set from [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM) contains 589
-T2-weighted labeled images which are split into 411, 14, 164 images for training, validation and
-testing respectively.
+T2-weighted labeled images which are split into training, validation and test sets.

 ### Download and Build

@@ -69,7 +68,7 @@ tfds build imgx_datasets/muscle_us
 ### Description

 This data set from [Baid et al. 2021](https://arxiv.org/abs/2107.02314) contains 1251 labeled images
-which are split into 938, 13, 300 images for training, validation and testing respectively.
+which are split into training, validation and test sets.
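Stepping back to the `update_train_state` logic relocated in the `imgx/train_state.py` hunk above: it follows the standard Flax mixed-precision pattern, where `DynamicScale` flags non-finite gradients and the step is effectively skipped by keeping the previous parameters and optimizer state while the loss scale adjusts. A condensed sketch of that selection (the helper name is illustrative):

```python
from functools import partial

import jax
import jax.numpy as jnp


def keep_if_finite(is_fin: jnp.ndarray, new_tree, old_tree):
    """Pick leaves from new_tree only where all gradients were finite."""
    return jax.tree_util.tree_map(partial(jnp.where, is_fin), new_tree, old_tree)


# usage sketch: params = keep_if_finite(is_fin, new_state.params, train_state.params)
```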
### Download and Build diff --git a/imgx_datasets/amos_ct/amos_ct_dataset_builder.py b/imgx_datasets/amos_ct/amos_ct_dataset_builder.py index 9f662d9..f998e40 100644 --- a/imgx_datasets/amos_ct/amos_ct_dataset_builder.py +++ b/imgx_datasets/amos_ct/amos_ct_dataset_builder.py @@ -19,7 +19,7 @@ ) from imgx_datasets.dataset_info import OneHotLabeledDatasetInfo from imgx_datasets.preprocess import load_and_preprocess_image_and_label -from imgx_datasets.util import save_uids +from imgx_datasets.save import save_uids _DESCRIPTION = """ The data set includes 500 CT and 100 MR images from Amos: @@ -45,6 +45,7 @@ AMOS_CT_TFDS_FOLD = "ZIP.zenodo.org_record_7262581_files_amos22ZnMFi429bmx93zDuUBTrjdo9oGlndCnbGAVAP0I3p_M.zip" # noqa: E501, pylint: disable=line-too-long AMOS_CT_INFO = OneHotLabeledDatasetInfo( + name="amos_ct", tfds_preprocessed_dir=TFDS_EXTRACTED_DIR / AMOS_CT_TFDS_FOLD / "preprocessed", image_spacing=(1.5, 1.5, 5.0), image_spatial_shape=(192, 128, 128), diff --git a/imgx_datasets/brats2021_mr/brats2021_mr_dataset_builder.py b/imgx_datasets/brats2021_mr/brats2021_mr_dataset_builder.py index 7134019..da6f640 100644 --- a/imgx_datasets/brats2021_mr/brats2021_mr_dataset_builder.py +++ b/imgx_datasets/brats2021_mr/brats2021_mr_dataset_builder.py @@ -22,7 +22,7 @@ ) from imgx_datasets.dataset_info import DatasetInfo from imgx_datasets.preprocess import clip_and_normalise_intensity -from imgx_datasets.util import save_uids +from imgx_datasets.save import save_uids _DESCRIPTION = """ All BraTS mpMRI scans are available as NIfTI files (.nii.gz) and describe @@ -122,6 +122,7 @@ def label_to_mask( BRATS2021_MR_INFO = BRATS2021MRNestedDatasetInfo( + name="brats2021_mr", tfds_preprocessed_dir=TFDS_MANUAL_DIR / BRATS2021_MR_TFDS_FOLD / "preprocessed", image_spacing=(1.0, 1.0, 1.0), image_spatial_shape=(179, 219, 155), diff --git a/imgx_datasets/constant.py b/imgx_datasets/constant.py index 4d3744b..2789f38 100644 --- a/imgx_datasets/constant.py +++ b/imgx_datasets/constant.py @@ -1,4 +1,8 @@ -"""Constants for imgx_datasets.""" +"""Constants for imgx_datasets. + +Cannot be defined in __init__.py because of circular import. +__init__.py imports from each data set, which imports constants from this file. +""" from pathlib import Path # splits @@ -8,12 +12,12 @@ # data dict keys UID = "uid" -IMAGE = "image" -LABEL = "label" +IMAGE = "image" # in a batch, keys having image are also considered as images +LABEL = "label" # in a batch, keys having label are also considered as labels TFDS_DIR: Path = Path.home() / "tensorflow_datasets" TFDS_EXTRACTED_DIR: Path = TFDS_DIR / "downloads" / "extracted" TFDS_MANUAL_DIR: Path = TFDS_DIR / "downloads" / "manual" # segmentation task -FOREGROUND_RANGE = "foreground_range" +FOREGROUND_RANGE = "foreground_range" # added during pre-processing in tf for augmentation diff --git a/imgx_datasets/dataset_info.py b/imgx_datasets/dataset_info.py index 6e65009..43ed5ca 100644 --- a/imgx_datasets/dataset_info.py +++ b/imgx_datasets/dataset_info.py @@ -10,12 +10,13 @@ class DatasetInfo: """Data set class for imgx_datasets.""" + name: str tfds_preprocessed_dir: Path image_spacing: tuple[float, ...] image_spatial_shape: tuple[int, ...] image_channels: int - class_names: tuple[str, ...] - classes_are_exclusive: bool + class_names: tuple[str, ...] 
# for segmentation label only + classes_are_exclusive: bool # for segmentation label only @property def input_image_shape(self) -> tuple[int, ...]: diff --git a/imgx_datasets/image_io.py b/imgx_datasets/image_io.py deleted file mode 100644 index d2df080..0000000 --- a/imgx_datasets/image_io.py +++ /dev/null @@ -1,173 +0,0 @@ -"""Module for image io.""" -from __future__ import annotations - -from pathlib import Path - -import numpy as np -import SimpleITK as sitk # noqa: N813 -from PIL import Image - - -def save_segmentation_prediction( - preds: np.ndarray, - uids: list, - out_dir: Path, - tfds_dir: Path, -) -> None: - """Save segmentation predictions. - - Args: - preds: (num_samples, ...), the values are integers. - uids: (num_samples,). - out_dir: output directory. - tfds_dir: directory saving preprocessed images and labels. - """ - if preds.ndim == 3: - save_2d_segmentation_prediction( - preds=preds, - uids=uids, - out_dir=out_dir, - ) - elif preds.ndim == 4: - save_3d_segmentation_prediction( - preds=preds, - uids=uids, - out_dir=out_dir, - tfds_dir=tfds_dir, - ) - else: - raise ValueError( - f"Prediction should be 3D or 4D with num_samples axis, but {preds.ndim}D is given." - ) - - -def save_2d_segmentation_prediction( - preds: np.ndarray, - uids: list, - out_dir: Path, -) -> None: - """Save segmentation predictions for 2d images. - - Args: - preds: (num_samples, ...), the values are integers. - uids: (num_samples,). - out_dir: output directory. - """ - if preds.ndim != 3: - raise ValueError(f"Prediction should be 3D, but {preds.ndim}D is given.") - out_dir.mkdir(parents=True, exist_ok=True) - for i, uid in enumerate(uids): - mask_pred = preds[i, ...] - if np.max(mask_pred) > 1: - raise ValueError( - f"Prediction values should be 0 or 1, but " - f"max value is {np.max(mask_pred)} for {uid}. " - f"Multi-class segmentation for 2D images are not supported." - ) - save_2d_grayscale_image( - image=mask_pred, - out_path=out_dir / f"{uid}_mask_pred.png", - ) - - -def save_2d_grayscale_image( - image: np.ndarray, - out_path: Path, -) -> None: - """Save grayscale 2d images. - - Args: - image: (height, width), the values between [0, 1]. - out_path: output path. - """ - out_path.parent.mkdir(parents=True, exist_ok=True) - image = np.asarray(image * 255, dtype="uint8") - Image.fromarray(image, "L").save(str(out_path)) - - -def save_3d_segmentation_prediction( - preds: np.ndarray, - uids: list, - out_dir: Path, - tfds_dir: Path, -) -> None: - """Save segmentation predictions for 3d volumes. - - Args: - preds: (num_samples, width, height, depth), the values are integers. - uids: (num_samples,). - out_dir: output directory. - tfds_dir: directory saving preprocessed images and labels. - """ - if preds.ndim != 4: - raise ValueError(f"Prediction should be 4D, but {preds.ndim}D is given.") - out_dir.mkdir(parents=True, exist_ok=True) - for i, uid in enumerate(uids): - # (width, height, depth) -> (depth, height, width) - mask_pred = np.transpose(preds[i, ...], axes=[2, 1, 0]) - mask_pred = mask_pred.astype(dtype="uint8") - save_3d_mask( - mask=mask_pred, - mask_true_path=tfds_dir / f"{uid}_mask_preprocessed.nii.gz", - out_path=out_dir / f"{uid}_mask_pred.nii.gz", - ) - - -def save_3d_mask( - mask: np.ndarray, - mask_true_path: Path, - out_path: Path, -) -> None: - """Save segmentation predictions for 3d volumes. - - Args: - mask: (depth, height, width), the values are integers. - mask_true_path: path to the true mask. - out_path: output path. 
- """ - out_path.parent.mkdir(parents=True, exist_ok=True) - volume_mask = sitk.GetImageFromArray(mask) - # copy meta data - volume_mask_true = sitk.ReadImage(mask_true_path) - volume_mask.CopyInformation(volume_mask_true) - # output - sitk.WriteImage( - image=volume_mask, - fileName=out_path, - useCompression=True, - ) - - -def load_2d_grayscale_image( - image_path: Path, - dtype: np.dtype = np.uint8, -) -> np.ndarray: - """Load 2d mask. - - Args: - image_path: path to the mask. - dtype: data type of the output. - - Returns: - mask: (height, width), the values are between [0, 1]. - """ - mask = Image.open(str(image_path)).convert("L") # value [0, 255] - mask = np.asarray(mask) / 255 # value [0, 1] - mask = np.asarray(mask, dtype=dtype) - return mask - - -def load_3d_image( - image_path: Path, - dtype: np.dtype = np.float32, -) -> np.ndarray: - """Load 3d images. - - Args: - image_path: path to the mask. - dtype: data type of the output. - - Returns: - mask: (depth, height, width), the values are integers. - """ - return np.asarray(sitk.GetArrayFromImage(sitk.ReadImage(image_path)), dtype=dtype) diff --git a/imgx_datasets/male_pelvic_mr/male_pelvic_mr_dataset_builder.py b/imgx_datasets/male_pelvic_mr/male_pelvic_mr_dataset_builder.py index 47ea1fc..44b6f9d 100644 --- a/imgx_datasets/male_pelvic_mr/male_pelvic_mr_dataset_builder.py +++ b/imgx_datasets/male_pelvic_mr/male_pelvic_mr_dataset_builder.py @@ -19,7 +19,7 @@ ) from imgx_datasets.dataset_info import OneHotLabeledDatasetInfo from imgx_datasets.preprocess import load_and_preprocess_image_and_label -from imgx_datasets.util import save_uids +from imgx_datasets.save import save_uids _DESCRIPTION = """ The data set includes 589 T2-weighted images acquired from the same number of @@ -44,6 +44,7 @@ MALE_PELVIC_MR_TFDS_FOLD = "ZIP.zenodo.org_record_7013610_files_dataW0mCI6aH_V-TdeDbM4TdKelNcJ5ZxbAi5isebqCnMr0.zip" # noqa: E501, pylint: disable=line-too-long MALE_PELVIR_MR_INFO = OneHotLabeledDatasetInfo( + name="male_pelvic_mr", tfds_preprocessed_dir=TFDS_EXTRACTED_DIR / MALE_PELVIC_MR_TFDS_FOLD / "preprocessed", image_spacing=(0.75, 0.75, 2.5), image_spatial_shape=(256, 256, 48), diff --git a/imgx_datasets/muscle_us/muscle_us_dataset_builder.py b/imgx_datasets/muscle_us/muscle_us_dataset_builder.py index 26916fe..4661579 100644 --- a/imgx_datasets/muscle_us/muscle_us_dataset_builder.py +++ b/imgx_datasets/muscle_us/muscle_us_dataset_builder.py @@ -1,4 +1,5 @@ """muscle_us dataset.""" +from __future__ import annotations from collections.abc import Generator from pathlib import Path @@ -20,8 +21,7 @@ VALID_SPLIT, ) from imgx_datasets.dataset_info import OneHotLabeledDatasetInfo -from imgx_datasets.image_io import load_2d_grayscale_image, save_2d_grayscale_image -from imgx_datasets.util import save_uids +from imgx_datasets.save import load_2d_grayscale_image, save_2d_grayscale_image, save_uids _DESCRIPTION = """ The dataset included 3917 images of biceps brachii, tibialis anterior and @@ -212,6 +212,7 @@ def post_process_label(self, label: jnp.ndarray) -> jnp.ndarray: MUSCLE_US_TFDS_FOLD = "ZIP.data.mend.com_publ-file_data_3jyk_file_b160-98XNE6wqHCOxLE8Ap4-__x82VYGr1POiW-quZggxPZSCk" # noqa: E501, pylint: disable=line-too-long MUSCLE_US_INFO = MuscleUSDatasetInfo( + name="muscle_us", tfds_preprocessed_dir=TFDS_EXTRACTED_DIR / MUSCLE_US_TFDS_FOLD / "preprocessed", image_spacing=(1.0, 1.0), image_spatial_shape=(480, 512), diff --git a/imgx_datasets/preprocess.py b/imgx_datasets/preprocess.py index 85ebe6a..5d02ca7 100644 --- 
diff --git a/imgx_datasets/preprocess.py b/imgx_datasets/preprocess.py
index 85ebe6a..5d02ca7 100644
--- a/imgx_datasets/preprocess.py
+++ b/imgx_datasets/preprocess.py
@@ -10,64 +10,66 @@
 from imgx_datasets.util import get_center_crop_shape_from_bbox, get_center_pad_shape


-def check_image_and_label(
-    image_volume: sitk.Image,
-    label_volume: sitk.Image,
-    image_path: Path,
-    label_path: Path,
+def compare_volume_metadata(
+    volume1: sitk.Image,
+    volume2: sitk.Image,
+    path1: Path,
+    path2: Path,
     rtol: float = 1.0e-5,
     atol: float = 1.0e-3,
 ) -> None:
-    """Check if metadata matches between image and label.
+    """Check if metadata matches between two volumes.
+
+    A volume can be an image or a label.

     Args:
-        image_volume: loaded image.
-        label_volume: loaded label.
-        image_path: image file path.
-        label_path: label file path.
+        volume1: image volume 1.
+        volume2: image volume 2.
+        path1: file path 1, for error message.
+        path2: file path 2, for error message.
         rtol: relative tolerance for sanity check, 1E-5 is too big.
         atol: absolute tolerance for sanity check, 1E-8 is too big.

     Raises:
-        ValueError: if image and label metadata does not match
+        ValueError: if metadata does not match.
     """
-    if image_volume.GetSize() != label_volume.GetSize():
+    if volume1.GetSize() != volume2.GetSize():
         raise ValueError(
-            f"Image and label sizes are not the same for "
-            f"{image_path} and {label_path}: "
-            f"{image_volume.GetSize()} and {label_volume.GetSize()}."
+            f"Sizes are not the same for "
+            f"{path1} and {path2}: "
+            f"{volume1.GetSize()} and {volume2.GetSize()}."
         )
     if not np.allclose(
-        image_volume.GetSpacing(),
-        label_volume.GetSpacing(),
+        volume1.GetSpacing(),
+        volume2.GetSpacing(),
         rtol=rtol,
         atol=atol,
     ):
         raise ValueError(
-            f"Image and label spacing are not the same for "
-            f"{image_path} and {label_path}: "
-            f"{image_volume.GetSpacing()} and {label_volume.GetSpacing()}."
+            f"Spacings are not the same for "
+            f"{path1} and {path2}: "
+            f"{volume1.GetSpacing()} and {volume2.GetSpacing()}."
         )
     if not np.allclose(
-        image_volume.GetDirection(),
-        label_volume.GetDirection(),
+        volume1.GetDirection(),
+        volume2.GetDirection(),
         rtol=rtol,
         atol=atol,
     ):
-        arr_image = np.array(image_volume.GetDirection())
-        arr_label = np.array(label_volume.GetDirection())
+        arr_image = np.array(volume1.GetDirection())
+        arr_label = np.array(volume2.GetDirection())
         raise ValueError(
-            f"Image and label direction are not the same for "
-            f"{image_path} and {label_path}: "
+            f"Directions are not the same for "
+            f"{path1} and {path2}: "
             f"{arr_image} and {arr_label}, "
             f"difference is {arr_image - arr_label} for "
             f"rtol={rtol} and atol = {atol}."
         )
-    if not np.allclose(image_volume.GetOrigin(), label_volume.GetOrigin(), rtol=rtol, atol=atol):
+    if not np.allclose(volume1.GetOrigin(), volume2.GetOrigin(), rtol=rtol, atol=atol):
         raise ValueError(
-            f"Image and label origin are not the same for "
-            f"{image_path} and {label_path}: "
-            f"{image_volume.GetOrigin()} and {label_volume.GetOrigin()}."
+            f"Origins are not close for "
+            f"{path1} and {path2}: "
+            f"{volume1.GetOrigin()} and {volume2.GetOrigin()}."
         )
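For reference, the renamed helper now compares any two volumes rather than specifically an image
and its label. A minimal usage sketch, assuming two preprocessed files with hypothetical names:

```python
from pathlib import Path

import SimpleITK as sitk  # noqa: N813

from imgx_datasets.preprocess import compare_volume_metadata

# hypothetical paths to two volumes whose grids should agree
path1 = Path("uid1_img_preprocessed.nii.gz")
path2 = Path("uid1_mask_preprocessed.nii.gz")

# raises ValueError if size, spacing, direction, or origin disagree
compare_volume_metadata(
    volume1=sitk.ReadImage(str(path1)),
    volume2=sitk.ReadImage(str(path2)),
    path1=path1,
    path2=path2,
)
```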
@@ -291,11 +293,11 @@ def load_and_preprocess_image_and_label(
     label_volume = sitk.ReadImage(str(label_path))

     # metadata should be the same
-    check_image_and_label(
-        image_volume=image_volume,
-        label_volume=label_volume,
-        image_path=image_path,
-        label_path=label_path,
+    compare_volume_metadata(
+        volume1=image_volume,
+        volume2=label_volume,
+        path1=image_path,
+        path2=label_path,
     )

     # resample
@@ -355,6 +357,7 @@ def load_and_preprocess_image_and_label(
     )

     # save processed image/mask
+    out_dir.mkdir(parents=True, exist_ok=True)
     image_out_path = out_dir / (uid + "_img_preprocessed.nii.gz")
     label_out_path = out_dir / (uid + "_mask_preprocessed.nii.gz")
     sitk.WriteImage(image=image_volume, fileName=str(image_out_path), useCompression=True)
diff --git a/imgx_datasets/save.py b/imgx_datasets/save.py
new file mode 100644
index 0000000..ca638f6
--- /dev/null
+++ b/imgx_datasets/save.py
@@ -0,0 +1,155 @@
+"""IO-related functions (the file cannot be named io.py).
+
+https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import SimpleITK as sitk  # noqa: N813
+from absl import logging
+from PIL import Image
+
+
+def save_uids(
+    train_uids: list[str],
+    valid_uids: list[str],
+    test_uids: list[str],
+    out_dir: Path,
+) -> None:
+    """Save uids to csv files.
+
+    Args:
+        train_uids: list of training uids.
+        valid_uids: list of validation uids.
+        test_uids: list of test uids.
+        out_dir: directory to save the csv files.
+    """
+    pd.DataFrame({"uid": train_uids}).to_csv(out_dir / "train_uids.csv", index=False)
+    pd.DataFrame({"uid": valid_uids}).to_csv(out_dir / "valid_uids.csv", index=False)
+    pd.DataFrame({"uid": test_uids}).to_csv(out_dir / "test_uids.csv", index=False)
+    logging.info(f"There are {len(train_uids)} training samples.")
+    logging.info(f"There are {len(valid_uids)} validation samples.")
+    logging.info(f"There are {len(test_uids)} test samples.")
+
+
+def save_2d_grayscale_image(
+    image: np.ndarray,
+    out_path: Path,
+) -> None:
+    """Save grayscale 2d images.
+
+    Args:
+        image: (height, width), the values are between [0, 1].
+        out_path: output path.
+    """
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+    image = np.asarray(image * 255, dtype="uint8")
+    Image.fromarray(image, "L").save(str(out_path))
+
+
+def load_2d_grayscale_image(
+    image_path: Path,
+    dtype: np.dtype = np.uint8,
+) -> np.ndarray:
+    """Load 2d images.
+
+    Args:
+        image_path: path to the image.
+        dtype: data type of the output.
+
+    Returns:
+        image: (height, width), the values are between [0, 1].
+    """
+    mask = Image.open(str(image_path)).convert("L")  # value [0, 255]
+    mask = np.asarray(mask) / 255  # value [0, 1]
+    mask = np.asarray(mask, dtype=dtype)
+    return mask
+
+
+def save_3d_image(
+    image: np.ndarray,
+    reference_image: sitk.Image,
+    out_path: Path,
+) -> None:
+    """Save 3d image.
+
+    Args:
+        image: (depth, height, width), the values are integers.
+        reference_image: reference image to copy metadata from.
+        out_path: output path.
+    """
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+    image = sitk.GetImageFromArray(image)
+    image.CopyInformation(reference_image)
+    # output
+    sitk.WriteImage(
+        image=image,
+        fileName=out_path,
+        useCompression=True,
+    )
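The 2D helpers above round-trip binary masks through PNG. A small sketch under the assumption of
a binary mask (the output path is hypothetical); with the default `uint8` dtype,
`load_2d_grayscale_image` returns 0/1 integer values:

```python
from pathlib import Path

import numpy as np

from imgx_datasets.save import load_2d_grayscale_image, save_2d_grayscale_image

# hypothetical binary mask with values in {0, 1}
mask = np.zeros((480, 512), dtype=np.float32)
mask[100:200, 150:300] = 1.0

save_2d_grayscale_image(image=mask, out_path=Path("demo_mask.png"))
loaded = load_2d_grayscale_image(Path("demo_mask.png"))  # uint8 array of 0s and 1s
assert np.array_equal(loaded, mask.astype(np.uint8))
```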
+
+
+def save_image(
+    image: np.ndarray,
+    reference_image: sitk.Image,
+    out_path: Path,
+    dtype: np.dtype,
+) -> None:
+    """Save 2d or 3d image.
+
+    Args:
+        image: (width, height, depth) for 3D or (height, width) for 2D;
+            3D inputs are transposed to (depth, height, width) before saving,
+            2D inputs are saved as-is.
+        reference_image: reference image to copy metadata from.
+        out_path: output path.
+        dtype: data type of the output.
+    """
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+    if image.ndim not in [2, 3]:
+        raise ValueError(
+            f"Image should be 2D or 3D, but {image.ndim}D is given with shape {image.shape}."
+        )
+    if image.ndim == 2:
+        save_2d_grayscale_image(
+            image=image.astype(dtype=dtype),
+            out_path=out_path,
+        )
+    else:
+        # (width, height, depth) -> (depth, height, width)
+        image = np.transpose(image, axes=[2, 1, 0]).astype(dtype=dtype)
+        save_3d_image(
+            image=image,
+            reference_image=reference_image,
+            out_path=out_path,
+        )
+
+
+def save_ddf(
+    ddf: np.ndarray,
+    reference_image: sitk.Image,
+    out_path: Path,
+    dtype: np.dtype = np.float64,
+) -> None:
+    """Save a dense displacement field (DDF) for 3d volumes.
+
+    Args:
+        ddf: (width, height, depth, 3), in voxel units, i.e. not yet scaled by spacing.
+        reference_image: reference image to copy metadata from.
+        out_path: output path.
+        dtype: data type of the output.
+    """
+    if ddf.ndim != 4:
+        raise ValueError(f"DDF should be 4D, but {ddf.ndim}D is given.")
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # scale the ddf by spacing, converting voxel displacements to physical units
+    ddf = np.transpose(ddf, axes=[2, 1, 0, 3]).astype(dtype=dtype)
+    ddf *= np.expand_dims(reference_image.GetSpacing(), axis=list(range(ddf.ndim - 1)))
+
+    ddf_volume = sitk.GetImageFromArray(ddf, isVector=True)
+    ddf_volume.SetSpacing(reference_image.GetSpacing())
+    ddf_volume.CopyInformation(reference_image)
+    tx = sitk.DisplacementFieldTransform(ddf_volume)
+    sitk.WriteTransform(tx, out_path)
diff --git a/imgx_datasets/tests/test_muscle_us.py b/imgx_datasets/tests/test_muscle_us.py
index 848727c..1c73c67 100644
--- a/imgx_datasets/tests/test_muscle_us.py
+++ b/imgx_datasets/tests/test_muscle_us.py
@@ -6,8 +6,8 @@
 import numpy as np
 import pytest

-from imgx_datasets.image_io import load_2d_grayscale_image
 from imgx_datasets.muscle_us.muscle_us_dataset_builder import select_connected_component
+from imgx_datasets.save import load_2d_grayscale_image


 @pytest.mark.parametrize(
diff --git a/imgx_datasets/util.py b/imgx_datasets/util.py
index 87100d6..7fec66b 100644
--- a/imgx_datasets/util.py
+++ b/imgx_datasets/util.py
@@ -5,11 +5,7 @@
 """
 from __future__ import annotations

-from pathlib import Path
-
 import numpy as np
-import pandas as pd
-from absl import logging


 def get_center_pad_shape(
@@ -135,25 +131,3 @@ def get_center_crop_shape_from_bbox(
     crop_lower.append(crop_lower_i)
     crop_upper.append(crop_upper_i)
     return tuple(crop_lower), tuple(crop_upper)
-
-
-def save_uids(
-    train_uids: list[str],
-    valid_uids: list[str],
-    test_uids: list[str],
-    out_dir: Path,
-) -> None:
-    """Save uids to csv files.
-
-    Args:
-        train_uids: list of training uids.
-        valid_uids: list of validation uids.
-        test_uids: list of test uids.
-        out_dir: directory to save the csv files.
-    """
-    pd.DataFrame({"uid": train_uids}).to_csv(out_dir / "train_uids.csv", index=False)
-    pd.DataFrame({"uid": valid_uids}).to_csv(out_dir / "valid_uids.csv", index=False)
-    pd.DataFrame({"uid": test_uids}).to_csv(out_dir / "test_uids.csv", index=False)
-    logging.info(f"There are {len(train_uids)} training samples.")
-    logging.info(f"There are {len(valid_uids)} validation samples.")
-    logging.info(f"There are {len(test_uids)} test samples.")
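A usage sketch for `save_ddf`, assuming a hypothetical preprocessed reference volume; the input
field is in voxel units, and the helper applies the spacing scaling and axis transpose before
writing a SimpleITK `DisplacementFieldTransform`. The `.h5` extension is an assumption of this
sketch (ITK's HDF5 transform IO handles displacement fields), not something the diff prescribes:

```python
from pathlib import Path

import numpy as np
import SimpleITK as sitk  # noqa: N813

from imgx_datasets.save import save_ddf

# hypothetical preprocessed volume providing size/spacing/direction/origin
reference = sitk.ReadImage("uid1_img_preprocessed.nii.gz")
width, height, depth = reference.GetSize()

# an all-zero (identity) displacement field in voxel units
ddf = np.zeros((width, height, depth, 3))

save_ddf(ddf=ddf, reference_image=reference, out_path=Path("uid1_ddf.h5"))
```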
- """ - pd.DataFrame({"uid": train_uids}).to_csv(out_dir / "train_uids.csv", index=False) - pd.DataFrame({"uid": valid_uids}).to_csv(out_dir / "valid_uids.csv", index=False) - pd.DataFrame({"uid": test_uids}).to_csv(out_dir / "test_uids.csv", index=False) - logging.info(f"There are {len(train_uids)} training samples.") - logging.info(f"There are {len(valid_uids)} validation samples.") - logging.info(f"There are {len(test_uids)} test samples.") diff --git a/pyproject.toml b/pyproject.toml index 8309afe..7c18285 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ description = "A Jax-based deep learning toolkit for biomedical applications." requires-python = ">=3.9" license = {text = "Apache-2.0"} -version = "0.3.0" +version = "0.3.1" [project.scripts] imgx_train="imgx.run_train:main" @@ -25,7 +25,8 @@ package-dir = {"imgx"="./imgx", "imgx_datasets"="./imgx_datasets"} # pytest [tool.pytest.ini_options] markers = [ - "slow", + "slow", # slow unit tests + "integration", # integration tests ] # pre-commit