Create unet_lstm.py #207

Merged: 21 commits from 3D_LSTM_Transfomers into master, Aug 12, 2022

Commits
75048ca  Create unet_lstm.py (Aakanksha-Rana, Jan 19, 2022)
d8c1f9b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 19, 2022)
7df74ec  Update unet_lstm.py (Aakanksha-Rana, Jan 20, 2022)
80f7817  Merge branch 'master' into 3D_LSTM_Transfomers (Aakanksha-Rana, Jan 24, 2022)
dea32f6  Merge branch 'master' into 3D_LSTM_Transfomers (satra, Feb 28, 2022)
2b3a217  Merge branch 'master' into 3D_LSTM_Transfomers (Aakanksha-Rana, Mar 9, 2022)
dd5310c  Merge branch 'neuronets:master' into 3D_LSTM_Transfomers (Aakanksha-Rana, Mar 13, 2022)
49396b7  Merge branch 'master' into 3D_LSTM_Transfomers (Aakanksha-Rana, May 10, 2022)
aba42ce  Merge branch 'master' into 3D_LSTM_Transfomers (satra, May 11, 2022)
194ea7c  Update unet_lstm.py (Aakanksha-Rana, May 11, 2022)
f1cf8bd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 11, 2022)
aff4172  updated unet_lstm test (Aakanksha-Rana, May 11, 2022)
1cd8082  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 11, 2022)
449ba7f  Update unet_lstm.py (Aakanksha-Rana, May 11, 2022)
96b1e44  docstrings (Aakanksha-Rana, May 16, 2022)
f0a46bc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 16, 2022)
afeace7  Create max_pool4d.py (Aakanksha-Rana, May 16, 2022)
2d683b7  Update max_pool4d.py (Aakanksha-Rana, May 16, 2022)
497095e  Merge branch 'master' into 3D_LSTM_Transfomers (Aakanksha-Rana, Jul 25, 2022)
295d33f  Merge branch 'master' into 3D_LSTM_Transfomers (Hoda1394, Aug 3, 2022)
9d6bc88  Merge branch 'master' into 3D_LSTM_Transfomers (satra, Aug 8, 2022)
80 changes: 80 additions & 0 deletions nobrainer/layers/max_pool4d.py
@@ -0,0 +1,80 @@
# import collections.abc as collections_abc

# import tensorflow as tf
# from tensorflow.python.framework import ops
# from tensorflow.python.ops import gen_nn_ops
# from tensorflow.python.util import dispatch
# from tensorflow.python.util.tf_export import tf_export

# def _get_sequence(value, n, channel_index, name):
#     """Formats a value input for gen_nn_ops."""
#     # Performance is fast-pathed for common cases:
#     # `None`, `list`, `tuple` and `int`.
#     if value is None:
#         return [1] * (n + 2)

#     # Always convert `value` to a `list`.
#     if isinstance(value, list):
#         pass
#     elif isinstance(value, tuple):
#         value = list(value)
#     elif isinstance(value, int):
#         value = [value]
#     elif not isinstance(value, collections_abc.Sized):
#         value = [value]
#     else:
#         value = list(value)  # Try casting to a list.

#     len_value = len(value)

#     # Fully specified, including batch and channel dims.
#     if len_value == n + 2:
#         return value

#     # Apply value to spatial dims only.
#     if len_value == 1:
#         value = value * n  # Broadcast to spatial dimensions.
#     elif len_value != n:
#         raise ValueError(f"{name} should be of length 1, {n} or {n + 2}. "
#                          f"Received: {name}={value} of length {len_value}")

#     # Add batch and channel dims (always 1).
#     if channel_index == 1:
#         return [1, 1] + value
#     else:
#         return [1] + value + [1]

# @tf_export("nn.max_pool4d")
# @dispatch.add_dispatch_support
# def max_pool4d(input, ksize, strides, padding, data_format="NVDHWC", name=None):
#     """Performs the max pooling on the input.
#     Args:
#       input: A 6-D `Tensor` of the format specified by `data_format`.
#       ksize: An int or list of `ints` that has length `1`, `4` or `6`. The size
#         of the window for each dimension of the input tensor.
#       strides: An int or list of `ints` that has length `1`, `4` or `6`. The
#         stride of the sliding window for each dimension of the input tensor.
#       padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
#         [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
#         for more information.
#       data_format: An optional string from: "NVDHWC", "NCVDHW". Defaults to
#         "NVDHWC". The data format of the input and output data. With the
#         default format "NVDHWC", the data is stored in the order of: [batch,
#         in_volumes, in_depth, in_height, in_width, in_channels].
#         Alternatively, with the format "NCVDHW", the data storage order is:
#         [batch, in_channels, in_volumes, in_depth, in_height, in_width].
#       name: A name for the operation (optional).
#     Returns:
#       A `Tensor` of format specified by `data_format`.
#       The max pooled output tensor.
#     """
#     with ops.name_scope(name, "MaxPool4D", [input]) as name:
#         if data_format is None:
#             data_format = "NVDHWC"
#         channel_index = 1 if data_format.startswith("NC") else 5

#         ksize = _get_sequence(ksize, 4, channel_index, "ksize")
#         strides = _get_sequence(strides, 4, channel_index, "strides")

#         return gen_nn_ops.max_pool4d(
#             input,
#             ksize=ksize,
#             strides=strides,
#             padding=padding,
#             data_format=data_format,
#             name=name)
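
Reviewer note: `gen_nn_ops.max_pool4d` does not exist in released TensorFlow, which is presumably why the whole file is committed commented out. The `_get_sequence` helper mirrors TensorFlow's own pooling utilities; for example, with n=4 and the channel axis last, `_get_sequence(2, 4, 5, "ksize")` expands a scalar window size to [1, 2, 2, 2, 2, 1]. Until a native 4-D pooling kernel exists, the operation can be emulated with public ops. The sketch below is illustrative only and not part of this PR (the function name is hypothetical); it assumes an NVDHWC tensor with statically known V/D/H/W/C dimensions and relies on max pooling being separable across axes.

import tensorflow as tf


def max_pool4d_emulated(x, ksize, strides, padding="SAME"):
    """Emulate 4-D max pooling on an NVDHWC tensor [batch, V, D, H, W, C].

    `ksize` and `strides` are length-4 sequences for the (V, D, H, W) axes.
    Illustrative sketch; assumes V, D, H, W and C are statically known.
    """
    v, d, h, w, c = x.shape[1:]
    kv, kd, kh, kw = ksize
    sv, sd, sh, sw = strides

    # Pool over D, H, W by folding the volume axis V into the batch axis.
    x = tf.reshape(x, [-1, d, h, w, c])
    x = tf.nn.max_pool3d(x, ksize=[kd, kh, kw], strides=[sd, sh, sw], padding=padding)
    d2, h2, w2 = x.shape[1:4]

    # Pool over V by folding the already-pooled D*H*W*C values into channels.
    # Max pooling is separable, so pooling V after D/H/W matches a 4-D window.
    x = tf.reshape(x, [-1, v, d2 * h2 * w2 * c])
    x = tf.nn.max_pool1d(x, ksize=kv, strides=sv, padding=padding)
    v2 = x.shape[1]

    return tf.reshape(x, [-1, v2, d2, h2, w2, c])

For example, max_pool4d_emulated(x, ksize=(2, 2, 2, 2), strides=(2, 2, 2, 2)) halves each of the V, D, H and W dimensions of a [batch, 32, 32, 32, 32, C] input.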
11 changes: 11 additions & 0 deletions nobrainer/models/tests/models_test.py
@@ -13,6 +13,7 @@
from ..meshnet import meshnet
from ..progressivegan import progressivegan
from ..unet import unet
from ..unet_lstm import unet_lstm
from ..vnet import vnet
from ..vox2vox import Vox_ensembler, vox_gan

@@ -208,6 +209,16 @@ def test_bayesian_vnet():
)


def test_unet_lstm():
    input_shape = (1, 32, 32, 32, 32)
    n_classes = 1
    x = 10 * np.random.random(input_shape)
    y = 10 * np.random.random(input_shape)
    model = unet_lstm(input_shape=(32, 32, 32, 32, 1), n_classes=1)
    actual_output = model.predict(x)
    assert actual_output.shape == y.shape[:-1] + (n_classes,)


def test_vox2vox():
    input_shape = (1, 32, 32, 32, 1)
    n_classes = 1
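
The test_unet_lstm addition above doubles as a usage example for the new model. For convenience, the same smoke check can be run standalone outside pytest; the shapes mirror the test, and the expected output shape follows its assertion. This snippet is a reader aid, not part of the diff.

# Standalone version of the unet_lstm smoke test above.
import numpy as np

from nobrainer.models.unet_lstm import unet_lstm

model = unet_lstm(input_shape=(32, 32, 32, 32, 1), n_classes=1)
x = 10 * np.random.random((1, 32, 32, 32, 32))
prediction = model.predict(x)
print(prediction.shape)  # (1, 32, 32, 32, 1), per the shape assertion in the test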