Skip to content

Commit

Permalink
Merge reduce-window batching rules. Add batching rule for reduce_wind…
Browse files Browse the repository at this point in the history
…ow_min.
  • Loading branch information
hawkinsp committed Jun 26, 2019
1 parent b8bac19 commit 755d281
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 23 deletions.
35 changes: 14 additions & 21 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3453,8 +3453,9 @@ def _reduce_window_sum_transpose_rule(cotangent, window_dimensions,
assert result.shape == input_shape
return [result]

def _reduce_window_sum_batch_rule(
batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
def _reduce_window_batch_rule(
reduce_window, batched_args, bdims, window_dimensions, window_strides,
padding, input_shape=None):
operand, = batched_args
bdim, = bdims

Expand All @@ -3463,16 +3464,17 @@ def _reduce_window_sum_batch_rule(
window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]

oprand = _reduce_window_sum(
operand = reduce_window(
operand, window_dimensions, window_strides, padding)

return oprand, 0
return operand, 0

reduce_window_sum_p = standard_primitive(
_reduce_window_sum_shape_rule, _input_dtype, 'reduce_window_sum',
_reduce_window_sum_translation_rule)
ad.deflinear(reduce_window_sum_p, _reduce_window_sum_transpose_rule)
batching.primitive_batchers[reduce_window_sum_p] = _reduce_window_sum_batch_rule
batching.primitive_batchers[reduce_window_sum_p] = partial(
_reduce_window_batch_rule, _reduce_window_sum)

def _reduce_window_chooser_translation_rule(
prim, identity, c, operand, window_dimensions, window_strides, padding):
Expand Down Expand Up @@ -3514,28 +3516,14 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
onp.subtract(operand_padded, window_dimensions), window_strides) + 1
return tuple(t)

def _reduce_window_max_batch_rule(
batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
operand, = batched_args
bdim, = bdims

if bdim is not None:
window_dimensions = \
window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]

operand = _reduce_window_max(
operand, window_dimensions, window_strides, padding)

return operand, 0

_reduce_window_max_translation_rule = partial(
_reduce_window_chooser_translation_rule, max_p, _get_max_identity)
reduce_window_max_p = standard_primitive(
_common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
_reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, max_p))
batching.primitive_batchers[reduce_window_max_p] = _reduce_window_max_batch_rule
batching.primitive_batchers[reduce_window_max_p] = partial(
_reduce_window_batch_rule, _reduce_window_max)

_reduce_window_min_translation_rule = partial(
_reduce_window_chooser_translation_rule, min_p, _get_min_identity)
Expand All @@ -3544,6 +3532,11 @@ def _reduce_window_max_batch_rule(
_reduce_window_min_translation_rule)
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, min_p))

_reduce_window_min_batch_rule = partial(_reduce_window_batch_rule,
_reduce_window_min)
batching.primitive_batchers[reduce_window_min_p] = partial(
_reduce_window_batch_rule, _reduce_window_min)


def _select_and_scatter_shape_rule(
operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
Expand Down
7 changes: 5 additions & 2 deletions tests/batching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ def f(params, x):
(5, 21, 5, 1)))
self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

def testMaxPool(self):
@parameterized.named_parameters(
{"testcase_name": "_op={}".format(name), "op": op, "unit": unit}
for name, op, unit in [("max", lax.max, -np.inf), ("min", lax.min, np.inf)])
def testMinMaxPool(self, op, unit):
W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

Expand All @@ -495,7 +498,7 @@ def f(params, x):
y = lax.conv_general_dilated(
x, params, one, 'SAME', one, one, dimension_numbers)
y = lax.reduce_window(
y, -np.inf, lax.max, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
return y
grad_loss = grad(lambda params, x: np.mean(f(params, x) ** 2))

Expand Down

0 comments on commit 755d281

Please sign in to comment.