mirror of https://github.com/tensorflow/addons.git
remove internal modules (#1812)
* remove internal modules * clean up * refact * remove six * remove losses * refact get_config
This commit is contained in:
parent
37dbb92e3b
commit
40d8e59cbf
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.keras.losses import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.types import TensorLike, Number
|
||||
from typeguard import typechecked
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
import tensorflow as tf
|
||||
import tensorflow.keras.backend as K
|
||||
|
||||
from tensorflow.python.keras.losses import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
|
||||
from typeguard import typechecked
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
"""Implements GIoU loss."""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.losses import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.types import TensorLike
|
||||
from typing import Optional
|
||||
from typeguard import typechecked
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow_addons.losses import metric_learning
|
||||
|
||||
from tensorflow.python.keras.losses import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
|
||||
from typeguard import typechecked
|
||||
from typing import Optional
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
import tensorflow as tf
|
||||
from typeguard import typechecked
|
||||
from tensorflow.python.keras.losses import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.types import TensorLike, FloatTensorLike
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
import tensorflow as tf
|
||||
from tensorflow_addons.losses import metric_learning
|
||||
from tensorflow.python.keras.losses import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
|
||||
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
|
||||
from typeguard import typechecked
|
||||
from typing import Optional, Union, Callable
|
||||
|
|
|
|||
|
|
@ -17,6 +17,57 @@
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def is_tensor_or_variable(x):
|
||||
return tf.is_tensor(x) or isinstance(x, tf.Variable)
|
||||
|
||||
|
||||
class LossFunctionWrapper(tf.keras.losses.Loss):
|
||||
"""Wraps a loss function in the `Loss` class."""
|
||||
|
||||
def __init__(
|
||||
self, fn, reduction=tf.keras.losses.Reduction.AUTO, name=None, **kwargs
|
||||
):
|
||||
"""Initializes `LossFunctionWrapper` class.
|
||||
|
||||
Args:
|
||||
fn: The loss function to wrap, with signature `fn(y_true, y_pred,
|
||||
**kwargs)`.
|
||||
reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
|
||||
loss. Default value is `AUTO`. `AUTO` indicates that the reduction
|
||||
option will be determined by the usage context. For almost all cases
|
||||
this defaults to `SUM_OVER_BATCH_SIZE`. When used with
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
name: (Optional) name for the loss.
|
||||
**kwargs: The keyword arguments that are passed on to `fn`.
|
||||
"""
|
||||
super().__init__(reduction=reduction, name=name)
|
||||
self.fn = fn
|
||||
self._fn_kwargs = kwargs
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
"""Invokes the `LossFunctionWrapper` instance.
|
||||
|
||||
Args:
|
||||
y_true: Ground truth values.
|
||||
y_pred: The predicted values.
|
||||
|
||||
Returns:
|
||||
Loss values per sample.
|
||||
"""
|
||||
return self.fn(y_true, y_pred, **self._fn_kwargs)
|
||||
|
||||
def get_config(self):
|
||||
config = {}
|
||||
for k, v in iter(self._fn_kwargs.items()):
|
||||
config[k] = tf.keras.backend.eval(v) if is_tensor_or_variable(v) else v
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
||||
|
||||
|
||||
def normalize_data_format(value):
|
||||
if value is None:
|
||||
value = tf.keras.backend.image_data_format()
|
||||
|
|
|
|||
|
|
@ -92,12 +92,6 @@ def test_no_private_tf_api():
|
|||
"tensorflow_addons/optimizers/moving_average.py",
|
||||
"tensorflow_addons/metrics/r_square.py",
|
||||
"tensorflow_addons/utils/test_utils.py",
|
||||
"tensorflow_addons/losses/contrastive.py",
|
||||
"tensorflow_addons/losses/focal_loss.py",
|
||||
"tensorflow_addons/losses/lifted.py",
|
||||
"tensorflow_addons/losses/quantiles.py",
|
||||
"tensorflow_addons/losses/triplet.py",
|
||||
"tensorflow_addons/losses/giou_loss.py",
|
||||
"tensorflow_addons/seq2seq/decoder.py",
|
||||
"tensorflow_addons/seq2seq/attention_wrapper.py",
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue