# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pooling modules."""
import jax.numpy as jnp
import numpy as np
from jax import lax
def pool(inputs, init, reduce_fn, window_shape, strides, padding):
  """Helper function to define pooling functions.

  Pooling functions are implemented using the ReduceWindow XLA op.

  .. note::
    Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply that
    pool is differentiable.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
      Inputs without a batch dimension are given a temporary singleton one,
      and inputs with several leading batch dimensions are supported.
    init: the initial value for the reduction
    reduce_fn: a reduce function of the form ``(T, T) -> T``.
    window_shape: a sequence of integers defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides. ``None`` or an empty sequence means ``(1, ..., 1)``.
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension.

  Returns:
    The output of the reduction for each window slice.

  Raises:
    AssertionError: if ``strides`` or ``padding`` do not match the length of
      ``window_shape``, if a padding entry is not a pair, or if ``inputs`` has
      fewer dimensions than ``len(window_shape) + 1``.
  """
  # Normalise to tuples: both values are concatenated with tuples below, which
  # would raise a TypeError for lists.
  window_shape = tuple(window_shape)
  strides = tuple(strides) if strides is not None else ()
  strides = strides or (1,) * len(window_shape)
  assert len(window_shape) == len(
    strides
  ), f'len({window_shape}) must equal len({strides})'

  num_batch_dims = inputs.ndim - (len(window_shape) + 1)
  is_single_input = num_batch_dims == 0
  if is_single_input:
    # add singleton batch dimension because lax.reduce_window always
    # needs a batch dimension.
    inputs = inputs[None]
    num_batch_dims = 1

  # Batch and feature dimensions are never pooled: window 1, stride 1.
  strides = (1,) * num_batch_dims + strides + (1,)
  dims = (1,) * num_batch_dims + window_shape + (1,)
  assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})'

  if not isinstance(padding, str):
    padding = tuple(map(tuple, padding))
    assert len(padding) == len(window_shape), (
      f'padding {padding} must specify pads for same number of dims as '
      f'window_shape {window_shape}'
    )
    assert all(
      len(x) == 2 for x in padding
    ), f'each entry in padding {padding} must be length 2'
    # One zero pad per leading batch dimension (not just one in total), so the
    # padding rank matches ``dims`` when there are several batch dimensions.
    padding = ((0, 0),) * num_batch_dims + padding + ((0, 0),)

  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
  if is_single_input:
    y = jnp.squeeze(y, axis=0)
  return y
def avg_pool(
  inputs, window_shape, strides=None, padding='VALID', count_include_pad=True
):
  """Average-pools the input over a sliding window.

  The per-window sums are computed with ``pool`` using ``lax.add`` and then
  divided either by the window size or by the pooled sum of an array of ones.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension (default: ``'VALID'``).
    count_include_pad: a boolean whether to include padded tokens
      in the average calculation (default: ``True``). When ``True`` every
      window sum is divided by ``np.prod(window_shape)``; when ``False`` it is
      divided by the sum of ones pooled with the same window, strides and
      padding.

  Returns:
    The average for each window slice.
  """
  window_sums = pool(inputs, 0.0, lax.add, window_shape, strides, padding)
  if count_include_pad:
    return window_sums / np.prod(window_shape)

  # Shape of the ones array: the input shape with the feature axis set to 1.
  ones_shape = inputs.shape[:-1] + (1,)
  has_one_batch_dim = len(ones_shape) == len(window_shape) + 2
  if has_one_batch_dim:
    # With exactly one batch dimension, shrink it to 1 and let the division
    # broadcast over the batch.
    ones_shape = (1,) + ones_shape[1:]
  window_counts = pool(
    jnp.ones(ones_shape), 0.0, lax.add, window_shape, strides, padding
  )
  return window_sums / window_counts
def max_pool(inputs, window_shape, strides=None, padding='VALID'):
  """Max-pools the input over a sliding window.

  Delegates to ``pool`` with ``lax.max`` as the reduction and ``-jnp.inf`` as
  its initial value.

  Args:
    inputs: input data with dimensions (batch, window dims..., features).
    window_shape: a shape tuple defining the window to reduce over.
    strides: a sequence of ``n`` integers, representing the inter-window
      strides (default: ``(1, ..., 1)``).
    padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence
      of ``n`` ``(low, high)`` integer pairs that give the padding to apply before
      and after each spatial dimension (default: ``'VALID'``).

  Returns:
    The maximum for each window slice.
  """
  # Starting from -inf means any input value replaces the initial value.
  lowest = -jnp.inf
  return pool(inputs, lowest, lax.max, window_shape, strides, padding)
def min_pool(inputs, window_shape, strides=None, padding='VALID'):
  """Min-pools the input over a sliding window.

  Delegates to ``pool`` with ``lax.min`` as the reduction and ``jnp.inf`` as
  its initial value.

  Args:
    inputs: Input data with dimensions (batch, window dims..., features).
    window_shape: A shape tuple defining the window to reduce over.
    strides: A sequence of ``n`` integers, representing the inter-window strides
      (default: ``(1, ..., 1)``).
    padding: Either the string ``'SAME'``, the string ``'VALID'``, or a sequence of
      ``n`` ``(low, high)`` integer pairs that give the padding to apply before and
      after each spatial dimension (default: ``'VALID'``).

  Returns:
    The minimum for each window slice.
  """
  # Starting from +inf means any input value replaces the initial value.
  highest = jnp.inf
  pooled = pool(inputs, highest, lax.min, window_shape, strides, padding)
  return pooled