remove internal modules (#1812)

* remove internal modules

* clean up

* refact

* remove six

* remove losses

* refact get_config
This commit is contained in:
who who who 2020-05-20 01:18:34 +08:00 committed by GitHub
parent 37dbb92e3b
commit 40d8e59cbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 57 additions and 12 deletions

View File

@ -16,7 +16,7 @@
import tensorflow as tf
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import TensorLike, Number
from typeguard import typechecked

View File

@ -17,7 +17,7 @@
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
from typeguard import typechecked

View File

@ -15,7 +15,7 @@
"""Implements GIoU loss."""
import tensorflow as tf
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import TensorLike
from typing import Optional
from typeguard import typechecked

View File

@ -17,7 +17,7 @@
import tensorflow as tf
from tensorflow_addons.losses import metric_learning
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
from typeguard import typechecked
from typing import Optional

View File

@ -16,7 +16,7 @@
import tensorflow as tf
from typeguard import typechecked
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import TensorLike, FloatTensorLike

View File

@ -16,7 +16,7 @@
import tensorflow as tf
from tensorflow_addons.losses import metric_learning
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
from typeguard import typechecked
from typing import Optional, Union, Callable

View File

@ -17,6 +17,57 @@
import tensorflow as tf
def is_tensor_or_variable(x):
return tf.is_tensor(x) or isinstance(x, tf.Variable)
class LossFunctionWrapper(tf.keras.losses.Loss):
"""Wraps a loss function in the `Loss` class."""
def __init__(
self, fn, reduction=tf.keras.losses.Reduction.AUTO, name=None, **kwargs
):
"""Initializes `LossFunctionWrapper` class.
Args:
fn: The loss function to wrap, with signature `fn(y_true, y_pred,
**kwargs)`.
reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
loss. Default value is `AUTO`. `AUTO` indicates that the reduction
option will be determined by the usage context. For almost all cases
this defaults to `SUM_OVER_BATCH_SIZE`. When used with
`tf.distribute.Strategy`, outside of built-in training loops such as
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
will raise an error. Please see this custom training [tutorial](
https://www.tensorflow.org/tutorials/distribute/custom_training)
for more details.
name: (Optional) name for the loss.
**kwargs: The keyword arguments that are passed on to `fn`.
"""
super().__init__(reduction=reduction, name=name)
self.fn = fn
self._fn_kwargs = kwargs
def call(self, y_true, y_pred):
"""Invokes the `LossFunctionWrapper` instance.
Args:
y_true: Ground truth values.
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
return self.fn(y_true, y_pred, **self._fn_kwargs)
def get_config(self):
config = {}
for k, v in iter(self._fn_kwargs.items()):
config[k] = tf.keras.backend.eval(v) if is_tensor_or_variable(v) else v
base_config = super().get_config()
return {**base_config, **config}
def normalize_data_format(value):
if value is None:
value = tf.keras.backend.image_data_format()

View File

@ -92,12 +92,6 @@ def test_no_private_tf_api():
"tensorflow_addons/optimizers/moving_average.py",
"tensorflow_addons/metrics/r_square.py",
"tensorflow_addons/utils/test_utils.py",
"tensorflow_addons/losses/contrastive.py",
"tensorflow_addons/losses/focal_loss.py",
"tensorflow_addons/losses/lifted.py",
"tensorflow_addons/losses/quantiles.py",
"tensorflow_addons/losses/triplet.py",
"tensorflow_addons/losses/giou_loss.py",
"tensorflow_addons/seq2seq/decoder.py",
"tensorflow_addons/seq2seq/attention_wrapper.py",
]