examples/github_issue_summarization/notebooks/train.py

288 lines
9.4 KiB
Python

"""Train the github-issue-summarization model
train.py trains the github-issue-summarization model.
It reads the input data from GCS in a zip file format.
--input_data_gcs_bucket and --input_data_gcs_path specify
the location of input data.
It write the model back to GCS.
--output_model_gcs_bucket and --output_model_gcs_path specify
the location of output.
It also has parameters which control the training like
--learning_rate and --sample_size
"""
import argparse
import logging
import os
import re
import zipfile
from google.cloud import storage # pylint: disable=no-name-in-module
import dill as dpickle
import numpy as np
import pandas as pd
from keras import optimizers
from keras.layers import GRU, BatchNormalization, Dense, Embedding, Input
from keras.models import Model
from sklearn.model_selection import train_test_split
from ktext.preprocess import processor
from seq2seq_utils import load_encoder_inputs, load_text_processor
GCS_REGEX = re.compile("gs://([^/]*)(/.*)?")
def split_gcs_uri(gcs_uri):
"""Split a GCS URI into bucket and path."""
m = GCS_REGEX.match(gcs_uri)
bucket = m.group(1)
path = ""
if m.group(2):
path = m.group(2).lstrip("/")
return bucket, path
def main(): # pylint: disable=too-many-statements
# Parsing flags.
parser = argparse.ArgumentParser()
parser.add_argument("--sample_size", type=int, default=2000000)
parser.add_argument("--learning_rate", default="0.001")
parser.add_argument(
"--input_data",
type=str,
default="",
help="The input location. Can be a GCS or local file path.")
# TODO(jlewi): The following arguments are deprecated; just
# use input_data. We should remove them as soon as all call sites
# are updated.
parser.add_argument(
"--input_data_gcs_bucket", type=str, default="kubeflow-examples")
parser.add_argument(
"--input_data_gcs_path",
type=str,
default="github-issue-summarization-data/github-issues.zip")
parser.add_argument(
"--output_model",
type=str,
default="",
help="The output location for the model GCS or local file path.")
parser.add_argument("--output_model_gcs_bucket", type=str, default="")
parser.add_argument(
"--output_model_gcs_path",
type=str,
default="github-issue-summarization-data")
parser.add_argument(
"--output_body_preprocessor_dpkl",
type=str,
default="body_pp.dpkl")
parser.add_argument(
"--output_title_preprocessor_dpkl",
type=str,
default="title_pp.dpkl")
parser.add_argument(
"--output_train_title_vecs_npy", type=str, default="train_title_vecs.npy")
parser.add_argument(
"--output_train_body_vecs_npy", type=str, default="train_body_vecs.npy")
parser.add_argument(
"--output_model_h5", type=str, default="seq2seq_model_tutorial.h5")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format=('%(levelname)s|%(asctime)s'
'|%(pathname)s|%(lineno)d| %(message)s'),
datefmt='%Y-%m-%dT%H:%M:%S',
)
logging.getLogger().setLevel(logging.INFO)
logging.info(args)
learning_rate = float(args.learning_rate)
pd.set_option('display.max_colwidth', 500)
# For backwords compatibility
input_data_gcs_bucket = None
input_data_gcs_path = None
if not args.input_data:
# Since input_data isn't set fall back on old arguments.
input_data_gcs_bucket = args.input_data_gcs_bucket
input_data_gcs_path = args.input_data_gcs_path
else:
if args.input_data.startswith('gs://'):
input_data_gcs_bucket, input_data_gcs_path = split_gcs_uri(
args.input_data)
if input_data_gcs_bucket:
logging.info("Download bucket %s object %s.", input_data_gcs_bucket,
input_data_gcs_path)
bucket = storage.Bucket(storage.Client(), input_data_gcs_bucket)
args.input_data = 'github-issues.zip'
storage.Blob(input_data_gcs_path, bucket).download_to_filename(
args.input_data)
ext = os.path.splitext(args.input_data)[-1]
if ext.lower() == '.zip':
zip_ref = zipfile.ZipFile(args.input_data, 'r')
zip_ref.extractall('.')
zip_ref.close()
# TODO(jlewi): Hardcoding the file in the Archive to use is brittle.
# We should probably just require the input to be a CSV file.
csv_file = 'github_issues.csv'
else:
csv_file = args.input_data
# Read in data sample 2M rows (for speed of tutorial)
traindf, testdf = train_test_split(
pd.read_csv(csv_file).sample(n=args.sample_size), test_size=.10)
# Print stats about the shape of the data.
logging.info('Train: %d rows %d columns', traindf.shape[0], traindf.shape[1])
logging.info('Test: %d rows %d columns', testdf.shape[0], testdf.shape[1])
train_body_raw = traindf.body.tolist()
train_title_raw = traindf.issue_title.tolist()
# Clean, tokenize, and apply padding / truncating such that each document
# length = 70. Also, retain only the top 8,000 words in the vocabulary and set
# the remaining words to 1 which will become common index for rare words.
body_pp = processor(keep_n=8000, padding_maxlen=70)
train_body_vecs = body_pp.fit_transform(train_body_raw)
logging.info('Example original body: %s', train_body_raw[0])
logging.info('Example body after pre-processing: %s', train_body_vecs[0])
# Instantiate a text processor for the titles, with some different parameters.
title_pp = processor(
append_indicators=True, keep_n=4500, padding_maxlen=12, padding='post')
# process the title data
train_title_vecs = title_pp.fit_transform(train_title_raw)
logging.info('Example original title: %s', train_title_raw[0])
logging.info('Example title after pre-processing: %s', train_title_vecs[0])
# Save the preprocessor.
with open(args.output_body_preprocessor_dpkl, 'wb') as f:
dpickle.dump(body_pp, f)
with open(args.output_title_preprocessor_dpkl, 'wb') as f:
dpickle.dump(title_pp, f)
# Save the processed data.
np.save(args.output_train_title_vecs_npy, train_title_vecs)
np.save(args.output_train_body_vecs_npy, train_body_vecs)
_, doc_length = load_encoder_inputs(args.output_train_body_vecs_npy)
num_encoder_tokens, body_pp = load_text_processor(
args.output_body_preprocessor_dpkl)
num_decoder_tokens, title_pp = load_text_processor(
args.output_title_preprocessor_dpkl)
# Arbitrarly set latent dimension for embedding and hidden units
latent_dim = 300
###############
# Encoder Model.
###############
encoder_inputs = Input(shape=(doc_length,), name='Encoder-Input')
# Word embeding for encoder (ex: Issue Body)
x = Embedding(
num_encoder_tokens, latent_dim, name='Body-Word-Embedding',
mask_zero=False)(encoder_inputs)
x = BatchNormalization(name='Encoder-Batchnorm-1')(x)
# We do not need the `encoder_output` just the hidden state.
_, state_h = GRU(latent_dim, return_state=True, name='Encoder-Last-GRU')(x)
# Encapsulate the encoder as a separate entity so we can just
# encode without decoding if we want to.
encoder_model = Model(
inputs=encoder_inputs, outputs=state_h, name='Encoder-Model')
seq2seq_encoder_out = encoder_model(encoder_inputs)
################
# Decoder Model.
################
decoder_inputs = Input(
shape=(None,), name='Decoder-Input') # for teacher forcing
# Word Embedding For Decoder (ex: Issue Titles)
dec_emb = Embedding(
num_decoder_tokens,
latent_dim,
name='Decoder-Word-Embedding',
mask_zero=False)(decoder_inputs)
dec_bn = BatchNormalization(name='Decoder-Batchnorm-1')(dec_emb)
# Set up the decoder, using `decoder_state_input` as initial state.
decoder_gru = GRU(
latent_dim, return_state=True, return_sequences=True, name='Decoder-GRU')
decoder_gru_output, _ = decoder_gru(dec_bn, initial_state=seq2seq_encoder_out)
x = BatchNormalization(name='Decoder-Batchnorm-2')(decoder_gru_output)
# Dense layer for prediction
decoder_dense = Dense(
num_decoder_tokens, activation='softmax', name='Final-Output-Dense')
decoder_outputs = decoder_dense(x)
################
# Seq2Seq Model.
################
seq2seq_Model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
seq2seq_Model.compile(
optimizer=optimizers.Nadam(lr=learning_rate),
loss='sparse_categorical_crossentropy')
seq2seq_Model.summary()
#############
# Save model.
#############
seq2seq_Model.save(args.output_model_h5)
######################
# Upload model to GCS.
######################
# For backwords compatibility
output_model_gcs_bucket = None
output_model_gcs_path = None
if not args.output_model:
# Since input_data isn't set fall back on old arguments.
output_model_gcs_bucket = args.output_model_gcs_bucket
output_model_gcs_path = args.output_model_gcs_path
else:
if args.output_model.startswith('gs://'):
output_model_gcs_bucket, output_model_gcs_path = split_gcs_uri(
args.output_model)
if output_model_gcs_bucket:
logging.info("Uploading model files to bucket %s path %s.",
output_model_gcs_bucket, output_model_gcs_path)
bucket = storage.Bucket(storage.Client(), output_model_gcs_bucket)
storage.Blob(
output_model_gcs_path + "/" + args.output_model_h5, bucket).upload_from_filename(
args.output_model_h5)
storage.Blob(output_model_gcs_path + "/" + args.output_body_preprocessor_dpkl,
bucket).upload_from_filename(args.output_body_preprocessor_dpkl)
storage.Blob(output_model_gcs_path + "/" + args.output_title_preprocessor_dpkl,
bucket).upload_from_filename(args.output_title_preprocessor_dpkl)
if __name__ == '__main__':
main()