Skip to content

Commit

Permalink
Refactor code and add example notebooks (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
mathpluscode authored Nov 25, 2023
1 parent 39d76b3 commit 5349fdb
Show file tree
Hide file tree
Showing 95 changed files with 2,805 additions and 1,889 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/unit-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install tensorflow-cpu==2.12.0
pip install jax==0.4.14
pip install jaxlib==0.4.14
pip install jax==0.4.20
pip install jaxlib==0.4.20
pip install -r docker/requirements.txt
pip install -e .
- name: Test with pytest
run: |
pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow"
pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow and not integration"
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.nox/
.coverage
.coverage*
.coverage.*
.cache
nosetests.xml
Expand Down Expand Up @@ -137,6 +137,7 @@ dmypy.json
# notebook
*.ipynb
notebooks/
!examples/segmentation/inference.ipynb

# hydra outputs
outputs/
Expand Down
32 changes: 7 additions & 25 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ repos:
rev: v4.5.0
hooks:
- id: check-added-large-files
args: ["--maxkb=15000"]
- id: check-ast
- id: check-byte-order-marker
- id: check-builtin-literals
Expand All @@ -27,43 +28,24 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.10.0
rev: 23.11.0
hooks:
- id: black
args:
- --line-length=100
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
rev: v1.7.1
hooks: # https://github.com/python/mypy/issues/4008#issuecomment-582458665
- id: mypy
name: mypy-imgx
files: ^imgx/
entry: mypy imgx/
pass_filenames: false
args:
[
--strict-equality,
--disallow-untyped-calls,
--disallow-untyped-defs,
--disallow-incomplete-defs,
--check-untyped-defs,
--disallow-untyped-decorators,
--warn-redundant-casts,
--warn-unused-ignores,
--no-warn-no-return,
--warn-unreachable,
]
- id: mypy
name: mypy-imgx_datasets
files: ^imgx_datasets/
entry: mypy imgx_datasets/
name: mypy
pass_filenames: false
args:
[
--strict-equality,
--disallow-untyped-calls,
--disallow-untyped-defs,
--disallow-incomplete-defs,
--disallow-any-generics,
--check-untyped-defs,
--disallow-untyped-decorators,
--warn-redundant-casts,
Expand All @@ -72,15 +54,15 @@ repos:
--warn-unreachable,
]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.3
rev: v3.1.0
hooks:
- id: prettier
args:
- --print-width=100
- --prose-wrap=always
- --tab-width=2
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.1.1"
rev: "v0.1.6"
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-pylint
Expand Down
48 changes: 34 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
# ImgX-DiffSeg

ImgX-DiffSeg is a Jax-based deep learning toolkit (now using Flax) for biomedical image
segmentation.
ImgX-DiffSeg is a Jax-based deep learning toolkit using Flax for biomedical image segmentations.

This repository currently includes the implementation of the following work
This repository includes the implementation of the following work

- [A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models](https://arxiv.org/abs/2308.16355)
- [Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation](https://arxiv.org/abs/2303.06040)

:construction: **The codebase is still under active development for more enhancements and
applications.** :construction:

- November 2023:
- :warning: Upgrade to JAX to 0.4.20.
- :warning: Removed Haiku specific modification to convolutional layers. This may impact model
performance.
- :smiley: Added example notebooks for inference on single image without TFDS.
- Added integration tests for training, validation and testing.
- Refactored config.
- Added `patch_size` and `scale_factor` to data config.
- Moved loss config from main config to task config.
- Refactored code, including defining `imgx/task` submodule.
- October 2023: :sunglasses: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
[Flax](https://github.com/google/flax) following Google DeepMind's recommendation.

:mailbox: Please feel free to
[create an issue](https://github.com/mathpluscode/ImgX-DiffSeg/issues/new/choose) to request
features or [reach out](https://orcid.org/0000-0002-1184-7421) for collaborations. :mailbox:
Expand Down Expand Up @@ -61,11 +73,6 @@ See the [readme](imgx_datasets/README.md) for further details.
- Gradient clipping and accumulation.
- [Early stopping](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html).

**Changelog**

- October 2023: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
[Flax](https://github.com/google/flax) following Google DeepMind's recommendation.

## Installation

### TPU with Docker
Expand Down Expand Up @@ -112,8 +119,7 @@ The following instructions have been tested only for TPU-v3-8. The docker contai

### GPU with Docker

The following instructions have been tested only for CUDA == 11.4.1 and CUDNN == 8.2.0. The docker
container uses non-root user.
CUDA >= 11.8 is required. The docker container uses non-root user.
[Docker image used may be removed.](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md)

1. Build the docker image inside the repository.
Expand Down Expand Up @@ -141,7 +147,7 @@ container uses non-root user.
where

- `--rm` removes the container once exit it.
- `-v` maps the `ImgX` folder into container.
- `-v` maps the current folder into container.

3. Install the package inside container.

Expand Down Expand Up @@ -214,12 +220,10 @@ export DATASET_NAME="brats2021_mr"

# Vanilla segmentation
imgx_train data=${DATASET_NAME} task=seg
imgx_valid --log_dir wandb/latest-run/
imgx_test --log_dir wandb/latest-run/

# Diffusion-based segmentation
imgx_train data=${DATASET_NAME} task=gaussian_diff_seg
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
Expand Down Expand Up @@ -259,10 +263,26 @@ Run the command below to test and get coverage report. As JAX tests requires two
threads, therefore requires 8 CPUs in total.

```bash
pytest --cov=imgx -n 4 imgx
pytest --cov=imgx -n 4 imgx -k "not integration"
pytest --cov=imgx_datasets -n 4 imgx_datasets
```

`-k "not integration"` excludes integration tests, which requires downloading muscle ultrasound and
amos CT data sets.

For integration tests, run the command below. `-s` enables the print of stdout. This test may take
40-60 minutes.

```bash
pytest imgx/integration_test.py -s
```

To test the jupyter notebooks, run the command below.

```bash
pytest --nbmake examples/**/*.ipynb
```

## References

- [Segment Anything (PyTorch)](https://github.com/facebookresearch/segment-anything)
Expand Down
6 changes: 3 additions & 3 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ COPY docker/requirements.txt /${USER}/requirements.txt

RUN /${USER}/conda/bin/pip3 install --upgrade pip \
&& /${USER}/conda/bin/pip3 install \
jax==0.4.14 \
jaxlib==0.4.14+cuda11.cudnn86 \
jax==0.4.20 \
jaxlib==0.4.20+cuda11.cudnn86 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& /${USER}/conda/bin/pip3 install tensorflow-cpu==2.12.0 \
&& /${USER}/conda/bin/pip3 install tensorflow-cpu==2.14.0 \
&& /${USER}/conda/bin/pip3 install -r /${USER}/requirements.txt

RUN git config --global --add safe.directory /${USER}/ImgX
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile.tpu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM mambaorg/micromamba:0.27.0 as conda
FROM mambaorg/micromamba:1.5.1 as conda

# Speed up the build, and avoid unnecessary writes to disk
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
Expand Down
8 changes: 4 additions & 4 deletions docker/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ channels:
- defaults
dependencies:
- python=3.9
- pip=23.0.1
- pip=23.3.1
- pip:
- tensorflow-cpu==2.13.0
- jax==0.4.14
- jaxlib==0.4.14
- tensorflow-cpu==2.14.0
- jax==0.4.20
- jaxlib==0.4.20
- -r requirements.txt
10 changes: 5 additions & 5 deletions docker/environment_mac_m1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ channels:
- defaults
dependencies:
- python=3.9
- pip=23.0.1
- pip=23.3.1
- pip:
- tensorflow-macos==2.13.0
- tensorflow-metal==1.0.1
- jax==0.4.14
- jaxlib==0.4.14
- tensorflow-macos==2.14.0
- tensorflow-metal==1.1.0
- jax==0.4.20
- jaxlib==0.4.20
- -r requirements.txt
8 changes: 4 additions & 4 deletions docker/environment_tpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ channels:
- conda-forge
dependencies:
- python=3.9
- pip=23.0.1
- pip=23.3.1
- pip:
- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
- tensorflow-cpu==2.13.0
- jax[tpu]==0.4.14
- jaxlib==0.4.14
- tensorflow-cpu==2.14.0
- jax[tpu]==0.4.20
- jaxlib==0.4.20
- -r requirements.txt
30 changes: 16 additions & 14 deletions docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
SimpleITK==2.3.0
SimpleITK==2.3.1
chex==0.1.8
coverage==7.3.1
flax==0.7.4
coverage==7.3.2
flax==0.7.5
hydra-core==1.3.2
kaggle==1.5.16
numpy==1.24.3 # limited by tensorflow-macos 2.13.0
opencv-python==4.8.0.76
nbmake==1.4.6
numpy==1.26.2
opencv-python==4.8.1.78
optax==0.1.7
pandas==2.1.1
pre-commit==3.4.0
pandas==2.1.3
pre-commit==3.5.0
protobuf==3.20.3 # https://github.com/tensorflow/datasets/issues/4858
pytest-cov==4.1.0
pytest-mock==3.12.0
pytest-randomly==3.15.0
pytest-split==0.8.1
pytest-xdist==3.3.1
pytest==7.4.2
pytest-xdist==3.5.0
pytest==7.4.3
rdkit-pypi==2022.9.5
rich==13.5.3
ruff==0.0.291
rich==13.7.0
ruff==0.1.6
tensorflow-datasets==4.9.3
torch==2.0.1 # for testing only
wandb==0.15.11
wily==1.24.2
torch==2.1.1 # for testing only
wandb==0.16.0
wily==1.25.0
Binary file added examples/segmentation/BB_anon_348_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/segmentation/BB_anon_348_1_mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
83 changes: 83 additions & 0 deletions examples/segmentation/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
data:
name: muscle_us
loader:
max_num_samples_per_split: -1
patch_shape:
- 480
- 512
patch_overlap:
- 0
- 0
data_augmentation:
max_rotation:
- 0.088
max_translation:
- 10
- 10
max_scaling:
- 0.15
- 0.15
trainer:
max_num_samples: 512000
batch_size: 64
batch_size_per_replica: 8
num_devices_per_replica: 1
patch_size:
- 2
- 2
scale_factor:
- 2
- 2
task:
name: segmentation
model:
_target_: imgx.model.Unet
remat: true
num_spatial_dims: 2
patch_size:
- 2
- 2
scale_factor:
- 2
- 2
num_channels:
- 8
- 16
- 32
- 64
out_channels: 2
num_heads: 8
widening_factor: 4
num_transform_layers: 1
loss:
dice: 1.0
cross_entropy: 0.0
focal: 20.0
early_stopping:
metric: mean_binary_dice_score_without_background
mode: max
min_delta: 0.0001
patience: 10
debug: false
seed: 0
half_precision: true
optimizer:
name: adamw
kwargs:
b1: 0.9
b2: 0.999
weight_decay: 1.0e-08
grad_norm: 1.0
lr_schedule:
warmup_steps: 100
decay_steps: 10000
init_value: 1.0e-05
peak_value: 0.0008
end_value: 5.0e-05
logging:
root_dir: null
log_freq: 10
save_freq: 100
wandb:
project: imgx
entity: entity
Binary file not shown.
Loading

0 comments on commit 5349fdb

Please sign in to comment.