From 6bdc3b2c03b3a036a6f8dd3771211c8c0aa9b167 Mon Sep 17 00:00:00 2001 From: Alex Eftimiades Date: Wed, 2 Oct 2019 17:16:10 -0400 Subject: [PATCH 1/3] ResBlock combinator in stax --- examples/resnet50.py | 10 +++++----- jax/experimental/stax.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/examples/resnet50.py b/examples/resnet50.py index 211176fd2e23..3aeb2962a539 100644 --- a/examples/resnet50.py +++ b/examples/resnet50.py @@ -30,9 +30,9 @@ from jax import jit, grad, random from jax.experimental import optimizers from jax.experimental import stax -from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum, - FanOut, Flatten, GeneralConv, Identity, - MaxPool, Relu, LogSoftmax) +from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, + Flatten, GeneralConv, Identity, + MaxPool, Relu, ResBlock, LogSoftmax) # ResNet blocks compose other layers @@ -45,7 +45,7 @@ def ConvBlock(kernel_size, filters, strides=(2, 2)): Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, Conv(filters3, (1, 1)), BatchNorm()) Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm()) - return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) + return ResBlock(Main, Shortcut) def IdentityBlock(kernel_size, filters): @@ -58,7 +58,7 @@ def make_main(input_shape): Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, Conv(input_shape[3], (1, 1)), BatchNorm()) Main = stax.shape_dependent(make_main) - return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) + return ResBlock(Main, Identity) # ResNet architectures compose layers and ResNet blocks diff --git a/jax/experimental/stax.py b/jax/experimental/stax.py index 8c745fb728ab..ad864ccfcf71 100644 --- a/jax/experimental/stax.py +++ b/jax/experimental/stax.py @@ -260,6 +260,22 @@ def apply_fun(params, inputs, **kwargs): return init_fun, apply_fun +def ResBlock(*layers, fan_in=FanInSum, tail=Relu): + """Split input, feed it through one or more layers in parallel, + recombine them with a fan-in, apply a trailing layer (i.e. an activation) + + Args: + *layers: a sequence of layers, each an (init_fun, apply_fun) pair. + fan_in, optional: a fan-in to recombine the outputs of each layer + tail, optional: a final layer to apply after recombination + + Returns: + A new layer, meaning an (init_fun, apply_fun) pair, representing the + parallel composition of the given sequence of layers fed into fan_in and then tail. + """ + return serial(FanOut(len(layers)), parallel(*layers), fan_in, tail) + + # Composing layers via combinators From 05b891e7ce79c958f5109634af0c42975ae50543 Mon Sep 17 00:00:00 2001 From: Alex Eftimiades Date: Mon, 28 Oct 2019 09:28:16 -0400 Subject: [PATCH 2/3] change default tail to Identity layer --- examples/resnet50.py | 4 ++-- jax/experimental/stax.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/resnet50.py b/examples/resnet50.py index 3aeb2962a539..ef3361651684 100644 --- a/examples/resnet50.py +++ b/examples/resnet50.py @@ -45,7 +45,7 @@ def ConvBlock(kernel_size, filters, strides=(2, 2)): Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, Conv(filters3, (1, 1)), BatchNorm()) Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm()) - return ResBlock(Main, Shortcut) + return ResBlock(Main, Shortcut, tail=Relu) def IdentityBlock(kernel_size, filters): @@ -58,7 +58,7 @@ def make_main(input_shape): Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, Conv(input_shape[3], (1, 1)), BatchNorm()) Main = stax.shape_dependent(make_main) - return ResBlock(Main, Identity) + return ResBlock(Main, Identity, tail=Relu) # ResNet architectures compose layers and ResNet blocks diff --git a/jax/experimental/stax.py b/jax/experimental/stax.py index ad864ccfcf71..a2b9a1f0e093 100644 --- a/jax/experimental/stax.py +++ b/jax/experimental/stax.py @@ -260,7 +260,7 @@ def apply_fun(params, inputs, **kwargs): return init_fun, apply_fun -def ResBlock(*layers, fan_in=FanInSum, tail=Relu): +def ResBlock(*layers, fan_in=FanInSum, tail=Identity): """Split input, feed it through one or more layers in parallel, recombine them with a fan-in, apply a trailing layer (i.e. an activation) From 9094716916e6a767e1314a2ba272c467ade33cdc Mon Sep 17 00:00:00 2001 From: Alex Eftimiades Date: Mon, 28 Oct 2019 09:30:54 -0400 Subject: [PATCH 3/3] python2 compatibility --- jax/experimental/stax.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/experimental/stax.py b/jax/experimental/stax.py index a2b9a1f0e093..d5209f29f9c9 100644 --- a/jax/experimental/stax.py +++ b/jax/experimental/stax.py @@ -260,7 +260,7 @@ def apply_fun(params, inputs, **kwargs): return init_fun, apply_fun -def ResBlock(*layers, fan_in=FanInSum, tail=Identity): +def ResBlock(*layers, **kwargs): """Split input, feed it through one or more layers in parallel, recombine them with a fan-in, apply a trailing layer (i.e. an activation) @@ -273,6 +273,13 @@ def ResBlock(*layers, fan_in=FanInSum, tail=Identity): A new layer, meaning an (init_fun, apply_fun) pair, representing the parallel composition of the given sequence of layers fed into fan_in and then tail. """ + # TODO(aeftimia): change signature to + # def ResBlock(*layers, fan_in=FanInSum, tail=Identity): + # when Python 2 support expires + default_args = {'fan_in': FanInSum, 'tail': Identity} + default_args.update(kwargs) + fan_in = default_args['fan_in'] + tail = default_args['tail'] return serial(FanOut(len(layers)), parallel(*layers), fan_in, tail)