gan/tensorflow_gan/python/tpu/normalization_ops.py

376 lines
15 KiB
Python

# coding=utf-8
# Copyright 2024 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides ops for supporting TPU operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
from six.moves import range
import tensorflow as tf
from tensorflow_gan.python.tpu.cross_replica_ops import cross_replica_moments
from tensorflow.python.tpu import tpu_function # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training import moving_averages # pylint: disable=g-direct-tensorflow-import
# Names exported by `from <this module> import *`. The `*_for_inference`
# helpers defined further down are not listed here.
__all__ = [
    'batch_norm',
    'standardize_batch',
]
def batch_norm(inputs,
               is_training,
               conditional_class_labels=None,
               axis=-1,
               variance_epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer=tf.compat.v1.initializers.zeros(),
               gamma_initializer=tf.compat.v1.initializers.ones(),
               batch_axis=0,
               name='batch_norm'):
  """Adds Batch Norm or Conditional Batch Norm.

  Creates the (optionally class-conditional) `beta` / `gamma` variables and
  delegates the actual normalization to `standardize_batch`.

  Args:
    inputs: Tensor of inputs (e.g. images). Its rank must be statically known.
    is_training: Whether or not the layer is in training mode. In training
      mode it would accumulate the statistics of the moments into the
      `moving_mean` and `moving_variance` using an exponential moving average
      (with the default decay of `standardize_batch`). When
      is_training=False, these variables are not updated, and the precomputed
      values are used verbatim.
    conditional_class_labels: If `None`, this layer is vanilla Batch
      Normalization. If not, it is a tensor of one-hot labels - same first
      dimension as inputs, and the layer is Conditional Batch Normalization
      with normalization constants determined by the class (see
      https://arxiv.org/pdf/1610.07629.pdf for more detail).
    axis: Integer, the axis that should be normalized (typically the features
      axis). For instance, after a `Convolution2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
      NOTE(review): `axis` only shapes `beta` / `gamma`; `standardize_batch`
      is called with its default `data_format`, so the statistics themselves
      are reduced as for channels-last inputs - confirm before relying on
      `axis != -1`.
    variance_epsilon: A small float number to avoid dividing by 0.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored (no variable is created).
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used (no variable is created). When the next layer is linear (also
      e.g. `nn.relu`), this can be disabled since the scaling can be done by
      the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    batch_axis: The axis of the batch dimension.
    name: String name to be used for scoping.

  Returns:
    Output tensor.
  """
  with tf.compat.v1.variable_scope(
      name, values=[inputs], reuse=tf.compat.v1.AUTO_REUSE):
    # Determine the variable shape: broadcastable against `inputs`, with the
    # full size only along the normalized axis.
    var_shape = [1] * inputs.shape.rank
    var_shape[axis] = tf.compat.dimension_value(inputs.shape[axis])

    # In the conditional case there is one set of parameters per class, stored
    # along `batch_axis` and selected per example by its integer label.
    labels = None
    if conditional_class_labels is not None:
      num_categories = tf.compat.dimension_value(
          conditional_class_labels.shape[-1])
      var_shape[batch_axis] = num_categories
      labels = tf.math.argmax(
          input=conditional_class_labels, axis=1)  # to integer

    # Default to None so that `center=False` / `scale=False` simply skip the
    # corresponding term instead of leaving the name unbound.
    beta = None
    gamma = None
    if center:
      beta = tf.compat.v1.get_variable(
          'beta', var_shape, initializer=beta_initializer)
      if labels is not None:
        # The class dimension lives on `batch_axis`, so gather along it.
        beta = tf.gather(beta, labels, axis=batch_axis)
    if scale:
      gamma = tf.compat.v1.get_variable(
          'gamma', var_shape, initializer=gamma_initializer)
      if labels is not None:
        gamma = tf.gather(gamma, labels, axis=batch_axis)

    outputs = standardize_batch(
        inputs, is_training=is_training, epsilon=variance_epsilon, offset=beta,
        scale=gamma)
    outputs.set_shape(inputs.shape)
    return outputs
def standardize_batch(inputs,
                      is_training,
                      offset=None,
                      scale=None,
                      decay=0.999,
                      epsilon=1e-3,
                      data_format='NHWC',
                      use_moving_averages=True,
                      use_cross_replica_mean=None):
  """Adds TPU-enabled batch normalization layer.

  Details on Batch Normalization can be found in 'Batch Normalization:
  Accelerating Deep Network Training by Reducing Internal Covariate Shift',
  Ioffe S. and Szegedy C. 2015 [http://arxiv.org/abs/1502.03167].

  Note #1: This method computes the batch statistic across all TPU replicas,
  thus simulating the true batch norm in the distributed setting. If one wants
  to avoid the cross-replica communication set use_cross_replica_mean=False.

  Note #2: When is_training is True the moving_mean and moving_variance need
  to be updated in each training step. By default, the update_ops are placed
  in `tf.GraphKeys.UPDATE_OPS` and they need to be added as a dependency to
  the `train_op`. For example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  Note #3: Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.99, 0.9, etc. Lower the `decay` value (trying
  `decay`=0.9) if model experiences reasonably good training performance but
  poor validation and/or test performance.

  Args:
    inputs: A tensor with 2 or 4 dimensions, where the first dimension is
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC`, and over all but the second dimension if
      `data_format` is `NCHW`.
    is_training: Whether or not the layer is in training mode. In training
      mode it would accumulate the statistics of the moments into the
      `moving_mean` and `moving_variance` using an exponential moving average
      with the given `decay`. When is_training=False, these variables are not
      updated, and the precomputed values are used verbatim.
    offset: An offset `Tensor`, often denoted `beta` in equations, or
      None. If present, will be added to the normalized tensor.
    scale: A scale `Tensor`, often denoted `gamma` in equations, or
      `None`. If present, the scale is applied to the normalized tensor.
    decay: Decay for the moving averages. See notes above for reasonable
      values.
    epsilon: Small float added to variance to avoid dividing by zero.
    data_format: Input data format. NHWC or NCHW.
    use_moving_averages: If True keep moving averages of mean and variance that
      are used during inference. Otherwise use accumulators.
    use_cross_replica_mean: If True add operations to compute batch norm
      statistics across all TPU cores. These ops are not compatible with other
      platforms. The default (None) will only add the operations if running
      on TPU.

  Returns:
    The normalized tensor with the same type and shape as `inputs`.

  Raises:
    ValueError: If `data_format` is not NCHW/NHWC, if the rank of `inputs` is
      undefined or not 2 or 4, or if the channel dimension is not statically
      known.
  """
  if data_format not in {'NCHW', 'NHWC'}:
    raise ValueError(
        'Invalid data_format {}. Allowed: NCHW, NHWC.'.format(data_format))

  if use_cross_replica_mean is None:
    # Default to global batch norm only on TPUs.
    use_cross_replica_mean = (
        tpu_function.get_tpu_context().number_of_shards is not None)
    logging.debug('Automatically determined use_cross_replica_mean=%s.',
                  use_cross_replica_mean)

  inputs = tf.convert_to_tensor(value=inputs)
  inputs_dtype = inputs.dtype
  inputs_shape = inputs.get_shape()

  # Validate the rank first: the position of the channel axis depends on it.
  inputs_rank = inputs_shape.ndims
  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank' % inputs.name)
  elif inputs_rank not in [2, 4]:
    raise ValueError(
        'Inputs %s has unsupported rank.'
        ' Expected 2 or 4 but got %d' % (inputs.name, inputs_rank))

  # The channel axis is the second one for NCHW and the last one for NHWC
  # (for 2-D inputs both coincide).
  channel_index = 1 if data_format == 'NCHW' else inputs_rank - 1
  num_channels = tf.compat.dimension_value(inputs_shape[channel_index])
  if num_channels is None:
    raise ValueError('`C` dimension must be known but is None')

  # Bring 2-D inputs into 4-D format.
  if inputs_rank == 2:
    new_shape = [-1, 1, 1, num_channels]
    if data_format == 'NCHW':
      new_shape = [-1, num_channels, 1, 1]
    inputs = tf.reshape(inputs, new_shape)
    if offset is not None:
      offset = tf.reshape(offset, new_shape)
    if scale is not None:
      scale = tf.reshape(scale, new_shape)

  # Execute a distributed batch normalization. Statistics are always computed
  # in float32 and the result is cast back to the input dtype at the end.
  axis = 1 if data_format == 'NCHW' else 3
  inputs = tf.cast(inputs, tf.float32)
  reduction_axes = [i for i in range(4) if i != axis]
  if use_cross_replica_mean:
    mean, variance = cross_replica_moments(inputs, reduction_axes)
  else:
    counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
        inputs, reduction_axes, keepdims=False)
    mean, variance = tf.nn.normalize_moments(
        counts, mean_ss, variance_ss, shift=None)

  if use_moving_averages:
    mean, variance = moving_moments_for_inference(
        mean=mean, variance=variance, is_training=is_training, decay=decay)
  else:
    mean, variance = accumulated_moments_for_inference(
        mean=mean, variance=variance, is_training=is_training)

  if data_format == 'NCHW':
    # `tf.nn.batch_normalization` only accepts 1-D moments when the channel
    # axis is the last one; for NCHW they must carry explicit singleton
    # dimensions so that they broadcast along axis 1.
    moments_shape = [1, num_channels, 1, 1]
    mean = tf.reshape(mean, moments_shape)
    variance = tf.reshape(variance, moments_shape)

  outputs = tf.nn.batch_normalization(
      inputs,
      mean=mean,
      variance=variance,
      offset=offset,
      scale=scale,
      variance_epsilon=epsilon)
  outputs = tf.cast(outputs, inputs_dtype)

  # Bring 2-D inputs back into 2-D format.
  if inputs_rank == 2:
    outputs = tf.reshape(outputs, [-1] + inputs_shape[1:].as_list())
  outputs.set_shape(inputs_shape)
  return outputs
def moving_moments_for_inference(mean, variance, is_training, decay):
  """Selects batch moments for training and moving averages for inference.

  The `moving_mean` / `moving_variance` variables are always created. In
  training mode an op updating both of them from the current batch moments is
  added to the `UPDATE_OPS` collection and the batch moments are returned
  unchanged; otherwise the moving-average variables themselves are returned.

  Args:
    mean: Tensor of shape [num_channels] with the mean of the current batch.
    variance: Tensor of shape [num_channels] with the variance of the current
      batch.
    is_training: Boolean, whether to construct ops for training or inference
      graph.
    decay: Decay rate to use for moving averages.

  Returns:
    Tuple of (mean, variance) to use. This can be the same as the inputs.
  """
  # Collections the moving-average variables are registered in.
  stat_collections = [
      tf.compat.v1.GraphKeys.MOVING_AVERAGE_VARIABLES,
      tf.compat.v1.GraphKeys.MODEL_VARIABLES,
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
  ]

  def _stat_variable(var_name, shape, initializer):
    # Partitioning is disabled because the assign_moving_average op used
    # below doesn't support partitioned variables.
    return tf.compat.v1.get_variable(
        var_name,
        shape=shape,
        initializer=initializer,
        trainable=False,
        partitioner=None,
        collections=stat_collections)

  moving_mean = _stat_variable(
      'moving_mean', mean.shape, tf.compat.v1.zeros_initializer())
  moving_variance = _stat_variable(
      'moving_variance', variance.shape, tf.compat.v1.ones_initializer())

  if not is_training:
    logging.debug('Using moving mean and variance.')
    return moving_mean, moving_variance

  logging.debug('Adding update ops for moving averages of mean and variance.')
  # One EMA update per statistic, mean first and variance second.
  ema_updates = [
      moving_averages.assign_moving_average(
          stat_var, tf.cast(batch_stat, stat_var.dtype), decay,
          zero_debias=False)
      for stat_var, batch_stat in ((moving_mean, mean),
                                   (moving_variance, variance))
  ]
  tf.compat.v1.add_to_collection(
      tf.compat.v1.GraphKeys.UPDATE_OPS,
      tf.group(*ema_updates, name='ema_update_ops'))
  return mean, variance
def accumulated_moments_for_inference(mean, variance, is_training):
  """Selects batch moments for training and accumulated ones for inference.

  The accumulator variables are always created. In training mode the batch
  moments are returned. Otherwise the returned moments are the accumulated
  sums divided by the accumulation counter; the current batch statistics are
  added to the accumulators whenever the `update_accus` variable equals 1.
  After training the user is responsible for filling the accumulators with the
  actual values.

  Args:
    mean: Tensor of shape [num_channels] with the mean of the current batch.
    variance: Tensor of shape [num_channels] with the variance of the current
      batch.
    is_training: Boolean, whether to construct ops for training or inference
      graph.

  Returns:
    Tuple of (mean, variance) to use. This can be the same as the inputs.
  """
  accu_collections = [
      tf.compat.v1.GraphKeys.MODEL_VARIABLES,
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
  ]
  # NOTE(review): all ops below are created inside the 'accu' scope - confirm
  # this matches the intended graph naming.
  with tf.compat.v1.variable_scope('accu', values=[mean, variance]):

    def _accumulator(var_name, shape, initializer, dtype=None):
      # Non-trainable bookkeeping variable registered in `accu_collections`.
      return tf.compat.v1.get_variable(
          var_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=False,
          collections=accu_collections)

    # Accumulators for the batch statistics, used during inference. The ops
    # filling them must be created and run before eval (see docstring).
    accu_mean = _accumulator(
        'accu_mean', mean.shape, tf.compat.v1.zeros_initializer())
    accu_variance = _accumulator(
        'accu_variance', variance.shape, tf.compat.v1.zeros_initializer())
    # Starts at 1e-12 rather than 0 - presumably so that the divisions below
    # are defined before anything has been accumulated.
    accu_counter = _accumulator(
        'accu_counter', [], tf.compat.v1.initializers.constant(1e-12))
    # Integer switch: accumulation only happens while this variable equals 1.
    update_accus = _accumulator(
        'update_accus', [], tf.compat.v1.zeros_initializer(), dtype=tf.int32)

    mean = tf.identity(mean, name='mean')
    variance = tf.identity(variance, name='variance')

    if is_training:
      return mean, variance

    logging.debug('Using accumulated moments.')

    def _accumulate_batch():
      # Adds the current batch statistics to the accumulators.
      return tf.group(
          tf.compat.v1.assign_add(accu_mean, mean),
          tf.compat.v1.assign_add(accu_variance, variance),
          tf.compat.v1.assign_add(accu_counter, 1))

    maybe_accumulate = tf.cond(
        pred=tf.equal(update_accus, 1),
        true_fn=_accumulate_batch,
        false_fn=tf.no_op)

    # The returned averages carry a control dependency on the (conditional)
    # accumulation step.
    with tf.control_dependencies([maybe_accumulate]):
      averaged_mean = accu_mean / accu_counter
      averaged_variance = accu_variance / accu_counter
      return averaged_mean, averaged_variance