Moved pooling ops from linen to core; added pooling documentation to NNX
jorisSchaller committed Nov 28, 2024
1 parent 5d896bc commit f677937
Showing 6 changed files with 168 additions and 149 deletions.
2 changes: 1 addition & 1 deletion docs_nnx/api_reference/flax.nnx/nn/index.rst
@@ -14,6 +14,6 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for
linear
lora
normalization
pooling
recurrent
stochastic

10 changes: 10 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/pooling.rst
@@ -0,0 +1,10 @@
Pooling
========================
.. currentmodule:: flax.nnx

Pooling functions
------------------------

.. autofunction:: max_pool
.. autofunction:: avg_pool
.. autofunction:: pool
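
For context, a minimal usage sketch (not part of this commit) of the functions the new page documents, following the (batch, window dims..., features) convention from the docstrings further down:

import jax.numpy as jnp
from flax import nnx

x = jnp.ones((1, 8, 8, 3))  # (batch, height, width, features)

# Both calls use a 2x2 window with stride 2 and the default 'VALID' padding.
y_max = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))  # (1, 4, 4, 3)
y_avg = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # (1, 4, 4, 3)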
2 changes: 1 addition & 1 deletion flax/core/nn/__init__.py
@@ -35,7 +35,7 @@
swish as swish,
tanh as tanh,
)
from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool)
from .pooling import (max_pool as max_pool, avg_pool as avg_pool, min_pool as min_pool, pool as pool)
from .attention import (
dot_product_attention as dot_product_attention,
multi_head_dot_product_attention as multi_head_dot_product_attention,
143 changes: 143 additions & 0 deletions flax/core/nn/pooling.py
@@ -0,0 +1,143 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pooling modules."""

import jax.numpy as jnp
from numpy import prod as np_prod
from jax import lax


def pool(inputs, init, reduce_fn, window_shape, strides, padding):
  """Helper function to define pooling functions.

  Pooling functions are implemented using the ReduceWindow XLA op.

  .. note::
    Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply that
    pool is differentiable.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    init: the initial value for the reduction
    reduce_fn: a reduce function of the form ``(T, T) -> T``.
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension.

  Returns:
    The output of the reduction for each window slice.
  """
  num_batch_dims = inputs.ndim - (len(window_shape) + 1)
  strides = strides or (1,) * len(window_shape)
  assert len(window_shape) == len(
    strides
  ), f'len({window_shape}) must equal len({strides})'
  strides = (1,) * num_batch_dims + strides + (1,)
  dims = (1,) * num_batch_dims + window_shape + (1,)

  is_single_input = False
  if num_batch_dims == 0:
    # add singleton batch dimension because lax.reduce_window always
    # needs a batch dimension.
    inputs = inputs[None]
    strides = (1,) + strides
    dims = (1,) + dims
    is_single_input = True

  assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})'
  if not isinstance(padding, str):
    padding = tuple(map(tuple, padding))
    assert len(padding) == len(window_shape), (
      f'padding {padding} must specify pads for same number of dims as '
      f'window_shape {window_shape}'
    )
    assert all(
      [len(x) == 2 for x in padding]
    ), f'each entry in padding {padding} must be length 2'
    padding = ((0, 0),) + padding + ((0, 0),)
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
  if is_single_input:
    y = jnp.squeeze(y, axis=0)
  return y


def avg_pool(
  inputs, window_shape, strides=None, padding='VALID', count_include_pad=True
):
  """Pools the input by taking the average over a window.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension (default: ``'VALID'``).
    count_include_pad: a boolean whether to include padded tokens
      in the average calculation (default: ``True``).

  Returns:
    The average for each window slice.
  """
  y = pool(inputs, 0.0, lax.add, window_shape, strides, padding)
  if count_include_pad:
    y = y / np_prod(window_shape)
  else:
    div_shape = inputs.shape[:-1] + (1,)
    if len(div_shape) - 2 == len(window_shape):
      div_shape = (1,) + div_shape[1:]
    y = y / pool(
      jnp.ones(div_shape), 0.0, lax.add, window_shape, strides, padding
    )
  return y


def max_pool(inputs, window_shape, strides=None, padding='VALID'):
  """Pools the input by taking the maximum of a window slice.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension (default: ``'VALID'``).

  Returns:
    The maximum for each window slice.
  """
  y = pool(inputs, -jnp.inf, lax.max, window_shape, strides, padding)
  return y


def min_pool(inputs, window_shape, strides=None, padding='VALID'):
  """Pools the input by taking the minimum of a window slice.

  Args:
    inputs: Input data with dimensions (batch, window dims..., features).
    window_shape: A shape tuple defining the window to reduce over.
    strides: A sequence of ``n`` integers, representing the inter-window strides
      (default: ``(1, ..., 1)``).
    padding: Either the string ``'SAME'``, the string ``'VALID'``, or a sequence of
      ``n`` ``(low, high)`` integer pairs that give the padding to apply before and
      after each spatial dimension (default: ``'VALID'``).

  Returns:
    The minimum for each window slice.
  """
  return pool(inputs, jnp.inf, lax.min, window_shape, strides, padding)
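
For reference, a hedged usage sketch of the relocated functions (not part of this diff). The import path below is the new file added in this commit; the generic pool helper accepts any ``(T, T) -> T`` reduction:

import jax.numpy as jnp
from jax import lax
from flax.core.nn.pooling import avg_pool, max_pool, min_pool, pool

x = jnp.arange(16.0).reshape(1, 4, 4, 1)  # (batch, height, width, features)

max_pool(x, (2, 2), strides=(2, 2))  # shape (1, 2, 2, 1)
min_pool(x, (2, 2), strides=(2, 2))

# With 'SAME' padding, count_include_pad picks the divisor near the edges:
# True divides each window sum by prod(window_shape); False divides by the
# number of real (unpadded) elements covered by that window.
avg_pool(x, (3, 3), strides=(1, 1), padding='SAME', count_include_pad=False)

# pool is the generic helper; for example, a product-pool built from lax.mul.
pool(x, 1.0, lax.mul, (2, 2), (2, 2), 'VALID')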
152 changes: 9 additions & 143 deletions flax/linen/pooling.py
@@ -1,143 +1,9 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pooling modules."""

import jax.numpy as jnp
import numpy as np
from jax import lax


def pool(inputs, init, reduce_fn, window_shape, strides, padding):
  """Helper function to define pooling functions.

  Pooling functions are implemented using the ReduceWindow XLA op.

  .. note::
    Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply that
    pool is differentiable.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    init: the initial value for the reduction
    reduce_fn: a reduce function of the form ``(T, T) -> T``.
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension.

  Returns:
    The output of the reduction for each window slice.
  """
  num_batch_dims = inputs.ndim - (len(window_shape) + 1)
  strides = strides or (1,) * len(window_shape)
  assert len(window_shape) == len(
    strides
  ), f'len({window_shape}) must equal len({strides})'
  strides = (1,) * num_batch_dims + strides + (1,)
  dims = (1,) * num_batch_dims + window_shape + (1,)

  is_single_input = False
  if num_batch_dims == 0:
    # add singleton batch dimension because lax.reduce_window always
    # needs a batch dimension.
    inputs = inputs[None]
    strides = (1,) + strides
    dims = (1,) + dims
    is_single_input = True

  assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})'
  if not isinstance(padding, str):
    padding = tuple(map(tuple, padding))
    assert len(padding) == len(window_shape), (
      f'padding {padding} must specify pads for same number of dims as '
      f'window_shape {window_shape}'
    )
    assert all(
      [len(x) == 2 for x in padding]
    ), f'each entry in padding {padding} must be length 2'
    padding = ((0, 0),) + padding + ((0, 0),)
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
  if is_single_input:
    y = jnp.squeeze(y, axis=0)
  return y


def avg_pool(
  inputs, window_shape, strides=None, padding='VALID', count_include_pad=True
):
  """Pools the input by taking the average over a window.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension (default: ``'VALID'``).
    count_include_pad: a boolean whether to include padded tokens
      in the average calculation (default: ``True``).

  Returns:
    The average for each window slice.
  """
  y = pool(inputs, 0.0, lax.add, window_shape, strides, padding)
  if count_include_pad:
    y = y / np.prod(window_shape)
  else:
    div_shape = inputs.shape[:-1] + (1,)
    if len(div_shape) - 2 == len(window_shape):
      div_shape = (1,) + div_shape[1:]
    y = y / pool(
      jnp.ones(div_shape), 0.0, lax.add, window_shape, strides, padding
    )
  return y


def max_pool(inputs, window_shape, strides=None, padding='VALID'):
  """Pools the input by taking the maximum of a window slice.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension (default: ``'VALID'``).

  Returns:
    The maximum for each window slice.
  """
  y = pool(inputs, -jnp.inf, lax.max, window_shape, strides, padding)
  return y


def min_pool(inputs, window_shape, strides=None, padding='VALID'):
  """Pools the input by taking the minimum of a window slice.

  Args:
    inputs: Input data with dimensions (batch, window dims..., features).
    window_shape: A shape tuple defining the window to reduce over.
    strides: A sequence of ``n`` integers, representing the inter-window strides
      (default: ``(1, ..., 1)``).
    padding: Either the string ``'SAME'``, the string ``'VALID'``, or a sequence of
      ``n`` ``(low, high)`` integer pairs that give the padding to apply before and
      after each spatial dimension (default: ``'VALID'``).

  Returns:
    The minimum for each window slice.
  """
  return pool(inputs, jnp.inf, lax.min, window_shape, strides, padding)
# Export pooling functions
from flax.core.nn.pooling import (
  avg_pool as avg_pool,
  max_pool as max_pool,
  min_pool as min_pool,
  pool as pool,
)

__all__ = ['avg_pool', 'max_pool', 'min_pool', 'pool']
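
Because the file now only re-exports the core implementations, existing linen call sites are unaffected. A quick compatibility sketch (a hypothetical check, not part of this diff):

from flax.linen import max_pool               # public linen entry point
from flax.core.nn.pooling import max_pool as core_max_pool

# After this commit, both names refer to the same function object.
assert max_pool is core_max_pool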
8 changes: 4 additions & 4 deletions flax/nnx/__init__.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from flax.linen.pooling import avg_pool as avg_pool
from flax.linen.pooling import max_pool as max_pool
from flax.linen.pooling import min_pool as min_pool
from flax.linen.pooling import pool as pool
from flax.core.nn.pooling import (avg_pool as avg_pool,
                                  max_pool as max_pool,
                                  min_pool as min_pool,
                                  pool as pool)
from flax.typing import Initializer as Initializer

from .bridge import wrappers as wrappers
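
With these imports in place, the pooling functions are available directly under flax.nnx. A minimal sketch (not part of this diff; the Conv/relu usage assumes the public NNX API):

import jax.numpy as jnp
from flax import nnx

class TinyCNN(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.conv = nnx.Conv(3, 16, kernel_size=(3, 3), rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.conv(x))
    # Pooling is a stateless function, not a Module, so it is called inline.
    return nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))

model = TinyCNN(rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 28, 28, 3)))  # y.shape == (1, 14, 14, 16)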
