# examples/github_issue_summarization/notebooks/train.py
"""Train the github-issue-summarization model
train.py trains the github-issue-summarization model.
It reads the input data from GCS in a zip file format.
--input_data_gcs_bucket and --input_data_gcs_path specify
the location of input data.
It write the model back to GCS.
--output_model_gcs_bucket and --output_model_gcs_path specify
the location of output.
It also has parameters which control the training like
--learning_rate and --sample_size
"""
import argparse
import logging
import os
import re
import shutil
import tempfile
import time
import zipfile

from google.cloud import storage  # pylint: disable=no-name-in-module

import trainer

GCS_REGEX = re.compile("gs://([^/]*)(/.*)?")

def split_gcs_uri(gcs_uri):
  """Split a GCS URI into bucket and path."""
  m = GCS_REGEX.match(gcs_uri)
  bucket = m.group(1)
  path = ""
  if m.group(2):
    path = m.group(2).lstrip("/")
  return bucket, path
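
# Illustrative example (hypothetical URI):
#   split_gcs_uri("gs://my-bucket/data/issues.zip") returns
#   ("my-bucket", "data/issues.zip").
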
def is_gcs_path(gcs_uri):
  return GCS_REGEX.match(gcs_uri)

def process_input_file(remote_file):
  """Process the input file.

  If it's a GCS file we download it to a temporary local file. We do this
  because Keras text preprocessing doesn't work with GCS.
  If it's a zip file we unpack it.

  Args:
    remote_file: The input file; a GCS URI or a local path.

  Returns:
    csv_file: The local csv file to process.
  """
  if is_gcs_path(remote_file):
    # Download the input to a local file.
    # NamedTemporaryFile is only used to pick a unique local path; the file
    # is recreated by the download below.
    with tempfile.NamedTemporaryFile() as hf:
      input_data = hf.name

    logging.info("Copying %s to %s", remote_file, input_data)
    input_data_gcs_bucket, input_data_gcs_path = split_gcs_uri(remote_file)

    logging.info("Download bucket %s object %s.", input_data_gcs_bucket,
                 input_data_gcs_path)
    bucket = storage.Bucket(storage.Client(), input_data_gcs_bucket)
    storage.Blob(input_data_gcs_path, bucket).download_to_filename(input_data)
  else:
    input_data = remote_file

  ext = os.path.splitext(input_data)[-1]
  if ext.lower() == '.zip':
    zip_ref = zipfile.ZipFile(input_data, 'r')
    zip_ref.extractall('.')
    zip_ref.close()
    # TODO(jlewi): Hardcoding the file in the archive to use is brittle.
    # We should probably just require the input to be a CSV file.
    csv_file = 'github_issues.csv'
  else:
    csv_file = input_data

  return csv_file
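
# Illustrative usage (hypothetical path):
#   csv_file = process_input_file("gs://my-bucket/github-issues.zip")
# downloads the archive locally, unpacks it into the working directory, and
# returns the hardcoded 'github_issues.csv' the archive is expected to contain.
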
def wait_for_preprocessing(preprocessed_file):
  """Wait for preprocessing.

  In the case of distributed training the workers need to wait for
  preprocessing to be completed, but only the master runs preprocessing.
  """
  # TODO(jlewi): Why do we need to block waiting for the file?
  # I think this is because only the master produces the npy
  # files so the other workers need to wait for the files to arrive.
  # It might be better to make preprocessing a separate job.
  # We should move this code since it's only needed when using
  # TF.Estimator.
  while True:
    if os.path.isfile(preprocessed_file):
      break
    logging.info("Waiting for dataset")
    time.sleep(2)

def main(unparsed_args=None):  # pylint: disable=too-many-statements
  # Parsing flags.
  parser = argparse.ArgumentParser()
  parser.add_argument("--sample_size", type=int, default=2000000)
  parser.add_argument("--num_epochs", type=int, default=7,
                      help="Number of training epochs.")
  parser.add_argument("--learning_rate", default=0.001, type=float)

  parser.add_argument(
      "--input_data",
      type=str,
      default="",
      help="The input location. Can be a GCS or local file path.")

  parser.add_argument(
      "--output_model",
      type=str,
      default="",
      help="The output location for the model. Can be a GCS or local file "
           "path.")

  parser.add_argument(
      "--output_body_preprocessor_dpkl",
      type=str,
      default="body_pp.dpkl")

  parser.add_argument(
      "--output_title_preprocessor_dpkl",
      type=str,
      default="title_pp.dpkl")

  parser.add_argument(
      "--output_train_title_vecs_npy", type=str, default="train_title_vecs.npy")

  parser.add_argument(
      "--output_train_body_vecs_npy", type=str, default="train_body_vecs.npy")

  parser.add_argument(
      "--mode",
      type=str,
      default="keras",
      help="Whether to train using TF.Estimator or Keras.")

  args = parser.parse_args(unparsed_args)

  logging.basicConfig(
      level=logging.INFO,
      format=('%(levelname)s|%(asctime)s'
              '|%(pathname)s|%(lineno)d| %(message)s'),
      datefmt='%Y-%m-%dT%H:%M:%S',
  )
  logging.getLogger().setLevel(logging.INFO)

  logging.info(args)

  mode = args.mode.lower()
  if mode not in ["estimator", "keras"]:
    raise ValueError("Unrecognized mode %s; must be keras or estimator" % mode)
  csv_file = process_input_file(args.input_data)

  # Use a temporary directory for all the outputs.
  # We will then copy the files to the final directory.
  output_dir = tempfile.mkdtemp()

  model_trainer = trainer.Trainer(output_dir)
  model_trainer.preprocess(csv_file, args.sample_size)

  if mode == "estimator":
    wait_for_preprocessing(model_trainer.preprocessed_bodies)

  model_trainer.build_model(args.learning_rate)

  # Tuples of (temporary, final) paths.
  pairs = []

  if mode == "keras":
    local_model_output = args.output_model
    if is_gcs_path(args.output_model):
      local_model_output = os.path.join(output_dir, "model.h5")

    model_trainer.train_keras(
        local_model_output,
        base_name=os.path.join(output_dir, "model-checkpoint"),
        epochs=args.num_epochs)
    model_trainer.evaluate_keras()

    # With Keras we might need to write to a local directory and then
    # copy to GCS.
    pairs.append((local_model_output, args.output_model))
  elif mode == "estimator":
    # With TF.Estimator we should be able to write directly to GCS.
    model_trainer.train_estimator()

  pairs.extend([
      (model_trainer.body_pp_file, args.output_body_preprocessor_dpkl),
      (model_trainer.title_pp_file, args.output_title_preprocessor_dpkl),
      (model_trainer.preprocessed_titles, args.output_train_title_vecs_npy),
      (model_trainer.preprocessed_bodies, args.output_train_body_vecs_npy),
  ])
  # Copy outputs.
  for p in pairs:
    local = p[0]
    remote = p[1]

    if local == remote:
      continue

    logging.info("Copying %s to %s", local, remote)
    if is_gcs_path(remote):
      bucket_name, path = split_gcs_uri(remote)
      bucket = storage.Bucket(storage.Client(), bucket_name)
      blob = storage.Blob(path, bucket)
      blob.upload_from_filename(local)
    else:
      shutil.move(local, remote)


if __name__ == '__main__':
  main()