From fbcc951e4b77bfdcb13ad693570c723908ba535d Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 12 Mar 2021 09:12:25 -0500 Subject: [PATCH] took into account thomas suggestion to user friendly tag --- btk/draw_blends.py | 14 +++++++++----- poetry.lock | 10 +++++----- pyproject.toml | 1 + 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/btk/draw_blends.py b/btk/draw_blends.py index 7bcd08bc7..fcb720fdc 100644 --- a/btk/draw_blends.py +++ b/btk/draw_blends.py @@ -147,7 +147,7 @@ def __init__( add_noise=True, shifts=None, indexes=None, - dim_order=(0, 1, 2), + dim_order="NCHW", ): """Initializes the DrawBlendsGenerator class. @@ -169,9 +169,10 @@ def __init__( with indexes. indexes (list): Contains the ids of the galaxies to use in the stamp. Must be of length batch_size. Must be used with shifts. - dim_order (tuple): Transpose arrays so that image dimensions following - a specific order. Default order (0, 1, 2) corresponds to - [n_bands, nx, ny]""" + dim_order (str): Whether to return images as numpy arrays with the channel + (band) dimension before the pixel dimensions 'NCHW' (default) or + after 'NHWC'. + """ self.blend_generator = BlendGenerator( catalog, sampling_function, batch_size, shifts, indexes, verbose @@ -194,7 +195,10 @@ def __init__( self.add_noise = add_noise self.verbose = verbose - self.dim_order = dim_order + + if dim_order not in ("NCHW", "NHWC"): + raise ValueError("dim_order must be either 'NCHW' or 'NHWC'.") + self.dim_order = (0, 1, 2) if dim_order == "NCHW" else (1, 2, 0) def __iter__(self): return self diff --git a/poetry.lock b/poetry.lock index 28a4afbde..cad0481c3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -292,7 +292,7 @@ python-versions = "*" name = "flake8" version = "3.8.4" description = "the modular source code checker: pep8 pyflakes and co" -category = "dev" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" @@ -625,7 +625,7 @@ python-dateutil = ">=2.1" name = "mccabe" version = "0.6.1" description = "McCabe checker, plugin for flake8" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -952,7 +952,7 @@ global = ["pybind11-global (==2.6.2)"] name = "pycodestyle" version = "2.6.0" description = "Python style guide checker" -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" @@ -983,7 +983,7 @@ test = ["pytest", "pytest-doctestplus (>=0.7)"] name = "pyflakes" version = "2.2.0" description = "passive checker of Python programs" -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" @@ -1478,7 +1478,7 @@ notebook = ">=4.4.1" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "c0a0a3cbc310531d981f4a81b634049dc4de4d1933385fb29cac3d5c5006a2bb" +content-hash = "681ee5589d0dd7e8d8a1319920e84010fcb39d87f32adaadab40273b1a70b037" [metadata.files] alabaster = [ diff --git a/pyproject.toml b/pyproject.toml index 595600b4e..2b1868b8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ sep = "^1.1.1" [tool.poetry.dev-dependencies] Cython = "^0.29.21" black = "^20.8b1" +flake8 = "^3.8.4" flake8-absolute-import = "^1.0" jupyter-sphinx = "^0.3" mock = "^3.0.5"