Commit
Merge branch 'main' into crop-kspace-transforms-merged
georgeyiasemis committed Jun 5, 2022
2 parents d259b5d + bc3a38b commit 4463268
Showing 11 changed files with 94 additions and 71 deletions.
2 changes: 1 addition & 1 deletion direct/__init__.py
@@ -2,4 +2,4 @@
 # Copyright (c) DIRECT Contributors

 __author__ = """direct contributors"""
-__version__ = "1.0.1"
+__version__ = "1.0.2"
10 changes: 4 additions & 6 deletions direct/cli/utils.py
@@ -19,12 +19,10 @@ def is_file(path):
 def file_or_url(path: PathOrString) -> FileOrUrl:
     if check_is_valid_url(path):
         return FileOrUrl(path)
-    else:
-        path = pathlib.Path(path)
-        if path.is_file():
-            return FileOrUrl(path)
-        else:
-            raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.")
+    path = pathlib.Path(path)
+    if path.is_file():
+        return FileOrUrl(path)
+    raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.")
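The refactor above flattens nested `else` branches into guard clauses. As a usage sketch (hypothetical, not part of this commit), `file_or_url` is the kind of converter that plugs straight into `argparse`, since it raises `argparse.ArgumentTypeError` on bad input:

    import argparse

    parser = argparse.ArgumentParser()
    # Invalid values are rejected with a clean CLI error because
    # file_or_url raises argparse.ArgumentTypeError.
    parser.add_argument("checkpoint", type=file_or_url)
    args = parser.parse_args(["https://example.com/model.ckpt"])  # hypothetical URL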


def check_train_val(key, name):
24 changes: 16 additions & 8 deletions direct/common/subsample.py
@@ -126,8 +126,11 @@ def mask_func(self, shape, return_acs=False, seed=None):
 The mask selects a subset of columns from the input k-space data. If the k-space data has N
 columns, the mask picks out:
-#. :math:`N_{\text{low\_freqs}} = (N \times \text{center_fraction})` columns in the center corresponding to low-frequencies. # pylint: disable=line-too-long
-#. The other columns are selected uniformly at random with a probability equal to: :math:`\text{prob} = (N / \text{acceleration} - N_{\text{low\_freqs}}) / (N - N_{\text{low\_freqs}})`. This ensures that the expected number of columns selected is equal to (N / acceleration) # pylint: disable=line-too-long
+#. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding
+   to low-frequencies.
+#. The other columns are selected uniformly at random with a probability equal to:
+   :math:`\text{prob} = (N / \text{acceleration} - N_{\text{low freqs}}) / (N - N_{\text{low freqs}})`.
+   This ensures that the expected number of columns selected is equal to (N / acceleration).
 It is possible to use multiple center_fractions and accelerations, in which case one possible
 (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is called.
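A minimal sketch of the column-selection rule described in the docstring above (illustrative only; the function name and the centering of the fully sampled region are assumptions, not this module's API):

    import numpy as np

    def random_column_mask(num_cols, center_fraction, acceleration, rng=np.random):
        num_low_freqs = int(round(num_cols * center_fraction))
        # Probability for the remaining columns so that the expected total of
        # selected columns equals num_cols / acceleration.
        prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
        mask = rng.uniform(size=num_cols) < prob
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad : pad + num_low_freqs] = True  # always keep the low-frequency center
        return mask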
@@ -204,15 +207,18 @@ def mask_func(self, shape, return_acs=False, seed=None):
 FastMRIEquispacedMaskFunc creates a sub-sampling mask of a given shape. The mask selects a subset of columns
 from the input k-space data. If the k-space data has N columns, the mask picks out:
-#. :math:`N_{\text{low\_freqs}} = (N \times \text{center_fraction})` columns in the center corresponding to low-frequencies. # pylint: disable=line-too-long
-#. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration rate taking into consideration the number of low frequencies. This ensures that the expected number of columns selected is equal to :math:`\frac{N}{\text{acceleration}}`. # pylint: disable=line-too-long
+#. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding
+   to low-frequencies.
+#. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration
+   rate taking into consideration the number of low frequencies. This ensures that the expected number of
+   columns selected is equal to :math:`\frac{N}{\text{acceleration}}`.
It is possible to use multiple center_fractions and accelerations, in which case one possible
(center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called.
-Note that this function may not give equispaced samples (documented in https://github.com/facebookresearch/fastMRI/issues/54),
-which will require modifications to standard GRAPPA approaches. Nonetheless, this aspect of the function has
-been preserved to match the public multicoil data.
+Note that this function may not give equispaced samples (documented in
+https://github.com/facebookresearch/fastMRI/issues/54), which will require modifications to standard GRAPPA
+approaches. Nonetheless, this aspect of the function has been preserved to match the public multicoil data.
Parameters
----------
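A comparable sketch of the equispaced selection described in the docstring above, modeled on the fastMRI reference implementation (the exact spacing adjustment is an assumption, not taken from this diff):

    import numpy as np

    def equispaced_column_mask(num_cols, center_fraction, acceleration, rng=np.random):
        num_low_freqs = int(round(num_cols * center_fraction))
        mask = np.zeros(num_cols, dtype=bool)
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad : pad + num_low_freqs] = True  # dense low-frequency center
        # Spacing chosen so that, together with the dense center, the expected
        # number of selected columns is num_cols / acceleration.
        spacing = (acceleration * (num_low_freqs - num_cols)) / (
            num_low_freqs * acceleration - num_cols
        )
        offset = rng.randint(0, round(spacing))
        samples = np.around(np.arange(offset, num_cols - 1, spacing)).astype(int)
        mask[samples] = True
        return mask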
@@ -374,7 +380,9 @@ class CIRCUSMaskFunc(BaseMaskFunc):
References
----------
-.. [1] Liu J, Saloner D. Accelerated MRI with CIRcular Cartesian UnderSampling (CIRCUS): a variable density Cartesian sampling strategy for compressed sensing and parallel imaging. Quant Imaging Med Surg. 2014 Feb;4(1):57-67. doi: 10.3978/j.issn.2223-4292.2014.02.01. PMID: 24649436; PMCID: PMC3947985.
+.. [1] Liu J, Saloner D. Accelerated MRI with CIRcular Cartesian UnderSampling (CIRCUS): a variable density
+   Cartesian sampling strategy for compressed sensing and parallel imaging. Quant Imaging Med Surg.
+   2014 Feb;4(1):57-67. doi: 10.3978/j.issn.2223-4292.2014.02.01. PMID: 24649436; PMCID: PMC3947985.
"""

def __init__(
38 changes: 23 additions & 15 deletions direct/data/datasets.py
@@ -138,7 +138,7 @@ def __init__(
 self.pass_attrs = pass_attrs if pass_attrs is not None else True
 self.text_description = text_description
 if self.text_description:
-    self.logger.info(f"Dataset description: {self.text_description}.")
+    self.logger.info("Dataset description: %s.", self.text_description)

self.fake_data: Callable = FakeMRIData(
ndim=len(self.spatial_shape),
@@ -178,7 +178,8 @@ def parse_filenames_data(self, filenames):
 current_slice_number = 0
 for idx, filename in enumerate(filenames):
     if len(filenames) < 5 or idx % (len(filenames) // 5) == 0 or len(filenames) == (idx + 1):
-        self.logger.info("Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.")
+        # pylint: disable=logging-fstring-interpolation
+        self.logger.info(f"Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.")

     num_slices = self.spatial_shape[0] if len(self.spatial_shape) == 3 else 1
     self.volume_indices[pathlib.PosixPath(filename)] = range(
@@ -514,14 +515,19 @@ class ImageIntensityMode(str, Enum):
class SheppLoganDataset(Dataset):
"""Shepp Logan Dataset for MRI as implemented in [1]_. Code was adapted from [2]_.
-References
-----------
-.. [1] Gach, H. Michael, Costin Tanase, and Fernando Boada. "2D & 3D Shepp-Logan phantom standards for MRI." 2008 19th International Conference on Systems Engineering. IEEE, 2008.
-.. [2] https://github.com/mckib2/phantominator/blob/master/phantominator/mr_shepp_logan.py
-Notes
------
-This dataset reconstructs into a single volume.
+References
+----------
+.. [1] Gach, H. Michael, Costin Tanase, and Fernando Boada. "2D & 3D Shepp-Logan phantom standards for MRI."
+   2008 19th International Conference on Systems Engineering. IEEE, 2008.
+.. [2] https://github.com/mckib2/phantominator/blob/master/phantominator/mr_shepp_logan.py
+Notes
+-----
+This dataset reconstructs into a single volume.
"""

GYROMAGNETIC_RATIO: float = 267.52219
@@ -593,7 +599,7 @@ def __init__(
 self.seed = list(self.rng.choice(a=range(int(1e5)), size=self.nz, replace=False))
 self.text_description = text_description
 if self.text_description:
-    self.logger.info(f"Dataset description: {self.text_description}.")
+    self.logger.info("Dataset description: %s.", self.text_description)

 self.name = "shepp_loggan" + "_" + self.intensity
 self.ndim = 2
@@ -623,7 +629,7 @@ def _set_params(self, ellipsoids=None) -> None:

self.ellipsoids = ellipsoids

-def sample_image(self, idx: int) -> np.ndarray:
+def sample_image(self, idx: int) -> np.ndarray:  # pylint: disable=too-many-locals
     # meshgrid does X, Y backwards
     X, Y, Z = np.meshgrid(
         np.linspace(-1, 1, self.ny),
@@ -708,7 +714,8 @@ def default_mr_ellipsoid_parameters() -> np.ndarray:
 References
 ----------
-.. [1] Gach, H. Michael, Costin Tanase, and Fernando Boada. "2D & 3D Shepp-Logan phantom standards for MRI." 2008 19th International Conference on Systems Engineering. IEEE, 2008.
+.. [1] Gach, H. Michael, Costin Tanase, and Fernando Boada. "2D & 3D Shepp-Logan phantom standards for MRI."
+   2008 19th International Conference on Systems Engineering. IEEE, 2008.
 """
params = _mr_relaxation_parameters()

@@ -780,11 +787,12 @@ def _mr_relaxation_parameters():
 References
 ----------
-.. [1] Gach, H. Michael, Costin Tanase, and Fernando Boada. "2D & 3D Shepp-Logan phantom standards for MRI." 2008 19th International Conference on Systems Engineering. IEEE, 2008.
+.. [1] Gach, H. Michael, Costin Tanase, and Fernando Boada. "2D & 3D Shepp-Logan phantom standards for MRI."
+   2008 19th International Conference on Systems Engineering. IEEE, 2008.
 """

 # params['tissue-name'] = [A, C, (t1 value if explicit), t2, chi]
-params = dict()
+params = {}
 params["scalp"] = [0.324, 0.137, np.nan, 0.07, -7.5e-6]
 params["marrow"] = [0.533, 0.088, np.nan, 0.05, -8.85e-6]
 params["csf"] = [np.nan, np.nan, 4.2, 1.99, -9e-6]
13 changes: 7 additions & 6 deletions direct/data/fake.py
@@ -108,7 +108,8 @@ def make_blobs(

 return samples

-def _get_image_from_samples(self, samples, spatial_shape):
+@staticmethod
+def _get_image_from_samples(samples, spatial_shape):
     image = np.zeros(list(spatial_shape))
     image[tuple(np.split(samples, len(spatial_shape), axis=-1))] = 1
@@ -140,11 +141,11 @@ def __call__(
 Returns:
 --------
 sample: dict or list of dicts
     Contains:
     "kspace": np.array of shape (slice, num_coils, height, width)
     "reconstruction_rss": np.array of shape (slice, height, width)
     If spatial_shape is of shape 2 (height, width), slice=1.
"""

if len(spatial_shape) != self.ndim:
11 changes: 5 additions & 6 deletions direct/data/h5_data.py
@@ -100,14 +100,13 @@ def __init__(
 if filenames_filter is None:
     if filenames_lists is not None:
         if filenames_lists_root is None:
-            e = f"`filenames_lists` is passed but `filenames_lists_root` is None."
+            e = "`filenames_lists` is passed but `filenames_lists_root` is None."
             self.logger.error(e)
             raise ValueError(e)
-        else:
-            filenames = get_filenames_for_datasets(
-                lists=filenames_lists, files_root=filenames_lists_root, data_root=root
-            )
-            self.logger.info("Attempting to load %s filenames from list(s).", len(filenames))
+        filenames = get_filenames_for_datasets(
+            lists=filenames_lists, files_root=filenames_lists_root, data_root=root
+        )
+        self.logger.info("Attempting to load %s filenames from list(s).", len(filenames))
     else:
         self.logger.info("Parsing directory %s for h5 files.", self.root)
         filenames = list(self.root.glob("*.h5"))
20 changes: 12 additions & 8 deletions direct/train.py
@@ -116,11 +116,14 @@ def build_training_datasets_from_environment(
 dataset_args.update({"pass_dictionaries": pass_dictionaries})
 dataset = build_dataset_from_input(**dataset_args)

-logger.debug(f"Transforms {idx + 1} / {len(datasets_config)} :\n{transforms}")
+logger.debug("Transforms %s / %s :\n%s", idx + 1, len(datasets_config), transforms)
 datasets.append(dataset)
 logger.info(
-    f"Data size for {dataset_config.text_description}"
-    f" ({idx + 1}/{len(datasets_config)}): {len(dataset)}."  # type: ignore
+    "Data size for %s (%s/%s): %s.",
+    dataset_config.text_description,  # type: ignore
+    idx + 1,
+    len(datasets_config),
+    len(dataset),
 )

 return datasets
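The logging changes in this hunk (and throughout the commit) replace f-strings with lazy %-style arguments. A small illustration of the difference (illustrative values, not from the diff):

    import logging

    logger = logging.getLogger(__name__)

    # Lazy %-style: the message is only interpolated if the record is emitted,
    # so a suppressed DEBUG call costs almost nothing.
    logger.debug("Transforms %s / %s :\n%s", 1, 10, "CropKspace")

    # An f-string is always evaluated, even when the logger drops the record.
    logger.debug(f"Transforms {1} / {10} :\n{'CropKspace'}")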
@@ -188,7 +191,7 @@ def setup_train(
# Build training datasets
training_datasets = build_training_datasets_from_environment(**training_dataset_args)
training_data_sizes = [len(_) for _ in training_datasets]
logger.info(f"Training data sizes: {training_data_sizes} (sum={sum(training_data_sizes)}).")
logger.info("Training data sizes: %s (sum=%s).", training_data_sizes, sum(training_data_sizes))

# Create validation data
if "validation" in env.cfg:
@@ -219,7 +222,7 @@ def setup_train(
for curr_model_name in env.engine.models:
# TODO(jt): Can get learning rate from the config per additional model too.
curr_learning_rate = env.cfg.training.lr
logger.info(f"Adding model parameters of {curr_model_name} with learning rate {curr_learning_rate}.")
logger.info("Adding model parameters of %s with learning rate %s.", curr_model_name, curr_learning_rate)
optimizer_params.append(
{
"params": env.engine.models[curr_model_name].parameters(),
@@ -257,9 +260,10 @@ def setup_train(
 if env.cfg.training.model_checkpoint:
     if initialization_checkpoint:
         logger.warning(
-            f"`--initialization-checkpoint is set, and config has a set `training.model_checkpoint`: "
-            f"{env.cfg.training.model_checkpoint}. Will overwrite config variable with the command line: "
-            f"{initialization_checkpoint}."
+            "`--initialization-checkpoint is set, and config has a set `training.model_checkpoint`: %s. "
+            "Will overwrite config variable with the command line: %s.",
+            env.cfg.training.model_checkpoint,
+            initialization_checkpoint,
         )
         # Now overwrite this in the configuration, so the correct value is dumped.
         env.cfg.training.model_checkpoint = str(initialization_checkpoint)
5 changes: 2 additions & 3 deletions direct/utils/dataset.py
@@ -24,9 +24,8 @@ def get_filenames_for_datasets_from_config(cfg, files_root: PathOrString, data_r
"""
if "filenames_lists" not in cfg:
return None
else:
lists = cfg.filenames_lists
return get_filenames_for_datasets(lists, files_root, data_root)
lists = cfg.filenames_lists
return get_filenames_for_datasets(lists, files_root, data_root)


def get_filenames_for_datasets(lists: List[PathOrString], files_root: PathOrString, data_root: pathlib.Path):
23 changes: 12 additions & 11 deletions docker/Dockerfile
@@ -1,10 +1,9 @@
-ARG CUDA="10.2"
-ARG CUDNN="8"
-ARG PYTORCH="1.7"
+ARG CUDA="11.3.0"
+ARG PYTORCH="1.10"
 ARG PYTHON="3.8"

 # TODO: conda installs its own version of cuda
-FROM nvidia/cuda:${CUDA}-cudnn${CUDNN}-devel-ubuntu18.04
+FROM nvidia/cuda:${CUDA}-devel-ubuntu18.04

ENV CUDA_PATH /usr/local/cuda
ENV CUDA_ROOT /usr/local/cuda/bin
@@ -17,10 +16,10 @@ RUN apt-get -qq update && apt-get install -y --no-install-recommends libxext6 li
&& rm -rf /var/lib/apt/lists/* \
&& ldconfig

-RUN git clone https://github.com/mrirecon/bart.git /tmp/bart
-    && cd /tmp/bart
-    && make -j4
-    && make install
+RUN git clone https://github.com/mrirecon/bart.git /tmp/bart \
+    && cd /tmp/bart \
+    && make -j4 \
+    && make install \
+    && rm -rf /tmp/bart

# Make a user
@@ -41,27 +40,29 @@ ENV PATH "/users/direct/miniconda3/bin:/tmp/bart/:$PATH:$CUDA_ROOT"
# Setup python packages
RUN conda update -n base conda -yq \
&& conda install python=${PYTHON} \
&& conda install jupyter \
&& conda install cudatoolkit=${CUDA} torchvision -c pytorch

RUN if [ "nightly$PYTORCH" = "nightly" ] ; then echo "Installing pytorch nightly" && \
conda install pytorch -c pytorch-nightly; else conda install pytorch=${PYTORCH} -c pytorch ; fi

USER root
-RUN mkdir /direct && chmod 777 /direct
+RUN mkdir /direct && chown direct:direct /direct && chmod 777 /direct

USER direct

RUN jupyter notebook --generate-config
ENV CONFIG_PATH "/users/direct/.jupyter/jupyter_notebook_config.py"
COPY "jupyter_notebook_config.py" ${CONFIG_PATH}
COPY "docker/jupyter_notebook_config.py" ${CONFIG_PATH}

# Copy files into the docker
COPY [".", "/direct"]
WORKDIR /direct
USER root
RUN python -m pip install -e ".[dev]"
USER direct

ENV PYTHONPATH /tmp/bart/python:/direct
WORKDIR /direct

# Provide an open entrypoint for the docker
ENTRYPOINT $0 $@
17 changes: 11 additions & 6 deletions docker/README.rst
@@ -1,20 +1,25 @@
Docker Installation
-------------------

-Use the container (docker ≥ 19.03 required)
--------------------------------------------
+Use the container
+~~~~~~~~~~~~~~~~~

-To build:
+To build the image:

 .. code-block:: bash

-    cd docker/
-    docker build -t direct:latest .
+    cd direct/
+    docker build -t direct:latest -f docker/Dockerfile .
-To run using all GPUs:
+To run `DIRECT` using all GPUs:

 .. code-block:: bash

     docker run --gpus all -it \
         --shm-size=24gb --volume=<source_to_data>:/data --volume=<source_to_results>:/output \
         --name=direct direct:latest /bin/bash

+Requirements
+~~~~~~~~~~~~
+
+* docker ≥ 19.03
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 1.0.1
+current_version = 1.0.2
commit = True
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
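The parse pattern above determines what bumpversion captures from the current version string; a quick check of what it matches (illustrative snippet):

    import re

    pattern = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?"
    print(re.match(pattern, "1.0.2").groupdict())
    # -> {'major': '1', 'minor': '0', 'patch': '2', 'release': None, 'build': None}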
