Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Select to concurrency.py; incorporate outputs #9136

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/operators/select_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace operators {

static constexpr char kX[] = "X";
static constexpr char kCaseToExecute[] = "case_to_execute";
static constexpr char kOutputs[] = "Out";

static constexpr char kCases[] = "cases";
static constexpr char kCasesBlock[] = "sub_block";
Expand Down Expand Up @@ -388,6 +389,10 @@ class SelectOpMaker : public framework::OpProtoAndCheckerMaker {
"(Int) The variable the sets the index of the case to execute, "
"after evaluating the channels being sent to and received from")
.AsDuplicable();
AddOutput(kOutputs,
"A set of variables, which will be assigned with values "
"generated by the operators inside the cases of Select Op.")
.AsDuplicable();
AddAttr<std::vector<std::string>>(kCases,
"(String vector) Serialized list of"
"all cases in the select op. Each"
Expand Down
182 changes: 181 additions & 1 deletion python/paddle/fluid/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from layers.control_flow import BlockGuard, Select
from layers.control_flow import BlockGuard, equal
from .framework import Operator
from layer_helper import LayerHelper, unique_name
from layers import fill_constant
import core
Expand Down Expand Up @@ -75,6 +76,185 @@ def construct_go_op(self):
attrs={'sub_block': go_block})


class SelectCase(object):
DEFAULT = 0
SEND = 1
RECEIVE = 2

def __init__(self,
case_idx,
case_to_execute,
channel_action_fn=None,
channel=None,
value=None):
self.helper = LayerHelper('conditional_block')
self.main_program = self.helper.main_program
self.is_scalar_condition = True

self.case_to_execute = case_to_execute
self.idx = case_idx

# Since we aren't going to use the `channel_send` or `channel_recv`
# functions directly, we just need to capture the name.
self.action = (self.SEND
if channel_action_fn.__name__ == ('channel_send') else
self.RECEIVE) if channel_action_fn else self.DEFAULT
self.value = value
self.channel = channel

def __enter__(self):
self.block = self.main_program.create_block()

def construct_op(self):
main_program = self.helper.main_program
cases_block = main_program.current_block()

inner_outputs = set()
input_set = set()
params = set()

for op in self.block.ops:
# Iterate over all operators, get all the inputs
# and add as input to the SelectCase operator.
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in inner_outputs:
input_set.add(in_var_name)

for oname in op.output_names:
for out_var_name in op.output(oname):
inner_outputs.add(out_var_name)

param_list = [
cases_block.var(each_name) for each_name in params
if each_name not in input_set
]

# Iterate over all operators, get all the outputs
# add to the output list of SelectCase operator only if
# they exist in the parent block.
out_vars = []
for inner_out_name in inner_outputs:
if inner_out_name in cases_block.vars:
out_vars.append(cases_block.var(inner_out_name))

# First, create an op that will determine whether or not this is the
# conditional variable to execute.
should_execute_block = equal(
fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=self.idx),
self.case_to_execute)

step_scope = cases_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)

cases_block.append_op(
type='conditional_block',
inputs={'X': [should_execute_block],
'Params': param_list},
outputs={'Out': out_vars,
'Scope': [step_scope]},
attrs={
'sub_block': self.block,
'is_scalar_condition': self.is_scalar_condition
})

return '%s,%s,%s,%s' % (self.idx, self.action, self.channel.name
if self.channel else '', self.value.name
if self.value else '')

def __exit__(self, exc_type, exc_val, exc_tb):
self.main_program.rollback()
if exc_type is not None:
return False # re-raise exception
return True


class Select(BlockGuard):
def __init__(self, name=None):
self.helper = LayerHelper('select', name=name)
self.cases = []

super(Select, self).__init__(self.helper.main_program)
self.case_to_execute = fill_constant(
shape=[1], dtype=core.VarDesc.VarType.INT32, value=-1)

def __enter__(self):
super(Select, self).__enter__()
return self

def case(self, channel_action_fn, channel, value):
"""Create a new block for this condition.
"""
select_case = SelectCase(
len(self.cases), self.case_to_execute, channel_action_fn, channel,
value)

self.cases.append(select_case)

return select_case

def default(self):
"""Create a default case block for this condition.
"""
default_case = SelectCase(len(self.cases), self.case_to_execute)

self.cases.append(default_case)

return default_case

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False

# Create a select op and another block to wrap its
# case blocks.
select_block = self.helper.main_program.current_block()
parent_block = self.helper.main_program.block(select_block.parent_idx)

# Construct each case op, inside the newly created select block.
serialized_cases = []
for case in self.cases:
serialized_cases.append(case.construct_op())

intermediate = set()
params = set()

for case_block in select_block.ops:
if case_block.attrs and 'sub_block' in case_block.attrs:
for each_op in case_block.attrs['sub_block'].ops:
assert isinstance(each_op, Operator)
for iname in each_op.input_names:
for in_var_name in each_op.input(iname):
if in_var_name not in intermediate:
params.add(in_var_name)

for oname in each_op.output_names:
for out_var_name in each_op.output(oname):
intermediate.add(out_var_name)

out_list = [
parent_block.var(var_name) for var_name in parent_block.vars
if var_name in intermediate
]

X = [select_block.var_recursive(x_name) for x_name in params]

# Needs to be used by `equal` inside the cases block.
X.append(self.case_to_execute)

# Construct the select op.
parent_block.append_op(
type='select',
inputs={'X': X,
'case_to_execute': self.case_to_execute},
attrs={'sub_block': select_block,
'cases': serialized_cases},
outputs={'Out': out_list})

return super(Select, self).__exit__(exc_type, exc_val, exc_tb)


def make_channel(dtype, capacity=0):
"""
Helps implementation of a concurrent program by creating a "channel" of
Expand Down
Loading