mirror of https://github.com/kubeflow/examples.git
278 lines
9.0 KiB
Python
278 lines
9.0 KiB
Python
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""Provides an entrypoint for the training task."""
|
|
|
|
#pylint: disable=unused-import
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import datetime
|
|
import logging
|
|
import os
|
|
import pprint
|
|
import uuid
|
|
|
|
from google.cloud import storage
|
|
import tensorflow as tf
|
|
|
|
import agents
|
|
import pybullet_envs # To make AntBulletEnv-v0 available.
|
|
|
|
# Command-line flags. `_get_agents_configuration` copies the registered flags
# into the hyperparameter dict, so the flag names below double as hparam keys.
flags = tf.app.flags

# General
flags.DEFINE_string("run_mode", "train",
                    "Run mode, one of [train, render, train_and_render].")
flags.DEFINE_string("logdir", '/tmp/test',
                    "The base directory in which to write logs and "
                    "checkpoints.")
flags.DEFINE_string("hparam_set_id", "pybullet_kuka_ff",
                    "The name of the config object to be used to parameterize "
                    "the run.")
# NOTE: the default timestamp is computed once, when this module is imported.
flags.DEFINE_string("run_base_tag",
                    datetime.datetime.now().strftime('%Y%m%dT%H%M%S'),
                    "Base tag to prepend to logs dir folder name. Defaults "
                    "to timestamp.")
flags.DEFINE_boolean("env_processes", True,
                     "Step environments in separate processes to circumvent "
                     "the GIL.")
# The two help fragments are concatenated, so the first must end in a space.
flags.DEFINE_integer("num_gpus", 0,
                     "Total number of gpus for each machine. "
                     "If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("save_checkpoint_secs", 600,
                     "Number of seconds between checkpoint save.")
flags.DEFINE_boolean("log_device_placement", False,
                     "Whether to output logs listing the devices on which "
                     "variables are placed.")
flags.DEFINE_boolean("debug", True,
                     "Run in debug mode.")

# Render
flags.DEFINE_integer("render_secs", 600,
                     "Number of seconds between triggering render jobs.")
flags.DEFINE_string("render_out_dir", None,
                    "The path to which to copy generated renders.")

# Algorithm
flags.DEFINE_string("algorithm", "agents.ppo.PPOAlgorithm",
                    "The name of the algorithm to use.")
flags.DEFINE_integer("num_agents", 30,
                     "The number of agents to use.")
flags.DEFINE_integer("eval_episodes", 25,
                     "The number of eval episodes to use.")
flags.DEFINE_string("env", "AntBulletEnv-v0",
                    "The gym / bullet simulation environment to use.")
flags.DEFINE_integer("max_length", 1000,
                     "The maximum length of an episode.")
flags.DEFINE_integer("steps", 10000000,
                     "The number of steps.")

# Network
flags.DEFINE_string("network", "agents.scripts.networks.feed_forward_gaussian",
                    "The registered network name to use for policy and value.")
# TODO: the next two flags have empty help strings; document them once their
# exact meaning in the network constructor has been confirmed.
flags.DEFINE_float("init_mean_factor", 0.1,
                   "")
flags.DEFINE_float("init_std", 0.35,
                   "")

# Optimization
flags.DEFINE_float("learning_rate", 1e-4,
                   "The learning rate of the optimizer.")
flags.DEFINE_string("optimizer", "tensorflow.train.AdamOptimizer",
                    "The import path of the optimizer to use.")
flags.DEFINE_integer("update_epochs", 25,
                     "The number of update epochs.")
flags.DEFINE_integer("update_every", 60,
                     "The update frequency.")

# Losses
flags.DEFINE_float("discount", 0.995,
                   "The discount.")
flags.DEFINE_float("kl_target", 1e-2,
                   "the KL target.")
flags.DEFINE_integer("kl_cutoff_factor", 2,
                     "The KL cutoff factor.")
flags.DEFINE_integer("kl_cutoff_coef", 1000,
                     "The KL cutoff coefficient.")
flags.DEFINE_integer("kl_init_penalty", 1,
                     "The initial KL penalty?.")

FLAGS = flags.FLAGS
|
|
|
|
|
|
def hparams_base():
    """Return the base hyperparameters for tf/Agents PPO.

    No defaults are active in this function, so the result is a fresh, empty
    dict each call. The values listed in the comment below are inactive and
    kept only for reference.

    Returns:
      dict: A new empty dictionary.
    """
    # Inactive reference values:
    #
    #   General
    #     algorithm = agents.ppo.PPOAlgorithm
    #     num_agents = 30
    #     eval_episodes = 30
    #     use_gpu = False
    #
    #   Environment
    #     env = 'KukaBulletEnv-v0'
    #     normalize_ranges = True
    #     max_length = 1000
    #
    #   Network
    #     network = agents.scripts.networks.feed_forward_gaussian
    #     weight_summaries = dict(
    #         all=r'.*', policy=r'.*/policy/.*', value=r'.*/value/.*')
    #     policy_layers = 200, 100
    #     value_layers = 200, 100
    #     init_output_factor = 0.1
    #     init_logstd = -1
    #     init_std = 0.35
    #
    #   Optimization
    #     update_every = 60
    #     update_epochs = 25
    #     optimizer = tf.train.AdamOptimizer
    #     learning_rate = 1e-4
    #     steps = 3e7  # 30M
    #
    #   Losses
    #     discount = 0.995
    #     kl_target = 1e-2
    #     kl_cutoff_factor = 2
    #     kl_cutoff_coef = 1000
    #     kl_init_penalty = 1
    return {}
|
|
|
|
|
|
def _object_import_from_string(name):
    """Resolve a dotted import path such as ``pkg.mod.attr`` to an object.

    The first component is imported as a top-level module and each following
    component is looked up as an attribute of the previous result. When a
    component is not yet an attribute of its parent, it is imported as a
    submodule, so the path resolves even if the parent package does not
    import that submodule itself.

    Args:
      name (str): Dotted path, e.g. "os.path.join".

    Returns:
      The module, class, function or other object the path refers to.

    Raises:
      ImportError: If the top-level module cannot be imported.
      AttributeError: If a later component is neither an attribute nor an
        importable submodule of its parent.
    """
    import importlib  # Local import keeps this helper self-contained.

    components = name.split('.')
    obj = importlib.import_module(components[0])
    for depth, comp in enumerate(components[1:], start=2):
        try:
            obj = getattr(obj, comp)
        except AttributeError:
            # Importing a package does not import its submodules, so fall
            # back to importing the dotted prefix up to this component.
            dotted = '.'.join(components[:depth])
            try:
                obj = importlib.import_module(dotted)
            except ImportError:
                raise AttributeError(
                    "Failed to realize import path %r: %r has no attribute "
                    "or importable submodule %r"
                    % (name, '.'.join(components[:depth - 1]), comp))
    return obj
|
|
|
|
|
|
def _realize_import_attrs(d, hparam_filter):
    """Replace selected import-path strings in `d` with the imported objects.

    For every key of `d` that also appears in `hparam_filter`, the value is
    passed to `_object_import_from_string` and the result is stored back
    under the same key. Other entries are left untouched.

    Args:
      d (dict): Hyperparameters; modified in place.
      hparam_filter: Collection of keys whose values are dotted import paths.

    Returns:
      dict: The same `d` object, after the replacements.

    Raises:
      ImportError, AttributeError: Re-raised unchanged from
        `_object_import_from_string`, after logging which hparam and import
        path failed.
    """
    # Only existing keys are reassigned, so the dict does not change size
    # while it is being iterated.
    for key, import_path in d.items():
        if key not in hparam_filter:
            continue
        try:
            d[key] = _object_import_from_string(import_path)
        except (ImportError, AttributeError):
            # Name the offending hparam; the bare import error alone does not
            # say which setting carried the bad path.
            logging.error(
                "Failed to realize import path %s for hparam %s.",
                import_path, key)
            raise
    return d
|
|
|
|
|
|
def _get_agents_configuration(log_dir=None):
    """Load the hyperparameter config for a run, creating it if necessary.

    A config already stored under `log_dir` is loaded as is. If loading
    raises IOError, a new config is assembled from `hparams_base()` and the
    command-line flags, the "network", "algorithm" and "optimizer" import
    paths are replaced by the imported objects, and the result is saved to
    `log_dir`. The final config is pretty-printed before being returned.

    Args:
      log_dir (str): Directory the config is loaded from and saved to.

    Returns:
      The config returned by `load_config` or by `save_config`.
    """
    import_keys = ["network", "algorithm", "optimizer"]

    try:
        # An existing config means a previous run is being resumed.
        config = agents.scripts.utility.load_config(log_dir)
    except IOError:
        config = hparams_base()

        # Extend the base hparams with every flag.
        # NOTE(review): this reads the private '__flags' entry of FLAGS;
        # confirm it still holds plain values on the TensorFlow version used.
        config.update(FLAGS.__dict__['__flags'].items())

        # Turn the dotted import paths into the objects they name.
        config = _realize_import_attrs(config, import_keys)

        config = agents.tools.AttrDict(config)
        config = agents.scripts.utility.save_config(config, log_dir)

    pprint.pprint(config)
    return config
|
|
|
|
|
|
def _split_gcs_path(gcs_path):
    """Split a ``gs://bucket/prefix`` path into ``(bucket_name, prefix)``."""
    scheme = "gs://"
    if not gcs_path.startswith(scheme):
        raise ValueError(
            "gcs_upload expected gcs_out_dir argument to start with gs://, "
            "saw %s" % gcs_path)
    bucket_name, _, prefix = gcs_path[len(scheme):].partition('/')
    if not bucket_name:
        raise ValueError(
            "gcs_upload expected gcs_out_dir argument to name a bucket after "
            "gs://, saw %s" % gcs_path)
    return bucket_name, prefix


def gcs_upload(local_dir, gcs_out_dir):
    """Upload the contents of a local directory to a specific GCS path.

    Only regular files directly inside `local_dir` are uploaded;
    subdirectories are skipped. The destination is validated before the
    directory is listed or a GCS client is created, so a malformed path
    fails without touching the file system or the network.

    Args:
      local_dir (str): The local directory containing files to upload.
      gcs_out_dir (str): The target Google Cloud Storage directory path.

    Raises:
      ValueError: If `gcs_out_dir` does not start with "gs://" or does not
        name a bucket.
    """
    import posixpath  # GCS object names always use '/' as the separator.

    # Obtain the bucket name and the path below it (excluding gs://bucket).
    bucket_name, blob_base_path = _split_gcs_path(gcs_out_dir)

    # Collect the names of all regular files in local_dir; sorted so the log
    # line and the upload order are deterministic.
    local_files = sorted(
        entry for entry in os.listdir(local_dir)
        if os.path.isfile(os.path.join(local_dir, entry)))
    tf.logging.info("Preparing local files for upload:\n %s", local_files)

    # TODO: Detect and handle case where a GCS path has been provided
    # corresponding to a bucket that does not exist or for which the user does
    # not have permissions.
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)

    for local_filename in local_files:
        # Target object name and local source path for this file.
        blob_path = posixpath.join(blob_base_path, local_filename)
        local_file_path = os.path.join(local_dir, local_filename)
        bucket.blob(blob_path).upload_from_filename(local_file_path)
|
|
|
|
|
|
def main(_):
    """Run training or rendering, depending on ``--run_mode``.

    * ``train``: runs `agents.scripts.train.train` on the loaded config and
      logs every score it yields.
    * ``render``: renders one episode of one agent from ``--logdir`` into a
      local scratch directory and uploads the result with `gcs_upload`,
      either to ``--render_out_dir`` or, if that is unset, to a unique
      ``render/<timestamp>-<id>`` subdirectory of ``--logdir``.

    Any other run mode is reported with a warning and does nothing further.

    Args:
      _: Unused positional argument supplied by the caller.
    """
    tf.logging.set_verbosity(
        tf.logging.DEBUG if FLAGS.debug else tf.logging.INFO)

    log_dir = FLAGS.logdir

    # The config is loaded (or created) for every run mode.
    agents_config = _get_agents_configuration(log_dir)

    if FLAGS.run_mode == 'train':
        # Honour --env_processes instead of a hard-coded True.
        for score in agents.scripts.train.train(
                agents_config, env_processes=FLAGS.env_processes):
            # Logged through tf.logging so the verbosity set above applies.
            tf.logging.info('Score %s.', score)
    elif FLAGS.run_mode == 'render':
        now = datetime.datetime.now()
        subdir = now.strftime("%m%d-%H%M") + "-" + uuid.uuid4().hex[0:4]

        # Create the scratch directory without shelling out to `mkdir -p`.
        render_tmp_dir = "/tmp/agents-render/"
        if not os.path.isdir(render_tmp_dir):
            os.makedirs(render_tmp_dir)

        agents.scripts.visualize.visualize(
            logdir=log_dir, outdir=render_tmp_dir, num_agents=1,
            num_episodes=1, checkpoint=None,
            env_processes=FLAGS.env_processes)

        # Unless a render out dir is specified explicitly, upload to a unique
        # subdir of the log dir with the parent render/.
        render_out_dir = FLAGS.render_out_dir
        if render_out_dir is None:
            render_out_dir = os.path.join(log_dir, "render", subdir)
        gcs_upload(render_tmp_dir, render_out_dir)
    else:
        tf.logging.warning(
            "Run mode %s is not handled; nothing to do.", FLAGS.run_mode)
|
|
|
|
|
|
# Script entry point: only when this file is executed directly, hand control
# to TensorFlow's app runner (expected to parse the flags defined in this
# module and then call `main`).
if __name__ == '__main__':
    tf.app.run()
|