From 71ed68d47f0870a988421c643e4b961b401d6216 Mon Sep 17 00:00:00 2001 From: Joost van Griethuysen Date: Tue, 31 Oct 2017 16:58:24 +0100 Subject: [PATCH] ENH: Simplify Wavelet filter function PyWavelets package supports applying wavelets sequentially along multiple axes, therefore there is no need for reslicing the (partially filtered) data using python functions. This reduces a lot of complexity in the code and also includes a higher performance of the filter. Additionally, allow for 2D application of wavelets when `force2D` is set to `True`, where no wavelet will be applied along the `force2Ddimension`. --- radiomics/imageoperations.py | 81 ++++++------------------------------ 1 file changed, 13 insertions(+), 68 deletions(-) diff --git a/radiomics/imageoperations.py b/radiomics/imageoperations.py index 03efa6e3..b1b474a6 100644 --- a/radiomics/imageoperations.py +++ b/radiomics/imageoperations.py @@ -1,6 +1,5 @@ from __future__ import print_function -from itertools import chain import logging import numpy @@ -714,7 +713,11 @@ def getWaveletImage(inputImage, **kwargs): logger.debug('Generating Wavelet images') - approx, ret = _swt3(inputImage, kwargs.get('wavelet', 'coif1'), kwargs.get('level', 1), kwargs.get('start_level', 0)) + axes = {2, 1, 0} # set + if kwargs.get('force2D', False): + axes -= {kwargs.get('force2Ddimension', 0)} # set + + approx, ret = _swt3(inputImage, kwargs.get('wavelet', 'coif1'), kwargs.get('level', 1), kwargs.get('start_level', 0), axes=tuple(axes)) for idx, wl in enumerate(ret, start=1): for decompositionName, decompositionImage in wl.items(): @@ -735,7 +738,7 @@ def getWaveletImage(inputImage, **kwargs): yield approx, inputImageName, kwargs -def _swt3(inputImage, wavelet='coif1', level=1, start_level=0): +def _swt3(inputImage, wavelet='coif1', level=1, start_level=0, axes=(2, 1, 0)): matrix = sitk.GetArrayFromImage(inputImage) matrix = numpy.asarray(matrix) if matrix.ndim != 3: @@ -750,41 +753,23 @@ def _swt3(inputImage, wavelet='coif1', level=1, start_level=0): wavelet = pywt.Wavelet(wavelet) for i in range(0, start_level): - H, L = _decompose_i(data, wavelet) - LH, LL = _decompose_j(L, wavelet) - LLH, LLL = _decompose_k(LL, wavelet) - - data = LLL.copy() + dec = pywt.swtn(data, wavelet, level=1, start_level=0, axes=axes)[0] + data = dec['a' * len(axes)].copy() ret = [] for i in range(start_level, start_level + level): - H, L = _decompose_i(data, wavelet) - - HH, HL = _decompose_j(H, wavelet) - LH, LL = _decompose_j(L, wavelet) - - HHH, HHL = _decompose_k(HH, wavelet) - HLH, HLL = _decompose_k(HL, wavelet) - LHH, LHL = _decompose_k(LH, wavelet) - LLH, LLL = _decompose_k(LL, wavelet) + dec = pywt.swtn(data, wavelet, level=1, start_level=0, axes=axes)[0] + data = dec['a' * len(axes)].copy() - data = LLL.copy() - - dec = {'HHH': HHH, - 'HHL': HHL, - 'HLH': HLH, - 'HLL': HLL, - 'LHH': LHH, - 'LHL': LHL, - 'LLH': LLH} + dec_im = {} for decName, decImage in six.iteritems(dec): decTemp = decImage.copy() decTemp = numpy.resize(decTemp, original_shape) sitkImage = sitk.GetImageFromArray(decTemp) sitkImage.CopyInformation(inputImage) - dec[decName] = sitkImage + dec_im[str(decName).replace('a', 'L').replace('d', 'H')] = sitkImage - ret.append(dec) + ret.append(dec_im) data = numpy.resize(data, original_shape) approximation = sitk.GetImageFromArray(data) @@ -793,46 +778,6 @@ def _swt3(inputImage, wavelet='coif1', level=1, start_level=0): return approximation, ret -def _decompose_i(data, wavelet): - # process in i: - H, L = [], [] - i_arrays = chain.from_iterable(data) - for i_array in i_arrays: - cA, cD = pywt.swt(i_array, wavelet, level=1, start_level=0)[0] - H.append(cD) - L.append(cA) - H = numpy.hstack(H).reshape(data.shape) - L = numpy.hstack(L).reshape(data.shape) - return H, L - - -def _decompose_j(data, wavelet): - # process in j: - s = data.shape - H, L = [], [] - j_arrays = chain.from_iterable(numpy.transpose(data, (0, 2, 1))) - for j_array in j_arrays: - cA, cD = pywt.swt(j_array, wavelet, level=1, start_level=0)[0] - H.append(cD) - L.append(cA) - H = numpy.hstack(H).reshape((s[0], s[2], s[1])).transpose((0, 2, 1)) - L = numpy.hstack(L).reshape((s[0], s[2], s[1])).transpose((0, 2, 1)) - return H, L - - -def _decompose_k(data, wavelet): - # process in k: - H, L = [], [] - k_arrays = chain.from_iterable(numpy.transpose(data, (2, 1, 0))) - for k_array in k_arrays: - cA, cD = pywt.swt(k_array, wavelet, level=1, start_level=0)[0] - H.append(cD) - L.append(cA) - H = numpy.asarray([slice for slice in numpy.split(numpy.vstack(H), data.shape[2])]).T - L = numpy.asarray([slice for slice in numpy.split(numpy.vstack(L), data.shape[2])]).T - return H, L - - def getSquareImage(inputImage, **kwargs): r""" Computes the square of the image intensities.