Skip to content

Commit

Permalink
convert prefs dict to a class for futureproofing
Browse files Browse the repository at this point in the history
  • Loading branch information
ejolly committed Dec 15, 2022
1 parent a610840 commit ff353ca
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 46 deletions.
8 changes: 6 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ methods in the current release of Neurolearn.
:members:

:mod:`nltools.utils`: Utilities
==============================
===============================

.. automodule:: nltools.utils
:members:

:mod:`nltools.prefs`: Preferences
================================
=================================

This module can be used to adjust the default MNI template settings that are used
internally by all `Brain_Data` operations. By default all operations are performed in
Expand All @@ -93,8 +93,12 @@ Alternatively this module can be used to switch between 2mm or 3mm MNI spaces wi
from nltools.prefs import MNI_Template, resolve_mni_path
from nltools.data import Brain_Data
# Update the resolution globally
MNI_Template['resolution'] = '3mm'
# This works too:
MNI_Template.resolution = 3
# my_brain will be resampled to 3mm and future operation will be in 3mm space
brain = Brain_Data('my_brain.nii.gz')
Expand Down
132 changes: 89 additions & 43 deletions nltools/prefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,95 @@

__all__ = ["MNI_Template", "resolve_mni_path"]

MNI_Template = dict(
resolution="2mm",
mask_type="with_ventricles",
mask=os.path.join(get_resource_path(), "MNI152_T1_2mm_brain_mask.nii.gz"),
plot=os.path.join(get_resource_path(), "MNI152_T1_2mm.nii.gz"),
brain=os.path.join(get_resource_path(), "MNI152_T1_2mm_brain.nii.gz"),
)

class MNI_Template_Factory(dict):
"""Class to build the default MNI_Template dictionary. This should never be used
directly, instead just `from nltools.prefs import MNI_Template` and update that
object's attributes to change MNI templates."""

def __init__(
self,
resolution="2mm",
mask_type="with_ventricles",
mask=os.path.join(get_resource_path(), "MNI152_T1_2mm_brain_mask.nii.gz"),
plot=os.path.join(get_resource_path(), "MNI152_T1_2mm.nii.gz"),
brain=os.path.join(get_resource_path(), "MNI152_T1_2mm_brain.nii.gz"),
):
self._resolution = resolution
self._mask_type = mask_type
self._mask = mask
self._plot = plot
self._brain = brain

self.update(
{
"resolution": self.resolution,
"mask_type": self.mask_type,
"mask": self.mask,
"plot": self.plot,
"brain": self.brain,
}
)

@property
def resolution(self):
return self._resolution

@resolution.setter
def resolution(self, resolution):
if isinstance(resolution, (int, float)):
resolution = f"{int(resolution)}mm"
if resolution not in ["2mm", "3mm"]:
raise NotImplementedError(
"Only 2mm and 3mm resolutions are currently supported"
)
self._resolution = resolution
self.update({"resolution": self._resolution})

@property
def mask_type(self):
return self._mask_type

@mask_type.setter
def mask_type(self, mask_type):
if mask_type not in ["with_ventricles", "no_ventricles"]:
raise NotImplementedError(
"Only 'with_ventricles' and 'no_ventricles' mask_types are currently supported"
)
self._mask_type = mask_type
self.update({"mask_type": self._mask_type})

@property
def mask(self):
return self._mask

@mask.setter
def mask(self, mask):
self._mask = mask
self.update({"mask": self._mask})

@property
def plot(self):
return self._plot

@plot.setter
def plot(self, plot):
self._plot = plot
self.update({"plot": self._plot})

@property
def brain(self):
return self._brain

@brain.setter
def brain(self, brain):
self._brain = brain
self.update({"brain": self._brain})


# NOTE: We export this from the module and expect users to interact with it instead of
# the class constructor above
MNI_Template = MNI_Template_Factory()


def resolve_mni_path(MNI_Template):
Expand Down Expand Up @@ -64,39 +146,3 @@ def resolve_mni_path(MNI_Template):
else:
raise ValueError("Available templates are '2mm' or '3mm'")
return MNI_Template


# class Prefs(object):
#
# """
# Prefs is a class to represent module level preferences for nltools, e.g. masks.
# """
#
# def __init__(self):
# self.MNI_Template = {}
# self.MNI_Template['mask'] = os.path.join(get_resource_path(),'MNI152_T1_2mm_brain_mask.nii.gz')
# self.MNI_Template['plot']= os.path.join(get_resource_path(),'MNI152_T1_2mm.nii.gz')
# self.MNI_Template['brain'] = os.path.join(get_resource_path(),'MNI152_T1_2mm_brain.nii.gz')
#
# def __repr__(self):
# strOut = "nltools preferences:\n"
# for section_name in ['MNI_Template']:
# section = getattr(self,section_name)
# for key, val in list(section.items()):
# strOut += "%s['%s'] = %s\n" % (section_name, key, repr(val))
# return strOut
#
# def use_template(self,template_name):
# if isinstance(template_name, str):
# if template_name == '3mm':
# self.MNI_Template['mask'] = os.path.join(get_resource_path(),'MNI152_T1_3mm_brain_mask.nii.gz')
# self.MNI_Template['plot'] = os.path.join(get_resource_path(),'MNI152_T1_3mm.nii.gz')
# self.MNI_Template['brain'] = os.path.join(get_resource_path(),'MNI152_T1_3mm_brain.nii.gz')
# elif template_name == '2mm':
# self.MNI_Template['mask'] = os.path.join(get_resource_path(),'MNI152_T1_2mm_brain_mask.nii.gz')
# self.MNI_Template['plot'] = os.path.join(get_resource_path(),'MNI152_T1_2mm.nii.gz')
# self.MNI_Template['brain'] = os.path.join(get_resource_path(),'MNI152_T1_2mm_brain.nii.gz')
# else:
# raise ValueError("Available templates are '2mm' or '3mm'")
# else:
# raise TypeError("Available templates are '2mm' or '3mm'")
17 changes: 16 additions & 1 deletion nltools/tests/test_prefs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from nltools.prefs import MNI_Template
from nltools.data import Brain_Data
import pytest


def test_change_mni_resolution():
assert MNI_Template["resolution"] == "2mm"

# Defaults
brain = Brain_Data()
assert brain.mask.affine[1, 1] == 2.0
assert MNI_Template["resolution"] == "2mm"

# -> 3mm
MNI_Template["resolution"] = "3mm"
new_brain = Brain_Data()
assert new_brain.mask.affine[1, 1] == 3.0

# switch back and test attribute setting
MNI_Template.resolution = 2.0 # floats are cool
assert MNI_Template["resolution"] == "2mm"

newer_brain = Brain_Data()
assert newer_brain.mask.affine[1, 1] == 2.0

with pytest.raises(NotImplementedError):
MNI_Template.resolution = 1

0 comments on commit ff353ca

Please sign in to comment.