Skip to content

Commit

Permalink
Merge pull request #207 from Aakanksha-Rana/3D_LSTM_Transfomers
Browse files Browse the repository at this point in the history
Create unet_lstm.py
  • Loading branch information
satra authored Aug 12, 2022
2 parents 88b5e92 + 9d6bc88 commit e3e7113
Show file tree
Hide file tree
Showing 3 changed files with 423 additions and 0 deletions.
80 changes: 80 additions & 0 deletions nobrainer/layers/max_pool4d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# import tensorflow as tf
# from tensorflow.python.ops import gen_nn_ops

# def _get_sequence(value, n, channel_index, name):
# """Formats a value input for gen_nn_ops."""
# # Performance is fast-pathed for common cases:
# # `None`, `list`, `tuple` and `int`.
# if value is None:
# return [1] * (n + 2)

# # Always convert `value` to a `list`.
# if isinstance(value, list):
# pass
# elif isinstance(value, tuple):
# value = list(value)
# elif isinstance(value, int):
# value = [value]
# elif not isinstance(value, collections_abc.Sized):
# value = [value]
# else:
# value = list(value) # Try casting to a list.

# len_value = len(value)

# # Fully specified, including batch and channel dims.
# if len_value == n + 2:
# return value

# # Apply value to spatial dims only.
# if len_value == 1:
# value = value * n # Broadcast to spatial dimensions.
# elif len_value != n:
# raise ValueError(f"{name} should be of length 1, {n} or {n + 2}. "
# f"Received: {name}={value} of length {len_value}")

# # Add batch and channel dims (always 1).
# if channel_index == 1:
# return [1, 1] + value
# else:
# return [1] + value + [1]

# @tf_export("nn.max_pool4d")
# @dispatch.add_dispatch_support
# def max_pool4d(input, ksize, strides, padding, data_format="NVDHWC", name=None):
# """Performs the max pooling on the input.
# Args:
# input: A 6-D `Tensor` of the format specified by `data_format`.
# ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of
# the window for each dimension of the input tensor.
# strides: An int or list of `ints` that has length `1`, `3` or `5`. The
# stride of the sliding window for each dimension of the input tensor.
# padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
# [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
# for more information.
# data_format: An optional string from: "NVDHWC", "NCVDHW". Defaults to "NVDHWC".
# The data format of the input and output data. With the default format
# "NVDHWC", the data is stored in the order of: [batch, in_depth, in_height,
# in_width, in_channels]. Alternatively, the format could be "NCVDHW", the
# data storage order is: [batch, in_channels, in_volumes, in_depth, in_height,
# in_width].
# name: A name for the operation (optional).
# Returns:
# A `Tensor` of format specified by `data_format`.
# The max pooled output tensor.
# """
# with ops.name_scope(name, "MaxPool4D", [input]) as name:
# if data_format is None:
# data_format = "NVDHWC"
# channel_index = 1 if data_format.startswith("NC") else 5

# ksize = _get_sequence(ksize, 3, channel_index, "ksize")
# strides = _get_sequence(strides, 3, channel_index, "strides")

# return gen_nn_ops.max_pool4d(
# input,
# ksize=ksize,
# strides=strides,
# padding=padding,
# data_format=data_format,
# name=name)
11 changes: 11 additions & 0 deletions nobrainer/models/tests/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..meshnet import meshnet
from ..progressivegan import progressivegan
from ..unet import unet
from ..unet_lstm import unet_lstm
from ..vnet import vnet
from ..vox2vox import Vox_ensembler, vox_gan

Expand Down Expand Up @@ -208,6 +209,16 @@ def test_bayesian_vnet():
)


def test_unet_lstm():
input_shape = (1, 32, 32, 32, 32)
n_classes = 1
x = 10 * np.random.random(input_shape)
y = 10 * np.random.random(input_shape)
model = unet_lstm(input_shape=(32, 32, 32, 32, 1), n_classes=1)
actual_output = model.predict(x)
assert actual_output.shape == y.shape[:-1] + (n_classes,)


def test_vox2vox():
input_shape = (1, 32, 32, 32, 1)
n_classes = 1
Expand Down
Loading

0 comments on commit e3e7113

Please sign in to comment.