mirror of https://github.com/kubeflow/examples.git
Distributed training using tensor2tensor (#86)
* Distributed training using tensor2tensor * Use a transformer model to train the github issue summarization problem * Dockerfile for building training image * ksonnet component for deploying tfjob Fixes https://github.com/kubeflow/examples/issues/43 * Fix lint issues
This commit is contained in:
parent
12b00f2921
commit
6cf382f597
|
|
@ -30,6 +30,7 @@ By the end of this tutorial, you should learn how to:
|
|||
1. Training the model. You can train the model either using Jupyter Notebook or using TFJob.
|
||||
1. [Training the model using a Jupyter Notebook](training_the_model.md)
|
||||
1. [Training the model using TFJob](training_the_model_tfjob.md)
|
||||
1. [Distributed Training using tensor2tensor and TFJob](tensor2tensor_training.md)
|
||||
1. [Serving the model](serving_the_model.md)
|
||||
1. [Querying the model](querying_the_model.md)
|
||||
1. [Teardown](teardown.md)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@
|
|||
// Each object below should correspond to a component in the components/ directory
|
||||
tfjob: {
|
||||
|
||||
},
|
||||
tensor2tensor: {
|
||||
|
||||
},
|
||||
ui: {
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
local env = std.extVar("__ksonnet/environments");
|
||||
local params = std.extVar("__ksonnet/params").components.tensor2tensor;
|
||||
local k = import "k.libsonnet";
|
||||
|
||||
local tensor2tensor = import "tensor2tensor.libsonnet";
|
||||
|
||||
std.prune(k.core.v1.list.new([tensor2tensor.parts(params)]))
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
{
|
||||
parts(params):: {
|
||||
apiVersion: "kubeflow.org/v1alpha1",
|
||||
kind: "TFJob",
|
||||
metadata: {
|
||||
name: "tensor2tensor",
|
||||
namespace: params.namespace,
|
||||
},
|
||||
spec: {
|
||||
replicaSpecs: [
|
||||
{
|
||||
replicas: 1,
|
||||
template: {
|
||||
spec: {
|
||||
containers: [
|
||||
{
|
||||
image: params.image,
|
||||
name: "tensorflow",
|
||||
command: [
|
||||
"bash",
|
||||
],
|
||||
args: [
|
||||
"/home/jovyan/train_dist_launcher.sh",
|
||||
"1",
|
||||
params.workers,
|
||||
"0",
|
||||
params.train_steps,
|
||||
"/job:master",
|
||||
"False",
|
||||
],
|
||||
},
|
||||
],
|
||||
restartPolicy: "OnFailure",
|
||||
},
|
||||
},
|
||||
tfReplicaType: "MASTER",
|
||||
},
|
||||
{
|
||||
replicas: params.workers,
|
||||
template: {
|
||||
spec: {
|
||||
containers: [
|
||||
{
|
||||
image: params.image,
|
||||
name: "tensorflow",
|
||||
command: [
|
||||
"bash",
|
||||
],
|
||||
args: [
|
||||
"/home/jovyan/train_dist_launcher.sh",
|
||||
"1",
|
||||
params.workers,
|
||||
"0",
|
||||
params.train_steps,
|
||||
"/job:master",
|
||||
"False",
|
||||
],
|
||||
},
|
||||
],
|
||||
restartPolicy: "OnFailure",
|
||||
},
|
||||
},
|
||||
tfReplicaType: "WORKER",
|
||||
},
|
||||
{
|
||||
replicas: 1,
|
||||
template: {
|
||||
spec: {
|
||||
containers: [
|
||||
{
|
||||
image: params.image,
|
||||
name: "tensorflow",
|
||||
command: [
|
||||
"bash",
|
||||
],
|
||||
args: [
|
||||
"/home/jovyan/ps_dist_launcher.sh",
|
||||
],
|
||||
},
|
||||
],
|
||||
restartPolicy: "OnFailure",
|
||||
},
|
||||
},
|
||||
tfReplicaType: "PS",
|
||||
},
|
||||
],
|
||||
terminationPolicy: {
|
||||
chief: {
|
||||
replicaIndex: 0,
|
||||
replicaName: "MASTER",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
FROM gcr.io/kubeflow-images-staging/tensorflow-1.7.0-notebook-gpu:latest
|
||||
|
||||
USER root
|
||||
|
||||
RUN pip install tensor2tensor && \
|
||||
apt-get install -y jq
|
||||
|
||||
COPY __init__.py github/__init__.py
|
||||
COPY github_problem.py github/github_problem.py
|
||||
COPY ps_dist_launcher.sh github/ps_dist_launcher.sh
|
||||
COPY train_dist_launcher.sh github/train_dist_launcher.sh
|
||||
|
||||
RUN chown -R jovyan:users /home/jovyan/github
|
||||
|
||||
USER jovyan
|
||||
|
|
@ -0,0 +1 @@
|
|||
from . import github_problem
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import pandas as pd
|
||||
from tensor2tensor.utils import registry
|
||||
from tensor2tensor.models import transformer
|
||||
from tensor2tensor.data_generators import problem
|
||||
from tensor2tensor.data_generators import text_problems
|
||||
|
||||
|
||||
@registry.register_problem
|
||||
class GithubIssueSummarizationProblem(text_problems.Text2TextProblem):
|
||||
"""Predict issue summary from issue body. Using Github issue data."""
|
||||
|
||||
@property
|
||||
def approx_vocab_size(self):
|
||||
return 2**12 # ~4k
|
||||
|
||||
@property
|
||||
def is_generate_per_split(self):
|
||||
# generate_data will NOT shard the data into TRAIN and EVAL for us.
|
||||
return False
|
||||
|
||||
@property
|
||||
def dataset_splits(self):
|
||||
"""Splits of data to produce and number of output shards for each."""
|
||||
# 10% evaluation data
|
||||
return [{
|
||||
"split": problem.DatasetSplit.TRAIN,
|
||||
"shards": 90,
|
||||
}, {
|
||||
"split": problem.DatasetSplit.EVAL,
|
||||
"shards": 10,
|
||||
}]
|
||||
|
||||
def generate_samples(self, data_dir, tmp_dir, dataset_split): # pylint: disable=unused-argument, no-self-use
|
||||
chunksize = 200000
|
||||
for issue_data in pd.read_csv(
|
||||
'csv_data/github_issues_10000.csv', chunksize=chunksize):
|
||||
issue_body_data = issue_data.body.tolist()
|
||||
issue_title_data = issue_data.issue_title.tolist()
|
||||
n = len(issue_title_data)
|
||||
for i in range(n):
|
||||
yield {"inputs": issue_body_data[i], "targets": issue_title_data[i]}
|
||||
|
||||
|
||||
# Smaller than the typical translate model, and with more regularization
|
||||
@registry.register_hparams
|
||||
def transformer_github_issues():
|
||||
hparams = transformer.transformer_base()
|
||||
hparams.num_hidden_layers = 2
|
||||
hparams.hidden_size = 128
|
||||
hparams.filter_size = 512
|
||||
hparams.num_heads = 4
|
||||
hparams.attention_dropout = 0.6
|
||||
hparams.layer_prepostprocess_dropout = 0.6
|
||||
hparams.learning_rate = 0.05
|
||||
return hparams
|
||||
|
||||
|
||||
# hyperparameter tuning ranges
|
||||
@registry.register_ranged_hparams
|
||||
def transformer_github_issues_range(rhp):
|
||||
rhp.set_float("learning_rate", 0.05, 0.25, scale=rhp.LOG_SCALE)
|
||||
rhp.set_int("num_hidden_layers", 2, 4)
|
||||
rhp.set_discrete("hidden_size", [128, 256, 512])
|
||||
rhp.set_float("attention_dropout", 0.4, 0.7)
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
# TODO(ankushagarwal): Convert this to a python launcher script
|
||||
set -x
|
||||
export TF_CONFIG=${TF_CONFIG}
|
||||
echo "TF_CONFIG = ${TF_CONFIG}"
|
||||
OUTDIR=./out
|
||||
DATA_DIR=gs://kubeflow-examples/tensor2tensor/data
|
||||
TMP_DIR=./tmp
|
||||
PROBLEM=github_issue_summarization_problem
|
||||
USR_DIR=./github
|
||||
HPARAMS_SET=transformer_github_issues
|
||||
WORKER_ID=$(echo ${TF_CONFIG} | jq ".task.index")
|
||||
WORKER_TYPE=$(echo ${TF_CONFIG} | jq -r ".task.type")
|
||||
MASTER_INSTANCE=$(echo ${TF_CONFIG} | jq -r ".cluster.${WORKER_TYPE}[${WORKER_ID}]")
|
||||
rm -rf "${OUTDIR}" "${TMP_DIR}"
|
||||
mkdir -p "${OUTDIR}"
|
||||
mkdir -p "${TMP_DIR}"
|
||||
t2t-trainer \
|
||||
--data_dir=${DATA_DIR} \
|
||||
--t2t_usr_dir=${USR_DIR} \
|
||||
--problems=${PROBLEM} \
|
||||
--model=transformer \
|
||||
--hparams_set=${HPARAMS_SET} \
|
||||
--output_dir=$OUTDIR --job-dir=$OUTDIR --train_steps=1000 \
|
||||
--master=grpc://${MASTER_INSTANCE} \
|
||||
--schedule=run_std_server
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
#!/bin/bash
|
||||
# TODO(ankushagarwal): Convert this to a python launcher script
|
||||
set -x
|
||||
PS_REPLICAS="${1}"
|
||||
WORKER_REPLICAS="${2}"
|
||||
WORKER_GPU="${3}"
|
||||
TRAIN_STEPS="${4}"
|
||||
WORKER_JOB="${5}"
|
||||
SYNC="${6}"
|
||||
export TF_CONFIG=$(echo ${TF_CONFIG} | sed 's/"worker"/"master"/g')
|
||||
echo "TF_CONFIG = ${TF_CONFIG}"
|
||||
OUTDIR=./out
|
||||
DATA_DIR=gs://kubeflow-examples/tensor2tensor/data
|
||||
TMP_DIR=./tmp
|
||||
PROBLEM=github_issue_summarization_problem
|
||||
USR_DIR=./github
|
||||
HPARAMS_SET=transformer_github_issues
|
||||
WORKER_ID=$(echo ${TF_CONFIG} | jq ".task.index")
|
||||
WORKER_TYPE=$(echo ${TF_CONFIG} | jq -r ".task.type")
|
||||
MASTER_INSTANCE=$(echo ${TF_CONFIG} | jq -r ".cluster.${WORKER_TYPE}[${WORKER_ID}]")
|
||||
rm -rf "${OUTDIR}" "${TMP_DIR}"
|
||||
mkdir -p "${OUTDIR}"
|
||||
mkdir -p "${TMP_DIR}"
|
||||
t2t-trainer \
|
||||
--data_dir=${DATA_DIR} \
|
||||
--t2t_usr_dir=${USR_DIR} \
|
||||
--problems=${PROBLEM} \
|
||||
--model=transformer \
|
||||
--hparams_set=${HPARAMS_SET} \
|
||||
--output_dir=$OUTDIR --job-dir=$OUTDIR --train_steps=${TRAIN_STEPS} \
|
||||
--master=grpc://${MASTER_INSTANCE} \
|
||||
--ps_replicas=${PS_REPLICAS} \
|
||||
--worker_replicas=${WORKER_REPLICAS} \
|
||||
--worker_gpu=${WORKER_GPU} \
|
||||
--worker_id=${WORKER_ID} \
|
||||
--worker_job=${WORKER_JOB} \
|
||||
--ps_gpu=0 \
|
||||
--schedule=train \
|
||||
--sync=${SYNC}
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
# Distributed Training using tensor2tensor
|
||||
|
||||
[Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
|
||||
[T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
|
||||
of deep learning models and datasets designed to make deep learning more
|
||||
accessible and [accelerate ML
|
||||
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html). To get started, follow the instructions on the tensor2tensor [README](https://github.com/tensorflow/tensor2tensor) and install it locally.
|
||||
|
||||
We are going to use the packaged [transformer](https://research.googleblog.com/2017/08/transformer-novel-neural-network.html) model to train our github issue summarization model.
|
||||
|
||||
## Defining a Problem
|
||||
A key concept in the T2T library is that of a Problem, which ties together all the pieces needed to train a machine learning model. It is easiest to inherit from the appropriate base class in the T2T library and then change only the pieces that are different for your model. We are going to define a problem in [github_problem.py](tensor2tensor/github/github_problem.py) which will extend the inbuilt `text_problems.Text2TextProblem`. `github_problem.py` overrides some properties such as approx_vocab_size, generate_samples, etc.
|
||||
|
||||
## Generate training data
|
||||
|
||||
For training a model using tensor2tensor, the input data must be in a particular format. tensor2tensor comes with a data generator which transforms your input data into a format which can be consumed by the training process.
|
||||
|
||||
```
|
||||
cd tensor2tensor/
|
||||
mkdir csv_data
|
||||
cd csv_data
|
||||
wget https://storage.googleapis.com/kubeflow-examples/github-issue-summarization-data/github-issues.zip
|
||||
unzip github-issues.zip
|
||||
cd ..
|
||||
DATA_DIR=data
|
||||
TMP_DIR=tmp
|
||||
mkdir -p $DATA_DIR $TMP_DIR
|
||||
PROBLEM=github_issue_summarization_problem
|
||||
USR_DIR=./github
|
||||
rm -rf $DATA_DIR/*
|
||||
# Generate data
|
||||
# This can take a while depending on the size of the data
|
||||
t2t-datagen \
|
||||
--t2t_usr_dir=$USR_DIR \
|
||||
--problem=$PROBLEM \
|
||||
--data_dir=$DATA_DIR \
|
||||
--tmp_dir=$TMP_DIR
|
||||
|
||||
# Copy to GCS where it can be used by distributed training
|
||||
gsutil cp -r ${DATA_DIR} gs://${BUCKET_NAME}/${DATA_DIR}
|
||||
```
|
||||
|
||||
## Build and push docker image for distributed training
|
||||
|
||||
The [github](tensor2tensor/github) directory contains a Dockerfile to build the docker image
|
||||
required for distributed training.
|
||||
|
||||
```
|
||||
cd tensor2tensor/github
|
||||
docker build . -t gcr.io/${GCR_REGISTRY}/tensor2tensor-training:latest
|
||||
gcloud docker -- push gcr.io/${GCR_REGISTRY}/tensor2tensor-training:latest
|
||||
```
|
||||
|
||||
## Launch distributed training
|
||||
|
||||
[notebooks](notebooks) contains a ksonnet app([ks-app](notebooks/ks-app)) to deploy the TFJob.
|
||||
|
||||
|
||||
Set the appropriate params for the tfjob component
|
||||
|
||||
```commandline
|
||||
ks param set tensor2tensor namespace ${NAMESPACE}
|
||||
|
||||
# The image pushed in the previous step
|
||||
ks param set tensor2tensor image "gcr.io/${GCR_REGISTRY}/tensor2tensor-training:latest"
|
||||
ks param set tensor2tensor workers 3
|
||||
ks param set tensor2tensor train_steps 5000
|
||||
|
||||
```
|
||||
|
||||
Deploy the app:
|
||||
|
||||
```commandline
|
||||
ks apply tensor2tensor -c tfjob
|
||||
```
|
||||
|
||||
You can view the logs of the tf-job operator using
|
||||
|
||||
```commandline
|
||||
kubectl logs -f $(kubectl get pods -n=${NAMESPACE} -lname=tf-job-operator -o=jsonpath='{.items[0].metadata.name}')
|
||||
```
|
||||
Loading…
Reference in New Issue