addons/tensorflow_addons/seq2seq/attention_wrapper.py

2080 lines
86 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A dynamic attention wrapper for RNN cells."""
import collections
import functools
import math
from packaging.version import Version
import numpy as np
import tensorflow as tf
from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.types import (
AcceptableDTypes,
FloatTensorLike,
TensorLike,
Initializer,
Number,
)
from typeguard import typechecked
from typing import Optional, Callable, Union, List
# Extra keyword arguments splatted into the Keras (de)serialization helpers
# used in this module (`tf.keras.layers.deserialize`,
# `tf.keras.initializers.serialize`). From TF 2.13 on the legacy format is
# requested explicitly; older versions get no extra arguments (presumably
# because they do not accept `use_legacy_format` -- verify if bumping the
# minimum TF version).
SERIALIZATION_ARGS = (
    {}
    if Version(tf.__version__) < Version("2.13")
    else {"use_legacy_format": True}
)
class AttentionMechanism(tf.keras.layers.Layer):
    """Base class for attention mechanisms.

    Common functionality includes:
      1. Storing the query and memory layers.
      2. Preprocessing and storing the memory.

    Note that this layer takes memory as its init parameter, which is an
    anti-pattern of the Keras API; the memory is kept as an init parameter for
    performance and dependency reasons. Under the hood, during `__init__()`, it
    invokes `base_layer.__call__(memory, setup_memory=True)`. This lets Keras
    keep track of the memory tensor as the input of this layer. Once
    `__init__()` is done, the user can query the attention by
    `score = att_obj([query, state])`, and use it as a normal Keras layer.

    Special care is needed when using this class as the base layer for a new
    attention mechanism:
      1. build() could be invoked at least twice, so make sure weights are not
         duplicated.
      2. Layer.get_weights() might return a different set of weights if the
         instance has a `query_layer`. The query_layer weights are not
         initialized until the memory is configured.

    Also note that this layer does not work with a Keras model when
    `model.compile(run_eagerly=True)` due to the fact that this layer is
    stateful. Support for that will be added in a future version.
    """

    @typechecked
    def __init__(
        self,
        memory: Union[TensorLike, None],
        probability_fn: callable,
        query_layer: Optional[tf.keras.layers.Layer] = None,
        memory_layer: Optional[tf.keras.layers.Layer] = None,
        memory_sequence_length: Optional[TensorLike] = None,
        **kwargs,
    ):
        """Construct base AttentionMechanism class.

        Args:
          memory: The memory to query; usually the output of an RNN encoder.
            This tensor should be shaped `[batch_size, max_time, ...]`.
          probability_fn: A `callable`. Converts the score and previous
            alignments to probabilities. Its signature should be:
            `probabilities = probability_fn(score, state)`.
          query_layer: Optional `tf.keras.layers.Layer` instance. The layer's
            depth must match the depth of `memory_layer`. If `query_layer` is
            not provided, the shape of `query` must match that of
            `memory_layer`.
          memory_layer: Optional `tf.keras.layers.Layer` instance. The layer's
            depth must match the depth of `query_layer`.
            If `memory_layer` is not provided, the shape of `memory` must match
            that of `query_layer`.
          memory_sequence_length: (optional) Sequence lengths for the batch
            entries in memory. If provided, the memory tensor rows are masked
            with zeros for values past the respective sequence lengths.
          **kwargs: Dictionary that contains other common arguments for layer
            creation.
        """
        self.query_layer = query_layer
        self.memory_layer = memory_layer
        super().__init__(**kwargs)
        # The unwrapped function is kept so setup_memory() can re-wrap it with
        # masking every time new memory is configured.
        self.default_probability_fn = probability_fn
        self.probability_fn = probability_fn

        self.keys = None
        self.values = None
        self.batch_size = None
        self._memory_initialized = False
        self._check_inner_dims_defined = True
        self.supports_masking = True

        if memory is not None:
            # Setup the memory by self.__call__() with memory and
            # memory_seq_length. This will make the attention follow the keras
            # convention which takes all the tensor inputs via __call__().
            if memory_sequence_length is None:
                inputs = memory
            else:
                inputs = [memory, memory_sequence_length]

            self.values = super().__call__(inputs, setup_memory=True)

    @property
    def memory_initialized(self):
        """Returns `True` if this attention mechanism has been initialized with
        a memory."""
        return self._memory_initialized

    def build(self, input_shape):
        """Build the memory layer (memory setup) or the query layer (query).

        Which one is built depends on whether the memory has already been
        configured; each inner layer is built at most once.
        """
        if not self._memory_initialized:
            # This is for setting up the memory, which contains memory and
            # optional memory_sequence_length. Build the memory_layer with
            # memory shape.
            if self.memory_layer is not None and not self.memory_layer.built:
                if isinstance(input_shape, list):
                    self.memory_layer.build(input_shape[0])
                else:
                    self.memory_layer.build(input_shape)
        else:
            # The input_shape should be query.shape and state.shape. Use the
            # query to init the query layer.
            if self.query_layer is not None and not self.query_layer.built:
                self.query_layer.build(input_shape[0])

    def __call__(self, inputs, **kwargs):
        """Preprocess the inputs before calling `base_layer.__call__()`.

        Note that there are two situations here, one for setting up the memory
        and one with the actual query and state.
        1. When the memory has not been configured, we just pass all the params
           to `base_layer.__call__()`, which will then invoke `self.call()`
           with proper inputs, which allows this class to setup memory.
        2. When the memory has already been setup, the input should contain
           query and state, and optionally processed memory. If the processed
           memory is not included in the input, it is appended to the inputs
           before they are given to `base_layer.__call__()`. The processed
           memory is the output of the first invocation of `self.__call__()`.
           If we don't add it here, then from the Keras perspective the graph
           is disconnected since the output from the previous call is never
           used.

        Args:
          inputs: the input tensors (a list or tuple when querying).
          **kwargs: dict, other keyword arguments for the `__call__()`.

        Raises:
          ValueError: if the memory is initialized and `inputs` does not hold
            2 or 3 tensors.
        """
        # Allow manual memory reset
        if kwargs.get("setup_memory", False):
            self._memory_initialized = False

        if self._memory_initialized:
            if len(inputs) not in (2, 3):
                raise ValueError(
                    "Expect the inputs to have 2 or 3 tensors, got %d" % len(inputs)
                )
            if len(inputs) == 2:
                # Append the processed memory so that the graph will be
                # connected. A new list is built instead of calling
                # `inputs.append()`: appending mutated the caller's list (a
                # reused `[query, state]` list grew on every call) and raised
                # AttributeError for tuple inputs.
                inputs = [*inputs, self.values]

        return super().__call__(inputs, **kwargs)

    def call(self, inputs, mask=None, setup_memory=False, **kwargs):
        """Setup the memory or query the attention.

        There are two cases here, one for setting up the memory, and the
        second is querying the attention score. `setup_memory` is the flag to
        indicate which mode it is. The input list will be treated differently
        based on that flag.

        Args:
          inputs: a list of tensors that could either be `query` and `state`,
            or `memory` and `memory_sequence_length`.
            `query` is the tensor of dtype matching `memory` and shape
            `[batch_size, query_depth]`.
            `state` is the tensor of dtype matching `memory` and shape
            `[batch_size, alignments_size]`. (`alignments_size` is memory's
            `max_time`).
            `memory` is the memory to query; usually the output of an RNN
            encoder. The tensor should be shaped `[batch_size, max_time, ...]`.
            `memory_sequence_length` (optional) is the sequence lengths for the
             batch entries in memory. If provided, the memory tensor rows are
            masked with zeros for values past the respective sequence lengths.
          mask: optional bool tensor with shape `[batch, max_time]` for the
            mask of memory. If it is not None, the corresponding item of the
            memory should be filtered out during calculation.
          setup_memory: boolean, whether the input is for setting up memory, or
            query attention.
          **kwargs: Dict, other keyword arguments for the call method.

        Returns:
          Either processed memory or attention score, based on `setup_memory`.

        Raises:
          ValueError: on a wrong number of inputs, or when querying before the
            memory has been set up.
        """
        if setup_memory:
            if isinstance(inputs, list):
                if len(inputs) not in (1, 2):
                    raise ValueError(
                        "Expect inputs to have 1 or 2 tensors, got %d" % len(inputs)
                    )
                memory = inputs[0]
                memory_sequence_length = inputs[1] if len(inputs) == 2 else None
            else:
                memory, memory_sequence_length = inputs, None
            self.setup_memory(memory, memory_sequence_length, mask)
            # We force self.built to False here since only the memory is
            # initialized; the real query/state has not been passed to call()
            # yet. The layer should be built and called again.
            self.built = False
            # Return the processed memory in order to create the Keras
            # connectivity data for it.
            return self.values

        if not self._memory_initialized:
            raise ValueError("Cannot query the attention before the setup of memory")
        if len(inputs) not in (2, 3):
            raise ValueError(
                "Expect the inputs to have query, state, and optional "
                "processed memory, got %d items" % len(inputs)
            )
        # Ignore the rest of the inputs and only care about the query and
        # state
        query, state = inputs[0], inputs[1]
        return self._calculate_attention(query, state)

    def setup_memory(self, memory, memory_sequence_length=None, memory_mask=None):
        """Pre-process the memory before actually querying the memory.

        This should only be called once at the first invocation of `call()`.

        Args:
          memory: The memory to query; usually the output of an RNN encoder.
            This tensor should be shaped `[batch_size, max_time, ...]`.
          memory_sequence_length (optional): Sequence lengths for the batch
            entries in memory. If provided, the memory tensor rows are masked
            with zeros for values past the respective sequence lengths.
          memory_mask: (Optional) The boolean tensor with shape `[batch_size,
            max_time]`. For any value equal to False, the corresponding value
            in memory should be ignored.

        Raises:
          ValueError: if both `memory_sequence_length` and `memory_mask` are
            given.
        """
        if memory_sequence_length is not None and memory_mask is not None:
            raise ValueError(
                "memory_sequence_length and memory_mask cannot be "
                "used at same time for attention."
            )
        with tf.name_scope(self.name or "BaseAttentionMechanismInit"):
            self.values = _prepare_memory(
                memory,
                memory_sequence_length=memory_sequence_length,
                memory_mask=memory_mask,
                check_inner_dims_defined=self._check_inner_dims_defined,
            )

            # Mark the value as checked since the memory and memory mask might
            # not be passed from __call__(), which does not have proper keras
            # metadata.
            # TODO(omalleyt12): Remove this hack once the mask has proper
            # keras history.

            def _mark_checked(tensor):
                tensor._keras_history_checked = True  # pylint: disable=protected-access

            tf.nest.map_structure(_mark_checked, self.values)

            if self.memory_layer is not None:
                self.keys = self.memory_layer(self.values)
            else:
                self.keys = self.values

            # Prefer static dimensions; fall back to dynamic ones when unknown.
            self.batch_size = self.keys.shape[0] or tf.shape(self.keys)[0]
            self._alignments_size = self.keys.shape[1] or tf.shape(self.keys)[1]

            if memory_mask is not None or memory_sequence_length is not None:
                unwrapped_probability_fn = self.default_probability_fn

                def _mask_probability_fn(score, prev):
                    # Padded positions get the dtype minimum so they carry
                    # (close to) zero probability after normalization.
                    return unwrapped_probability_fn(
                        _maybe_mask_score(
                            score,
                            memory_mask=memory_mask,
                            memory_sequence_length=memory_sequence_length,
                            score_mask_value=score.dtype.min,
                        ),
                        prev,
                    )

                self.probability_fn = _mask_probability_fn
        self._memory_initialized = True

    def _calculate_attention(self, query, state):
        """Compute `(alignments, next_state)`; must be overridden."""
        raise NotImplementedError(
            "_calculate_attention need to be implemented by subclasses."
        )

    def compute_mask(self, inputs, mask=None):
        # The real input of the attention is query and state, and the memory
        # layer mask shouldn't be passed down. Returning None for all output
        # masks here.
        return None, None

    def get_config(self):
        """Return the layer config, with inner layers serialized as dicts."""
        config = {}
        # Since the probability_fn is likely to be a wrapped function, the child
        # class should preserve the original function and how its wrapped.

        if self.query_layer is not None:
            config["query_layer"] = {
                "class_name": self.query_layer.__class__.__name__,
                "config": self.query_layer.get_config(),
            }
        if self.memory_layer is not None:
            config["memory_layer"] = {
                "class_name": self.memory_layer.__class__.__name__,
                "config": self.memory_layer.get_config(),
            }
        # memory is a required init parameter and its a tensor. It cannot be
        # serialized to config, so we put a placeholder for it.
        config["memory"] = None
        base_config = super().get_config()
        return {**base_config, **config}

    def _process_probability_fn(self, func_name):
        """Helper method to retrieve the probability function by string input.

        Raises:
          ValueError: if `func_name` is not "softmax" or "hardmax".
        """
        valid_probability_fns = {
            "softmax": tf.nn.softmax,
            "hardmax": hardmax,
        }
        if func_name not in valid_probability_fns.keys():
            raise ValueError(
                "Invalid probability function: %s, options are %s"
                % (func_name, valid_probability_fns.keys())
            )
        return valid_probability_fns[func_name]

    @classmethod
    def deserialize_inner_layer_from_config(cls, config, custom_objects):
        """Helper method that reconstructs the query and memory layers from the
        config.

        In the get_config() method, the query and memory layer configs are
        serialized into dicts for persistence; this method performs the reverse
        action to reconstruct the layers from the config.

        Args:
          config: dict, the configs that will be used to reconstruct the
            object.
          custom_objects: dict mapping class names (or function names) of
            custom (non-Keras) objects to class/functions.

        Returns:
          config: dict, a copy of the config with layer instances created,
            which is ready to be used as init parameters.
        """
        # Reconstruct the query and memory layer for parent class.
        # Instead of updating the input, create a copy and use that.
        config = config.copy()
        query_layer_config = config.pop("query_layer", None)
        if query_layer_config:
            query_layer = tf.keras.layers.deserialize(
                query_layer_config,
                custom_objects=custom_objects,
                **SERIALIZATION_ARGS,
            )
            config["query_layer"] = query_layer
        memory_layer_config = config.pop("memory_layer", None)
        if memory_layer_config:
            memory_layer = tf.keras.layers.deserialize(
                memory_layer_config,
                custom_objects=custom_objects,
                **SERIALIZATION_ARGS,
            )
            config["memory_layer"] = memory_layer
        return config

    @property
    def alignments_size(self):
        """Static alignments size, or `TensorShape([None])` when it is dynamic."""
        if isinstance(self._alignments_size, int):
            return self._alignments_size
        else:
            return tf.TensorShape([None])

    @property
    def state_size(self):
        return self.alignments_size

    def initial_alignments(self, batch_size, dtype):
        """Creates the initial alignment values for the
        `tfa.seq2seq.AttentionWrapper` class.

        This is important for attention mechanisms that use the previous
        alignment to calculate the alignment at the next time step
        (e.g. monotonic attention).

        The default behavior is to return a tensor of all zeros.

        Args:
          batch_size: `int32` scalar, the batch_size.
          dtype: The `dtype`.

        Returns:
          A `dtype` tensor shaped `[batch_size, alignments_size]`
          (`alignments_size` is the values' `max_time`).
        """
        return tf.zeros([batch_size, self._alignments_size], dtype=dtype)

    def initial_state(self, batch_size, dtype):
        """Creates the initial state values for the
        `tfa.seq2seq.AttentionWrapper` class.

        This is important for attention mechanisms that use the previous
        alignment to calculate the alignment at the next time step
        (e.g. monotonic attention).

        The default behavior is to return the same output as
        `initial_alignments`.

        Args:
          batch_size: `int32` scalar, the batch_size.
          dtype: The `dtype`.

        Returns:
          A structure of all-zero tensors with shapes as described by
          `state_size`.
        """
        return self.initial_alignments(batch_size, dtype)
def _luong_score(query, keys, scale):
    """Compute Luong-style (multiplicative) attention scores.

    Two variants are supported. Without `scale` this is standard Luong
    attention, as described in:

    Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
    "Effective Approaches to Attention-based Neural Machine Translation."
    EMNLP 2015.  https://arxiv.org/abs/1508.04025

    Passing a non-None `scale` gives the scaled variant, inspired partly by
    the normalized form of Bahdanau attention.

    Args:
      query: Tensor, shape `[batch_size, num_units]` to compare to keys.
      keys: Processed memory, shape `[batch_size, max_time, num_units]`.
      scale: Optional tensor multiplied into the raw scores; `None` skips
        scaling.

    Returns:
      A `[batch_size, max_time]` tensor of unnormalized score values.

    Raises:
      ValueError: If the last dimensions of `keys` and `query` differ.
    """
    query_units = query.shape[-1]
    key_units = keys.shape[-1]
    if query_units != key_units:
        raise ValueError(
            "Incompatible or unknown inner dimensions between query and keys. "
            "Query (%s) has units: %s. Keys (%s) have units: %s. "
            "Perhaps you need to set num_units to the keys' dimension (%s)?"
            % (query, query_units, keys, key_units, key_units)
        )

    # [batch_size, depth] -> [batch_size, 1, depth]. Multiplying by the keys
    # transposed on their last two axes ([batch_size, depth, max_time]) gives
    # [batch_size, 1, max_time]; the singleton middle axis is then dropped.
    query_row = tf.expand_dims(query, 1)
    scores = tf.squeeze(tf.matmul(query_row, keys, transpose_b=True), [1])

    return scores if scale is None else scale * scores
class LuongAttention(AttentionMechanism):
    """Luong-style (multiplicative) attention scoring.

    Two variants are available. The default is standard Luong attention, as
    described in:

    Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
    [Effective Approaches to Attention-based Neural Machine Translation.
    EMNLP 2015.](https://arxiv.org/abs/1508.04025)

    Constructing the object with `scale=True` selects the scaled variant,
    inspired partly by the normalized form of Bahdanau attention.
    """

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        memory: Optional[TensorLike] = None,
        memory_sequence_length: Optional[TensorLike] = None,
        scale: bool = False,
        probability_fn: str = "softmax",
        dtype: AcceptableDTypes = None,
        name: str = "LuongAttention",
        **kwargs,
    ):
        """Build a Luong attention mechanism.

        Args:
          units: The depth of the attention mechanism.
          memory: The memory to query; usually the output of an RNN encoder.
            This tensor should be shaped `[batch_size, max_time, ...]`.
          memory_sequence_length: (optional): Sequence lengths for the batch
            entries in memory. If provided, the memory tensor rows are masked
            with zeros for values past the respective sequence lengths.
          scale: Python boolean. Whether to scale the energy term.
          probability_fn: (optional) string, the name of the function that
            converts the attention score to probabilities. The default is
            `softmax` which is `tf.nn.softmax`. The other option is `hardmax`,
            which is hardmax() within this module. Any other value results in
            a validation error. Defaults to `softmax`.
          dtype: The data type for the memory layer of the attention mechanism.
          name: Name to use when creating ops.
          **kwargs: Dictionary that contains other common arguments for layer
            creation; a `memory_layer` entry overrides the default projection.
        """
        # Luong attention only projects the memory, so `units` **must** equal
        # the depth of the incoming query.
        self.probability_fn_name = probability_fn
        normalizer = self._process_probability_fn(self.probability_fn_name)

        def wrapped_probability_fn(score, _):
            # Luong attention ignores the previous alignments.
            return normalizer(score)

        # A falsy (missing) memory_layer falls back to a bias-free projection.
        memory_layer = kwargs.pop("memory_layer", None) or tf.keras.layers.Dense(
            units, name="memory_layer", use_bias=False, dtype=dtype
        )
        self.units = units
        self.scale = scale
        # Created lazily in build() when `scale` is True.
        self.scale_weight = None
        super().__init__(
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            query_layer=None,
            memory_layer=memory_layer,
            probability_fn=wrapped_probability_fn,
            name=name,
            dtype=dtype,
            **kwargs,
        )

    def build(self, input_shape):
        """Build inner layers, then the scalar scale weight (at most once)."""
        super().build(input_shape)
        needs_scale_weight = self.scale and self.scale_weight is None
        if needs_scale_weight:
            self.scale_weight = self.add_weight(
                "attention_g", initializer=tf.ones_initializer, shape=()
            )
        self.built = True

    def _calculate_attention(self, query, state):
        """Score the query against the keys and normalize it.

        Args:
          query: Tensor of dtype matching `self.values` and shape
            `[batch_size, query_depth]`.
          state: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]`
            (`alignments_size` is memory's `max_time`).

        Returns:
          alignments: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]` (`alignments_size` is memory's
            `max_time`).
          next_state: The same tensor as the alignments.
        """
        raw_score = _luong_score(query, self.keys, self.scale_weight)
        alignments = self.probability_fn(raw_score, state)
        return alignments, alignments

    def get_config(self):
        """Return the layer config, including Luong-specific settings."""
        base_config = super().get_config()
        return {
            **base_config,
            "units": self.units,
            "scale": self.scale,
            "probability_fn": self.probability_fn_name,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer, deserializing the inner query/memory layers."""
        init_kwargs = AttentionMechanism.deserialize_inner_layer_from_config(
            config, custom_objects=custom_objects
        )
        return cls(**init_kwargs)
def _bahdanau_score(
    processed_query, keys, attention_v, attention_g=None, attention_b=None
):
    """Compute Bahdanau-style (additive) attention scores.

    The plain variant is Bahdanau attention, as described in:

    Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
    "Neural Machine Translation by Jointly Learning to Align and Translate."
    ICLR 2015. https://arxiv.org/abs/1409.0473

    Supplying both `attention_g` and `attention_b` selects the normalized
    variant, inspired by:

    Tim Salimans, Diederik P. Kingma.
    "Weight Normalization: A Simple Reparameterization to Accelerate
     Training of Deep Neural Networks."
    https://arxiv.org/abs/1602.07868

    Args:
      processed_query: Tensor, shape `[batch_size, num_units]` to compare to
        keys.
      keys: Processed memory, shape `[batch_size, max_time, num_units]`.
      attention_v: Tensor, shape `[num_units]`.
      attention_g: Optional scalar tensor for normalization.
      attention_b: Optional tensor with shape `[num_units]` for normalization.

    Returns:
      A `[batch_size, max_time]` tensor of unnormalized score values.
    """
    # Add a time axis so the query broadcasts over every memory position.
    query_3d = tf.expand_dims(processed_query, 1)

    if attention_g is None or attention_b is None:
        return tf.reduce_sum(attention_v * tf.tanh(keys + query_3d), [2])

    # Weight-normalized v: g * v / ||v||.
    inv_v_norm = tf.math.rsqrt(tf.reduce_sum(tf.square(attention_v)))
    normed_v = attention_g * attention_v * inv_v_norm
    return tf.reduce_sum(normed_v * tf.tanh(keys + query_3d + attention_b), [2])
class BahdanauAttention(AttentionMechanism):
    """Bahdanau-style (additive) attention.

    The default variant is Bahdanau attention, as described in:

    Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
    "Neural Machine Translation by Jointly Learning to Align and Translate."
    ICLR 2015. https://arxiv.org/abs/1409.0473

    Constructing the object with `normalize=True` selects the normalized
    variant, inspired by:

    Tim Salimans, Diederik P. Kingma.
    "Weight Normalization: A Simple Reparameterization to Accelerate
     Training of Deep Neural Networks."
    https://arxiv.org/abs/1602.07868
    """

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        memory: Optional[TensorLike] = None,
        memory_sequence_length: Optional[TensorLike] = None,
        normalize: bool = False,
        probability_fn: str = "softmax",
        kernel_initializer: Initializer = "glorot_uniform",
        dtype: AcceptableDTypes = None,
        name: str = "BahdanauAttention",
        **kwargs,
    ):
        """Build a Bahdanau attention mechanism.

        Args:
          units: The depth of the query mechanism.
          memory: The memory to query; usually the output of an RNN encoder.
            This tensor should be shaped `[batch_size, max_time, ...]`.
          memory_sequence_length: (optional): Sequence lengths for the batch
            entries in memory. If provided, the memory tensor rows are masked
            with zeros for values past the respective sequence lengths.
          normalize: Python boolean. Whether to normalize the energy term.
          probability_fn: (optional) string, the name of the function that
            converts the attention score to probabilities. The default is
            `softmax` which is `tf.nn.softmax`. The other option is `hardmax`,
            which is hardmax() within this module. Any other value results in
            a validation error. Defaults to `softmax`.
          kernel_initializer: (optional), the name of the initializer for the
            attention kernel.
          dtype: The data type for the query and memory layers of the attention
            mechanism.
          name: Name to use when creating ops.
          **kwargs: Dictionary that contains other common arguments for layer
            creation; `query_layer` / `memory_layer` entries override the
            default projections.
        """
        self.probability_fn_name = probability_fn
        normalizer = self._process_probability_fn(self.probability_fn_name)

        def wrapped_probability_fn(score, _):
            # Bahdanau attention ignores the previous alignments.
            return normalizer(score)

        # Missing (falsy) projection layers fall back to bias-free Dense layers.
        query_layer = kwargs.pop("query_layer", None) or tf.keras.layers.Dense(
            units, name="query_layer", use_bias=False, dtype=dtype
        )
        memory_layer = kwargs.pop("memory_layer", None) or tf.keras.layers.Dense(
            units, name="memory_layer", use_bias=False, dtype=dtype
        )
        self.units = units
        self.normalize = normalize
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        # Attention weights are created lazily in build().
        self.attention_v = None
        self.attention_g = None
        self.attention_b = None
        super().__init__(
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            query_layer=query_layer,
            memory_layer=memory_layer,
            probability_fn=wrapped_probability_fn,
            name=name,
            dtype=dtype,
            **kwargs,
        )

    def build(self, input_shape):
        """Build inner layers, then create attention weights at most once."""
        super().build(input_shape)
        if self.attention_v is None:
            self.attention_v = self.add_weight(
                "attention_v",
                [self.units],
                dtype=self.dtype,
                initializer=self.kernel_initializer,
            )
        wants_norm_weights = (
            self.normalize and self.attention_g is None and self.attention_b is None
        )
        if wants_norm_weights:
            g_start = math.sqrt(1.0 / self.units)
            self.attention_g = self.add_weight(
                "attention_g",
                initializer=tf.constant_initializer(g_start),
                shape=(),
            )
            self.attention_b = self.add_weight(
                "attention_b", shape=[self.units], initializer=tf.zeros_initializer()
            )
        self.built = True

    def _calculate_attention(self, query, state):
        """Score the query against the keys and normalize it.

        Args:
          query: Tensor of dtype matching `self.values` and shape
            `[batch_size, query_depth]`.
          state: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]`
            (`alignments_size` is memory's `max_time`).

        Returns:
          alignments: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]` (`alignments_size` is memory's
            `max_time`).
          next_state: The same tensor as the alignments.
        """
        if self.query_layer:
            query = self.query_layer(query)
        raw_score = _bahdanau_score(
            query,
            self.keys,
            self.attention_v,
            attention_g=self.attention_g,
            attention_b=self.attention_b,
        )
        alignments = self.probability_fn(raw_score, state)
        return alignments, alignments

    def get_config(self):
        """Return the layer config, including Bahdanau-specific settings."""
        serialized_initializer = tf.keras.initializers.serialize(
            self.kernel_initializer,
            **SERIALIZATION_ARGS,
        )
        base_config = super().get_config()
        return {
            **base_config,
            "units": self.units,
            "normalize": self.normalize,
            "probability_fn": self.probability_fn_name,
            "kernel_initializer": serialized_initializer,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer, deserializing the inner query/memory layers."""
        init_kwargs = AttentionMechanism.deserialize_inner_layer_from_config(
            config,
            custom_objects=custom_objects,
        )
        return cls(**init_kwargs)
def safe_cumprod(x: TensorLike, *args, **kwargs) -> tf.Tensor:
    """Cumulative product of `x`, computed in log space to avoid underflow.

    `cumprod` and its gradient can be numerically unstable when the argument
    holds very small and/or zero values. For a non-negative argument the
    product can instead be computed as `exp(cumsum(log(x)))`; values are
    clipped to `[tiny, 1]` first, where `tiny` is the smallest positive normal
    number of `x`'s dtype. The call signature mirrors `tf.math.cumprod`.

    Args:
      x: Tensor to take the cumulative product of.
      *args: Forwarded to `tf.cumsum` (same meaning as in cumprod).
      **kwargs: Forwarded to `tf.cumsum` (e.g. `axis`, `exclusive`).

    Returns:
      Cumulative product of x.
    """
    with tf.name_scope("SafeCumprod"):
        x = tf.convert_to_tensor(x, name="x")
        smallest_positive = np.finfo(x.dtype.as_numpy_dtype).tiny
        log_x = tf.math.log(tf.clip_by_value(x, smallest_positive, 1))
        return tf.exp(tf.cumsum(log_x, *args, **kwargs))
def monotonic_attention(
    p_choose_i: FloatTensorLike, previous_attention: FloatTensorLike, mode: str
) -> tf.Tensor:
    """Computes monotonic attention distribution from choosing probabilities.

    Monotonic attention implies that the input sequence is processed in an
    explicitly left-to-right manner when generating the output sequence. In
    addition, once an input sequence element is attended to at a given output
    timestep, elements occurring before it cannot be attended to at subsequent
    output timesteps. This function generates attention distributions
    according to these assumptions. For more information, see `Online and
    Linear-Time Attention by Enforcing Monotonic Alignments`.

    Args:
      p_choose_i: Probability of choosing input sequence/memory element i.
        Should be of shape (batch_size, input_sequence_length), and should all
        be in the range [0, 1].
      previous_attention: The attention distribution from the previous output
        timestep. Should be of shape (batch_size, input_sequence_length), with
        the same dtype as `p_choose_i`. For the first output timestep,
        previous_attention[n] should be [1, 0, 0, ..., 0] for all n in
        [0, ... batch_size - 1].
      mode: How to compute the attention distribution. Must be one of
        'recursive', 'parallel', or 'hard'.
          * 'recursive' uses tf.scan to recursively compute the distribution.
            This is slowest but is exact, general, and does not suffer from
            numerical instabilities.
          * 'parallel' uses parallelized cumulative-sum and cumulative-product
            operations to compute a closed-form solution to the recurrence
            relation defining the attention distribution. This makes it more
            efficient than 'recursive', but it requires numerical checks which
            make the distribution non-exact. This can be a problem in
            particular when input_sequence_length is long and/or p_choose_i has
            entries very close to 0 or 1.
          * 'hard' requires that the probabilities in p_choose_i are all either
            0 or 1, and subsequently uses a more efficient and exact solution.

    Returns:
      A tensor of shape (batch_size, input_sequence_length) representing the
      attention distributions for each sequence in the batch.

    Raises:
      ValueError: mode is not one of 'recursive', 'parallel', 'hard'.
    """
    # Force things to be tensors
    p_choose_i = tf.convert_to_tensor(p_choose_i, name="p_choose_i")
    previous_attention = tf.convert_to_tensor(
        previous_attention, name="previous_attention"
    )
    if mode == "recursive":
        # Use .shape[0] when it's not None, or fall back on symbolic shape
        batch_size = p_choose_i.shape[0] or tf.shape(p_choose_i)[0]
        # Helper tensors are created in the input's dtype: tf.concat and
        # tf.scan require matching dtypes, so the former float32 defaults
        # made this mode fail for float16/float64 inputs.
        dtype = p_choose_i.dtype
        # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ...,
        #          1 - p_choose_i[-2]]
        shifted_1mp_choose_i = tf.concat(
            [tf.ones((batch_size, 1), dtype=dtype), 1 - p_choose_i[:, :-1]], 1
        )
        # Compute attention distribution recursively as
        # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i]
        # attention[i] = p_choose_i[i]*q[i]
        attention = p_choose_i * tf.transpose(
            tf.scan(
                # Need to use reshape to remind TF of the shape between loop
                # iterations
                lambda x, yz: tf.reshape(yz[0] * x + yz[1], (batch_size,)),
                # Loop variables yz[0] and yz[1]
                [tf.transpose(shifted_1mp_choose_i), tf.transpose(previous_attention)],
                # Initial value of x is just zeros
                tf.zeros((batch_size,), dtype=dtype),
            )
        )
    elif mode == "parallel":
        # safe_cumprod computes cumprod in logspace with numeric checks
        cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True)
        # Compute recurrence relation solution
        attention = (
            p_choose_i
            * cumprod_1mp_choose_i
            * tf.cumsum(
                previous_attention /
                # Clip cumprod_1mp to avoid divide-by-zero
                tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.0),
                axis=1,
            )
        )
    elif mode == "hard":
        # Remove any probabilities before the index chosen last time step
        p_choose_i *= tf.cumsum(previous_attention, axis=1)
        # Now, use exclusive cumprod to remove probabilities after the first
        # chosen index, like so:
        # p_choose_i                        = [0, 0, 0, 1, 1, 0, 1, 1]
        # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0]
        # Product of above:                   [0, 0, 0, 1, 0, 0, 0, 0]
        attention = p_choose_i * tf.math.cumprod(1 - p_choose_i, axis=1, exclusive=True)
    else:
        raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.")
    return attention
def _monotonic_probability_fn(
    score, previous_alignments, sigmoid_noise, mode, seed=None
):
    """Turn raw attention scores into a monotonic attention distribution.

    The scores are optionally perturbed with Gaussian noise before being
    squashed, which pushes the model towards discrete attend/skip decisions.
    The resulting "choosing" probabilities are handed to
    `monotonic_attention`, which turns them into an attention distribution.
    See

    Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
    "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
    ICML 2017. https://arxiv.org/abs/1704.00784

    Args:
      score: Unnormalized attention scores, shape
        `[batch_size, alignments_size]`.
      previous_alignments: Attention distribution of the previous step,
        shape `[batch_size, alignments_size]`.
      sigmoid_noise: Standard deviation of the pre-sigmoid noise. Values
        above 0 encourage large-magnitude scores, i.e. nearly discrete
        choosing probabilities and a near one-hot attention distribution.
        Set it to 0 at test time and whenever hard attention is unwanted.
      mode: One of 'recursive', 'parallel' or 'hard'; forwarded to
        `tfa.seq2seq.monotonic_attention`. In 'hard' mode the sigmoid is
        replaced by the step function `score > 0`.
      seed: (optional) Random seed for the pre-sigmoid noise.

    Returns:
      A `[batch_size, alignments_size]` tensor holding the attention
      distribution.
    """
    if sigmoid_noise > 0:
        score += sigmoid_noise * tf.random.normal(
            tf.shape(score), dtype=score.dtype, seed=seed
        )
    # Hard mode thresholds the (possibly noisy) scores instead of squashing them.
    p_choose_i = (
        tf.cast(score > 0, score.dtype) if mode == "hard" else tf.sigmoid(score)
    )
    return monotonic_attention(p_choose_i, previous_alignments, mode)
class _BaseMonotonicAttentionMechanism(AttentionMechanism):
    """Shared base class for the monotonic attention mechanisms.

    Its only change relative to `AttentionMechanism` is the initial
    alignment. Monotonic attention has to start from a dirac distribution on
    the first memory entry for its distributions to behave correctly.
    """

    def initial_alignments(self, batch_size, dtype):
        """Build dirac initial alignments, one row per batch entry.

        Every row is `[1, 0, 0, ..., 0]` and has one column per memory step.

        Args:
          batch_size: `int32` scalar, the batch size.
          dtype: The `dtype` of the returned tensor.

        Returns:
          A `dtype` tensor shaped `[batch_size, alignments_size]`, where
          `alignments_size` is the values' `max_time`.
        """
        depth = self._alignments_size
        # Index 0 for every batch entry -> one-hot on the first memory step.
        first_step = tf.zeros((batch_size,), dtype=tf.int32)
        return tf.one_hot(first_step, depth, dtype=dtype)
class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
    """Monotonic attention mechanism with Bahdanau-style energy function.

    This type of attention enforces a monotonic constraint on the attention
    distributions; that is, once the model attends to a given point in the
    memory it can't attend to any prior points at subsequent output timesteps.
    It achieves this by using the `_monotonic_probability_fn` instead of `softmax`
    to construct its attention distributions. Since the attention scores are
    passed through a sigmoid, a learnable scalar bias parameter is applied
    after the score function and before the sigmoid. Otherwise, it is
    equivalent to `tfa.seq2seq.BahdanauAttention`. This approach is proposed in

    Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
    "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
    ICML 2017. https://arxiv.org/abs/1704.00784
    """

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        memory: Optional[TensorLike] = None,
        memory_sequence_length: Optional[TensorLike] = None,
        normalize: bool = False,
        sigmoid_noise: FloatTensorLike = 0.0,
        sigmoid_noise_seed: Optional[FloatTensorLike] = None,
        score_bias_init: FloatTensorLike = 0.0,
        mode: str = "parallel",
        kernel_initializer: Initializer = "glorot_uniform",
        dtype: AcceptableDTypes = None,
        name: str = "BahdanauMonotonicAttention",
        **kwargs,
    ):
        """Construct the attention mechanism.

        Args:
          units: The depth of the query mechanism.
          memory: The memory to query; usually the output of an RNN encoder.
            This tensor should be shaped `[batch_size, max_time, ...]`.
          memory_sequence_length: (optional): Sequence lengths for the batch
            entries in memory. If provided, the memory tensor rows are masked
            with zeros for values past the respective sequence lengths.
          normalize: Python boolean. Whether to normalize the energy term.
          sigmoid_noise: Standard deviation of pre-sigmoid noise. See the
            docstring for `_monotonic_probability_fn` for more information.
          sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
          score_bias_init: Initial value for score bias scalar. It's
            recommended to initialize this to a negative value when the length
            of the memory is large.
          mode: How to compute the attention distribution. Must be one of
            'recursive', 'parallel', or 'hard'. See the docstring for
            `tfa.seq2seq.monotonic_attention` for more information.
          kernel_initializer: (optional), the name of the initializer for the
            attention kernel.
          dtype: The data type for the query and memory layers of the attention
            mechanism.
          name: Name to use when creating ops.
          **kwargs: Dictionary that contains other common arguments for layer
            creation. `query_layer` and `memory_layer` may be passed here to
            replace the default bias-free `Dense(units)` projections.
        """
        # Set up the monotonic probability fn with supplied parameters.
        # Binding noise/mode/seed now leaves a callable of (score, state),
        # which is how `_calculate_attention` invokes `self.probability_fn`.
        wrapped_probability_fn = functools.partial(
            _monotonic_probability_fn,
            sigmoid_noise=sigmoid_noise,
            mode=mode,
            seed=sigmoid_noise_seed,
        )
        # Custom query/memory layers are popped from kwargs so they are not
        # passed twice to the base constructor below; otherwise build
        # bias-free projections to `units`.
        query_layer = kwargs.pop("query_layer", None)
        if not query_layer:
            query_layer = tf.keras.layers.Dense(
                units, name="query_layer", use_bias=False, dtype=dtype
            )
        memory_layer = kwargs.pop("memory_layer", None)
        if not memory_layer:
            memory_layer = tf.keras.layers.Dense(
                units, name="memory_layer", use_bias=False, dtype=dtype
            )
        # These attributes are assigned before `super().__init__()` because
        # `build()` reads them, and per the AttentionMechanism docstring the
        # base constructor may invoke `__call__(memory, setup_memory=True)`
        # (and therefore build) right away when `memory` is given.
        self.units = units
        self.normalize = normalize
        self.sigmoid_noise = sigmoid_noise
        self.sigmoid_noise_seed = sigmoid_noise_seed
        self.score_bias_init = score_bias_init
        self.mode = mode
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        # Weight placeholders; created lazily in `build()`.
        self.attention_v = None
        self.attention_score_bias = None
        self.attention_g = None
        self.attention_b = None
        super().__init__(
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            query_layer=query_layer,
            memory_layer=memory_layer,
            probability_fn=wrapped_probability_fn,
            name=name,
            dtype=dtype,
            **kwargs,
        )

    def build(self, input_shape):
        """Create the scoring weights (only those not created yet)."""
        super().build(input_shape)
        # The `is None` guards make repeated build() calls idempotent.
        if self.attention_v is None:
            self.attention_v = self.add_weight(
                "attention_v",
                [self.units],
                dtype=self.dtype,
                initializer=self.kernel_initializer,
            )
        if self.attention_score_bias is None:
            # Scalar bias added to the score before the sigmoid.
            self.attention_score_bias = self.add_weight(
                "attention_score_bias",
                shape=(),
                dtype=self.dtype,
                initializer=tf.constant_initializer(self.score_bias_init),
            )
        if self.normalize and self.attention_g is None and self.attention_b is None:
            # Normalization parameters: scalar gain initialized to
            # sqrt(1 / units) and a `units`-sized bias initialized to zeros.
            self.attention_g = self.add_weight(
                "attention_g",
                dtype=self.dtype,
                initializer=tf.constant_initializer(math.sqrt(1.0 / self.units)),
                shape=(),
            )
            self.attention_b = self.add_weight(
                "attention_b",
                [self.units],
                dtype=self.dtype,
                initializer=tf.zeros_initializer(),
            )
        self.built = True

    def _calculate_attention(self, query, state):
        """Score the query based on the keys and values.

        Args:
          query: Tensor of dtype matching `self.values` and shape
            `[batch_size, query_depth]`.
          state: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]`
            (`alignments_size` is memory's `max_time`).

        Returns:
          alignments: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]` (`alignments_size` is memory's
            `max_time`).
          next_state: Same as alignments.
        """
        processed_query = self.query_layer(query) if self.query_layer else query
        score = _bahdanau_score(
            processed_query,
            self.keys,
            self.attention_v,
            attention_g=self.attention_g,
            attention_b=self.attention_b,
        )
        # Learnable bias shifts the score before the sigmoid applied inside
        # the monotonic probability function.
        score += self.attention_score_bias
        # The previous alignments (`state`) drive the monotonic recurrence.
        alignments = self.probability_fn(score, state)
        next_state = alignments
        return alignments, next_state

    def get_config(self):
        """Return the constructor arguments of this class merged with the base config."""
        # yapf: disable
        config = {
            "units": self.units,
            "normalize": self.normalize,
            "sigmoid_noise": self.sigmoid_noise,
            "sigmoid_noise_seed": self.sigmoid_noise_seed,
            "score_bias_init": self.score_bias_init,
            "mode": self.mode,
            "kernel_initializer": tf.keras.initializers.serialize(
                self.kernel_initializer,
                **SERIALIZATION_ARGS,
            ),
        }
        # yapf: enable
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Recreate the layer from the output of `get_config`."""
        # NOTE(review): presumably rebuilds serialized inner layers (e.g. the
        # query/memory layers) before calling the constructor — see
        # AttentionMechanism.deserialize_inner_layer_from_config.
        config = AttentionMechanism.deserialize_inner_layer_from_config(
            config, custom_objects=custom_objects
        )
        return cls(**config)
class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
    """Monotonic attention mechanism with Luong-style energy function.

    This type of attention enforces a monotonic constraint on the attention
    distributions; that is, once the model attends to a given point in the
    memory it can't attend to any prior points at subsequent output timesteps.
    It achieves this by using the `_monotonic_probability_fn` instead of `softmax`
    to construct its attention distributions. Otherwise, it is equivalent to
    `tfa.seq2seq.LuongAttention`. This approach is proposed in

    [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
    "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
    ICML 2017.](https://arxiv.org/abs/1704.00784)
    """

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        memory: Optional[TensorLike] = None,
        memory_sequence_length: Optional[TensorLike] = None,
        scale: bool = False,
        sigmoid_noise: FloatTensorLike = 0.0,
        sigmoid_noise_seed: Optional[FloatTensorLike] = None,
        score_bias_init: FloatTensorLike = 0.0,
        mode: str = "parallel",
        dtype: AcceptableDTypes = None,
        name: str = "LuongMonotonicAttention",
        **kwargs,
    ):
        """Construct the attention mechanism.

        Args:
          units: The depth of the query mechanism.
          memory: The memory to query; usually the output of an RNN encoder.
            This tensor should be shaped `[batch_size, max_time, ...]`.
          memory_sequence_length: (optional): Sequence lengths for the batch
            entries in memory. If provided, the memory tensor rows are masked
            with zeros for values past the respective sequence lengths.
          scale: Python boolean. Whether to scale the energy term.
          sigmoid_noise: Standard deviation of pre-sigmoid noise. See the
            docstring for `_monotonic_probability_fn` for more information.
          sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
          score_bias_init: Initial value for score bias scalar. It's
            recommended to initialize this to a negative value when the length
            of the memory is large.
          mode: How to compute the attention distribution. Must be one of
            'recursive', 'parallel', or 'hard'. See the docstring for
            `tfa.seq2seq.monotonic_attention` for more information.
          dtype: The data type for the query and memory layers of the attention
            mechanism.
          name: Name to use when creating ops.
          **kwargs: Dictionary that contains other common arguments for layer
            creation. `memory_layer` may be passed here to replace the default
            bias-free `Dense(units)` projection.
        """
        # Set up the monotonic probability fn with supplied parameters.
        # Binding noise/mode/seed now leaves a callable of (score, state),
        # which is how `_calculate_attention` invokes `self.probability_fn`.
        wrapped_probability_fn = functools.partial(
            _monotonic_probability_fn,
            sigmoid_noise=sigmoid_noise,
            mode=mode,
            seed=sigmoid_noise_seed,
        )
        # Only the memory is projected; Luong scoring uses the raw query
        # (`query_layer=None` below).
        memory_layer = kwargs.pop("memory_layer", None)
        if not memory_layer:
            memory_layer = tf.keras.layers.Dense(
                units, name="memory_layer", use_bias=False, dtype=dtype
            )
        # Assigned before `super().__init__()` because `build()` reads them
        # and the base constructor may build the layer immediately when
        # `memory` is given.
        self.units = units
        self.scale = scale
        self.sigmoid_noise = sigmoid_noise
        self.sigmoid_noise_seed = sigmoid_noise_seed
        self.score_bias_init = score_bias_init
        self.mode = mode
        # Weight placeholders; created lazily in `build()`.
        self.attention_g = None
        self.attention_score_bias = None
        super().__init__(
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            query_layer=None,
            memory_layer=memory_layer,
            probability_fn=wrapped_probability_fn,
            name=name,
            dtype=dtype,
            **kwargs,
        )

    def build(self, input_shape):
        """Create the scoring weights (only those not created yet)."""
        super().build(input_shape)
        # NOTE(review): unlike BahdanauMonotonicAttention, no explicit `dtype`
        # is passed to `add_weight` here; confirm the layer default is intended.
        if self.scale and self.attention_g is None:
            # Scalar scale of the energy term, initialized to 1.
            self.attention_g = self.add_weight(
                "attention_g", initializer=tf.ones_initializer, shape=()
            )
        if self.attention_score_bias is None:
            # Scalar bias added to the score before the sigmoid.
            self.attention_score_bias = self.add_weight(
                "attention_score_bias",
                shape=(),
                initializer=tf.constant_initializer(self.score_bias_init),
            )
        self.built = True

    def _calculate_attention(self, query, state):
        """Score the query based on the keys and values.

        Args:
          query: Tensor of dtype matching `self.values` and shape
            `[batch_size, query_depth]`.
          state: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]`
            (`alignments_size` is memory's `max_time`).

        Returns:
          alignments: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]` (`alignments_size` is memory's
            `max_time`).
          next_state: Same as alignments
        """
        # `attention_g` is None unless `scale` is enabled.
        score = _luong_score(query, self.keys, self.attention_g)
        score += self.attention_score_bias
        # The previous alignments (`state`) drive the monotonic recurrence.
        alignments = self.probability_fn(score, state)
        next_state = alignments
        return alignments, next_state

    def get_config(self):
        """Return the constructor arguments of this class merged with the base config."""
        config = {
            "units": self.units,
            "scale": self.scale,
            "sigmoid_noise": self.sigmoid_noise,
            "sigmoid_noise_seed": self.sigmoid_noise_seed,
            "score_bias_init": self.score_bias_init,
            "mode": self.mode,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Recreate the layer from the output of `get_config`."""
        # NOTE(review): presumably rebuilds serialized inner layers (e.g. the
        # memory layer) before calling the constructor — see
        # AttentionMechanism.deserialize_inner_layer_from_config.
        config = AttentionMechanism.deserialize_inner_layer_from_config(
            config, custom_objects=custom_objects
        )
        return cls(**config)
class AttentionWrapperState(
    collections.namedtuple(
        "AttentionWrapperState",
        (
            "cell_state",
            "attention",
            "alignments",
            "alignment_history",
            "attention_state",
        ),
    )
):
    """State carried between steps by a `tfa.seq2seq.AttentionWrapper`.

    Attributes:
      cell_state: State of the wrapped RNN cell after the previous step.
      attention: Attention vector produced at the previous step.
      alignments: One `Tensor`, or a tuple with one per attention
        mechanism, holding the alignments produced at the previous step.
      alignment_history: (only when enabled) One `TensorArray`, or a tuple
        with one per attention mechanism, accumulating the alignment
        matrices of every step. Call `stack()` on each to get a `Tensor`.
      attention_state: One nested object, or a tuple with one per attention
        mechanism, holding that mechanism's state. It may contain Tensors
        or TensorArrays.
    """

    def clone(self, **kwargs):
        """Return a copy of this state with the fields in `kwargs` replaced.

        Each replacement tensor must have the same shape as the tensor it
        replaces.
        - In eager mode, a mismatch raises `ValueError` immediately.
        - In graph mode, a runtime shape assertion is attached to the new
          tensor instead.
        - Leaves that are not `tf.Tensor` objects are swapped in unchecked.

        Example:

        >>> batch_size = 1
        >>> memory = tf.random.normal(shape=[batch_size, 3, 100])
        >>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
        >>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size)
        >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
        >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
        >>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

        Args:
          **kwargs: Fields of the state to replace in the returned
            `AttentionWrapperState`.

        Returns:
          A new `AttentionWrapperState` equal to this one except for the
          fields overridden through `kwargs`.
        """

        def _checked(old, new):
            """Return `new`, verifying that its shape matches `old`."""
            both_tensors = isinstance(old, tf.Tensor) and isinstance(new, tf.Tensor)
            if not both_tensors:
                return new
            if tf.executing_eagerly():
                if old.shape.as_list() != new.shape.as_list():
                    raise ValueError(
                        "The shape of the AttentionWrapperState is "
                        "expected to be same as the one to clone. "
                        "self.shape: %s, input.shape: %s" % (old.shape, new.shape)
                    )
                return new
            shapes_match = tf.debugging.assert_equal(tf.shape(new), tf.shape(old))
            with tf.control_dependencies([shapes_match]):
                # The identity op gives the control dependency something to hang on.
                return tf.identity(new)

        replaced = super()._replace(**kwargs)
        return tf.nest.map_structure(_checked, self, replaced)
def _prepare_memory(
    memory, memory_sequence_length=None, memory_mask=None, check_inner_dims_defined=True
):
    """Convert `memory` to tensors, validate it, and zero out masked steps.

    Args:
      memory: `Tensor` (or nested structure of them), shaped
        `[batch_size, max_time, ...]`.
      memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`.
      memory_mask: `boolean` tensor with shape `[batch_size, max_time]`.
        Memory entries whose mask is False are skipped (zeroed).
      check_inner_dims_defined: Python boolean. When `True`, every memory
        tensor must have fully defined dimensions beyond the two outermost.

    Returns:
      The converted, checked and possibly masked `memory`.

    Raises:
      ValueError: if both `memory_sequence_length` and `memory_mask` are
        given, or if `check_inner_dims_defined` is `True` and some
        `memory.shape[2:]` is not fully defined.
    """

    def _to_tensor(m):
        return tf.convert_to_tensor(m, name="memory")

    memory = tf.nest.map_structure(_to_tensor, memory)
    has_lengths = memory_sequence_length is not None
    has_mask = memory_mask is not None
    if has_lengths and has_mask:
        raise ValueError(
            "memory_sequence_length and memory_mask can't be provided at same time."
        )
    if has_lengths:
        memory_sequence_length = tf.convert_to_tensor(
            memory_sequence_length, name="memory_sequence_length"
        )

    if check_inner_dims_defined:

        def _check_dims(m):
            if m.shape[2:].is_fully_defined():
                return
            raise ValueError(
                "Expected memory %s to have fully defined inner dims, "
                "but saw shape: %s" % (m.name, m.shape)
            )

        tf.nest.map_structure(_check_dims, memory)

    if not (has_lengths or has_mask):
        return memory

    # The first flattened memory tensor defines max_time and the mask dtype.
    reference = tf.nest.flatten(memory)[0]
    if has_lengths:
        mask = tf.sequence_mask(
            memory_sequence_length,
            maxlen=tf.shape(reference)[1],
            dtype=reference.dtype,
        )
    else:
        mask = tf.cast(memory_mask, dtype=reference.dtype)

    def _apply_mask(m):
        """Multiply `m` by the mask, padded with trailing 1-dims to broadcast."""
        ndims = m.shape.ndims
        if ndims is None:
            ndims = tf.rank(m)
        trailing_ones = tf.ones(ndims - 2, dtype=tf.int32)
        broadcastable = tf.reshape(mask, tf.concat((tf.shape(mask), trailing_ones), 0))
        return m * broadcastable

    return tf.nest.map_structure(_apply_mask, memory)
def _maybe_mask_score(
    score, memory_sequence_length=None, memory_mask=None, score_mask_value=None
):
    """Replace attention scores at masked memory positions.

    Args:
      score: Attention scores, shape `[batch_size, max_time]`.
      memory_sequence_length: (optional) Lengths of the valid memory entries,
        shape `[batch_size]`; all values must be positive.
      memory_mask: (optional) Mask of shape `[batch_size, max_time]` that is
        truthy for valid positions. Non-bool masks are cast to `tf.bool`.
      score_mask_value: (optional) Value written into masked positions.
        Defaults to negative infinity in the dtype of `score`, so masked
        positions can never outscore a valid one.

    Returns:
      `score` unchanged when no mask is given; otherwise a tensor of the same
      shape with masked positions set to `score_mask_value`.

    Raises:
      ValueError: if both `memory_sequence_length` and `memory_mask` are
        provided.
    """
    if memory_sequence_length is None and memory_mask is None:
        return score
    if memory_sequence_length is not None and memory_mask is not None:
        raise ValueError(
            "memory_sequence_length and memory_mask can't be provided at same time."
        )
    if score_mask_value is None:
        # Previously `None * tensor` raised TypeError here; use -inf instead.
        score_mask_value = tf.as_dtype(score.dtype).as_numpy_dtype(-np.inf)
    if memory_sequence_length is not None:
        message = "All values in memory_sequence_length must greater than zero."
        with tf.control_dependencies(
            [
                tf.debugging.assert_positive(  # pylint: disable=bad-continuation
                    memory_sequence_length, message=message
                )
            ]
        ):
            memory_mask = tf.sequence_mask(
                memory_sequence_length, maxlen=tf.shape(score)[1]
            )
    else:
        # `tf.where` needs a boolean condition; accept 0/1 masks too, matching
        # `_prepare_memory`, which casts the mask to the memory dtype.
        memory_mask = tf.cast(memory_mask, tf.bool)
    score_mask_values = score_mask_value * tf.ones_like(score)
    return tf.where(memory_mask, score, score_mask_values)
def hardmax(logits: TensorLike, name: Optional[str] = None) -> tf.Tensor:
    """Returns batched one-hot vectors.

    Along the last axis, each output vector has a `1` at the position of the
    largest logit and `0` elsewhere.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with tf.name_scope(name or "Hardmax"):
        logits = tf.convert_to_tensor(logits, name="logits")
        static_depth = logits.shape[-1]
        # Fall back to the dynamic size when the static one is unknown (or 0).
        depth = static_depth if static_depth else tf.shape(logits)[-1]
        winners = tf.argmax(logits, -1)
        return tf.one_hot(winners, depth, dtype=logits.dtype)
def _compute_attention(
    attention_mechanism, cell_output, attention_state, attention_layer
):
    """Compute attention, alignments and next state for one mechanism.

    Args:
      attention_mechanism: The attention mechanism layer to query.
      cell_output: Output of the wrapped cell, used as the query.
      attention_state: Previous state of `attention_mechanism`.
      attention_layer: Optional layer applied to the concatenation of
        `cell_output` and the context vector; `None` returns the context.

    Returns:
      A tuple `(attention, alignments, next_attention_state)`.
    """
    alignments, next_attention_state = attention_mechanism(
        [cell_output, attention_state]
    )
    # [batch_size, memory_time] -> [batch_size, 1, memory_time]
    weights = tf.expand_dims(alignments, 1)
    # Batched matmul over memory_time with values of shape
    # [batch_size, memory_time, memory_size] gives [batch_size, 1, memory_size];
    # the singleton axis is then squeezed away.
    context_ = tf.squeeze(tf.matmul(weights, attention_mechanism.values), [1])
    if attention_layer is None:
        return context_, alignments, next_attention_state
    attention = attention_layer(tf.concat([cell_output, context_], 1))
    return attention, alignments, next_attention_state
class AttentionWrapper(AbstractRNNCell):
"""Wraps another RNN cell with attention.
Example:
>>> batch_size = 4
>>> max_time = 7
>>> hidden_size = 32
>>>
>>> memory = tf.random.uniform([batch_size, max_time, hidden_size])
>>> memory_sequence_length = tf.fill([batch_size], max_time)
>>>
>>> attention_mechanism = tfa.seq2seq.LuongAttention(hidden_size)
>>> attention_mechanism.setup_memory(memory, memory_sequence_length)
>>>
>>> cell = tf.keras.layers.LSTMCell(hidden_size)
>>> cell = tfa.seq2seq.AttentionWrapper(
... cell, attention_mechanism, attention_layer_size=hidden_size)
>>>
>>> inputs = tf.random.uniform([batch_size, hidden_size])
>>> state = cell.get_initial_state(inputs)
>>>
>>> outputs, state = cell(inputs, state)
>>> outputs.shape
TensorShape([4, 32])
"""
@typechecked
def __init__(
    self,
    cell: tf.keras.layers.Layer,
    attention_mechanism: Union[AttentionMechanism, List[AttentionMechanism]],
    attention_layer_size: Optional[Union[Number, List[Number]]] = None,
    alignment_history: bool = False,
    cell_input_fn: Optional[Callable] = None,
    output_attention: bool = True,
    initial_cell_state: Optional[TensorLike] = None,
    name: Optional[str] = None,
    attention_layer: Optional[
        Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]]
    ] = None,
    attention_fn: Optional[Callable] = None,
    **kwargs,
):
    """Construct the `AttentionWrapper`.

    **NOTE** If you are using the `tfa.seq2seq.BeamSearchDecoder` with a cell wrapped
    in `AttentionWrapper`, then you must ensure that:

    - The encoder output has been tiled to `beam_width` via
      `tfa.seq2seq.tile_batch` (NOT `tf.tile`).
    - The `batch_size` argument passed to the `get_initial_state` method of
      this wrapper is equal to `true_batch_size * beam_width`.
    - The initial state created with `get_initial_state` above contains a
      `cell_state` value containing properly tiled final state from the
      encoder.

    An example:

    >>> batch_size = 1
    >>> beam_width = 5
    >>> sequence_length = tf.convert_to_tensor([5])
    >>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
    >>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]
    >>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
    >>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
    >>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
    >>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
    >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), attention_mechanism)
    >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32)
    >>> decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)

    Args:
      cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
        interface.
      attention_mechanism: A list of `tfa.seq2seq.AttentionMechanism`
        instances or a single instance.
      attention_layer_size: A list of Python integers or a single Python
        integer, the depth of the attention (output) layer(s). If `None`
        (default), use the context as attention at each time step.
        Otherwise, feed the context and cell output into the attention
        layer to generate attention at each time step. If
        `attention_mechanism` is a list, `attention_layer_size` must be a list
        of the same length. If `attention_layer` is set, this must be `None`.
        If `attention_fn` is set, it must be guaranteed that the outputs of
        `attention_fn` also meet the above requirements.
      alignment_history: Python boolean, whether to store alignment history
        from all time steps in the final output state (currently stored as
        a time major `TensorArray` on which you must call `stack()`).
      cell_input_fn: (optional) A `callable`. The default is:
        `lambda inputs, attention:
          tf.concat([inputs, attention], -1)`.
      output_attention: Python bool. If `True` (default), the output at
        each time step is the attention value. This is the behavior of
        Luong-style attention mechanisms. If `False`, the output at each
        time step is the output of `cell`. This is the behavior of
        Bahdanau-style attention mechanisms. In both cases, the
        `attention` tensor is propagated to the next time step via the
        state and is used there. This flag only controls whether the
        attention mechanism is propagated up to the next cell in an RNN
        stack or to the top RNN output.
      initial_cell_state: The initial state value to use for the cell when
        the user calls `get_initial_state()`. Note that if this value is
        provided now, and the user uses a `batch_size` argument of
        `get_initial_state` which does not match the batch size of
        `initial_cell_state`, proper behavior is not guaranteed.
      name: Name to use when creating ops.
      attention_layer: A list of `tf.keras.layers.Layer` instances or a
        single `tf.keras.layers.Layer` instance taking the context
        and cell output as inputs to generate attention at each time step.
        If `None` (default), use the context as attention at each time step.
        If `attention_mechanism` is a list, `attention_layer` must be a list of
        the same length. If `attention_layer_size` is set, this must be
        `None`.
      attention_fn: An optional callable function that allows users to
        provide their own customized attention function, which takes input
        `(attention_mechanism, cell_output, attention_state,
        attention_layer)` and outputs `(attention, alignments,
        next_attention_state)`. If provided, the `attention_layer_size` should
        be the size of the outputs of `attention_fn`.
      **kwargs: Other keyword arguments for layer creation.

    Raises:
      ValueError: if `attention_layer_size` and `attention_layer` are set
        simultaneously; if either of them is provided but does not contain
        exactly one entry per attention mechanism (a single value is treated
        as a one-element list); or if `initial_cell_state` is provided while
        an attention mechanism has not had its memory set up yet.
    """
    super().__init__(name=name, **kwargs)
    keras_utils.assert_like_rnncell("cell", cell)
    if isinstance(attention_mechanism, (list, tuple)):
        self._is_multi = True
        attention_mechanisms = list(attention_mechanism)
    else:
        self._is_multi = False
        attention_mechanisms = [attention_mechanism]

    if cell_input_fn is None:

        def cell_input_fn(inputs, attention):
            return tf.concat([inputs, attention], -1)

    if attention_layer_size is not None and attention_layer is not None:
        raise ValueError(
            "Only one of attention_layer_size and attention_layer should be set"
        )

    if attention_layer_size is not None:
        # A single size is accepted and treated as a one-element list.
        attention_layer_sizes = tuple(
            attention_layer_size
            if isinstance(attention_layer_size, (list, tuple))
            else (attention_layer_size,)
        )
        if len(attention_layer_sizes) != len(attention_mechanisms):
            raise ValueError(
                "If provided, attention_layer_size must contain exactly "
                "one integer per attention_mechanism, saw: %d vs %d"
                % (len(attention_layer_sizes), len(attention_mechanisms))
            )
        dtype = kwargs.get("dtype", None)
        # One bias-free projection per attention mechanism.
        self._attention_layers = [
            tf.keras.layers.Dense(
                layer_size,
                name="attention_layer",
                use_bias=False,
                dtype=dtype,
            )
            for layer_size in attention_layer_sizes
        ]
    elif attention_layer is not None:
        self._attention_layers = list(
            attention_layer
            if isinstance(attention_layer, (list, tuple))
            else (attention_layer,)
        )
        if len(self._attention_layers) != len(attention_mechanisms):
            raise ValueError(
                "If provided, attention_layer must contain exactly one "
                "layer per attention_mechanism, saw: %d vs %d"
                % (len(self._attention_layers), len(attention_mechanisms))
            )
    else:
        self._attention_layers = None

    if attention_fn is None:
        attention_fn = _compute_attention
    self._attention_fn = attention_fn
    # Lazily computed and cached by `_get_attention_layer_size()`.
    self._attention_layer_size = None

    self._cell = cell
    self._attention_mechanisms = attention_mechanisms
    self._cell_input_fn = cell_input_fn
    self._output_attention = output_attention
    self._alignment_history = alignment_history
    with tf.name_scope(name or "AttentionWrapperInit"):
        if initial_cell_state is None:
            self._initial_cell_state = None
        else:
            # The batch size is read from the last flattened state tensor and
            # checked against every mechanism's memory batch size.
            final_state_tensor = tf.nest.flatten(initial_cell_state)[-1]
            state_batch_size = (
                final_state_tensor.shape[0] or tf.shape(final_state_tensor)[0]
            )
            error_message = (
                "When constructing AttentionWrapper %s: " % self.name
                + "Non-matching batch sizes between the memory "
                "(encoder output) and initial_cell_state. Are you using "
                "the BeamSearchDecoder? You may need to tile your "
                "initial state via the tfa.seq2seq.tile_batch "
                "function with argument multiplier=beam_width."
            )
            with tf.control_dependencies(
                self._batch_size_checks(  # pylint: disable=bad-continuation
                    state_batch_size, error_message
                )
            ):
                # Identity ops give the batch-size assertions something to
                # attach to.
                self._initial_cell_state = tf.nest.map_structure(
                    lambda s: tf.identity(s, name="check_initial_cell_state"),
                    initial_cell_state,
                )
def _attention_mechanisms_checks(self):
    """Ensure every wrapped attention mechanism already has its memory.

    Raises:
      ValueError: if any mechanism reports `memory_initialized` as false.
    """
    if any(not mechanism.memory_initialized for mechanism in self._attention_mechanisms):
        raise ValueError(
            "The AttentionMechanism instances passed to "
            "this AttentionWrapper should be initialized "
            "with a memory first, either by passing it "
            "to the AttentionMechanism constructor or "
            "calling attention_mechanism.setup_memory()"
        )
def _batch_size_checks(self, batch_size, error_message):
    """Build assertions that `batch_size` equals each mechanism's batch size.

    Args:
      batch_size: The batch size to validate.
      error_message: Message attached to every assertion.

    Returns:
      A list with one `tf.debugging.assert_equal` result per attention
      mechanism, suitable for `tf.control_dependencies`.

    Raises:
      ValueError: if some mechanism has no memory yet (from
        `_attention_mechanisms_checks`).
    """
    self._attention_mechanisms_checks()
    assertions = []
    for mechanism in self._attention_mechanisms:
        assertions.append(
            tf.debugging.assert_equal(
                batch_size, mechanism.batch_size, message=error_message
            )
        )
    return assertions
def _get_attention_layer_size(self):
    """Return (and cache) the total depth of the attention output.

    Without attention layers this is the sum of the mechanisms' value
    depths. With attention layers, each layer's output depth is derived from
    its input depth (cell output size + value depth), and the results are
    summed.

    Raises:
      ValueError: if some mechanism has no memory yet.
    """
    if self._attention_layer_size is None:
        self._attention_mechanisms_checks()
        value_depths = (
            mechanism.values.shape[-1] for mechanism in self._attention_mechanisms
        )
        if self._attention_layers is None:
            total = sum(value_depths)
        else:
            total = sum(
                layer.compute_output_shape(
                    [None, self._cell.output_size + value_depth]
                )[-1]
                for layer, value_depth in zip(self._attention_layers, value_depths)
            )
        self._attention_layer_size = total
    return self._attention_layer_size
def _item_or_tuple(self, seq):
    """Collapse `seq` to match how the attention mechanisms were supplied.

    Args:
      seq: A non-empty sequence or generator of items.

    Returns:
      A tuple of all items when a list/tuple of mechanisms was passed to the
      constructor, otherwise just the first (only) item.
    """
    items = tuple(seq)
    return items if self._is_multi else items[0]
@property
def output_size(self):
    """The attention layer size if `output_attention` is set, else the cell's output size."""
    return (
        self._get_attention_layer_size()
        if self._output_attention
        else self._cell.output_size
    )
@property
def state_size(self):
    """The `state_size` property of `tfa.seq2seq.AttentionWrapper`.

    Returns:
      A `tfa.seq2seq.AttentionWrapperState` tuple whose fields hold the
      sizes of the corresponding state components.
    """
    mechanisms = self._attention_mechanisms
    cell_state_size = self._cell.state_size
    attention_size = self._get_attention_layer_size()
    alignments_sizes = self._item_or_tuple(m.alignments_size for m in mechanisms)
    attention_state_sizes = self._item_or_tuple(m.state_size for m in mechanisms)
    # An empty tuple per mechanism when history is off. When it is on, the
    # state field is sometimes a TensorArray at runtime.
    history_sizes = self._item_or_tuple(
        m.alignments_size if self._alignment_history else () for m in mechanisms
    )
    return AttentionWrapperState(
        cell_state=cell_state_size,
        attention=attention_size,
        alignments=alignments_sizes,
        attention_state=attention_state_sizes,
        alignment_history=history_sizes,
    )
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    """Return an initial (zero) state tuple for this `tfa.seq2seq.AttentionWrapper`.

    **NOTE** Please see the initializer documentation for details of how
    to call `get_initial_state` if using a `tfa.seq2seq.AttentionWrapper`
    with a `tfa.seq2seq.BeamSearchDecoder`.

    Args:
      inputs: The inputs that will be fed to this cell. When given, the
        batch size and dtype are taken from it and override `batch_size`
        and `dtype`.
      batch_size: `0D` integer tensor: the batch size.
      dtype: The internal state data type.

    Returns:
      A `tfa.seq2seq.AttentionWrapperState` tuple containing zeroed out
      tensors and, when alignment history is enabled, empty dynamically
      sized `TensorArray` objects.

    Raises:
      ValueError: (or, possibly at runtime, `InvalidArgument`), if
        `batch_size` does not match the batch size of the memory (encoder
        output) passed to the wrapper object at initialization time.
    """
    # An explicit `inputs` tensor takes precedence over the keyword arguments.
    if inputs is not None:
        batch_size = tf.shape(inputs)[0]
        dtype = inputs.dtype
    with tf.name_scope(
        type(self).__name__ + "ZeroState"
    ): # pylint: disable=bad-continuation
        # A user-provided initial cell state is used verbatim; otherwise the
        # wrapped cell builds its own initial state.
        if self._initial_cell_state is not None:
            cell_state = self._initial_cell_state
        else:
            cell_state = self._cell.get_initial_state(
                batch_size=batch_size, dtype=dtype
            )
        error_message = (
            "When calling get_initial_state of AttentionWrapper %s: " % self.name
            + "Non-matching batch sizes between the memory "
            "(encoder output) and the requested batch size. Are you using "
            "the BeamSearchDecoder? If so, make sure your encoder output "
            "has been tiled to beam_width via "
            "tfa.seq2seq.tile_batch, and the batch_size= argument "
            "passed to get_initial_state is batch_size * beam_width."
        )
        # Route every cell-state tensor through an identity op created under
        # the batch-size checks, so consumers of the cell state depend on
        # those checks. (Presumably `_batch_size_checks` returns assertion
        # ops; it is defined outside this view.)
        with tf.control_dependencies(
            self._batch_size_checks(batch_size, error_message)
        ): # pylint: disable=bad-continuation
            cell_state = tf.nest.map_structure(
                lambda s: tf.identity(s, name="checked_cell_state"), cell_state
            )
        # Computed once because they are used both as the `alignments`
        # field and to size the alignment-history TensorArrays below.
        initial_alignments = [
            attention_mechanism.initial_alignments(batch_size, dtype)
            for attention_mechanism in self._attention_mechanisms
        ]
        return AttentionWrapperState(
            cell_state=cell_state,
            attention=tf.zeros(
                [batch_size, self._get_attention_layer_size()], dtype=dtype
            ),
            alignments=self._item_or_tuple(initial_alignments),
            attention_state=self._item_or_tuple(
                attention_mechanism.initial_state(batch_size, dtype)
                for attention_mechanism in self._attention_mechanisms
            ),
            # One growable TensorArray per mechanism (written to once per
            # step by `call`), or `()` when alignment history is disabled.
            alignment_history=self._item_or_tuple(
                tf.TensorArray(
                    dtype, size=0, dynamic_size=True, element_shape=alignment.shape
                )
                if self._alignment_history
                else ()
                for alignment in initial_alignments
            ),
        )
def call(self, inputs, state, **kwargs):
    """Perform a step of attention-wrapped RNN.

    - Step 1: Mix the `inputs` and previous step's `attention` output via
      `cell_input_fn`.
    - Step 2: Call the wrapped `cell` with this input and its previous
      state.
    - Step 3: Score the cell's output with `attention_mechanism`.
    - Step 4: Calculate the alignments by passing the score through the
      `normalizer`.
    - Step 5: Calculate the context vector as the inner product between the
      alignments and the attention_mechanism's values (memory).
    - Step 6: Calculate the attention output by concatenating the cell
      output and context through the attention layer (a linear layer with
      `attention_layer_size` outputs).

    Args:
      inputs: (Possibly nested tuple of) Tensor, the input at this time
        step.
      state: An instance of `tfa.seq2seq.AttentionWrapperState` containing
        tensors from the previous time step, or a sequence of values that
        can be unpacked into one.
      **kwargs: Dict, other keyword arguments for the cell call method.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` is the concatenated attention output
        when `output_attention` is set, and the cell output otherwise.
      - `next_state` is an instance of `tfa.seq2seq.AttentionWrapperState`
        containing the state calculated at this time step.

    Raises:
      TypeError: If `state` is not an instance of
        `tfa.seq2seq.AttentionWrapperState` and cannot be unpacked into
        one. The underlying error is chained as the cause.
    """
    if not isinstance(state, AttentionWrapperState):
        try:
            state = AttentionWrapperState(*state)
        except TypeError as err:
            # Chain the original error so the traceback shows the real cause
            # instead of "another exception occurred during handling".
            raise TypeError(
                "Expected state to be instance of AttentionWrapperState or "
                "values that can construct AttentionWrapperState. "
                "Received type %s instead." % type(state)
            ) from err

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    # Step 2: Run the wrapped cell, then restore the nesting of its state to
    # match the incoming cell state structure.
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state, **kwargs)
    next_cell_state = tf.nest.pack_sequence_as(
        cell_state, tf.nest.flatten(next_cell_state)
    )

    # Prefer the static batch size; fall back to the dynamic one when it is
    # unknown at graph-construction time.
    cell_batch_size = cell_output.shape[0] or tf.shape(cell_output)[0]
    error_message = (
        "When applying AttentionWrapper %s: " % self.name
        + "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input "
        "via the tfa.seq2seq.tile_batch function with argument "
        "multiple=beam_width."
    )
    # Make downstream uses of the cell output depend on the batch-size checks.
    with tf.control_dependencies(
        self._batch_size_checks(cell_batch_size, error_message)
    ): # pylint: disable=bad-continuation
        cell_output = tf.identity(cell_output, name="checked_cell_output")

    # Normalize single-mechanism state into one-element lists so the loop
    # below handles both cases uniformly.
    if self._is_multi:
        previous_attention_state = state.attention_state
        previous_alignment_history = state.alignment_history
    else:
        previous_attention_state = [state.attention_state]
        previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_attention_states = []
    maybe_all_histories = []
    # Steps 3-6: score, normalize, build the context and (optionally) apply
    # the attention layer, once per attention mechanism.
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
        attention, alignments, next_attention_state = self._attention_fn(
            attention_mechanism,
            cell_output,
            previous_attention_state[i],
            self._attention_layers[i] if self._attention_layers else None,
        )
        # Append this step's alignments to the history TensorArray when
        # history tracking is enabled.
        alignment_history = (
            previous_alignment_history[i].write(
                previous_alignment_history[i].size(), alignments
            )
            if self._alignment_history
            else ()
        )

        all_attention_states.append(next_attention_state)
        all_alignments.append(alignments)
        all_attentions.append(attention)
        maybe_all_histories.append(alignment_history)

    # Concatenate the per-mechanism attention outputs along the feature axis.
    attention = tf.concat(all_attentions, 1)
    next_state = AttentionWrapperState(
        cell_state=next_cell_state,
        attention=attention,
        attention_state=self._item_or_tuple(all_attention_states),
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(maybe_all_histories),
    )

    if self._output_attention:
        return attention, next_state
    else:
        return cell_output, next_state