addons/tensorflow_addons/optimizers/average_wrapper.py

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import abc
import warnings

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils import types
from typeguard import typechecked


class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
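    """Base class for optimizers that maintain a moving average of model
    variables.

    Gradient updates are delegated to the wrapped `optimizer`, while one
    extra "average" slot per variable holds the running average. Concrete
    subclasses (for example `tfa.optimizers.MovingAverage` and
    `tfa.optimizers.SWA`) define how that average is updated by
    implementing `average_op`.
    """
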
    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        name: str = "AverageOptimizer",
        **kwargs,
    ):
        super().__init__(name, **kwargs)

        if isinstance(optimizer, str):
            if (
                hasattr(tf.keras.optimizers, "legacy")
                and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
            ):
                optimizer = tf.keras.optimizers.get(
                    optimizer, use_legacy_optimizer=True
                )
            else:
                optimizer = tf.keras.optimizers.get(optimizer)

        if not isinstance(optimizer, KerasLegacyOptimizer):
            raise TypeError(
                "optimizer is not an instance of tf.keras.optimizers.legacy.Optimizer"
            )

        self._optimizer = optimizer
        self._track_trackable(self._optimizer, "awg_optimizer")
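
    # Slot and hyperparameter plumbing is delegated to the wrapped optimizer;
    # on top of its own slots, each variable gets one extra "average" slot
    # that holds the running average maintained by `average_op`.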
    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list=var_list)
        for var in var_list:
            self.add_slot(var, "average")

    def _create_hypers(self):
        self._optimizer._create_hypers()

    def _prepare_local(self, var_device, var_dtype, apply_state):
        return self._optimizer._prepare_local(var_device, var_dtype, apply_state)

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        self._optimizer._iterations = self.iterations
        return super().apply_gradients(grads_and_vars, name, **kwargs)
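
    # Subclasses implement the actual averaging rule here; it receives the
    # model variable, its "average" slot, and the per-device/dtype apply
    # state prepared by `_prepare_local`.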
    @abc.abstractmethod
    def average_op(self, var, average_var, local_apply_state):
        raise NotImplementedError

    def _apply_average_op(self, train_op, var, apply_state):
        apply_state = apply_state or {}
        local_apply_state = apply_state.get((var.device, var.dtype.base_dtype))
        if local_apply_state is None:
            local_apply_state = self._fallback_apply_state(
                var.device, var.dtype.base_dtype
            )
        average_var = self.get_slot(var, "average")
        return self.average_op(var, average_var, local_apply_state)
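
    # Each _resource_apply_* runs the wrapped optimizer's update and the
    # averaging op together under tf.group, so both execute on every step.
    # The membership checks against _dense_apply_args / _sparse_apply_args
    # keep compatibility with optimizers whose apply methods do not accept
    # `apply_state`.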
    def _resource_apply_dense(self, grad, var, apply_state=None):
        if "apply_state" in self._optimizer._dense_apply_args:
            train_op = self._optimizer._resource_apply_dense(
                grad, var, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_dense(grad, var)
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse(
                grad, var, indices, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_sparse(grad, var, indices)
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def _resource_apply_sparse_duplicate_indices(
        self, grad, var, indices, apply_state=None
    ):
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse_duplicate_indices(
                grad, var, indices, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_sparse_duplicate_indices(
                grad, var, indices
            )
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def assign_average_vars(self, var_list):
        """Assign variables in var_list with their respective averages.

        Args:
            var_list: List of model variables to be assigned to their average.

        Returns:
            assign_op: The op corresponding to the assignment operation of
                variables to their average.

        Example:
        ```python
        model = tf.keras.Sequential([...])
        opt = tfa.optimizers.SWA(
            tf.keras.optimizers.SGD(learning_rate=2.0), 100, 10)
        model.compile(opt, ...)
        model.fit(x, y, ...)

        # Update the weights to their mean before saving
        opt.assign_average_vars(model.variables)

        model.save('model.h5')
        ```
        """
        assign_ops = []
        for var in var_list:
            try:
                assign_ops.append(
                    var.assign(
                        self.get_slot(var, "average"),
                        use_locking=self._use_locking,
                    )
                )
            except Exception as e:
                warnings.warn(
                    "Unable to assign average slot to {} : {}".format(var, e)
                )

        return tf.group(assign_ops)
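
    # Serialization round-trip: the wrapped optimizer is stored in the config
    # under "optimizer" by `get_config` and popped back out in `from_config`.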
    def get_config(self):
        config = {
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)
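
    # Learning-rate access is proxied to the wrapped optimizer, so reading or
    # setting `lr` / `learning_rate` on the wrapper touches the inner
    # optimizer's hyperparameter.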
    @property
    def weights(self):
        return self._weights + self._optimizer.weights

    @property
    def lr(self):
        return self._optimizer._get_hyper("learning_rate")

    @lr.setter
    def lr(self, lr):
        self._optimizer._set_hyper("learning_rate", lr)

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)
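

# A minimal sketch of a concrete subclass, for illustration only: it is not
# part of this module, and the `ExampleEMA` name and `decay` hyperparameter
# are assumptions made for the example. It keeps an exponential moving
# average of each variable, updated after every optimizer step.
class ExampleEMA(AveragedOptimizerWrapper):
    def __init__(self, optimizer, decay=0.99, name="ExampleEMA", **kwargs):
        super().__init__(optimizer, name=name, **kwargs)
        self._set_hyper("decay", decay)

    def average_op(self, var, average_var, local_apply_state):
        # new_average = decay * old_average + (1 - decay) * var
        decay = tf.cast(self._get_hyper("decay"), var.dtype.base_dtype)
        return average_var.assign(
            decay * average_var + (1.0 - decay) * var,
            use_locking=self._use_locking,
        )

    def get_config(self):
        config = {"decay": self._serialize_hyperparameter("decay")}
        base_config = super().get_config()
        return {**base_config, **config}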