# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Context for Universal Value Function agents.
|
|
|
|
A context specifies a list of contextual variables, each with
|
|
own sampling and reward computation methods.
|
|
|
|
Examples of contextual variables include
|
|
goal states, reward combination vectors, etc.
|
|
"""
|
|

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tf_agents import specs
import gin.tf

from utils import utils as uvf_utils
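
# A minimal usage sketch (illustrative only): `RandomSampler` and
# `negative_distance` are stand-in names for whatever sampler classes and
# reward fns are gin-configured alongside this class, not part of this file.
#
#   context = Context(
#       tf_env,
#       context_shapes=[(2,)],
#       context_ranges=[(-10.0, 10.0)],
#       samplers={'train': [RandomSampler], 'explore': [RandomSampler],
#                 'eval': [RandomSampler]},
#       reward_fn=negative_distance,
#       meta_action_every_n=10)
#   goals, next_goals = context.sample_contexts('train', batch_size=64)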


@gin.configurable
class Context(object):
  """Base context."""
  VAR_NAME = 'action'

  def __init__(self,
               tf_env,
               context_ranges=None,
               context_shapes=None,
               state_indices=None,
               variable_indices=None,
               gamma_index=None,
               settable_context=False,
               timers=None,
               samplers=None,
               reward_weights=None,
               reward_fn=None,
               random_sampler_mode='random',
               normalizers=None,
               context_transition_fn=None,
               context_multi_transition_fn=None,
               meta_action_every_n=None):
    self._tf_env = tf_env
    self.variable_indices = variable_indices
    self.gamma_index = gamma_index
    self._settable_context = settable_context
    self.timers = timers
    self._context_transition_fn = context_transition_fn
    self._context_multi_transition_fn = context_multi_transition_fn
    self._random_sampler_mode = random_sampler_mode

    # assign specs
    self._obs_spec = self._tf_env.observation_spec()
    self._context_shapes = tuple([
        shape if shape is not None else self._obs_spec.shape
        for shape in context_shapes
    ])
    self.context_specs = tuple([
        specs.TensorSpec(dtype=self._obs_spec.dtype, shape=shape)
        for shape in self._context_shapes
    ])
    if context_ranges is not None:
      self.context_ranges = context_ranges
    else:
      self.context_ranges = [None] * len(self._context_shapes)

    self.context_as_action_specs = tuple([
        specs.BoundedTensorSpec(
            shape=shape,
            dtype=(tf.float32 if self._obs_spec.dtype in
                   [tf.float32, tf.float64] else self._obs_spec.dtype),
            minimum=context_range[0],
            maximum=context_range[-1])
        for shape, context_range in zip(self._context_shapes,
                                        self.context_ranges)
    ])
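    # Note: `context_as_action_specs` lets a higher-level policy treat
    # contexts (e.g. goals) as its action space; float-typed observations are
    # mapped to tf.float32, presumably to match actor network outputs.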

    if state_indices is not None:
      self.state_indices = state_indices
    else:
      self.state_indices = [None] * len(self._context_shapes)
    if self.variable_indices is not None and self.n != len(
        self.variable_indices):
      raise ValueError(
          'variable_indices (%s) must have the same length as contexts (%s).' %
          (self.variable_indices, self.context_specs))
    assert self.n == len(self.context_ranges)
    assert self.n == len(self.state_indices)

    # assign reward/sampler fns
    self._sampler_fns = dict()
    self._samplers = dict()
    self._reward_fns = dict()

    # assign reward fns
    self._add_custom_reward_fns()
    reward_weights = reward_weights or None
    self._reward_fn = self._make_reward_fn(reward_fn, reward_weights)

    # assign samplers
    self._add_custom_sampler_fns()
    for mode, sampler_fns in samplers.items():
      self._make_sampler_fn(sampler_fns, mode)

    # create normalizers
    if normalizers is None:
      self._normalizers = [None] * len(self.context_specs)
    else:
      self._normalizers = [
          normalizer(tf.zeros(shape=spec.shape, dtype=spec.dtype))
          if normalizer is not None else None
          for normalizer, spec in zip(normalizers, self.context_specs)
      ]
    assert self.n == len(self._normalizers)

    self.meta_action_every_n = meta_action_every_n
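    # If set, a new meta action (e.g. a goal) is expected every
    # `meta_action_every_n` environment steps; see step() below.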

    # create vars
    self.context_vars = {}
    self.timer_vars = {}
    self.create_vars(self.VAR_NAME)
    self.t = tf.Variable(
        tf.zeros(shape=(), dtype=tf.int32), name='num_timer_steps')
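    # `self.t` counts environment steps since the last reset; step()
    # increments it and compares it against `meta_action_every_n`.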

  def _add_custom_reward_fns(self):
    pass

  def _add_custom_sampler_fns(self):
    pass

  def sample_random_contexts(self, batch_size):
    """Sample random batch contexts."""
    assert self._random_sampler_mode is not None
    return self.sample_contexts(self._random_sampler_mode, batch_size)[0]

  def sample_contexts(self, mode, batch_size, state=None, next_state=None,
                      **kwargs):
    """Sample a batch of contexts.

    Args:
      mode: A string representing the mode [`train`, `explore`, `eval`].
      batch_size: Batch size.
      state: Optional [batch_size, num_state_dims] state tensor passed on to
        the sampler fns.
      next_state: Optional [batch_size, num_state_dims] next-state tensor.
      **kwargs: Additional kwargs forwarded to the sampler fns.
    Returns:
      Two lists of [batch_size, num_context_dims] contexts (current and next).
    """
    contexts, next_contexts = self._sampler_fns[mode](
        batch_size, state=state, next_state=next_state,
        **kwargs)
    self._validate_contexts(contexts)
    self._validate_contexts(next_contexts)
    return contexts, next_contexts

  def compute_rewards(self, mode, states, actions, rewards, next_states,
                      contexts):
    """Compute context-based rewards.

    Args:
      mode: A string representing the mode ['uvf', 'task'].
      states: A [batch_size, num_state_dims] tensor.
      actions: A [batch_size, num_action_dims] tensor.
      rewards: A [batch_size] tensor representing unmodified rewards.
      next_states: A [batch_size, num_state_dims] tensor.
      contexts: A list of [batch_size, num_context_dims] tensors.
    Returns:
      A tuple of [batch_size] tensors (rewards, discounts).
    """
    return self._reward_fn(states, actions, rewards, next_states,
                           contexts)

  def _make_reward_fn(self, reward_fns_list, reward_weights):
    """Returns a fn that computes rewards.

    Args:
      reward_fns_list: A fn or a list of reward fns.
      reward_weights: A list of reward weights, or None for uniform weights.
    """
    if not isinstance(reward_fns_list, (list, tuple)):
      reward_fns_list = [reward_fns_list]
    if reward_weights is None:
      reward_weights = [1.0] * len(reward_fns_list)
    assert len(reward_fns_list) == len(reward_weights)

    reward_fns_list = [
        self._custom_reward_fns[fn] if isinstance(fn, (str,)) else fn
        for fn in reward_fns_list
    ]

    def reward_fn(*args, **kwargs):
      """Returns rewards, discounts."""
      reward_tuples = [
          reward_fn(*args, **kwargs) for reward_fn in reward_fns_list
      ]
      rewards_list = [reward_tuple[0] for reward_tuple in reward_tuples]
      discounts_list = [reward_tuple[1] for reward_tuple in reward_tuples]
      ndims = max([r.shape.ndims for r in rewards_list])
      if ndims > 1:  # Expand reward shapes to allow broadcasting.
        for i in range(len(rewards_list)):
          # Pad lower-rank rewards/discounts with trailing singleton dims
          # until they reach `ndims`.
          for _ in range(ndims - rewards_list[i].shape.ndims):
            rewards_list[i] = tf.expand_dims(rewards_list[i], axis=-1)
          for _ in range(ndims - discounts_list[i].shape.ndims):
            discounts_list[i] = tf.expand_dims(discounts_list[i], axis=-1)
      rewards = tf.add_n(
          [r * tf.to_float(w) for r, w in zip(rewards_list, reward_weights)])
      discounts = discounts_list[0]
      for d in discounts_list[1:]:
        discounts *= d

      return rewards, discounts

    return reward_fn
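
  # For example, with reward_fns_list=[task_reward_fn, penalty_fn] (names
  # illustrative) and reward_weights=[1.0, 0.1], the combined fn returns
  # rewards = 1.0 * r_task + 0.1 * r_penalty together with the product of the
  # two discounts; each constituent fn must return a (reward, discount) tuple.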

  def _make_sampler_fn(self, sampler_cls_list, mode):
    """Returns a fn that samples a list of context vars.

    Args:
      sampler_cls_list: A list of sampler classes.
      mode: A string representing the operating mode.
    """
    if not isinstance(sampler_cls_list, (list, tuple)):
      sampler_cls_list = [sampler_cls_list]

    self._samplers[mode] = []
    sampler_fns = []
    for spec, sampler in zip(self.context_specs, sampler_cls_list):
      if isinstance(sampler, (str,)):
        sampler_fn = self._custom_sampler_fns[sampler]
      else:
        sampler_fn = sampler(context_spec=spec)
      self._samplers[mode].append(sampler_fn)
      sampler_fns.append(sampler_fn)

    def batch_sampler_fn(batch_size, state=None, next_state=None, **kwargs):
      """Sampler fn."""
      contexts_tuples = [
          sampler(batch_size, state=state, next_state=next_state, **kwargs)
          for sampler in sampler_fns]
      contexts = [c[0] for c in contexts_tuples]
      next_contexts = [c[1] for c in contexts_tuples]
      contexts = [
          normalizer.update_apply(c) if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, contexts)
      ]
      next_contexts = [
          normalizer.apply(c) if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, next_contexts)
      ]
      return contexts, next_contexts

    self._sampler_fns[mode] = batch_sampler_fn
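
  # Note: samplers are zipped against `context_specs`, so each mode should
  # provide one sampler per context variable (`zip` silently drops extras).
  # Normalizer statistics are updated on sampled contexts but only applied,
  # without an update, to the sampled next-contexts.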

  def set_env_context_op(self, context, disable_unnormalizer=False):
    """Returns a TensorFlow op that sets the environment context.

    Args:
      context: A list of context Tensor variables.
      disable_unnormalizer: Disable unnormalization.
    Returns:
      A TensorFlow op that sets the environment context.
    """
    ret_val = np.array(1.0, dtype=np.float32)
    if not self._settable_context:
      return tf.identity(ret_val)

    if not disable_unnormalizer:
      context = [
          normalizer.unapply(tf.expand_dims(c, 0))[0]
          if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, context)
      ]

    def set_context_func(*env_context_values):
      tf.logging.info('[set_env_context_op] Setting gym environment context.')
      # pylint: disable=protected-access
      self.gym_env.set_context(*env_context_values)
      # pylint: enable=protected-access
      return ret_val

    with tf.name_scope('set_env_context'):
      set_op = tf.py_func(set_context_func, context, tf.float32,
                          name='set_env_context_py_func')
      set_op.set_shape([])
    return set_op
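
  # Note: tf.py_func executes set_context_func in the host Python process at
  # session-run time, so the gym environment is only mutated when the returned
  # op is actually run.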

  def set_replay(self, replay):
    """Set replay buffer for samplers.

    Args:
      replay: A replay buffer.
    """
    for _, samplers in self._samplers.items():
      for sampler in samplers:
        sampler.set_replay(replay)

  def get_clip_fns(self):
    """Returns a list of clip fns for contexts.

    Returns:
      A list of fns that clip context tensors.
    """
    clip_fns = []
    for context_range in self.context_ranges:
      # Bind the current range through a default argument so that each
      # closure keeps its own range instead of the loop's final value.
      def clip_fn(var_, range_=context_range):
        """Clip a tensor."""
        if range_ is None:
          clipped_var = tf.identity(var_)
        elif isinstance(range_[0], (int, float, list, np.ndarray)):
          clipped_var = tf.clip_by_value(var_, range_[0], range_[1])
        else:
          raise NotImplementedError(range_)
        return clipped_var
      clip_fns.append(clip_fn)
    return clip_fns

  def _validate_contexts(self, contexts):
    """Validate that contexts match the expected specs.

    Args:
      contexts: A list of [batch_size, num_context_dims] tensors.
    Raises:
      ValueError: If shape or dtype mismatches that of spec.
    """
    for i, (context, spec) in enumerate(zip(contexts, self.context_specs)):
      if context[0].shape != spec.shape:
        raise ValueError('contexts[%d] has invalid shape %s wrt spec shape %s' %
                         (i, context[0].shape, spec.shape))
      if context.dtype != spec.dtype:
        raise ValueError('contexts[%d] has invalid dtype %s wrt spec dtype %s' %
                         (i, context.dtype, spec.dtype))

  def context_multi_transition_fn(self, contexts, **kwargs):
    """Returns multiple future contexts starting from a batch."""
    assert self._context_multi_transition_fn
    return self._context_multi_transition_fn(contexts, None, None, **kwargs)

  def step(self, mode, agent=None, action_fn=None, **kwargs):
    """Returns [next_contexts..., next_timer] list of ops.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      agent: an optional higher-level agent; if given, its context is stepped
        first and this context is treated as the low-level one.
      action_fn: an optional fn used to sample a new meta action (goal) from
        the higher-level policy every `meta_action_every_n` steps.
      **kwargs: kwargs for context_transition_fn.
    Returns:
      a list of ops that set the context.
    """
    if agent is None:
      ops = []
      if self._context_transition_fn is not None:
        def sampler_fn():
          samples = self.sample_contexts(mode, 1)[0]
          return [s[0] for s in samples]
        values = self._context_transition_fn(self.vars, self.t, sampler_fn,
                                             **kwargs)
        ops += [tf.assign(var, value)
                for var, value in zip(self.vars, values)]
      ops.append(tf.assign_add(self.t, 1))  # Increment timer.
      return ops
    else:
      ops = agent.tf_context.step(mode, **kwargs)
      state = kwargs['state']
      next_state = kwargs['next_state']
      state_repr = kwargs['state_repr']
      next_state_repr = kwargs['next_state_repr']
      # Step the high-level context before computing the low-level one.
      with tf.control_dependencies(ops):
        # Get the context transition function output.
        values = self._context_transition_fn(self.vars, self.t, None,
                                             state=state_repr,
                                             next_state=next_state_repr)
        # Select a new goal every C steps, otherwise use context transition.
        low_level_context = [
            tf.cond(tf.equal(self.t % self.meta_action_every_n, 0),
                    lambda: tf.cast(action_fn(next_state, context=None),
                                    tf.float32),
                    lambda: values)]
        ops = [tf.assign(var, value)
               for var, value in zip(self.vars, low_level_context)]
        with tf.control_dependencies(ops):
          return [tf.assign_add(self.t, 1)]  # Increment timer.
      return ops
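
  # In the efficient-hrl (HIRO) setup, `context_transition_fn` is typically
  # the relative goal transition h(s, g, s') = s + g - s', which re-expresses
  # the goal relative to the new state so it keeps pointing at the same
  # absolute state; a fresh goal comes from the meta policy (`action_fn`)
  # every `meta_action_every_n` steps.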

  def reset(self, mode, agent=None, action_fn=None, state=None):
    """Returns ops that reset the context.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      agent: an optional higher-level agent whose context is reset first.
      action_fn: unused here; accepted for interface compatibility with step().
      state: unused here; accepted for interface compatibility with step().
    Returns:
      a list of ops that reset the context.
    """
    if agent is None:
      values = self.sample_contexts(mode=mode, batch_size=1)[0]
      if values is None:
        return []
      values = [value[0] for value in values]
      values[0] = uvf_utils.tf_print(
          values[0],
          values,
          message='context:reset, mode=%s' % mode,
          first_n=10,
          name='context:reset:%s' % mode)
      all_ops = []
      for _, context_vars in sorted(self.context_vars.items()):
        ops = [tf.assign(var, value)
               for var, value in zip(context_vars, values)]
        all_ops += ops
      all_ops.append(self.set_env_context_op(values))
      all_ops.append(tf.assign(self.t, 0))  # Reset timer.
      return all_ops
    else:
      ops = agent.tf_context.reset(mode)
      # NOTE: The code is currently written in such a way that the higher-level
      # policy does not provide a low-level context until the second
      # observation. Instead, we just zero out low-level contexts.
      for key, context_vars in sorted(self.context_vars.items()):
        ops += [tf.assign(var, tf.zeros_like(var)) for var, _ in
                zip(context_vars, agent.tf_context.context_vars[key])]

      ops.append(tf.assign(self.t, 0))  # Reset timer.
      return ops

  def create_vars(self, name, agent=None):
    """Create tf variables for contexts.

    Args:
      name: Name of the variables.
      agent: an optional higher-level agent; if given, its context variables
        are created as well.
    Returns:
      A tuple of (context variables, meta variables dict).
    """
    if agent is not None:
      meta_vars = agent.create_vars(name)
    else:
      meta_vars = {}
    assert name not in self.context_vars, ('Conflict! %s is already '
                                           'initialized.') % name
    self.context_vars[name] = tuple([
        tf.Variable(
            tf.zeros(shape=spec.shape, dtype=spec.dtype),
            name='%s_context_%d' % (name, i))
        for i, spec in enumerate(self.context_specs)
    ])
    return self.context_vars[name], meta_vars

  @property
  def n(self):
    return len(self.context_specs)

  @property
  def vars(self):
    return self.context_vars[self.VAR_NAME]

  # pylint: disable=protected-access
  @property
  def gym_env(self):
    return self._tf_env.pyenv._gym_env

  @property
  def tf_env(self):
    return self._tf_env
  # pylint: enable=protected-access