From db358557ddba0bd02faade27e9a4e2f512b1f47c Mon Sep 17 00:00:00 2001
From: Pascal Vicaire <25757005+vicaire@users.noreply.github.com>
Date: Thu, 8 Mar 2018 16:03:10 -0800
Subject: [PATCH] Example workflow for Github issue summarization. (#35)

* Example workflow for Github issue summarization.

* Fixing quotes in README.md

* Fixing typo in README.md
---
 .../workflow/Dockerfile                       |  24 ++
 github_issue_summarization/workflow/README.md |  72 ++++
 .../workflow/github_issues_summarization.yaml | 183 ++++
 .../workflow/workspace/src/prediction.py      |  33 ++
 .../src/preprocess_data_for_deep_learning.py  |  50 +++
 .../workflow/workspace/src/process_data.py    |  28 ++
 .../workflow/workspace/src/recommend.py       |  38 ++
 .../workflow/workspace/src/seq2seq_utils.py   | 393 ++++++++++++++++++
 .../workflow/workspace/src/train.py           |  93 +++++
 9 files changed, 914 insertions(+)
 create mode 100644 github_issue_summarization/workflow/Dockerfile
 create mode 100644 github_issue_summarization/workflow/README.md
 create mode 100644 github_issue_summarization/workflow/github_issues_summarization.yaml
 create mode 100644 github_issue_summarization/workflow/workspace/src/prediction.py
 create mode 100644 github_issue_summarization/workflow/workspace/src/preprocess_data_for_deep_learning.py
 create mode 100644 github_issue_summarization/workflow/workspace/src/process_data.py
 create mode 100644 github_issue_summarization/workflow/workspace/src/recommend.py
 create mode 100644 github_issue_summarization/workflow/workspace/src/seq2seq_utils.py
 create mode 100644 github_issue_summarization/workflow/workspace/src/train.py

diff --git a/github_issue_summarization/workflow/Dockerfile b/github_issue_summarization/workflow/Dockerfile
new file mode 100644
index 00000000..64e7c1e5
--- /dev/null
+++ b/github_issue_summarization/workflow/Dockerfile
@@ -0,0 +1,24 @@
+FROM python:3.6
+COPY ./workspace/src /workspace/src/
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python-pandas \
+    && pip3 install -U scikit-learn \
+    && pip3 install -U ktext \
+    && pip3 install -U IPython \
+    && pip3 install -U annoy \
+    && pip3 install -U tqdm \
+    && pip3 install -U nltk \
+    && pip3 install -U matplotlib \
+    && pip3 install -U tensorflow \
+    && pip3 install -U bernoulli \
+    && pip3 install -U h5py \
+    && git clone https://github.com/google/seq2seq.git \
+    && pip3 install -e ./seq2seq/ \
+    && apt-get clean \
+    && rm -rf \
+       /var/lib/apt/lists/* \
+       /tmp/* \
+       /var/tmp/* \
+       /usr/share/man \
+       /usr/share/doc \
+       /usr/share/doc-base
diff --git a/github_issue_summarization/workflow/README.md b/github_issue_summarization/workflow/README.md
new file mode 100644
index 00000000..890b9ed1
--- /dev/null
+++ b/github_issue_summarization/workflow/README.md
@@ -0,0 +1,72 @@
+# [WIP] GitHub issue summarization workflow.
+
+# Prerequisites.
+
+* [Create a GKE cluster and configure kubectl.](https://cloud.google.com/kubernetes-engine/docs/how-to/creating-a-container-cluster)
+* [Install Argo.](https://github.com/argoproj/argo/blob/master/demo.md)
+* [Configure the default artifact repository.](https://github.com/argoproj/argo/blob/master/ARTIFACT_REPO.md#configure-the-default-artifact-repository)
+
+# Get the input data and upload it to GCS.
+
+Get the input data from this [location](https://towardsdatascience.com/how-to-create-data-products-that-are-magical-using-sequence-to-sequence-models-703f86a231f8).
+In the following, we assume that the file path is `./github-issues.zip`.
+
+Decompress the input data:
+
+```
+unzip ./github-issues.zip
+```
+
+For debugging purposes, consider reducing the size of the input data; the workflow will execute much faster:
+
+```
+cat ./github-issues.csv | head -n 10000 > ./github-issues-medium.csv
+```
+
+Compress the data using gzip (this is the format the workflow expects):
+
+```
+gzip ./github-issues-medium.csv
+```
+
+Upload the data to GCS:
+
+```
+gsutil cp ./github-issues-medium.csv.gz gs://BUCKET_NAME/BUCKET_KEY
+```
+
+# Building the container.
+
+Build the container and tag it so that it can be pushed to a [GCP container registry](https://cloud.google.com/container-registry/):
+
+```
+docker build -f Dockerfile -t gcr.io/PROJECT_NAME/github_issue_summarization:v1 .
+```
+
+Push the container to the GCP container registry:
+
+```
+gcloud docker -- push gcr.io/PROJECT_NAME/github_issue_summarization:v1
+```
+
+# Running the workflow.
+
+Run the workflow:
+
+```
+argo submit github_issues_summarization.yaml \
+  -p bucket=BUCKET_NAME \
+  -p bucket-key=BUCKET_KEY \
+  -p container-image=gcr.io/PROJECT_NAME/github_issue_summarization:v1
+```
+
+Where:
+
+* `BUCKET_NAME` is the name of a GCS bucket where the input data is stored (e.g.: "my_bucket_1234").
+* `BUCKET_KEY` is the path to the input data in csv.gz format (e.g.: "data/github_issues.csv.gz").
+* `PROJECT_NAME` is the name of the GCP project where the container was pushed.
+
+The data generated by the workflow will be stored in the default artifact
+repository specified in the previous section.
+
+The logs can be read using the `argo get` and `argo logs` commands ([link](https://github.com/argoproj/argo/blob/master/demo.md#4-run-simple-example-workflows)).
diff --git a/github_issue_summarization/workflow/github_issues_summarization.yaml b/github_issue_summarization/workflow/github_issues_summarization.yaml
new file mode 100644
index 00000000..a0c10a79
--- /dev/null
+++ b/github_issue_summarization/workflow/github_issues_summarization.yaml
@@ -0,0 +1,183 @@
+apiVersion: argoproj.io/v1alpha1
+kind: Workflow
+metadata:
+  generateName: github-issue-summarization-
+spec:
+  entrypoint: default
+
+  # Create a volume for containers to store their output data.
+  volumeClaimTemplates:
+  - metadata:
+      name: workdir
+    spec:
+      accessModes: [ "ReadWriteOnce" ]
+      resources:
+        requests:
+          storage: 20Gi
+
+  # Arguments of the workflow.
+  arguments:
+    parameters:
+
+    # The name of the GCS bucket where the data is stored.
+    - name: bucket
+
+    # The path to the input data in the GCS bucket, in csv.gz format.
+    - name: bucket-key
+
+    # The number of data points to use in the workflow.
+    # The default ensures that the workflow executes quickly but does
+    # not lead to good results.
+    - name: sample-size
+      value: 200
+
+    # The number of issues on which to provide recommendations.
+    - name: input-topic-number
+      value: 10
+
+    # The number of issues to summarize.
+    - name: input-prediction-count
+      value: 50
+
+    # The learning rate for training.
+    - name: learning-rate
+      value: "0.001"
+
+    # The container image to use in the workflow.
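+    # This is the image built from the Dockerfile in this directory and pushed
+    # to the GCP container registry as described in the README
+    # (e.g. gcr.io/PROJECT_NAME/github_issue_summarization:v1).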
+    - name: container-image
+
+  templates:
+
+  ##################################
+  # Define the steps of the workflow
+  ##################################
+  - name: default
+    steps:
+    - - name: import-data
+        template: import-data
+    - - name: process-data
+        template: process-data
+    - - name: preprocess-deep-learning
+        template: preprocess-deep-learning
+    - - name: training
+        template: training
+    - - name: prediction
+        template: prediction
+    - - name: feature-extraction
+        template: feature-extraction
+
+  #################################################
+  # Import / Unzip
+  # Imports the data on the volume and unzips it.
+  #################################################
+  - name: import-data
+    container:
+      image: alpine:latest
+      command: [sh, -c]
+      args: ["cp /mnt/workspace/data/issues.csv.gz /mnt/workspace/data/issues-copy.csv.gz; gzip -d /mnt/workspace/data/issues-copy.csv.gz; cp /mnt/workspace/data/issues-copy.csv /mnt/workspace/data/github_issues_medium.csv;"]
+      volumeMounts:
+      - name: workdir
+        mountPath: /mnt/workspace/data/
+    inputs:
+      artifacts:
+      - name: input
+        path: /mnt/workspace/data/issues.csv.gz
+        s3:
+          endpoint: storage.googleapis.com
+          bucket: "{{workflow.parameters.bucket}}"
+          key: "{{workflow.parameters.bucket-key}}"
+          accessKeySecret:
+            name: gcs-accesskey
+            key: accesskey
+          secretKeySecret:
+            name: gcs-accesskey
+            key: secretkey
+    outputs:
+      artifacts:
+      - name: output
+        path: "/mnt/workspace/data/issues-copy.csv"
+
+  #########################################################################
+  # Process Data
+  # Generates the training and test set. Only processes "sample-size" rows.
+  #########################################################################
+  - name: process-data
+    container:
+      image: "{{workflow.parameters.container-image}}"
+      command: [sh, -c]
+      args: ["cd ./workspace/src; python process_data.py --input_csv=/mnt/workspace/data/github_issues_medium.csv --sample_size={{workflow.parameters.sample-size}} --output_traindf_csv=/mnt/workspace/data/github_issues_medium_train.csv --output_testdf_csv=/mnt/workspace/data/github_issues_medium_test.csv"]
+      volumeMounts:
+      - name: workdir
+        mountPath: /mnt/workspace/data/
+    outputs:
+      artifacts:
+      - name: output-traindf-csv
+        path: /mnt/workspace/data/github_issues_medium_train.csv
+      - name: output-testdf-csv
+        path: /mnt/workspace/data/github_issues_medium_test.csv
+
+  #######################################
+  # Preprocess for deep learning
+  #######################################
+  - name: preprocess-deep-learning
+    container:
+      image: "{{workflow.parameters.container-image}}"
+      command: [sh, -c]
+      args: ["cd ./workspace/src; python ./preprocess_data_for_deep_learning.py --input_traindf_csv=/mnt/workspace/data/github_issues_medium_train.csv --output_body_preprocessor_dpkl=/mnt/workspace/data/body_preprocessor.dpkl --output_title_preprocessor_dpkl=/mnt/workspace/data/title_preprocessor.dpkl --output_train_title_vecs_npy=/mnt/workspace/data/train_title_vecs.npy --output_train_body_vecs_npy=/mnt/workspace/data/train_body_vecs.npy"]
+      volumeMounts:
+      - name: workdir
+        mountPath: /mnt/workspace/data/
+    outputs:
+      artifacts:
+      - name: output-body-preprocessor-dpkl
+        path: /mnt/workspace/data/body_preprocessor.dpkl
+      - name: output-title-preprocessor-dpkl
+        path: /mnt/workspace/data/title_preprocessor.dpkl
+      - name: output-train-title-vecs-npy
+        path: /mnt/workspace/data/train_title_vecs.npy
+      - name: output-train-body-vecs-npy
+        path: /mnt/workspace/data/train_body_vecs.npy
+
+  #######################################
+  # Training
+  #######################################
+  - name: training
+    container:
+      image: "{{workflow.parameters.container-image}}"
+      command: [sh, -c]
+      args: ["cd ./workspace/src; python train.py --input_body_preprocessor_dpkl=/mnt/workspace/data/body_preprocessor.dpkl --input_title_preprocessor_dpkl=/mnt/workspace/data/title_preprocessor.dpkl --input_train_title_vecs_npy=/mnt/workspace/data/train_title_vecs.npy --input_train_body_vecs_npy=/mnt/workspace/data/train_body_vecs.npy --output_model_h5=/mnt/workspace/data/output_model.h5 --learning_rate={{workflow.parameters.learning-rate}}"]
+      volumeMounts:
+      - name: workdir
+        mountPath: /mnt/workspace/data/
+    outputs:
+      artifacts:
+      - name: output-model-h5
+        path: /mnt/workspace/data/output_model.h5
+
+  ###########################################################################
+  # Prediction
+  # For now, this step simply summarizes "input-prediction-count" issues and
+  # prints the results in the logs.
+  ###########################################################################
+  - name: prediction
+    container:
+      image: "{{workflow.parameters.container-image}}"
+      command: [sh, -c]
+      args: ["cd ./workspace/src; python prediction.py --input_body_preprocessor_dpkl=/mnt/workspace/data/body_preprocessor.dpkl --input_title_preprocessor_dpkl=/mnt/workspace/data/title_preprocessor.dpkl --input_model_h5=/mnt/workspace/data/output_model.h5 --input_testdf_csv=/mnt/workspace/data/github_issues_medium_test.csv --input_prediction_count={{workflow.parameters.input-prediction-count}}"]
+      volumeMounts:
+      - name: workdir
+        mountPath: /mnt/workspace/data/
+
+  ###########################################################################
+  # Feature Extraction
+  # For now, this step simply provides recommendations about
+  # "input-topic-number" issues and prints the results in the logs.
+  ###########################################################################
+  - name: feature-extraction
+    container:
+      image: "{{workflow.parameters.container-image}}"
+      command: [sh, -c]
+      args: ["cd ./workspace/src; python recommend.py --input_csv=/mnt/workspace/data/github_issues_medium.csv --input_body_preprocessor_dpkl=/mnt/workspace/data/body_preprocessor.dpkl --input_title_preprocessor_dpkl=/mnt/workspace/data/title_preprocessor.dpkl --input_model_h5=/mnt/workspace/data/output_model.h5 --input_testdf_csv=/mnt/workspace/data/github_issues_medium_test.csv --input_topic_number={{workflow.parameters.input-topic-number}}"]
+      volumeMounts:
+      - name: workdir
+        mountPath: /mnt/workspace/data/
diff --git a/github_issue_summarization/workflow/workspace/src/prediction.py b/github_issue_summarization/workflow/workspace/src/prediction.py
new file mode 100644
index 00000000..5539f6c0
--- /dev/null
+++ b/github_issue_summarization/workflow/workspace/src/prediction.py
@@ -0,0 +1,33 @@
+import argparse
+import keras
+import pandas as pd
+from seq2seq_utils import load_decoder_inputs
+from seq2seq_utils import load_encoder_inputs
+from seq2seq_utils import load_text_processor
+from seq2seq_utils import Seq2Seq_Inference
+
+# Parsing flags.
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_model_h5")
+parser.add_argument("--input_body_preprocessor_dpkl")
+parser.add_argument("--input_title_preprocessor_dpkl")
+parser.add_argument("--input_testdf_csv")
+parser.add_argument("--input_prediction_count", type=int, default=50)
+args = parser.parse_args()
+print(args)
+
+# Read data.
+testdf = pd.read_csv(args.input_testdf_csv)
+
+# Load model, preprocessors.
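+# load_text_processor (defined in seq2seq_utils.py) unpickles a ktext processor
+# and returns (vocabulary size, processor); the vocabulary sizes are unused here,
+# but the processors are needed to tokenize issue bodies and decode generated titles.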
+seq2seq_Model = keras.models.load_model(args.input_model_h5)
+num_encoder_tokens, body_pp = load_text_processor(args.input_body_preprocessor_dpkl)
+num_decoder_tokens, title_pp = load_text_processor(args.input_title_preprocessor_dpkl)
+
+# Prepare inference.
+seq2seq_inf = Seq2Seq_Inference(encoder_preprocessor=body_pp,
+                                decoder_preprocessor=title_pp,
+                                seq2seq_model=seq2seq_Model)
+
+# Output predictions for n random rows in the test set.
+seq2seq_inf.demo_model_predictions(n=args.input_prediction_count, issue_df=testdf)
diff --git a/github_issue_summarization/workflow/workspace/src/preprocess_data_for_deep_learning.py b/github_issue_summarization/workflow/workspace/src/preprocess_data_for_deep_learning.py
new file mode 100644
index 00000000..125717e0
--- /dev/null
+++ b/github_issue_summarization/workflow/workspace/src/preprocess_data_for_deep_learning.py
@@ -0,0 +1,50 @@
+import argparse
+import dill as dpickle
+from ktext.preprocess import processor
+import numpy as np
+import pandas as pd
+
+# Parsing flags.
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_traindf_csv")
+parser.add_argument("--output_body_preprocessor_dpkl")
+parser.add_argument("--output_title_preprocessor_dpkl")
+parser.add_argument("--output_train_title_vecs_npy")
+parser.add_argument("--output_train_body_vecs_npy")
+args = parser.parse_args()
+print(args)
+
+# Read data.
+traindf = pd.read_csv(args.input_traindf_csv)
+train_body_raw = traindf.body.tolist()
+train_title_raw = traindf.issue_title.tolist()
+
+# Clean, tokenize, and apply padding / truncating such that each document
+# length = 70. Also, retain only the top 8,000 words in the vocabulary and set
+# the remaining words to 1, which becomes the common index for rare words.
+body_pp = processor(keep_n=8000, padding_maxlen=70)
+train_body_vecs = body_pp.fit_transform(train_body_raw)
+
+print('Example original body:', train_body_raw[0])
+print('Example body after pre-processing:', train_body_vecs[0])
+
+# Instantiate a text processor for the titles, with some different parameters.
+title_pp = processor(append_indicators=True, keep_n=4500,
+                     padding_maxlen=12, padding='post')
+
+# Process the title data.
+train_title_vecs = title_pp.fit_transform(train_title_raw)
+
+print('Example original title:', train_title_raw[0])
+print('Example title after pre-processing:', train_title_vecs[0])
+
+# Save the preprocessors.
+with open(args.output_body_preprocessor_dpkl, 'wb') as f:
+    dpickle.dump(body_pp, f)
+
+with open(args.output_title_preprocessor_dpkl, 'wb') as f:
+    dpickle.dump(title_pp, f)
+
+# Save the processed data.
+np.save(args.output_train_title_vecs_npy, train_title_vecs)
+np.save(args.output_train_body_vecs_npy, train_body_vecs)
diff --git a/github_issue_summarization/workflow/workspace/src/process_data.py b/github_issue_summarization/workflow/workspace/src/process_data.py
new file mode 100644
index 00000000..d6b27cf4
--- /dev/null
+++ b/github_issue_summarization/workflow/workspace/src/process_data.py
@@ -0,0 +1,28 @@
+import argparse
+import glob
+import logging
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+# Parsing flags.
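+# Note: the Argo workflow (github_issues_summarization.yaml) passes every flag
+# explicitly; the defaults below only apply when this script is run by hand.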
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_csv")
+parser.add_argument("--sample_size", type=int, default=2000000)
+parser.add_argument("--output_traindf_csv")
+parser.add_argument("--output_testdf_csv")
+args = parser.parse_args()
+print(args)
+
+pd.set_option('display.max_colwidth', 500)
+
+# Read in the data and sample `sample_size` rows (2M by default, for speed of tutorial).
+traindf, testdf = train_test_split(pd.read_csv(args.input_csv).sample(n=args.sample_size),
+                                   test_size=.10)
+
+# Print stats about the shape of the data.
+print(f'Train: {traindf.shape[0]:,} rows {traindf.shape[1]:,} columns')
+print(f'Test: {testdf.shape[0]:,} rows {testdf.shape[1]:,} columns')
+
+# Store output as CSV.
+traindf.to_csv(args.output_traindf_csv)
+testdf.to_csv(args.output_testdf_csv)
diff --git a/github_issue_summarization/workflow/workspace/src/recommend.py b/github_issue_summarization/workflow/workspace/src/recommend.py
new file mode 100644
index 00000000..f755bb4f
--- /dev/null
+++ b/github_issue_summarization/workflow/workspace/src/recommend.py
@@ -0,0 +1,38 @@
+import argparse
+import keras
+import pandas as pd
+from seq2seq_utils import load_decoder_inputs
+from seq2seq_utils import load_encoder_inputs
+from seq2seq_utils import load_text_processor
+from seq2seq_utils import Seq2Seq_Inference
+
+# Parsing flags.
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_csv")
+parser.add_argument("--input_model_h5")
+parser.add_argument("--input_body_preprocessor_dpkl")
+parser.add_argument("--input_title_preprocessor_dpkl")
+parser.add_argument("--input_testdf_csv")
+parser.add_argument("--input_topic_number", type=int, default=1)
+args = parser.parse_args()
+print(args)
+
+# Read data.
+all_data_df = pd.read_csv(args.input_csv)
+testdf = pd.read_csv(args.input_testdf_csv)
+
+# Load model, preprocessors.
+num_encoder_tokens, body_pp = load_text_processor(args.input_body_preprocessor_dpkl)
+num_decoder_tokens, title_pp = load_text_processor(args.input_title_preprocessor_dpkl)
+seq2seq_Model = keras.models.load_model(args.input_model_h5)
+
+# Prepare the recommender.
+all_data_bodies = all_data_df['body'].tolist()
+all_data_vectorized = body_pp.transform_parallel(all_data_bodies)
+seq2seq_inf_rec = Seq2Seq_Inference(encoder_preprocessor=body_pp,
+                                    decoder_preprocessor=title_pp,
+                                    seq2seq_model=seq2seq_Model)
+recsys_annoyobj = seq2seq_inf_rec.prepare_recommender(all_data_vectorized, all_data_df)
+
+# Output recommendations for n topics.
+seq2seq_inf_rec.demo_model_predictions(n=args.input_topic_number, issue_df=testdf, threshold=1)
diff --git a/github_issue_summarization/workflow/workspace/src/seq2seq_utils.py b/github_issue_summarization/workflow/workspace/src/seq2seq_utils.py
new file mode 100644
index 00000000..c278dfdb
--- /dev/null
+++ b/github_issue_summarization/workflow/workspace/src/seq2seq_utils.py
@@ -0,0 +1,393 @@
+from matplotlib import pyplot as plt
+import tensorflow as tf
+from keras import backend as K
+from keras.layers import Input
+from keras.models import Model
+from IPython.display import SVG, display
+from keras.utils.vis_utils import model_to_dot
+import logging
+import numpy as np
+import dill as dpickle
+from annoy import AnnoyIndex
+from tqdm import tqdm, tqdm_notebook
+from random import random
+from nltk.translate.bleu_score import corpus_bleu
+
+def load_text_processor(fname='title_pp.dpkl'):
+    """
+    Load preprocessors from disk.
+    Parameters
+    ----------
+    fname: str
+        file name of ktext.processor object
+    Returns
+    -------
+    num_tokens : int
+        size of vocabulary loaded into ktext.processor
+    pp : ktext.processor
+        the processor you are trying to load
+    Typical Usage:
+    -------------
+    num_decoder_tokens, title_pp = load_text_processor(fname='title_pp.dpkl')
+    num_encoder_tokens, body_pp = load_text_processor(fname='body_pp.dpkl')
+    """
+    # Load files from disk.
+    with open(fname, 'rb') as f:
+        pp = dpickle.load(f)
+
+    num_tokens = max(pp.id2token.keys()) + 1
+    print(f'Size of vocabulary for {fname}: {num_tokens:,}')
+    return num_tokens, pp
+
+
+def load_decoder_inputs(decoder_np_vecs='train_title_vecs.npy'):
+    """
+    Load decoder inputs.
+    Parameters
+    ----------
+    decoder_np_vecs : str
+        filename of serialized numpy.array of decoder input (issue title)
+    Returns
+    -------
+    decoder_input_data : numpy.array
+        The data fed to the decoder as input during training for teacher forcing.
+        This is the same as `decoder_np_vecs` except the last position.
+    decoder_target_data : numpy.array
+        The data that the decoder is trained to generate (issue title).
+        Calculated by sliding `decoder_np_vecs` one position forward.
+    """
+    vectorized_title = np.load(decoder_np_vecs)
+    # For decoder input, you don't need the last word, as that is only for prediction
+    # when we are training using teacher forcing.
+    decoder_input_data = vectorized_title[:, :-1]
+
+    # Decoder target data is ahead by 1 time step from decoder input data (teacher forcing).
+    decoder_target_data = vectorized_title[:, 1:]
+
+    print(f'Shape of decoder input: {decoder_input_data.shape}')
+    print(f'Shape of decoder target: {decoder_target_data.shape}')
+    return decoder_input_data, decoder_target_data
+
+
+def load_encoder_inputs(encoder_np_vecs='train_body_vecs.npy'):
+    """
+    Load variables & data that are inputs to encoder.
+    Parameters
+    ----------
+    encoder_np_vecs : str
+        filename of serialized numpy.array of encoder input (issue body)
+    Returns
+    -------
+    encoder_input_data : numpy.array
+        The issue body
+    doc_length : int
+        The standard document length of the input for the encoder after padding;
+        the shape of this array will be (num_examples, doc_length)
+    """
+    vectorized_body = np.load(encoder_np_vecs)
+    # Encoder input is simply the body of the issue text.
+    encoder_input_data = vectorized_body
+    doc_length = encoder_input_data.shape[1]
+    print(f'Shape of encoder input: {encoder_input_data.shape}')
+    return encoder_input_data, doc_length
+
+
+def viz_model_architecture(model):
+    """Visualize model architecture in Jupyter notebook."""
+    display(SVG(model_to_dot(model).create(prog='dot', format='svg')))
+
+
+def free_gpu_mem():
+    """Attempt to free gpu memory."""
+    K.get_session().close()
+    cfg = K.tf.ConfigProto()
+    cfg.gpu_options.allow_growth = True
+    K.set_session(K.tf.Session(config=cfg))
+
+
+def test_gpu():
+    """Run a toy computation task in tensorflow to test GPU."""
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    session = tf.Session(config=config)
+    hello = tf.constant('Hello, TensorFlow!')
+    print(session.run(hello))
+
+
+def plot_model_training_history(history_object):
+    """Plots model train vs. validation loss."""
+    plt.title('model accuracy')
+    plt.ylabel('accuracy')
+    plt.xlabel('epoch')
+    plt.plot(history_object.history['loss'])
+    plt.plot(history_object.history['val_loss'])
+    plt.legend(['train', 'test'], loc='upper left')
+    plt.show()
+
+
+def extract_encoder_model(model):
+    """
+    Extract the encoder from the original Sequence to Sequence Model.
+    Returns a keras model object that has one input (body of issue) and one
+    output (encoding of issue, which is the last hidden state).
+    Input:
+    -----
+    model: keras model object
+    Returns:
+    -----
+    keras model object
+    """
+    encoder_model = model.get_layer('Encoder-Model')
+    return encoder_model
+
+
+def extract_decoder_model(model):
+    """
+    Extract the decoder from the original model.
+    Inputs:
+    ------
+    model: keras model object
+    Returns:
+    -------
+    A Keras model object with the following inputs and outputs:
+    Inputs of Keras Model That Is Returned:
+    1: the embedding index for the last predicted word or the `_start_` indicator
+    2: the last hidden state, or in the case of the first word the hidden state from the encoder
+    Outputs of Keras Model That Is Returned:
+    1. Prediction (class probabilities) for the next word
+    2. The hidden state of the decoder, to be fed back into the decoder at the next time step
+    Implementation Notes:
+    ----------------------
+    Must extract relevant layers and reconstruct part of the computation graph
+    to allow for different inputs, as we are not going to use teacher forcing at
+    inference time.
+    """
+    # The latent dimension is the same throughout the architecture, so we are going to
+    # cheat and grab the latent dimension of the embedding because that is the same as what is
+    # output from the decoder.
+    latent_dim = model.get_layer('Decoder-Word-Embedding').output_shape[-1]
+
+    # Reconstruct the input into the decoder.
+    decoder_inputs = model.get_layer('Decoder-Input').input
+    dec_emb = model.get_layer('Decoder-Word-Embedding')(decoder_inputs)
+    dec_bn = model.get_layer('Decoder-Batchnorm-1')(dec_emb)
+
+    # Instead of setting the initial state from the encoder and forgetting about it, during
+    # inference we are not doing teacher forcing, so we will have to have a feedback loop from
+    # predictions back into the GRU; thus we define this input layer for the state so we can
+    # add this capability.
+    gru_inference_state_input = Input(shape=(latent_dim,), name='hidden_state_input')
+
+    # We need to reuse the weights; that is why we are getting this layer.
+    # If you inspect the decoder GRU that we created for training, it will take as input
+    # 2 tensors -> (1) is the embedding layer output for the teacher forcing
+    # (which will now be the last step's prediction, and will be _start_ on the first time step)
+    # (2) is the state, which we will initialize with the encoder on the first time step, but then
+    # grab the state after the first prediction and feed that back in again.
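+    # Because the training GRU was built with return_state=True, calling it here
+    # yields both the output sequence and the hidden state we feed back in.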
+    gru_out, gru_state_out = model.get_layer('Decoder-GRU')([dec_bn, gru_inference_state_input])
+
+    # Reconstruct dense layers.
+    dec_bn2 = model.get_layer('Decoder-Batchnorm-2')(gru_out)
+    dense_out = model.get_layer('Final-Output-Dense')(dec_bn2)
+    decoder_model = Model([decoder_inputs, gru_inference_state_input],
+                          [dense_out, gru_state_out])
+    return decoder_model


+class Seq2Seq_Inference(object):
+    def __init__(self,
+                 encoder_preprocessor,
+                 decoder_preprocessor,
+                 seq2seq_model):
+
+        self.pp_body = encoder_preprocessor
+        self.pp_title = decoder_preprocessor
+        self.seq2seq_model = seq2seq_model
+        self.encoder_model = extract_encoder_model(seq2seq_model)
+        self.decoder_model = extract_decoder_model(seq2seq_model)
+        self.default_max_len_title = self.pp_title.padding_maxlen
+        self.nn = None
+        self.rec_df = None
+
+    def generate_issue_title(self,
+                             raw_input_text,
+                             max_len_title=None):
+        """
+        Use the seq2seq model to generate a title given the body of an issue.
+        Inputs
+        ------
+        raw_input_text: str
+            The body of the issue text as an input string
+        max_len_title: int (optional)
+            The maximum length of the title the model will generate
+        """
+        if max_len_title is None:
+            max_len_title = self.default_max_len_title
+        # Get the encoder's features for the decoder.
+        raw_tokenized = self.pp_body.transform([raw_input_text])
+        body_encoding = self.encoder_model.predict(raw_tokenized)
+        # We want to save the encoder's embedding before it is updated by the decoder,
+        # because we can use that as an embedding for other tasks.
+        original_body_encoding = body_encoding
+        state_value = np.array(self.pp_title.token2id['_start_']).reshape(1, 1)
+
+        decoded_sentence = []
+        stop_condition = False
+        while not stop_condition:
+            preds, st = self.decoder_model.predict([state_value, body_encoding])
+
+            # We are going to ignore indices 0 (padding) and indices 1 (unknown).
+            # Argmax will return the integer index corresponding to the
+            # prediction + 2 b/c we chopped off the first two.
+            pred_idx = np.argmax(preds[:, :, 2:]) + 2
+
+            # Retrieve the word from the index prediction.
+            pred_word_str = self.pp_title.id2token[pred_idx]
+
+            if pred_word_str == '_end_' or len(decoded_sentence) >= max_len_title:
+                stop_condition = True
+                break
+            decoded_sentence.append(pred_word_str)
+
+            # Update the decoder for the next word.
+            body_encoding = st
+            state_value = np.array(pred_idx).reshape(1, 1)
+
+        return original_body_encoding, ' '.join(decoded_sentence)


+    def print_example(self,
+                      i,
+                      body_text,
+                      title_text,
+                      url,
+                      threshold):
+        """
+        Prints an example of the model's prediction for manual inspection.
+ """ + if i: + print('\n\n==============================================') + print(f'============== Example # {i} =================\n') + + if url: + print(url) + + print(f"Issue Body:\n {body_text} \n") + + if title_text: + print(f"Original Title:\n {title_text}") + + emb, gen_title = self.generate_issue_title(body_text) + print(f"\n****** Machine Generated Title (Prediction) ******:\n {gen_title}") + + if self.nn: + # return neighbors and distances + n, d = self.nn.get_nns_by_vector(emb.flatten(), n=4, + include_distances=True) + neighbors = n[1:] + dist = d[1:] + + if min(dist) <= threshold: + cols = ['issue_url', 'issue_title', 'body'] + dfcopy = self.rec_df.iloc[neighbors][cols].copy(deep=True) + dfcopy['dist'] = dist + similar_issues_df = dfcopy.query(f'dist <= {threshold}') + + print("\n**** Similar Issues (using encoder embedding) ****:\n") + display(similar_issues_df) + + + def demo_model_predictions(self, + n, + issue_df, + threshold=1): + """ + Pick n random Issues and display predictions. + Input: + ------ + n : int + Number of issues to display from issue_df + issue_df : pandas DataFrame + DataFrame that contains two columns: `body` and `issue_title`. + threshold : float + distance threshold for recommendation of similar issues. + Returns: + -------- + None + Prints the original issue body and the model's prediction. + """ + # Extract body and title from DF + body_text = issue_df.body.tolist() + title_text = issue_df.issue_title.tolist() + url = issue_df.issue_url.tolist() + + demo_list = np.random.randint(low=1, high=len(body_text), size=n) + for i in demo_list: + self.print_example(i, + body_text=body_text[i], + title_text=title_text[i], + url=url[i], + threshold=threshold) + + def prepare_recommender(self, vectorized_array, original_df): + """ + Use the annoy library to build recommender + Parameters + ---------- + vectorized_array : List[List[int]] + This is the list of list of integers that represents your corpus + that is fed into the seq2seq model for training. + original_df : pandas.DataFrame + This is the original dataframe that has the columns + ['issue_url', 'issue_title', 'body'] + Returns + ------- + annoy.AnnoyIndex object (see https://github.com/spotify/annoy) + """ + self.rec_df = original_df + emb = self.encoder_model.predict(x=vectorized_array, + batch_size=vectorized_array.shape[0]//200) + + f = emb.shape[1] + self.nn = AnnoyIndex(f) + logging.warning('Adding embeddings') + for i in tqdm(range(len(emb))): + self.nn.add_item(i, emb[i]) + logging.warning('Building trees for similarity lookup.') + self.nn.build(50) + return self.nn + + def set_recsys_data(self, original_df): + self.rec_df = original_df + + def set_recsys_annoyobj(self, annoyobj): + self.nn = annoyobj + + def evaluate_model(self, holdout_bodies, holdout_titles): + """ + Method for calculating BLEU Score. 
+        Parameters
+        ----------
+        holdout_bodies : List[str]
+            These are the issue bodies that we want to summarize
+        holdout_titles : List[str]
+            This is the ground truth we are trying to predict --> issue titles
+        Returns
+        -------
+        bleu : float
+            The BLEU Score
+        """
+        actual, predicted = list(), list()
+        assert len(holdout_bodies) == len(holdout_titles)
+        num_examples = len(holdout_bodies)
+
+        logging.warning('Generating predictions.')
+        # Step over the whole set. TODO: parallelize this.
+        for i in tqdm_notebook(range(num_examples)):
+            _, yhat = self.generate_issue_title(holdout_bodies[i])
+
+            actual.append(self.pp_title.process_text([holdout_titles[i]])[0])
+            predicted.append(self.pp_title.process_text([yhat])[0])
+        # Calculate BLEU score.
+        logging.warning('Calculating BLEU.')
+        bleu = corpus_bleu(actual, predicted)
+        return bleu
diff --git a/github_issue_summarization/workflow/workspace/src/train.py b/github_issue_summarization/workflow/workspace/src/train.py
new file mode 100644
index 00000000..0969019a
--- /dev/null
+++ b/github_issue_summarization/workflow/workspace/src/train.py
@@ -0,0 +1,93 @@
+import argparse
+from keras.callbacks import CSVLogger, ModelCheckpoint
+from keras.layers import Input, LSTM, GRU, Dense, Embedding, Bidirectional, BatchNormalization
+from keras.models import Model
+from keras import optimizers
+import numpy as np
+from seq2seq_utils import load_decoder_inputs, load_encoder_inputs, load_text_processor
+from seq2seq_utils import viz_model_architecture
+
+# Parsing flags.
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_body_preprocessor_dpkl")
+parser.add_argument("--input_title_preprocessor_dpkl")
+parser.add_argument("--input_train_title_vecs_npy")
+parser.add_argument("--input_train_body_vecs_npy")
+parser.add_argument("--output_model_h5")
+parser.add_argument("--learning_rate", default="0.001")
+args = parser.parse_args()
+print(args)
+
+learning_rate = float(args.learning_rate)
+
+encoder_input_data, doc_length = load_encoder_inputs(args.input_train_body_vecs_npy)
+decoder_input_data, decoder_target_data = load_decoder_inputs(args.input_train_title_vecs_npy)
+
+num_encoder_tokens, body_pp = load_text_processor(args.input_body_preprocessor_dpkl)
+num_decoder_tokens, title_pp = load_text_processor(args.input_title_preprocessor_dpkl)
+
+# Arbitrarily set latent dimension for embedding and hidden units.
+latent_dim = 300
+
+###############
+# Encoder Model.
+###############
+encoder_inputs = Input(shape=(doc_length,), name='Encoder-Input')
+
+# Word embedding for encoder (ex: Issue Body).
+x = Embedding(num_encoder_tokens, latent_dim, name='Body-Word-Embedding', mask_zero=False)(encoder_inputs)
+x = BatchNormalization(name='Encoder-Batchnorm-1')(x)
+
+# We do not need the `encoder_output`, just the hidden state.
+_, state_h = GRU(latent_dim, return_state=True, name='Encoder-Last-GRU')(x)
+
+# Encapsulate the encoder as a separate entity so we can just
+# encode without decoding if we want to.
+encoder_model = Model(inputs=encoder_inputs, outputs=state_h, name='Encoder-Model')
+
+seq2seq_encoder_out = encoder_model(encoder_inputs)
+
+################
+# Decoder Model.
+################
+decoder_inputs = Input(shape=(None,), name='Decoder-Input')  # for teacher forcing
+
+# Word Embedding For Decoder (ex: Issue Titles).
+dec_emb = Embedding(num_decoder_tokens, latent_dim, name='Decoder-Word-Embedding', mask_zero=False)(decoder_inputs)
+dec_bn = BatchNormalization(name='Decoder-Batchnorm-1')(dec_emb)
+
+# Set up the decoder, using `decoder_state_input` as initial state.
+decoder_gru = GRU(latent_dim, return_state=True, return_sequences=True, name='Decoder-GRU')
+decoder_gru_output, _ = decoder_gru(dec_bn, initial_state=seq2seq_encoder_out)
+x = BatchNormalization(name='Decoder-Batchnorm-2')(decoder_gru_output)
+
+# Dense layer for prediction.
+decoder_dense = Dense(num_decoder_tokens, activation='softmax', name='Final-Output-Dense')
+decoder_outputs = decoder_dense(x)
+
+################
+# Seq2Seq Model.
+################
+
+seq2seq_Model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
+
+seq2seq_Model.compile(optimizer=optimizers.Nadam(lr=learning_rate), loss='sparse_categorical_crossentropy')
+
+seq2seq_Model.summary()
+
+script_name_base = 'tutorial_seq2seq'
+csv_logger = CSVLogger('{:}.log'.format(script_name_base))
+model_checkpoint = ModelCheckpoint('{:}.epoch{{epoch:02d}}-val{{val_loss:.5f}}.hdf5'.format(script_name_base),
+                                   save_best_only=True)
+
+batch_size = 1200
+epochs = 7
+history = seq2seq_Model.fit([encoder_input_data, decoder_input_data], np.expand_dims(decoder_target_data, -1),
+                            batch_size=batch_size,
+                            epochs=epochs,
+                            validation_split=0.12, callbacks=[csv_logger, model_checkpoint])
+
+#############
+# Save model.
+#############
+seq2seq_Model.save(args.output_model_h5)
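+# The saved HDF5 model is consumed by the later workflow steps: prediction.py
+# and recommend.py load it through their --input_model_h5 flag.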