# coding=utf-8
# Copyright 2024 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The TF-GAN project provides a lightweight GAN training/testing framework.
|
|
|
|
This file contains the core helper functions to create and train a GAN model.
|
|
See the README or examples in `tensorflow_models` for details on how to use.
|
|
|
|
TF-GAN training occurs in four steps:
|
|
1) Create a model
|
|
2) Add a loss
|
|
3) Create train ops
|
|
4) Run the train ops
|
|
|
|
The functions in this file are organized around these four steps. Each function
|
|
corresponds to one of the steps.
|
|
"""
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import os
import time

import tensorflow as tf

from tensorflow import estimator as tf_estimator
from tensorflow_gan.python import contrib_utils as contrib
from tensorflow_gan.python import namedtuples
from tensorflow_gan.python.losses import losses_wargs
from tensorflow_gan.python.losses import tuple_losses

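# A minimal sketch of the four training steps from the module docstring, kept
# in a comment so this module stays import-safe. `my_generator_fn`,
# `my_discriminator_fn`, `noise`, and `images` are hypothetical placeholders,
# not names defined by this library:
#
#   model = gan_model(my_generator_fn, my_discriminator_fn,
#                     real_data=images, generator_inputs=noise)
#   loss = gan_loss(model)
#   train_ops = gan_train_ops(
#       model, loss,
#       generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4),
#       discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4))
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir')
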
__all__ = [
    'gan_model',
    'infogan_model',
    'acgan_model',
    'cyclegan_model',
    'stargan_model',
    'gan_loss',
    'cyclegan_loss',
    'stargan_loss',
    'gan_train_ops',
    'gan_train',
    'get_sequential_train_hooks',
    'get_joint_train_hooks',
    'get_sequential_train_steps',
    'RunTrainOpsHook',
]


def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns GAN model outputs and variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that generator produces Tensors that are the
      same shape as real data. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If TF is executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.gan_model` doesn\'t work when executing eagerly.')
  # Create models
  with tf.compat.v1.variable_scope(
      generator_scope, reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with tf.compat.v1.variable_scope(
      discriminator_scope, reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
    discriminator_gen_outputs = discriminator_fn(generated_data,
                                                 generator_inputs)
  with tf.compat.v1.variable_scope(dis_scope, reuse=True):
    real_data = _convert_tensor_or_l_or_d(real_data)
    discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = contrib.get_trainable_variables(gen_scope)
  discriminator_variables = contrib.get_trainable_variables(dis_scope)

  return namedtuples.GANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn)


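# Example (a hedged sketch, not part of the API): the callable signatures
# `gan_model` expects. `my_dnn` and the `images`/`noise` tensors are
# hypothetical placeholders:
#
#   def my_generator(generator_inputs):
#     return my_dnn(generator_inputs)            # same shape as `real_data`
#
#   def my_discriminator(data, generator_inputs):
#     return my_dnn(data)                        # unscaled logits
#
#   model = gan_model(my_generator, my_discriminator,
#                     real_data=images, generator_inputs=noise)
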
def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Returns InfoGAN model outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A python lambda that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of TensorFlow distributions representing the predicted noise distribution
      of the ith structured noise input.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator.
      These tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator.
      These tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If the discriminator output is malformed.
    ValueError: If TF is executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.infogan_model` doesn\'t work when executing '
                     'eagerly.')

  # Create models
  with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
    unstructured_generator_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_generator_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    generator_inputs = (
        unstructured_generator_inputs + structured_generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with tf.compat.v1.variable_scope(discriminator_scope) as disc_scope:
    dis_gen_outputs, predicted_distributions = discriminator_fn(
        generated_data, generator_inputs)
  _validate_distributions(predicted_distributions, structured_generator_inputs)
  with tf.compat.v1.variable_scope(disc_scope, reuse=True):
    real_data = tf.convert_to_tensor(value=real_data)
    dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

  if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

  # Get model-specific variables.
  generator_variables = contrib.get_trainable_variables(gen_scope)
  discriminator_variables = contrib.get_trainable_variables(disc_scope)

  return namedtuples.InfoGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      dis_real_outputs,
      dis_gen_outputs,
      discriminator_variables,
      disc_scope,
      lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
      structured_generator_inputs,
      predicted_distributions,
      discriminator_fn)


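# Example (a hedged sketch): an InfoGAN discriminator returns logits plus one
# distribution per structured input. `tfp` (tensorflow_probability),
# `real_logits`, `cat_logits`, and `cat_code` are hypothetical placeholders:
#
#   def my_infogan_discriminator(data, generator_inputs):
#     ...
#     q_cat = tfp.distributions.Categorical(logits=cat_logits)
#     return real_logits, [q_cat]  # one distribution per structured input
#
#   model = infogan_model(my_generator, my_infogan_discriminator,
#                         real_data=images,
#                         unstructured_generator_inputs=[noise],
#                         structured_generator_inputs=[cat_code])
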
def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns an ACGANModel containing all the pieces needed for ACGAN training.

  The `acgan_model` is the same as the `gan_model` with the only difference
  being that the discriminator additionally outputs logits to classify the
  input (whether real or generated).
  Therefore, an explicit field holding one_hot_labels is necessary, as well as
  a discriminator_fn that outputs a 2-tuple holding the logits for real/fake
  and classification.

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
        (1) real/fake logits in the range [-inf, inf]
        (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
      acgan_loss.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that generator produces Tensors that are the
      same shape as real data. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
    ValueError: If TF is executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.acgan_model` doesn\'t work when executing '
                     'eagerly.')

  # Create models
  with tf.compat.v1.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with tf.compat.v1.variable_scope(discriminator_scope) as dis_scope:
    with tf.compat.v1.name_scope(dis_scope.name + '/generated/'):
      (discriminator_gen_outputs, discriminator_gen_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(generated_data, generator_inputs))
  with tf.compat.v1.variable_scope(dis_scope, reuse=True):
    with tf.compat.v1.name_scope(dis_scope.name + '/real/'):
      real_data = tf.convert_to_tensor(value=real_data)
      (discriminator_real_outputs, discriminator_real_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(real_data, generator_inputs))
  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = contrib.get_trainable_variables(gen_scope)
  discriminator_variables = contrib.get_trainable_variables(dis_scope)

  return namedtuples.ACGANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn, one_hot_labels,
      discriminator_real_classification_logits,
      discriminator_gen_classification_logits)


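# Example (a hedged sketch): an ACGAN discriminator returns a 2-tuple of
# (real/fake logits, class logits). `real_logits`, `class_logits`, and the
# input tensors are hypothetical placeholders:
#
#   def my_acgan_discriminator(data, generator_inputs):
#     ...
#     return real_logits, class_logits
#
#   model = acgan_model(my_generator, my_acgan_discriminator,
#                       real_data=images, generator_inputs=noise,
#                       one_hot_labels=labels)
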
def cyclegan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # data X and Y.
    data_x,
    data_y,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    model_x2y_scope='ModelX2Y',
    model_y2x_scope='ModelY2X',
    # Options.
    check_shapes=True):
  """Returns CycleGAN model outputs and variables.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
    data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created. Defaults to 'Generator'.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created. Defaults to
      'Discriminator'.
    model_x2y_scope: Optional variable scope for model x2y variables. Defaults
      to 'ModelX2Y'.
    model_y2x_scope: Optional variable scope for model y2x variables. Defaults
      to 'ModelY2X'.
    check_shapes: If `True`, check that generator produces Tensors that are the
      same shape as `data_x` (`data_y`). Otherwise, skip this check.

  Returns:
    A `CycleGANModel` namedtuple.

  Raises:
    ValueError: If `check_shapes` is True and `data_x` or the generator output
      does not have the same shape as `data_y`.
    ValueError: If TF is executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.cyclegan_model` doesn\'t work when executing '
                     'eagerly.')

  # Create models.
  def _define_partial_model(input_data, output_data):
    return gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=output_data,
        generator_inputs=input_data,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

  with tf.compat.v1.variable_scope(model_x2y_scope):
    model_x2y = _define_partial_model(data_x, data_y)
  with tf.compat.v1.variable_scope(model_y2x_scope):
    model_y2x = _define_partial_model(data_y, data_x)

  with tf.compat.v1.variable_scope(model_y2x.generator_scope, reuse=True):
    reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
  with tf.compat.v1.variable_scope(model_x2y.generator_scope, reuse=True):
    reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

  return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                   reconstructed_y)


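# Example (a hedged sketch): `cyclegan_model` reuses one generator_fn and one
# discriminator_fn for both translation directions. `summer_images` and
# `winter_images` are hypothetical same-shape image batches:
#
#   model = cyclegan_model(my_generator, my_discriminator,
#                          data_x=summer_images, data_y=winter_images)
#   loss = cyclegan_loss(model)
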
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Returns StarGAN model outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
      and returns `generated_data` as the transformed version of `inputs`
      based on `targets`. `inputs` has shape (n, h, w, c), `targets` has shape
      (n, num_domains), and `generated_data` has the same shape as `inputs`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
      inputs and returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` represents the source (real/generated) prediction by
      the discriminator, and `domain_prediction` represents the domain
      prediction/classification by the discriminator. `source_prediction` has
      shape (n) and `domain_prediction` has shape (n, num_domains).
    input_data: A Tensor or a list of Tensors of shape (n, h, w, c)
      representing the real input images.
    input_data_domain_label: A Tensor or a list of Tensors of shape
      (batch_size, num_domains) representing the domain label associated with
      the real images.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A StarGANModel namedtuple containing the tensors needed to compute the
    loss.

  Raises:
    ValueError: If `input_data_domain_label` is not rank 2 or is not fully
      defined in every dimension.
    ValueError: If TF is executing eagerly.
  """
  if tf.executing_eagerly():
    raise ValueError('`tfgan.stargan_model` doesn\'t work when executing '
                     'eagerly.')

  # Convert to tensor.
  input_data = _convert_tensor_or_l_or_d(input_data)
  input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)

  # Convert list of tensor to a single tensor if applicable.
  if isinstance(input_data, (list, tuple)):
    input_data = tf.concat([tf.convert_to_tensor(value=x) for x in input_data],
                           0)
  if isinstance(input_data_domain_label, (list, tuple)):
    input_data_domain_label = tf.concat(
        [tf.convert_to_tensor(value=x) for x in input_data_domain_label], 0)

  # Get batch_size, num_domains from the labels.
  input_data_domain_label.shape.assert_has_rank(2)
  input_data_domain_label.shape.assert_is_fully_defined()
  batch_size, num_domains = input_data_domain_label.shape.as_list()

  # Transform input_data to random target domains.
  with tf.compat.v1.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target = generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, generated_data_domain_target)

  # Transform generated_data back to the original input_data domain.
  with tf.compat.v1.variable_scope(generator_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Predict source and domain for the generated_data using the discriminator.
  with tf.compat.v1.variable_scope(discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
        generated_data, num_domains)

  # Predict source and domain for the input_data using the discriminator.
  with tf.compat.v1.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
        input_data, num_domains)

  # Collect trainable variables from the neural networks.
  generator_variables = contrib.get_trainable_variables(generator_scope)
  discriminator_variables = contrib.get_trainable_variables(
      discriminator_scope)

  # Create the StarGANModel namedtuple.
  return namedtuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=input_data_domain_label,
      generated_data=generated_data,
      generated_data_domain_target=generated_data_domain_target,
      reconstructed_data=reconstructed_data,
      discriminator_input_data_source_predication=disc_input_data_source_pred,
      discriminator_generated_data_source_predication=disc_gen_data_source_pred,
      discriminator_input_data_domain_predication=disc_input_data_domain_pred,
      discriminator_generated_data_domain_predication=disc_gen_data_domain_pred,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=discriminator_variables,
      discriminator_scope=discriminator_scope,
      discriminator_fn=discriminator_fn)


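# Example (a hedged sketch): the callables `stargan_model` expects. Note that
# the discriminator takes `num_domains` rather than generator inputs. All
# names below are hypothetical placeholders:
#
#   def my_stargan_generator(inputs, targets):     # (n, h, w, c), (n, d)
#     return transformed_images                    # (n, h, w, c)
#
#   def my_stargan_discriminator(inputs, num_domains):
#     return source_logits, domain_logits          # (n), (n, num_domains)
#
#   model = stargan_model(my_stargan_generator, my_stargan_discriminator,
#                         input_data=images, input_data_domain_label=labels)
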
def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
  if isinstance(aux_loss_weight, tf.Tensor):
    aux_loss_weight.shape.assert_is_compatible_with([])
    with tf.control_dependencies(
        [tf.compat.v1.debugging.assert_greater_equal(aux_loss_weight, 0.0)]):
      aux_loss_weight = tf.identity(aux_loss_weight)
  elif aux_loss_weight is not None and aux_loss_weight < 0:
    raise ValueError('`%s` must be non-negative. Instead, was %s' %
                     (name, aux_loss_weight))
  return aux_loss_weight


def _use_aux_loss(aux_loss_weight):
  if aux_loss_weight is not None:
    if not isinstance(aux_loss_weight, tf.Tensor):
      return aux_loss_weight > 0
    else:
      return True
  else:
    return False


def tensor_pool_adjusted_model(model, tensor_pool_fn):
  """Adjusts model using `tensor_pool_fn`.

  Args:
    model: A GANModel tuple.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns a previously stored
      (generated_data, generator_inputs) with some probability. For example,
      `tfgan.features.tensor_pool`.

  Returns:
    A new GANModel tuple where discriminator outputs are adjusted by taking
    pooled generator outputs as inputs. Returns the original model if
    `tensor_pool_fn` is None.

  Raises:
    ValueError: If tensor pool does not support the `model`.
  """
  if isinstance(model, namedtuples.GANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with tf.compat.v1.variable_scope(model.discriminator_scope, reuse=True):
      dis_gen_outputs = model.discriminator_fn(pooled_generated_data,
                                               pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=dis_gen_outputs)
  elif isinstance(model, namedtuples.ACGANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with tf.compat.v1.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_discriminator_gen_classification_logits) = model.discriminator_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        discriminator_gen_classification_logits=
        pooled_discriminator_gen_classification_logits)
  elif isinstance(model, namedtuples.InfoGANModel):
    pooled_generator_inputs, pooled_generated_data, pooled_structured_input = (
        tensor_pool_fn((model.generator_inputs, model.generated_data,
                        model.structured_generator_inputs)))
    with tf.compat.v1.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_predicted_distributions) = model.discriminator_and_aux_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        structured_generator_inputs=pooled_structured_input,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        predicted_distributions=pooled_predicted_distributions)
  else:
    raise ValueError('Tensor pool does not support `model`: %s.' % type(model))


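# Example (a hedged sketch): a pool that replays old generator outputs to the
# discriminator, as in https://arxiv.org/abs/1612.07828. `tensor_pool` from
# `tfgan.features` is assumed to be importable here; its `pool_size` keyword
# is an assumption of this sketch:
#
#   pool_fn = lambda values: tensor_pool(values, pool_size=50)
#   pooled = tensor_pool_adjusted_model(model, pool_fn)
#
# More commonly, pass `pool_fn` as `tensor_pool_fn` to `gan_loss` below.
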
def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tuple_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tuple_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple. If it also takes `reduction` or `add_summaries`, it will be
      passed those values as well. All TF-GAN loss functions have these
      arguments.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple. If it also takes `reduction` or `add_summaries`, it will
      be passed those values as well. All TF-GAN loss functions have these
      arguments.
    gradient_penalty_weight: If not `None`, must be a non-negative Python
      number or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
      number or `Tensor` indicating the target value of gradient norm. See the
      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
    gradient_penalty_one_sided: If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not None: add a classification loss as in
      https://arxiv.org/abs/1610.09585
    aux_cond_discriminator_weight: If not None: add a classification loss as in
      https://arxiv.org/abs/1610.09585
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns previously stored
      (generated_data, generator_inputs). For example,
      `tfgan.features.tensor_pool`. Defaults to None (not using tensor pool).
    reduction: A `tf.losses.Reduction` to apply to loss, if the loss takes an
      argument called `reduction`. Otherwise, this is ignored.
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must be '
        'an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Optionally create pooled model.
  if tensor_pool_fn:
    pooled_model = tensor_pool_adjusted_model(model, tensor_pool_fn)
  else:
    pooled_model = model

  # Create standard losses with optional kwargs, if the loss functions accept
  # them.
  def _optional_kwargs(fn, possible_kwargs):
    """Returns a kwargs dictionary of valid kwargs for a given function."""
    spec = inspect.getfullargspec(fn)
    if spec.varkw is not None:
      return possible_kwargs
    actual_kwargs = {}
    for k, v in possible_kwargs.items():
      if k in spec.args or k in spec.kwonlyargs:
        actual_kwargs[k] = v
    return actual_kwargs

  possible_kwargs = {'reduction': reduction, 'add_summaries': add_summaries}
  gen_loss = generator_loss_fn(
      model, **_optional_kwargs(generator_loss_fn, possible_kwargs))
  dis_loss = discriminator_loss_fn(
      pooled_model, **_optional_kwargs(discriminator_loss_fn, possible_kwargs))

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tuple_losses.wasserstein_gradient_penalty(
        pooled_model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        reduction=reduction,
        add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    gen_info_loss = tuple_losses.mutual_information_penalty(
        model, reduction=reduction, add_summaries=add_summaries)
    if tensor_pool_fn is None:
      dis_info_loss = gen_info_loss
    else:
      dis_info_loss = tuple_losses.mutual_information_penalty(
          pooled_model, reduction=reduction, add_summaries=add_summaries)
    gen_loss += mutual_information_penalty_weight * gen_info_loss
    dis_loss += mutual_information_penalty_weight * dis_info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tuple_losses.acgan_generator_loss(
        model, reduction=reduction, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tuple_losses.acgan_discriminator_loss(
        pooled_model, reduction=reduction, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss
  # Gather regularization losses and add them to the main losses.
  if model.generator_scope:
    gen_reg_loss = tf.compat.v1.losses.get_regularization_loss(
        model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = tf.compat.v1.losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)


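# Example (a hedged sketch): a WGAN-GP style loss, i.e. the default
# Wasserstein losses plus a gradient penalty, per
# https://arxiv.org/pdf/1704.00028.pdf. `model` is a GANModel from above:
#
#   loss = gan_loss(model, gradient_penalty_weight=10.0)
#
# Or a non-saturating GAN loss with explicit loss functions:
#
#   loss = gan_loss(
#       model,
#       generator_loss_fn=tuple_losses.modified_generator_loss,
#       discriminator_loss_fn=tuple_losses.modified_discriminator_loss)
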
def cyclegan_loss(
    model,
    # Loss functions.
    generator_loss_fn=tuple_losses.least_squares_generator_loss,
    discriminator_loss_fn=tuple_losses.least_squares_discriminator_loss,
    # Auxiliary losses.
    cycle_consistency_loss_fn=tuple_losses.cycle_consistency_loss,
    cycle_consistency_loss_weight=10.0,
    # Options
    **kwargs):
  """Returns the losses for a `CycleGANModel`.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    model: A `CycleGANModel` namedtuple.
    generator_loss_fn: The loss function on the generator. Takes a `GANModel`
      namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `GANModel` namedtuple.
    cycle_consistency_loss_fn: The cycle consistency loss function. Takes a
      `CycleGANModel` namedtuple.
    cycle_consistency_loss_weight: A non-negative Python number or a scalar
      `Tensor` indicating how much to weigh the cycle consistency loss.
    **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss
      for each partial model of `model`.

  Returns:
    A `CycleGANLoss` namedtuple.

  Raises:
    ValueError: If `model` is not a `CycleGANModel` namedtuple.
  """
  # Sanity checks.
  if not isinstance(model, namedtuples.CycleGANModel):
    raise ValueError(
        '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model))

  # Defines cycle consistency loss.
  cycle_consistency_loss = cycle_consistency_loss_fn(
      model, add_summaries=kwargs.get('add_summaries', True))
  cycle_consistency_loss_weight = _validate_aux_loss_weight(
      cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
  aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss

  # Defines losses for each partial model.
  def _partial_loss(partial_model):
    partial_loss = gan_loss(
        partial_model,
        generator_loss_fn=generator_loss_fn,
        discriminator_loss_fn=discriminator_loss_fn,
        **kwargs)
    return partial_loss._replace(
        generator_loss=partial_loss.generator_loss + aux_loss)

  with tf.compat.v1.name_scope('cyclegan_loss_x2y'):
    loss_x2y = _partial_loss(model.model_x2y)
  with tf.compat.v1.name_scope('cyclegan_loss_y2x'):
    loss_y2x = _partial_loss(model.model_y2x)

  return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)


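# Example (a hedged sketch): `kwargs` such as `add_summaries` flow through to
# `gan_loss` for each partial model. `model` is a CycleGANModel from above:
#
#   loss = cyclegan_loss(model, cycle_consistency_loss_weight=10.0,
#                        add_summaries=False)
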
def stargan_loss(
    model,
    generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper(
        losses_wargs.wasserstein_generator_loss),
    discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper(
        losses_wargs.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True):
  """StarGAN Loss.

  Args:
    model: (StarGAN) Model output of the stargan_model() function call.
    generator_loss_fn: The loss function on the generator. Takes a
      `StarGANModel` namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `StarGANModel` namedtuple.
    gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10,
      per the original paper https://arxiv.org/abs/1711.09020. Set to 0 or
      None to turn off gradient penalty.
    gradient_penalty_epsilon: (float) A small positive number added for
      numerical stability when computing the gradient norm.
    gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
      gradient norm. Defaults to 1.0.
    gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    reconstruction_loss_fn: The reconstruction loss function. Defaults to an
      L1-norm, and the function must conform to the `tf.losses` API.
    reconstruction_loss_weight: Reconstruction loss weight. Defaults to 10.0.
    classification_loss_fn: The loss function on the discriminator's ability to
      classify the domain of the input. Defaults to one-hot softmax cross
      entropy loss, and the function must conform to the `tf.losses` API.
    classification_loss_weight: (float) Classification loss weight. Defaults to
      1.0.
    classification_one_hot: (bool) Whether the label is in one-hot
      representation. Defaults to True. If False, `classification_loss_fn`
      needs to be a sigmoid cross entropy loss instead.
    add_summaries: (bool) Whether to add the losses to the summary.

  Returns:
    A GANLoss namedtuple containing the generator loss and the discriminator
    loss.

  Raises:
    ValueError: If `StarGANModel.input_data_domain_label` does not have rank 2,
      or its second dimension is not defined.
  """

  def _classification_loss_helper(true_labels, predict_logits, scope_name):
    """Classification Loss Function Helper.

    Args:
      true_labels: Tensor of shape [batch_size, num_domains] representing the
        label where each row is a one-hot vector.
      predict_logits: Tensor of shape [batch_size, num_domains] representing
        the predicted label logits, which are the UNSCALED outputs from the
        network.
      scope_name: (string) Name scope of the loss component.

    Returns:
      A single scalar tensor representing the classification loss.
    """

    with tf.compat.v1.name_scope(
        scope_name, values=(true_labels, predict_logits)):

      loss = classification_loss_fn(
          onehot_labels=true_labels, logits=predict_logits)

      if not classification_one_hot:
        loss = tf.reduce_sum(input_tensor=loss, axis=1)
      loss = tf.reduce_mean(input_tensor=loss)

      if add_summaries:
        tf.compat.v1.summary.scalar(scope_name, loss)

      return loss

  # Check input shape.
  model.input_data_domain_label.shape.assert_has_rank(2)
  model.input_data_domain_label.shape[1:].assert_is_fully_defined()

  # Adversarial Loss.
  generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
  discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Gradient Penalty.
  if _use_aux_loss(gradient_penalty_weight):
    gradient_penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper(
        losses_wargs.wasserstein_gradient_penalty)
    discriminator_loss += gradient_penalty_fn(
        model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries) * gradient_penalty_weight

  # Reconstruction Loss.
  reconstruction_loss = reconstruction_loss_fn(model.input_data,
                                               model.reconstructed_data)
  generator_loss += reconstruction_loss * reconstruction_loss_weight
  if add_summaries:
    tf.compat.v1.summary.scalar('reconstruction_loss', reconstruction_loss)

  # Classification Loss.
  generator_loss += _classification_loss_helper(
      true_labels=model.generated_data_domain_target,
      predict_logits=model.discriminator_generated_data_domain_predication,
      scope_name='generator_classification_loss') * classification_loss_weight
  discriminator_loss += _classification_loss_helper(
      true_labels=model.input_data_domain_label,
      predict_logits=model.discriminator_input_data_domain_predication,
      scope_name='discriminator_classification_loss'
  ) * classification_loss_weight

  return namedtuples.GANLoss(generator_loss, discriminator_loss)


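# Example (a hedged sketch): the defaults reproduce the losses from the paper
# (Wasserstein + gradient penalty, L1 reconstruction, softmax classification).
# `model` is a StarGANModel from `stargan_model` above; the tweaked weight is
# just an illustration:
#
#   loss = stargan_loss(model)
#   loss = stargan_loss(model, reconstruction_loss_weight=5.0)
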
def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
  """Gets generator and discriminator update ops.

  Args:
    kwargs: A dictionary of kwargs to be passed to `create_train_op`.
      `update_ops` is removed, if present.
    gen_scope: A scope for the generator.
    dis_scope: A scope for the discriminator.
    check_for_unused_ops: A Python bool. If `True`, throw Exception if there
      are unused update ops.

  Returns:
    A 2-tuple of (generator update ops, discriminator update ops).

  Raises:
    ValueError: If there are update ops outside of the generator or
      discriminator scopes.
  """
  if 'update_ops' in kwargs:
    update_ops = set(kwargs['update_ops'])
    del kwargs['update_ops']
  else:
    update_ops = set(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))

  all_gen_ops = set(
      tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, gen_scope))
  all_dis_ops = set(
      tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, dis_scope))

  if check_for_unused_ops:
    unused_ops = update_ops - all_gen_ops - all_dis_ops
    if unused_ops:
      raise ValueError('There are unused update ops: %s' % unused_ops)

  gen_update_ops = list(all_gen_ops & update_ops)
  dis_update_ops = list(all_dis_ops & update_ops)

  return gen_update_ops, dis_update_ops


def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    is_chief=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TF-GAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    **kwargs: Keyword args to pass directly to `training.create_train_op` for
      both the generator and discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that
    can be used to train a generator/discriminator pair.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Get and store all arguments other than model and loss from locals.
    # The contents of locals() should not be modified, and modifying it may
    # not affect values, so make a copy.
    # https://docs.python.org/2/library/functions.html#locals.
    saved_params = dict(locals())
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    kwargs = saved_params.pop('kwargs', {})
    saved_params.update(kwargs)
    with tf.compat.v1.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with tf.compat.v1.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        tf.compat.v1.train.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  # Get the sync hooks if these are needed.
  sync_hooks = []

  generator_global_step = None
  if isinstance(generator_optimizer, tf.compat.v1.train.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this to work without including
    # the dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = tf.compat.v1.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=tf.compat.v1.initializers.zeros(),
        trainable=False,
        collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES])
    gen_update_ops += [generator_global_step.assign(global_step)]
    sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief))
  with tf.compat.v1.name_scope('generator_train'):
    gen_train_op = contrib.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                tf.compat.v1.train.SyncReplicasOptimizer):
    # See comment above `generator_global_step`.
    discriminator_global_step = tf.compat.v1.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=tf.compat.v1.initializers.zeros(),
        trainable=False,
        collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
    sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief))
  with tf.compat.v1.name_scope('discriminator_train'):
    disc_train_op = contrib.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc,
                                 sync_hooks)


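# Example (a hedged sketch): wiring train ops from a model and loss. The
# learning rate and beta1 values below are hypothetical, though beta1=0.5 is a
# common GAN choice:
#
#   train_ops = gan_train_ops(
#       model, loss,
#       generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
#       discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5))
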
# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
# Image Compression` (https://arxiv.org/abs/1705.05823)
class RunTrainOpsHook(tf_estimator.SessionRunHook):
  """A hook to run train ops a fixed number of times."""

  def __init__(self, train_ops, train_steps):
    """Run train ops a certain number of times.

    Args:
      train_ops: A train op or iterable of train ops to run.
      train_steps: The number of times to run the op(s).
    """
    if not isinstance(train_ops, (list, tuple)):
      train_ops = [train_ops]
    self._train_ops = train_ops
    self._train_steps = train_steps

  def before_run(self, run_context):
    for _ in range(self._train_steps):
      run_context.session.run(self._train_ops)


def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for sequential GAN training.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """

  def get_hooks(train_ops):
    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
                                     train_steps.generator_train_steps)
    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
                                         train_steps.discriminator_train_steps)
    return [generator_hook, discriminator_hook] + list(train_ops.train_hooks)

  return get_hooks


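# Example (a hedged sketch): five discriminator steps per generator step, a
# common GAN schedule. `GANTrainSteps` takes
# (generator_train_steps, discriminator_train_steps):
#
#   hooks_fn = get_sequential_train_hooks(namedtuples.GANTrainSteps(1, 5))
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir', get_hooks_fn=hooks_fn)
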
def _num_joint_steps(train_steps):
  g_steps = train_steps.generator_train_steps
  d_steps = train_steps.discriminator_train_steps
  # Get the number of each type of step that should be run.
  num_d_and_g_steps = min(g_steps, d_steps)
  num_g_steps = g_steps - num_d_and_g_steps
  num_d_steps = d_steps - num_d_and_g_steps

  return num_d_and_g_steps, num_g_steps, num_d_steps


def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for joint GAN training.

  When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
  ALL OPTIMIZERS TO AVOID RACE CONDITIONS.

  The order of steps taken is:
  1) Combined generator and discriminator steps
  2) Generator only steps, if any remain
  3) Discriminator only steps, if any remain

  **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
  for the generator and discriminator simultaneously whenever possible. This
  reduces the number of `tf.Session` calls, and can also change the training
  semantics.

  To illustrate the difference, look at the following example:

  `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
  `get_sequential_train_hooks` to make 8 session calls:
  1) 3 generator steps
  2) 5 discriminator steps

  In contrast, `get_joint_train_hooks` will make 5 session calls:
  1) 3 generator + discriminator steps
  2) 2 discriminator steps

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
  num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps)

  def get_hooks(train_ops):
    g_op = train_ops.generator_train_op
    d_op = train_ops.discriminator_train_op

    joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
    g_hook = RunTrainOpsHook(g_op, num_g_steps)
    d_hook = RunTrainOpsHook(d_op, num_d_steps)

    return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks)

  return get_hooks


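# Example (a hedged sketch): joint training with `use_locking=True` on the
# optimizers, per the warning in the docstring above. Learning rates are
# hypothetical:
#
#   gen_opt = tf.compat.v1.train.AdamOptimizer(1e-4, use_locking=True)
#   dis_opt = tf.compat.v1.train.AdamOptimizer(1e-4, use_locking=True)
#   train_ops = gan_train_ops(model, loss, gen_opt, dis_opt)
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir',
#             get_hooks_fn=get_joint_train_hooks())
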
# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
def gan_train(train_ops,
              logdir,
              get_hooks_fn=get_sequential_train_hooks(),
              master='',
              is_chief=True,
              scaffold=None,
              hooks=None,
              chief_only_hooks=None,
              save_checkpoint_secs=600,
              save_summaries_steps=100,
              max_wait_secs=7200,
              config=None):
  """A wrapper around `contrib.training.train` that uses GAN hooks.

  Args:
    train_ops: A GANTrainOps named tuple.
    logdir: The directory where the graph and checkpoints are saved.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A `tf.train.Scaffold` instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    max_wait_secs: Maximum time workers should wait for the session to
      become available. This should be kept relatively short to help detect
      incorrect code, but sometimes may need to be increased if the chief takes
      a while to start up.
    config: An instance of `tf.ConfigProto`.

  Returns:
    Output of the call to `training.train`.
  """
  _validate_gan_train_inputs(logdir, is_chief, save_summaries_steps,
                             save_checkpoint_secs)
  new_hooks = get_hooks_fn(train_ops)
  if hooks is not None:
    hooks = list(hooks) + list(new_hooks)
  else:
    hooks = new_hooks

  with tf.compat.v1.train.MonitoredTrainingSession(
      master=master,
      is_chief=is_chief,
      checkpoint_dir=logdir,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_only_hooks,
      save_checkpoint_secs=save_checkpoint_secs,
      save_summaries_steps=save_summaries_steps,
      config=config,
      max_wait_secs=max_wait_secs) as session:
    gstep = None
    while not session.should_stop():
      gstep = session.run(train_ops.global_step_inc_op)
  return gstep


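# Example (a hedged sketch): bounding training with a standard stopping hook.
# `tf_estimator.StopAtStepHook` is assumed to exist alongside the
# `SessionRunHook` API used elsewhere in this file; the step count is
# hypothetical:
#
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir',
#             hooks=[tf_estimator.StopAtStepHook(last_step=100000)])
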
def _validate_gan_train_inputs(logdir, is_chief, save_summaries_steps,
                               save_checkpoint_secs):
  if logdir is None and is_chief:
    if save_summaries_steps:
      raise ValueError(
          'logdir cannot be None when save_summaries_steps is not None')
    if save_checkpoint_secs:
      raise ValueError(
          'logdir cannot be None when save_checkpoint_secs is not None')


def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a thin wrapper around slim.learning.train_step, for GANs.

  This function is to provide support for the Supervisor. For new code, please
  use `MonitoredSession` and `get_sequential_train_hooks`.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that can be used for `train_step_fn` for GANs.
  """

  def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
    """A thin wrapper around slim.learning.train_step, for GANs.

    Args:
      sess: A Tensorflow session.
      train_ops: A GANTrainOps tuple of train ops to run.
      global_step: The global step.
      train_step_kwargs: Dictionary controlling `train_step` behavior.

    Returns:
      A scalar final loss and a bool whether or not the train loop should stop.
    """
    # Only run `should_stop` at the end, if required. Make a local copy of
    # `train_step_kwargs`, if necessary, so as not to modify the caller's
    # dictionary.
    should_stop_op, train_kwargs = None, train_step_kwargs
    if 'should_stop' in train_step_kwargs:
      should_stop_op = train_step_kwargs['should_stop']
      train_kwargs = train_step_kwargs.copy()
      del train_kwargs['should_stop']

    # Run generator training steps.
    gen_loss = 0
    for _ in range(train_steps.generator_train_steps):
      cur_gen_loss, _ = train_step(
          sess, train_ops.generator_train_op, global_step, train_kwargs)
      gen_loss += cur_gen_loss

    # Run discriminator training steps.
    dis_loss = 0
    for _ in range(train_steps.discriminator_train_steps):
      cur_dis_loss, _ = train_step(
          sess, train_ops.discriminator_train_op, global_step, train_kwargs)
      dis_loss += cur_dis_loss

    sess.run(train_ops.global_step_inc_op)

    # Run the `should_stop` op after the global step has been incremented, so
    # that the `should_stop` aligns with the proper `global_step` count.
    if should_stop_op is not None:
      should_stop = sess.run(should_stop_op)
    else:
      should_stop = False

    return gen_loss + dis_loss, should_stop

  return sequential_train_steps


# Helpers


def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
  """Convert input, list of inputs, or dictionary of inputs to Tensors."""
  if isinstance(tensor_or_l_or_d, (list, tuple)):
    return [tf.convert_to_tensor(value=x) for x in tensor_or_l_or_d]
  elif isinstance(tensor_or_l_or_d, dict):
    return {
        k: tf.convert_to_tensor(value=v) for k, v in tensor_or_l_or_d.items()
    }
  else:
    return tf.convert_to_tensor(value=tensor_or_l_or_d)


def _validate_distributions(distributions_l, noise_l):
  if not isinstance(distributions_l, (tuple, list)):
    raise ValueError('`predicted_distributions` must be a list. Instead, found '
                     '%s.' % type(distributions_l))
  if len(distributions_l) != len(noise_l):
    raise ValueError('Length of `predicted_distributions` %i must be the same '
                     'as the length of structured noise %i.' %
                     (len(distributions_l), len(noise_l)))


def _validate_acgan_discriminator_outputs(discriminator_output):
  try:
    a, b = discriminator_output
  except (TypeError, ValueError):
    raise TypeError(
        'A discriminator function for ACGAN must output a tuple '
        'consisting of (discrimination logits, classification logits).')
  return a, b


def generate_stargan_random_domain_target(batch_size, num_domains):
  """Generate random domain labels.

  Args:
    batch_size: (int) Number of random domain labels to generate.
    num_domains: (int) Number of domains represented by the label.

  Returns:
    Tensor of shape (batch_size, num_domains) representing random labels.
  """
  domain_idx = tf.random.uniform([batch_size],
                                 minval=0,
                                 maxval=num_domains,
                                 dtype=tf.int32)

  return tf.one_hot(domain_idx, num_domains)


# Slightly modified from
# `third_party/tensorflow/contrib/slim/python/slim/learning.py`.
def train_step(sess, train_op, global_step, train_step_kwargs):
  """Function that takes a gradient step and specifies whether to stop.

  Args:
    sess: The current session.
    train_op: An `Operation` that evaluates the gradients and returns the
      total loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments.

  Returns:
    The total loss and a boolean indicating whether or not to stop training.

  Raises:
    ValueError: If 'should_trace' is in `train_step_kwargs` but `logdir` is not.
  """
  start_time = time.time()

  trace_run_options = None
  run_metadata = None
  if 'should_trace' in train_step_kwargs:
    if 'logdir' not in train_step_kwargs:
      raise ValueError('logdir must be present in train_step_kwargs when '
                       'should_trace is present')
    if sess.run(train_step_kwargs['should_trace']):
      trace_run_options = tf.compat.v1.RunOptions(
          trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
      run_metadata = tf.compat.v1.RunMetadata()

  total_loss, np_global_step = sess.run([train_op, global_step],
                                        options=trace_run_options,
                                        run_metadata=run_metadata)
  time_elapsed = time.time() - start_time

  if run_metadata is not None:
    trace_filename = os.path.join(train_step_kwargs['logdir'],
                                  'tf_trace-%d.json' % np_global_step)
    tf.compat.v1.logging.info('Writing trace to %s', trace_filename)
    if 'summary_writer' in train_step_kwargs:
      train_step_kwargs['summary_writer'].add_run_metadata(
          run_metadata, 'run_metadata-%d' % np_global_step)

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      tf.compat.v1.logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                                np_global_step, total_loss, time_elapsed)

  # TODO(joelshor): Figure out why we can't put this into sess.run. The
  # issue right now is that the stop check depends on the global step. The
  # increment of the global step often happens via the train op, which was
  # created using optimizer.apply_gradients.
  #
  # Since running `train_op` causes the global step to be incremented, one
  # would expect that using a control dependency would allow the
  # should_stop check to be run in the same session.run call:
  #
  #   with ops.control_dependencies([train_op]):
  #     should_stop_op = ...
  #
  # However, this actually seems not to work on certain platforms.
  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False

  return total_loss, should_stop