Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish memory optimization transpiler #9905

Merged
merged 1 commit into from
Apr 14, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 88 additions & 47 deletions python/paddle/fluid/memory_optimization_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,20 @@
core.VarDesc.VarType.BOOL: 1
}

sub_block_ops = [
SUB_BLOCK_OPS = [
"while", "while_grad", "parallel_do", "parallel_do_grad",
"conditional_block", "conditional_block_grad"
]

SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
("conditional_block", "conditional_block_grad")]

PRINT_LOG = False


class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num, skip_opt):
self._program = Program
def __init__(self, program, ops, forward_num, skip_opt):
self._program = program
self._ops = ops
self._forward_num = forward_num
self._successors = defaultdict(set)
Expand All @@ -51,14 +54,19 @@ def __init__(self, Program, ops, forward_num, skip_opt):
self._skip_opt = skip_opt

def _add_connections(self, connections):
"""Populates _successors and _presuccessors for two neighbor nodes."""
for node1, node2 in connections:
self._add(node1, node2)

def _add(self, node1, node2):
self._successors[node1].add(node2)
self._presuccessors[node2].add(node1)

# TODO(panyx0718): We need to have a unified way of building intermediate
# representation.
def _build_graph(self):
"""Build a graph based on op sequence.
"""
self.op_size = len(self._ops)
op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
self._add_connections(op_node_connections)
Expand All @@ -82,22 +90,23 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
self._live_out[i].add(new_name)

def _reach_fixed_point(self, live_in, live_out):
"""Check if the liveness set has stablized."""
if len(live_in) != len(self._live_in):
return False
if len(live_out) != len(self._live_out):
return False
for i in range(self.op_size):
if live_in[i] != self._live_in[i]:
return False
for i in range(self.op_size):
if live_out[i] != self._live_out[i]:
if (live_in[i] != self._live_in[i] or
live_out[i] != self._live_out[i]):
return False
return True

def _dataflow_analyze(self):
self._build_graph()
live_in = defaultdict(set)
live_out = defaultdict(set)
# Repeatedly apply liveness updates until the algorithm stablize
# on a complete set live input vars and live output vars.
while True:
for i in range(self.op_size, 0, -1):
live_in[i] = set(self._live_in[i])
Expand Down Expand Up @@ -141,6 +150,8 @@ def _check_var_validity(self, block_desc, x, is_forward):
return False
return True

# TODO(panyx0718): This needs to be less hacky. It seems memory optimization
# doesn't consider vars copied between cpu and gpu.
def _update_skip_opt_set(self):
for i in range(self.op_size):
op = self._ops[i]
Expand All @@ -154,7 +165,7 @@ def release_memory(self):
bwd_id = 0
for i in range(self.op_size):
op = self._ops[i]
if op.type() in sub_block_ops:
if op.type() in SUB_BLOCK_OPS:
continue
block_desc = op.block()
is_forward = i < self._forward_num
Expand All @@ -177,24 +188,25 @@ def memory_optimize(self, level=0):
def compare_shape(x_shape, cache_shape, opt_level):
if opt_level == 0:
return x_shape == cache_shape
if opt_level == 1:
elif opt_level == 1:
if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
return False
x_size = abs(reduce(lambda x, y: x * y, x_shape))
cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
if x_size <= cache_size:
return True
else:
raise ValueError("only support opt_level 0 or 1.")
return False

self._dataflow_analyze()
self._update_skip_opt_set()
self.pool = []
for i in range(self.op_size):
op = self._ops[i]
if op.type() in sub_block_ops:
if op.type() in SUB_BLOCK_OPS:
continue
block_desc = op.block()
self.current_block_desc = block_desc
is_forward = i < self._forward_num
if self.pool:
defs_can_optimize = filter(
Expand All @@ -211,37 +223,40 @@ def compare_shape(x_shape, cache_shape, opt_level):
for index, cache_pair in enumerate(self.pool):
cache_var = cache_pair[0]
cache_shape = cache_pair[1]
if compare_shape(x_shape, cache_shape, level):
if self._has_var(block_desc, cache_var, is_forward):
x_dtype = self._find_var(block_desc, x,
is_forward).dtype()
cache_dtype = self._find_var(
block_desc, cache_var, is_forward).dtype()
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
# and dtype_to_size[cache_dtype]
if x_dtype == cache_dtype:
if PRINT_LOG:
print(
("Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
"var shape is %s ") %
(index, x, cache_var,
str(cache_shape)))
self.pool.pop(index)
if x == cache_var:
break
_rename_arg_(
self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(
str(x)).desc = self._find_var(
block_desc, cache_var, is_forward)
self._update_graph(
x, cache_var, begin_idx=i)
break

in_diff, out_diff = self._get_diff(self._live_in[i],
self._live_out[i])
if not compare_shape(x_shape, cache_shape, level):
continue

if not self._has_var(block_desc, cache_var, is_forward):
continue

x_dtype = self._find_var(block_desc, x,
is_forward).dtype()
cache_dtype = self._find_var(block_desc, cache_var,
is_forward).dtype()
# TODO(qijun): actually, we should compare
# dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
if x_dtype != cache_dtype:
continue

if PRINT_LOG:
print(("Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
"var shape is %s ") % (index, x, cache_var,
str(cache_shape)))
self.pool.pop(index)
if x == cache_var:
break
# Rename the var to the cache var already with
# memory allocated in order to reuse the memory.
_rename_arg_(self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(str(
x)).desc = self._find_var(block_desc, cache_var,
is_forward)
self._update_graph(x, cache_var, begin_idx=i)
break

in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
can_optimize = filter(
lambda x: self._check_var_validity(block_desc, x, is_forward),
in_diff)
Expand All @@ -252,6 +267,19 @@ def compare_shape(x_shape, cache_shape, opt_level):


def _process_sub_block_pair(pdesc, sub_block_pair):
"""Creates a list of tuple each of which tracks info of a subblock.

Note: this function doesn't handle nested subblocks yet.
TODO(panyx0718): assert if case nested subblocks happen.

:param pdesc: ProgramDesc.
:param sub_block_pair: A list op pairs. Each op pair is the forward
op and backward op. The ops in the list are special that they contain
a subblock of ops.
:return: A list of tuples, each tuple is (all ops in a subblock pair
including forward and backward, number of forward ops,
all output args names of the ops in the subblock pairs).
"""
ops_list = []
block_desc = pdesc.block(0)
op_size = block_desc.op_size()
Expand Down Expand Up @@ -308,6 +336,11 @@ def _process_sub_block_pair(pdesc, sub_block_pair):


def _get_cfgs(input_program):
"""Process each block and create ControlFlowGraph for each of them.

:param input_program: Program object.
:return: A list of ControlFlowGraph, each corresponds to a block.
"""
ops_list = []
pdesc = input_program.get_desc()
block_desc = pdesc.block(0)
Expand All @@ -316,11 +349,8 @@ def _get_cfgs(input_program):
ops_list.append(
([block_desc.op(i) for i in range(op_size)], op_size, set()))

sub_block_pair = [("while", "while_grad"), ("parallel_do",
"parallel_do_grad"),
("conditional_block", "conditional_block_grad")]

ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
# Only process one level of nested subblock.
ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))

cfgs = [
ControlFlowGraph(input_program, ops, forward_num, skip_opt)
Expand All @@ -330,6 +360,17 @@ def _get_cfgs(input_program):


def memory_optimize(input_program, print_log=False, level=0):
"""Optimize memory by reusing var memory.

Note: it doesn't not support subblock nested in subblock.

:param input_program: Input Program
:param print_log: whether to print debug log.
:param level: If level=0, reuse if the shape is completely equal, o
:return:
"""
if level != 0 and level != 1:
raise ValueError("only support opt_level 0 or 1.")
global PRINT_LOG
PRINT_LOG = print_log
cfgs = _get_cfgs(input_program)
Expand Down