Commit
Impl scalar switch case op with condition op (#8184)
jacquesqiao authored Feb 7, 2018
1 parent e583201 commit 20c4a4c
Showing 5 changed files with 171 additions and 10 deletions.
3 changes: 1 addition & 2 deletions doc/design/switch.md
@@ -10,8 +10,7 @@ The following example shows the usage of `fluid.switch`.
a = fluid.Var(10)
b = fluid.Var(0)

-switch = fluid.switch()
-with switch.block():
+with switch() as switch:
with switch.case(fluid.less_equal(a, 10)):
fluid.print("Case 1")
with switch.case(fluid.larger(a, 0)):
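The Python layer added in this commit (see `control_flow.py` and `test_switch.py` below) follows the same context-manager style as the updated design doc. A minimal sketch of the implemented API, not taken from the commit itself; it uses `layers.less_than` because the comparison helpers named in the design doc (`less_equal`, `larger`) are placeholders there, and in practice the ops would be built inside a `program_guard` as in the new unit test:

import paddle.v2.fluid.layers as layers

a = layers.fill_constant(shape=[1], dtype='float32', value=10.0)
zero = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
result = layers.create_global_var(
    shape=[1], value=-1.0, dtype='float32', persistable=True)

with layers.Switch() as switch:
    with switch.case(layers.less_than(a, zero)):
        # taken only when a < 0
        layers.assign(zero, result)
    with switch.default():
        # taken when no earlier case matched
        layers.assign(a, result)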
44 changes: 38 additions & 6 deletions paddle/operators/conditional_block_op.cc
@@ -41,6 +41,21 @@ class ConditionalOp : public framework::OperatorBase {
});
return retv;
}

bool ScalarCondition(
const std::vector<const framework::LoDTensor *> &ips) const {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition");
}
if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() &&
ips[0]->numel() == 1)) {
PADDLE_THROW(
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d",
ips[0]->numel());
}
return ips[0]->data<bool>()[0];
}
};

class ConditionalBlockOp : public ConditionalOp {
@@ -53,9 +68,15 @@ class ConditionalBlockOp : public ConditionalOp {
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = InputTensors(scope);
-bool need_run = std::all_of(
-    xs.begin(), xs.end(),
-    [](const framework::LoDTensor *t) { return t->numel() != 0; });
+
+bool need_run;
+if (Attr<bool>("is_scalar_condition")) {
+  need_run = ScalarCondition(xs);
+} else {
+  need_run = std::all_of(
+      xs.begin(), xs.end(),
+      [](const framework::LoDTensor *t) { return t->numel() != 0; });
+}

if (need_run) {
auto *scope_var = scope.FindVar(Output("Scope"));
@@ -88,6 +109,10 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"scope is std::vector<Scope*>");
AddAttr<framework::BlockDesc *>(
"sub_block", "The step block of conditional block operator");
AddAttr<bool>("is_scalar_condition",
"the input X is used as scalar "
"condition")
.SetDefault(false);
AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
@@ -106,9 +131,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = this->InputTensors(scope);
-bool need_run = std::all_of(
-    xs.begin(), xs.end(),
-    [](const framework::LoDTensor *t) { return t->numel() != 0; });
+
+bool need_run;
+if (Attr<bool>("is_scalar_condition")) {
+  need_run = ScalarCondition(xs);
+} else {
+  need_run = std::all_of(
+      xs.begin(), xs.end(),
+      [](const framework::LoDTensor *t) { return t->numel() != 0; });
+}

if (need_run) {
auto *scope_var = scope.FindVar(Input("Scope"));
@@ -182,6 +213,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op->SetOutput(framework::GradVarName("Params"),
InputGrad("Params", false));
grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
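The new `is_scalar_condition` attribute switches the op from its old behaviour ("run the sub-block when every input tensor is non-empty") to `ScalarCondition`, which requires exactly one initialized bool input with `numel == 1` and runs the sub-block iff that value is true. A construction-only sketch of driving this from Python, assuming `ConditionalBlock` and `ConditionalBlockGuard` can be imported from `paddle.v2.fluid.layers.control_flow` (the file shown next); this sketch is not part of the commit:

import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.layers.control_flow import (ConditionalBlock,
                                                 ConditionalBlockGuard)

x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
zero = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
result = layers.create_global_var(
    shape=[1], value=-1.0, dtype='float32', persistable=True)
cond = layers.less_than(x, zero)  # bool tensor with numel == 1

# is_scalar_condition=True makes the C++ op evaluate ScalarCondition(cond)
# instead of checking that every input tensor is non-empty.
cond_block = ConditionalBlock([cond], is_scalar_condition=True)
with ConditionalBlockGuard(cond_block):
    layers.assign(zero, result)  # executed only when x < 0 at run time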
66 changes: 64 additions & 2 deletions python/paddle/v2/fluid/layers/control_flow.py
@@ -18,6 +18,7 @@
from .. import core
from ..framework import Program, Variable, Operator
from ..layer_helper import LayerHelper, unique_name
from ops import logical_and, logical_not, logical_or

__all__ = [
'split_lod_tensor',
@@ -27,6 +28,7 @@
'StaticRNNMemoryLink',
'WhileGuard',
'While',
'Switch',
'lod_rank_table',
'max_sequence_len',
'topk',
@@ -1063,11 +1065,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class ConditionalBlock(object):
-def __init__(self, inputs, name=None):
+def __init__(self, inputs, is_scalar_condition=False, name=None):
for each_input in inputs:
if not isinstance(each_input, Variable):
raise TypeError("Each input should be variable")
self.inputs = inputs
self.is_scalar_condition = is_scalar_condition
self.helper = LayerHelper('conditional_block', name=name)

def block(self):
Expand Down Expand Up @@ -1112,7 +1115,66 @@ def complete(self):
},
outputs={'Out': out_list,
'Scope': [step_scope]},
-attrs={'sub_block': inside_block})
+attrs={
+    'sub_block': inside_block,
+    'is_scalar_condition': self.is_scalar_condition
+})


class Switch(object):
def __init__(self, name=None):
self.helper = LayerHelper('switch', name=name)
self.inside_scope = False
self.pre_not_conditions = []

def case(self, condition):
"""create a new block for this condition
"""
if not self.inside_scope:
raise ValueError("case should be called inside with")

if len(self.pre_not_conditions) == 0:
cond_block = ConditionalBlock([condition], is_scalar_condition=True)
not_cond = logical_not(x=condition)
self.pre_not_conditions.append(not_cond)
else:
pre_cond_num = len(self.pre_not_conditions)
pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
new_not_cond = logical_and(
x=pre_not_cond, y=logical_not(x=condition))
self.pre_not_conditions.append(new_not_cond)
cond_block = ConditionalBlock(
[logical_and(
x=pre_not_cond, y=condition)],
is_scalar_condition=True)

return ConditionalBlockGuard(cond_block)

def default(self):
"""create a default case for this switch
"""
pre_cond_num = len(self.pre_not_conditions)
if pre_cond_num == 0:
raise ValueError("there should be at least one condition")
cond_block = ConditionalBlock(
[self.pre_not_conditions[pre_cond_num - 1]],
is_scalar_condition=True)
return ConditionalBlockGuard(cond_block)

def __enter__(self):
"""
set flag that now is inside switch.block {}
:return:
"""
self.inside_scope = True
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.inside_scope = False
if exc_type is not None:
return False # re-raise exception

return True


class IfElseBlockGuard(object):
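Each `switch.case` above guards its block with a condition of the form "(no earlier case matched) and this case's condition", and `switch.default()` reuses the accumulated negation, so at most one branch runs. Spelled out by hand with the logical ops exported in `ops.py` below; a sketch that assumes `logical_and`/`logical_not` are reachable as `layers.logical_*` and accept keyword inputs `x`/`y`, as they are called inside `Switch`:

import paddle.v2.fluid.layers as layers

x = layers.fill_constant(shape=[1], dtype='float32', value=0.5)
zero = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)

cond1 = layers.less_than(x, zero)  # condition of the first case
cond2 = layers.less_than(x, one)   # condition of the second case

# Guard the second case actually runs under: (not cond1) and cond2
guard2 = layers.logical_and(x=layers.logical_not(x=cond1), y=cond2)

# Guard for switch.default(): (not cond1) and (not cond2)
guard_default = layers.logical_and(
    x=layers.logical_not(x=cond1), y=layers.logical_not(x=cond2))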
4 changes: 4 additions & 0 deletions python/paddle/v2/fluid/layers/ops.py
@@ -61,6 +61,10 @@
'clip_by_norm',
'softmax',
'sequence_softmax',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
] + __activations__

for _OP in set(__all__):
64 changes: 64 additions & 0 deletions python/paddle/v2/fluid/tests/test_switch.py
@@ -0,0 +1,64 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.framework as framework
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.framework import default_startup_program


class TestSwitch(unittest.TestCase):
def check_switch(self, value):
x = layers.fill_constant(shape=[1], dtype='float32', value=value)

zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
three_var = layers.fill_constant(shape=[1], dtype='float32', value=3.0)

result = layers.create_global_var(
shape=[1], value=-1.0, dtype='float32', persistable=True)

with layers.Switch() as switch:
with switch.case(layers.less_than(x, zero_var)):
layers.assign(zero_var, result)
with switch.case(layers.less_than(x, one_var)):
layers.assign(one_var, result)
with switch.case(layers.less_than(x, two_var)):
layers.assign(two_var, result)
with switch.default():
layers.assign(three_var, result)

cpu = core.CPUPlace()
exe = Executor(cpu)
exe.run(default_startup_program())

out = exe.run(feed={}, fetch_list=[result])[0][0]
return out

def test_switch(self):
test_data = {(-0.1, 0), (0.1, 1), (1.1, 2), (2.1, 3)}
for x, expected_result in test_data:
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
result = self.check_switch(x)
self.assertEqual(result, expected_result)


if __name__ == '__main__':
unittest.main()
