Skip to content

Commit

Permalink
Merge pull request #323 from JoostJM/wavelet-update
Browse files Browse the repository at this point in the history
Simplify Wavelet filter function
  • Loading branch information
JoostJM authored Nov 2, 2017
2 parents b395904 + 71ed68d commit 6ac3bb8
Showing 1 changed file with 13 additions and 68 deletions.
81 changes: 13 additions & 68 deletions radiomics/imageoperations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import print_function

from itertools import chain
import logging

import numpy
Expand Down Expand Up @@ -714,7 +713,11 @@ def getWaveletImage(inputImage, **kwargs):

logger.debug('Generating Wavelet images')

approx, ret = _swt3(inputImage, kwargs.get('wavelet', 'coif1'), kwargs.get('level', 1), kwargs.get('start_level', 0))
axes = {2, 1, 0} # set
if kwargs.get('force2D', False):
axes -= {kwargs.get('force2Ddimension', 0)} # set

approx, ret = _swt3(inputImage, kwargs.get('wavelet', 'coif1'), kwargs.get('level', 1), kwargs.get('start_level', 0), axes=tuple(axes))

for idx, wl in enumerate(ret, start=1):
for decompositionName, decompositionImage in wl.items():
Expand All @@ -735,7 +738,7 @@ def getWaveletImage(inputImage, **kwargs):
yield approx, inputImageName, kwargs


def _swt3(inputImage, wavelet='coif1', level=1, start_level=0):
def _swt3(inputImage, wavelet='coif1', level=1, start_level=0, axes=(2, 1, 0)):
matrix = sitk.GetArrayFromImage(inputImage)
matrix = numpy.asarray(matrix)
if matrix.ndim != 3:
Expand All @@ -750,41 +753,23 @@ def _swt3(inputImage, wavelet='coif1', level=1, start_level=0):
wavelet = pywt.Wavelet(wavelet)

for i in range(0, start_level):
H, L = _decompose_i(data, wavelet)
LH, LL = _decompose_j(L, wavelet)
LLH, LLL = _decompose_k(LL, wavelet)

data = LLL.copy()
dec = pywt.swtn(data, wavelet, level=1, start_level=0, axes=axes)[0]
data = dec['a' * len(axes)].copy()

ret = []
for i in range(start_level, start_level + level):
H, L = _decompose_i(data, wavelet)

HH, HL = _decompose_j(H, wavelet)
LH, LL = _decompose_j(L, wavelet)

HHH, HHL = _decompose_k(HH, wavelet)
HLH, HLL = _decompose_k(HL, wavelet)
LHH, LHL = _decompose_k(LH, wavelet)
LLH, LLL = _decompose_k(LL, wavelet)
dec = pywt.swtn(data, wavelet, level=1, start_level=0, axes=axes)[0]
data = dec['a' * len(axes)].copy()

data = LLL.copy()

dec = {'HHH': HHH,
'HHL': HHL,
'HLH': HLH,
'HLL': HLL,
'LHH': LHH,
'LHL': LHL,
'LLH': LLH}
dec_im = {}
for decName, decImage in six.iteritems(dec):
decTemp = decImage.copy()
decTemp = numpy.resize(decTemp, original_shape)
sitkImage = sitk.GetImageFromArray(decTemp)
sitkImage.CopyInformation(inputImage)
dec[decName] = sitkImage
dec_im[str(decName).replace('a', 'L').replace('d', 'H')] = sitkImage

ret.append(dec)
ret.append(dec_im)

data = numpy.resize(data, original_shape)
approximation = sitk.GetImageFromArray(data)
Expand All @@ -793,46 +778,6 @@ def _swt3(inputImage, wavelet='coif1', level=1, start_level=0):
return approximation, ret


def _decompose_i(data, wavelet):
# process in i:
H, L = [], []
i_arrays = chain.from_iterable(data)
for i_array in i_arrays:
cA, cD = pywt.swt(i_array, wavelet, level=1, start_level=0)[0]
H.append(cD)
L.append(cA)
H = numpy.hstack(H).reshape(data.shape)
L = numpy.hstack(L).reshape(data.shape)
return H, L


def _decompose_j(data, wavelet):
# process in j:
s = data.shape
H, L = [], []
j_arrays = chain.from_iterable(numpy.transpose(data, (0, 2, 1)))
for j_array in j_arrays:
cA, cD = pywt.swt(j_array, wavelet, level=1, start_level=0)[0]
H.append(cD)
L.append(cA)
H = numpy.hstack(H).reshape((s[0], s[2], s[1])).transpose((0, 2, 1))
L = numpy.hstack(L).reshape((s[0], s[2], s[1])).transpose((0, 2, 1))
return H, L


def _decompose_k(data, wavelet):
# process in k:
H, L = [], []
k_arrays = chain.from_iterable(numpy.transpose(data, (2, 1, 0)))
for k_array in k_arrays:
cA, cD = pywt.swt(k_array, wavelet, level=1, start_level=0)[0]
H.append(cD)
L.append(cA)
H = numpy.asarray([slice for slice in numpy.split(numpy.vstack(H), data.shape[2])]).T
L = numpy.asarray([slice for slice in numpy.split(numpy.vstack(L), data.shape[2])]).T
return H, L


def getSquareImage(inputImage, **kwargs):
r"""
Computes the square of the image intensities.
Expand Down

0 comments on commit 6ac3bb8

Please sign in to comment.