diff --git a/src/main/python/systemds/operator/nn/affine.py b/src/main/python/systemds/operator/nn/affine.py index 44c67d1eda6..35935871aa5 100644 --- a/src/main/python/systemds/operator/nn/affine.py +++ b/src/main/python/systemds/operator/nn/affine.py @@ -18,21 +18,15 @@ # under the License. # # ------------------------------------------------------------- -import os - from systemds.context import SystemDSContext -from systemds.operator import Matrix, Source, MultiReturn -from systemds.utils.helpers import get_path_to_script_layers +from systemds.operator import Matrix, MultiReturn +from systemds.operator.nn.layer import Layer -class Affine: - _source: Source = None +class Affine(Layer): weight: Matrix bias: Matrix - def __new__(cls, *args, **kwargs): - return super().__new__(cls) - def __init__(self, sds_context: SystemDSContext, d, m, seed=-1): """ sds_context: The systemdsContext to construct the layer inside of @@ -40,11 +34,8 @@ def __init__(self, sds_context: SystemDSContext, d, m, seed=-1): m: The number of neurons that are contained in the layer, and the number of features output """ - Affine._create_source(sds_context) - - # bypassing overload limitation in python - self.forward = self._instance_forward - self.backward = self._instance_backward + super().__init__(sds_context, 'affine.dml') + self._X = None # init weight and bias self.weight = Matrix(sds_context, '') @@ -64,7 +55,7 @@ def forward(X: Matrix, W: Matrix, b: Matrix): b: The bias added in the output. return out: An output matrix. """ - Affine._create_source(X.sds_context) + Affine._create_source(X.sds_context, "affine.dml") return Affine._source.forward(X, W, b) @staticmethod @@ -77,7 +68,7 @@ def backward(dout:Matrix, X: Matrix, W: Matrix, b: Matrix): return dX, dW, db: The gradients of: input X, weights and bias. """ sds = X.sds_context - Affine._create_source(sds) + Affine._create_source(sds, "affine.dml") params_dict = {'dout': dout, 'X': X, 'W': W, 'b': b} dX = Matrix(sds, '') dW = Matrix(sds, '') @@ -104,11 +95,6 @@ def _instance_backward(self, dout: Matrix, X: Matrix): X: The input to this layer. return dX, dW,db: gradient of input, weights and bias, respectively """ - return Affine.backward(dout, X, self.weight, self.bias) - - @staticmethod - def _create_source(sds: SystemDSContext): - if Affine._source is None or Affine._source.sds_context != sds: - path = get_path_to_script_layers() - path = os.path.join(path, "affine.dml") - Affine._source = sds.source(path, "affine") + gradients = Affine.backward(dout, X, self.weight, self.bias) + self._X = gradients[0] + return gradients diff --git a/src/main/python/systemds/operator/nn/layer.py b/src/main/python/systemds/operator/nn/layer.py new file mode 100644 index 00000000000..255fa2d4d15 --- /dev/null +++ b/src/main/python/systemds/operator/nn/layer.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- +import os + +from systemds.context import SystemDSContext +from systemds.operator import Source +from systemds.utils.helpers import get_path_to_script_layers + + +class Layer: + """ + Interface for neural network layers + """ + + _source: Source = None + + def __init__(self, sds_context: SystemDSContext = None, dml_script: str = None): + if sds_context is not None and dml_script is not None: + self.__class__._create_source(sds_context, dml_script) + + # bypassing overload limitation in python + self.forward = self._instance_forward + self.backward = self._instance_backward + + @classmethod + def _create_source(cls, sds_context: SystemDSContext, dml_script: str): + """ + Create SystemDS source + :param sds_context: SystemDS context + :param dml_script: DML script inside /scripts/nn/layers/ + :return: + """ + if cls._source is None or cls._source.sds_context != sds_context: + script_path = get_path_to_script_layers() + path = os.path.join(script_path, dml_script) + name = dml_script.split(".")[0] + cls._source = sds_context.source(path, name) + + def _instance_forward(self, *args): + raise NotImplementedError + + def _instance_backward(self, *args): + raise NotImplementedError + + @staticmethod + def forward(*args): + raise NotImplementedError + + @staticmethod + def backward(*args): + raise NotImplementedError diff --git a/src/main/python/systemds/operator/nn/relu.py b/src/main/python/systemds/operator/nn/relu.py index 99833e6d86d..e124e350d99 100644 --- a/src/main/python/systemds/operator/nn/relu.py +++ b/src/main/python/systemds/operator/nn/relu.py @@ -18,20 +18,16 @@ # under the License. # # ------------------------------------------------------------- -import os.path - from systemds.context import SystemDSContext from systemds.operator import Matrix, Source -from systemds.utils.helpers import get_path_to_script_layers +from systemds.operator.nn.layer import Layer -class ReLU: +class ReLU(Layer): _source: Source = None - def __init__(self, sds: SystemDSContext): - ReLU._create_source(sds) - self.forward = self._instance_forward - self.backward = self._instance_backward + def __init__(self, sds_context: SystemDSContext): + super().__init__(sds_context, "relu.dml") @staticmethod def forward(X: Matrix): @@ -39,7 +35,7 @@ def forward(X: Matrix): X: input matrix return out: output matrix """ - ReLU._create_source(X.sds_context) + ReLU._create_source(X.sds_context, "relu.dml") return ReLU._source.forward(X) @staticmethod @@ -49,7 +45,7 @@ def backward(dout: Matrix, X: Matrix): X: input matrix return dX: gradient of input """ - ReLU._create_source(dout.sds_context) + ReLU._create_source(dout.sds_context, "relu.dml") return ReLU._source.backward(dout, X) def _instance_forward(self, X: Matrix): @@ -58,11 +54,3 @@ def _instance_forward(self, X: Matrix): def _instance_backward(self, dout: Matrix, X: Matrix): return ReLU.backward(dout, X) - - @staticmethod - def _create_source(sds: SystemDSContext): - if ReLU._source is None or ReLU._source.sds_context != sds: - path = get_path_to_script_layers() - path = os.path.join(path, "relu.dml") - ReLU._source = sds.source(path, "relu") - diff --git a/src/main/python/systemds/operator/nn/sequential.py b/src/main/python/systemds/operator/nn/sequential.py new file mode 100644 index 00000000000..bf54ff5a650 --- /dev/null +++ b/src/main/python/systemds/operator/nn/sequential.py @@ -0,0 +1,97 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- +from systemds.operator import MultiReturn +from systemds.operator.nn.layer import Layer + + +class Sequential(Layer): + def __init__(self, *args): + super().__init__() + + self.layers = [] + if len(args) == 1 and isinstance(args[0], list): + self.layers = args[0] + else: + self.layers = list(args) + + def __len__(self): + return len(self.layers) + + def __getitem__(self, idx): + return self.layers[idx] + + def __setitem__(self, idx, value): + self.layers[idx] = value + + def __delitem__(self, idx): + del self.layers[idx] + + def __iter__(self): + return iter(self.layers) + + def __reversed__(self): + return reversed(self.layers) + + def push(self, layer: Layer): + """ + Add layer + :param layer: Layer + :return: + """ + self.layers.append(layer) + + def pop(self): + """ + Remove last layer + :return: Layer + """ + return self.layers.pop() + + def _instance_forward(self, X): + """ + Forward pass + :param X: Input matrix + :return: output matrix + """ + out = X + for layer in self: + out = layer.forward(out) + + # if MultiReturn, take only output matrix + if isinstance(out, MultiReturn): + out = out[0] + return out + + def _instance_backward(self, dout, X): + """ + Backward pass + :param dout: gradient of output, passed from the upstream + :param X: input matrix + :return: output matrix + """ + dx = dout + for layer in reversed(self): + dx = layer.backward(dx, X) + + # if MultiReturn, take only gradient of input + if isinstance(dx, MultiReturn): + dx = dx[0] + return dx diff --git a/src/main/python/systemds/operator/nodes/multi_return.py b/src/main/python/systemds/operator/nodes/multi_return.py index cb6b923d2c5..e2fa09b3dba 100644 --- a/src/main/python/systemds/operator/nodes/multi_return.py +++ b/src/main/python/systemds/operator/nodes/multi_return.py @@ -47,7 +47,7 @@ def __init__(self, sds_context, operation, named_input_nodes, OutputType.MULTI_RETURN, False) def __getitem__(self, key): - self._outputs[key] + return self._outputs[key] def code_line(self, var_name: str, unnamed_input_vars: Sequence[str], named_input_vars: Dict[str, str]) -> str: diff --git a/src/main/python/tests/nn/test_affine.py b/src/main/python/tests/nn/test_affine.py index 955945b29c1..a7de2c383d6 100644 --- a/src/main/python/tests/nn/test_affine.py +++ b/src/main/python/tests/nn/test_affine.py @@ -77,6 +77,7 @@ def test_forward(self): out = affine.forward(Xm).compute() self.assertEqual(len(out), 5) self.assertEqual(len(out[0]), 6) + assert_almost_equal(affine._X.compute(), Xm.compute()) # test static method out = Affine.forward(Xm, Wm, bm).compute() @@ -91,10 +92,13 @@ def test_backward(self): # test class method affine = Affine(self.sds, dim, m, 10) - [dx, dw, db] = affine.backward(doutm, Xm).compute() + gradients = affine.backward(doutm, Xm) + intermediate = affine._X.compute() + [dx, dw, db] = gradients.compute() assert len(dx) == 5 and len(dx[0]) == 6 assert len(dw) == 6 and len(dx[0]) == 6 assert len(db) == 1 and len(db[0]) == 6 + assert_almost_equal(intermediate, dx) # test static method res = Affine.backward(doutm, Xm, Wm, bm).compute() diff --git a/src/main/python/tests/nn/test_layer.py b/src/main/python/tests/nn/test_layer.py new file mode 100644 index 00000000000..0b6a0eb2e1d --- /dev/null +++ b/src/main/python/tests/nn/test_layer.py @@ -0,0 +1,80 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +import unittest + +from systemds.context import SystemDSContext +from systemds.operator.nn.layer import Layer + + +class TestLayer(unittest.TestCase): + sds: SystemDSContext = None + + @classmethod + def setUpClass(cls): + cls.sds = SystemDSContext() + + @classmethod + def tearDownClass(cls): + cls.sds.close() + + def test_init(self): + """ + Test that the source is created correctly from dml_script param when layer is initialized + """ + _ = Layer(self.sds, "relu.dml") + self.assertIsNotNone(Layer._source) + self.assertTrue(Layer._source.operation.endswith('relu.dml"')) + self.assertEqual(Layer._source._Source__name, "relu") + + def test_notimplemented(self): + """ + Test that NotImplementedError is raised + """ + + class TestLayerImpl(Layer): + pass + + layer = TestLayerImpl(self.sds, "relu.dml") + with self.assertRaises(NotImplementedError): + layer.forward(None) + with self.assertRaises(NotImplementedError): + layer.backward(None) + with self.assertRaises(NotImplementedError): + TestLayerImpl.forward(None) + with self.assertRaises(NotImplementedError): + TestLayerImpl.backward(None) + + def test_class_source_assignments(self): + """ + Test that the source is not shared between interface and implementation class + """ + + class TestLayerImpl(Layer): + @classmethod + def _create_source(cls, sds_context: SystemDSContext, dml_script: str): + cls._source = "test" + + _ = Layer(self.sds, "relu.dml") + _ = TestLayerImpl(self.sds, "relu.dml") + + self.assertNotEqual(Layer._source, "test") + self.assertEqual(TestLayerImpl._source, "test") diff --git a/src/main/python/tests/nn/test_sequential.py b/src/main/python/tests/nn/test_sequential.py new file mode 100644 index 00000000000..a7a361e40fb --- /dev/null +++ b/src/main/python/tests/nn/test_sequential.py @@ -0,0 +1,304 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +import unittest + +import numpy as np +from numpy.testing import assert_almost_equal + +from systemds.operator.nn.affine import Affine +from systemds.operator.nn.relu import ReLU +from systemds.operator.nn.sequential import Sequential +from systemds.operator import Matrix, MultiReturn +from systemds.operator.nn.layer import Layer +from systemds.context import SystemDSContext + + +class TestLayerImpl(Layer): + def __init__(self, test_id): + super().__init__() + self.test_id = test_id + + def _instance_forward(self, X: Matrix): + return X + self.test_id + + def _instance_backward(self, dout: Matrix, X: Matrix): + return dout - self.test_id + + +class MultiReturnImpl(Layer): + def __init__(self, sds): + super().__init__() + self.sds = sds + + def _instance_forward(self, X: Matrix): + return MultiReturn(self.sds, "test.dml", output_nodes=[X, 'some_random_return']) + + def _instance_backward(self, dout: Matrix, X: Matrix): + return MultiReturn(self.sds, "test.dml", output_nodes=[dout, X, 'some_random_return']) + + +class TestSequential(unittest.TestCase): + sds: SystemDSContext = None + + @classmethod + def setUpClass(cls): + cls.sds = SystemDSContext() + + @classmethod + def tearDownClass(cls): + cls.sds.close() + + def test_init_with_multiple_args(self): + """ + Test that Sequential is correctly initialized if multiple layers are passed as arguments + """ + model = Sequential(TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)) + self.assertEqual(len(model.layers), 3) + self.assertEqual(model.layers[0].test_id, 1) + self.assertEqual(model.layers[1].test_id, 2) + self.assertEqual(model.layers[2].test_id, 3) + + def test_init_with_list(self): + """ + Test that Sequential is correctly initialized if list of layers is passed as argument + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + self.assertEqual(len(model.layers), 3) + self.assertEqual(model.layers[0].test_id, 1) + self.assertEqual(model.layers[1].test_id, 2) + self.assertEqual(model.layers[2].test_id, 3) + + def test_len(self): + """ + Test that len() returns the number of layers + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + self.assertEqual(len(model), 3) + + def test_getitem(self): + """ + Test that Sequential[index] returns the layer at the given index + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + self.assertEqual(model[1].test_id, 2) + + def test_setitem(self): + """ + Test that Sequential[index] = layer sets the layer at the given index + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + model[1] = TestLayerImpl(4) + self.assertEqual(model[1].test_id, 4) + + def test_delitem(self): + """ + Test that del Sequential[index] removes the layer at the given index + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + del model[1] + self.assertEqual(len(model.layers), 2) + self.assertEqual(model[1].test_id, 3) + + def test_iter(self): + """ + Test that iter() returns an iterator over the layers + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + for i, layer in enumerate(model): + self.assertEqual(layer.test_id, i + 1) + + def test_push(self): + """ + Test that push() adds a layer + """ + model = Sequential() + model.push(TestLayerImpl(1)) + self.assertEqual(len(model.layers), 1) + self.assertEqual(model.layers[0].test_id, 1) + + def test_pop(self): + """ + Test that pop() removes the last layer + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + layer = model.pop() + self.assertEqual(len(model.layers), 2) + self.assertEqual(layer.test_id, 3) + + def test_reversed(self): + """ + Test that reversed() returns an iterator over the layers in reverse order + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + for i, layer in enumerate(reversed(model)): + self.assertEqual(layer.test_id, 3 - i) + + def test_forward(self): + """ + Test that forward() calls forward() on all layers + """ + model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)]) + in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]])) + out_matrix = model.forward(in_matrix).compute() + self.assertEqual(out_matrix.tolist(), [[7, 8], [9, 10]]) + + def test_forward_actual_layers(self): + """ + Test forward() with actual layers + """ + params = [ + np.array([[0.5, -0.5], [-0.5, 0.5]]), + np.array([[0.1, -0.1]]), + np.array([[0.4, -0.4], [-0.4, 0.4]]), + np.array([[0.2, -0.2]]), + np.array([[0.3, -0.3], [-0.3, 0.3]]), + np.array([[0.3, -0.3]]), + ] + + model = Sequential( + [ + Affine(self.sds, 2, 2), + ReLU(self.sds), + Affine(self.sds, 2, 2), + ReLU(self.sds), + Affine(self.sds, 2, 2), + ] + ) + + for i, layer in enumerate(model): + if isinstance(layer, Affine): + layer.weight = self.sds.from_numpy(params[i]) + layer.bias = self.sds.from_numpy(params[i + 1]) + + in_matrix = self.sds.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]])) + out_matrix = model.forward(in_matrix).compute() + expected = np.array([[0.3120, -0.3120], [0.3120, -0.3120]]) + assert_almost_equal(out_matrix, expected) + + def test_backward_actual_layers(self): + """ + Test backward() with actual layers + """ + params = [ + np.array([[0.5, -0.5], [-0.5, 0.5]]), + np.array([[0.1, -0.1]]), + np.array([[0.4, -0.4], [-0.4, 0.4]]), + np.array([[0.2, -0.2]]), + np.array([[0.3, -0.3], [-0.3, 0.3]]), + np.array([[0.3, -0.3]]), + ] + + model = Sequential( + [ + Affine(self.sds, 2, 2), + ReLU(self.sds), + Affine(self.sds, 2, 2), + ReLU(self.sds), + Affine(self.sds, 2, 2), + ] + ) + + for i, layer in enumerate(model): + if isinstance(layer, Affine): + layer.weight = self.sds.from_numpy(params[i]) + layer.bias = self.sds.from_numpy(params[i + 1]) + + in_matrix = self.sds.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]])) + out_matrix = model.forward(in_matrix) + gradient = model.backward(out_matrix, in_matrix).compute() + + # Test returned gradient + expected = np.array([[0.14976, -0.14976], [0.14976, -0.14976]]) + assert_almost_equal(gradient, expected) + + # Test if layers have been updated correctly + expected_gradients = [ + np.array([[0.14976, -0.14976], [0.14976, -0.14976]]), + np.array([[0.14976, -0.14976], [0.14976, -0.14976]]), + np.array([[0.1872, -0.1872], [0.1872, -0.1872]]), + ] + for i, layer in enumerate(model): + if isinstance(layer, Affine): + assert_almost_equal(layer._X.compute(), expected_gradients[int(i / 2)]) + + def test_multireturn_forward_pass(self): + """ + Test that forward() handles MultiReturn correctly + """ + model = Sequential(MultiReturnImpl(self.sds), TestLayerImpl(1)) + in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]])) + out_matrix = model.forward(in_matrix).compute() + self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]]) + + def test_multireturn_backward_pass(self): + """ + Test that backward() handles MultiReturn correctly + """ + model = Sequential(TestLayerImpl(1), MultiReturnImpl(self.sds)) + in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]])) + out_matrix = self.sds.from_numpy(np.array([[2, 3], [4, 5]])) + gradient = model.backward(out_matrix, in_matrix).compute() + self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]]) + + def test_multireturn_variation_multiple(self): + """ + Test that multiple MultiReturn after each other are handled correctly + """ + model = Sequential(MultiReturnImpl(self.sds), MultiReturnImpl(self.sds)) + in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]])) + out_matrix = model.forward(in_matrix).compute() + self.assertEqual(out_matrix.tolist(), [[1, 2], [3, 4]]) + gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute() + self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]]) + + def test_multireturn_variation_single_to_multiple(self): + """ + Test that a single return into multiple MultiReturn are handled correctly + """ + model = Sequential(TestLayerImpl(1), MultiReturnImpl(self.sds), MultiReturnImpl(self.sds)) + in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]])) + out_matrix = model.forward(in_matrix).compute() + self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]]) + gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute() + self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]]) + + def test_multireturn_variation_multiple_to_single(self): + """ + Test that multiple MultiReturn into a single return are handled correctly + """ + model = Sequential(MultiReturnImpl(self.sds), MultiReturnImpl(self.sds), TestLayerImpl(1)) + in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]])) + out_matrix = model.forward(in_matrix).compute() + self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]]) + gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute() + self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]]) + + def test_multireturn_variation_sandwich(self): + """ + Test that a single return between two MultiReturn are handled correctly + """ + model = Sequential(MultiReturnImpl(self.sds), TestLayerImpl(1), MultiReturnImpl(self.sds)) + in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]])) + out_matrix = model.forward(in_matrix).compute() + self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]]) + gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute() + self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]])