Add similarity transformer body (#159)

* Add similarity transformer body

* Update pipeline to write a single CSV file

* Fix lint errors

* Use CSV writer to handle formatting rows

* Use direct transformer encoding methods with variable scopes

* Complete end-to-end training with new model and problem

* Read from multiple CSV files
Sanyam Kapoor 2018-07-03 11:14:19 -07:00 committed by k8s-ci-robot
parent 836ad70421
commit 5a9748bf8f
7 changed files with 137 additions and 55 deletions

View File

@@ -63,10 +63,9 @@ Results are saved back into a BigQuery table.
 * Execute the `Dataflow` job
 ```
-$ python preprocess/scripts/process_github_archive.py -i files/select_github_archive.sql \
-    -o code_search:function_docstrings -p kubeflow-dev -j process-github-archive \
-    --storage-bucket gs://kubeflow-dev --machine-type n1-highcpu-32 --num-workers 16 \
-    --max-num-workers 16
+$ python preprocess/scripts/process_github_archive.py -p kubeflow-dev -j process-github-archive \
+    --storage-bucket gs://kubeflow-examples/t2t-code-search -o code_search:function_docstrings \
+    --machine-type n1-highcpu-32 --num-workers 16 --max-num-workers 16
 ```

 ## 2. Model Training
@@ -102,23 +101,13 @@ See [GCR Pushing and Pulling Images](https://cloud.google.com/container-registry
 #### 2.2.1 Function Summarizer

-This part generates a model to summarize functions into docstrings using the data generated in previous
-step. It uses `tensor2tensor`.
-
-* Generate `TFRecords` for training
-```
-$ export MOUNT_DATA_DIR=/path/to/data/folder
-$ docker run --rm -it -v ${MOUNT_DATA_DIR}:/data ${BUILD_IMAGE_TAG} \
-      t2t-datagen --problem=github_function_summarizer --data_dir=/data
-```
-
 * Train transduction model using `Tranformer Networks` and a base hyper-parameters set
 ```
 $ export MOUNT_DATA_DIR=/path/to/data/folder
 $ export MOUNT_OUTPUT_DIR=/path/to/output/folder
 $ docker run --rm -it -v ${MOUNT_DATA_DIR}:/data -v ${MOUNT_OUTPUT_DIR}:/output ${BUILD_IMAGE_TAG} \
-      t2t-trainer --problem=github_function_summarizer --data_dir=/data --output_dir=/output \
-      --model=transformer --hparams_set=transformer_base
+      --generate_data --problem=github_function_docstring --data_dir=/data --output_dir=/output \
+      --model=similarity_transformer --hparams_set=transformer_tiny
 ```

 ### 2.2 Train on Kubeflow
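The training command now targets the `similarity_transformer` model with the `transformer_tiny` hyper-parameter set instead of `transformer_base`. As a quick way to see what that configuration contains, here is a minimal sketch, assuming a local `tensor2tensor` install of the era this example targets (the printed attributes follow its `HParams` API and are not part of this commit):

```
# Minimal sketch: inspect the transformer_tiny hparams set referenced above.
# Assumes tensor2tensor is installed; transformer_tiny is one of its registered
# hyper-parameter sets, intended for quick experiments rather than full training.
from tensor2tensor.models import transformer

hparams = transformer.transformer_tiny()
print(hparams.num_hidden_layers, hparams.hidden_size, hparams.filter_size, hparams.num_heads)
```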

View File

@@ -1 +1,2 @@
-import code_search.t2t.similarity_transformer
+from . import function_docstring
+from . import similarity_transformer
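The package `__init__.py` imports both modules because `tensor2tensor` discovers problems and models through the `@registry.register_*` decorators, which only run when the module is imported (for example when the package is passed via `--t2t_usr_dir`). A minimal sketch of that side effect, assuming the `code_search.t2t` package is importable and a tensor2tensor release that exposes the `registry.list_problems()`/`list_models()` helpers:

```
# Minimal sketch: importing the package triggers the registry decorators, which is
# how t2t-trainer finds the problem/model names used on the command line.
from tensor2tensor.utils import registry

import code_search.t2t  # noqa: F401 -- imported only for its registration side effects

print('github_function_docstring' in registry.list_problems())  # expected: True
print('similarity_transformer' in registry.list_models())       # expected: True
```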

View File

@@ -0,0 +1,40 @@
+"""Github function/text similatrity problems."""
+import csv
+import glob
+import os
+
+from tensor2tensor.data_generators import text_problems
+from tensor2tensor.utils import metrics
+from tensor2tensor.utils import registry
+
+
+@registry.register_problem
+class GithubFunctionDocstring(text_problems.Text2TextProblem):
+  # pylint: disable=abstract-method
+  """This class defines the problem of finding similarity between Python
+  function and docstring"""
+
+  @property
+  def is_generate_per_split(self):
+    return False
+
+  @property
+  def approx_vocab_size(self):
+    return 2**13
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):  # pylint: disable=no-self-use,unused-argument
+    """Returns a generator to return {"inputs": [text], "targets": [text]}."""
+    # TODO(sanyamkapoor): separate train/eval data set.
+    pair_files_glob = os.path.join(data_dir, 'pairs-*.csv')
+    for pairs_file_path in glob.glob(pair_files_glob):
+      with open(pairs_file_path, 'r') as csv_file:
+        pairs_reader = csv.reader(csv_file)
+        for row in pairs_reader:
+          function_tokens, docstring_tokens = row[-2:]
+          yield {'inputs': docstring_tokens, 'targets': function_tokens}
+
+  def eval_metrics(self):  # pylint: disable=no-self-use
+    return [
+        metrics.Metrics.ACC
+    ]
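The new problem reads whatever `pairs-*.csv` files it finds under `data_dir` and yields docstring/function token pairs. The following sketch is illustrative only: the file name, column layout, and the import path `code_search.t2t.function_docstring` are assumptions based on this diff, not part of the commit.

```
# Illustrative only: fabricate one pairs-*.csv row and walk generate_samples()
# directly to see the yielded dicts. Real CSVs come from the Dataflow pipeline.
import csv
import os
import tempfile

from code_search.t2t.function_docstring import GithubFunctionDocstring  # assumed module path

data_dir = tempfile.mkdtemp()
with open(os.path.join(data_dir, 'pairs-00000-of-00010.csv'), 'w') as f:
    # Only the last two columns matter here: function_tokens, docstring_tokens.
    csv.writer(f).writerow(['repo', 'path', 'name', 'def add a b return a b', 'add two numbers'])

problem = GithubFunctionDocstring()
for sample in problem.generate_samples(data_dir, tmp_dir=None, dataset_split=None):
    print(sample)  # {'inputs': 'add two numbers', 'targets': 'def add a b return a b'}
```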

View File

@@ -1,7 +1,10 @@
-import os
-from tensor2tensor.utils import t2t_model
+"""Using Transformer Networks for String similarities."""
+from tensor2tensor.data_generators import problem
+from tensor2tensor.layers import common_layers
+from tensor2tensor.models import transformer
 from tensor2tensor.utils import registry
-from tensor2tensor.data_generators import text_problems
+from tensor2tensor.utils import t2t_model
+import tensorflow as tf


 @registry.register_model
@@ -9,31 +12,56 @@ class SimilarityTransformer(t2t_model.T2TModel):
   # pylint: disable=abstract-method
   """
-  This class defines the model to compute similarity scores between functions and
-  docstrings
+  This class defines the model to compute similarity scores between functions
+  and docstrings
   """

   def body(self, features):
-    # TODO: need to fill this with Transformer encoder/decoder
-    # and loss calculation
-    raise NotImplementedError
-
-
-@registry.register_problem
-class GithubFunctionDocstring(text_problems.Text2TextProblem):
-  # pylint: disable=abstract-method
-
-  """This class defines the problem of finding similarity between Python function
-  and docstring"""
-
-  @property
-  def is_generate_per_split(self):
-    return False
-
-  def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
-    """This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""
-    functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split))
-    docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
-
-    return text_problems.text2text_txt_iterator(functions_file_path, docstrings_file_path)
+    """Body of the Similarity Transformer Network."""
+
+    with tf.variable_scope('string_embedding'):
+      string_embedding = self.encode(features, 'inputs')
+
+    loss = None
+    if 'targets' in features:
+      with tf.variable_scope('code_embedding'):
+        code_embedding = self.encode(features, 'targets')
+
+      cosine_dist = tf.losses.cosine_distance(
+          tf.nn.l2_normalize(string_embedding, axis=1),
+          tf.nn.l2_normalize(code_embedding, axis=1),
+          axis=1, reduction=tf.losses.Reduction.NONE)
+
+      # TODO(sanyamkapoor): need negative sampling, won't be all ones anymore.
+      labels = tf.one_hot(tf.ones(
+          tf.shape(features['targets'])[0], tf.int32), 2)
+      logits = tf.concat([cosine_dist, 1 - cosine_dist], axis=1)
+
+      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
+                                                     logits=logits)
+
+    if loss is not None:
+      return string_embedding, loss
+
+    return string_embedding
+
+  def encode(self, features, input_key):
+    hparams = self._hparams
+
+    inputs = common_layers.flatten4d3d(features[input_key])
+
+    (encoder_input, encoder_self_attention_bias, _) = (
+        transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
+                                                self._hparams))
+
+    encoder_input = tf.nn.dropout(encoder_input,
+                                  1.0 - hparams.layer_prepostprocess_dropout)
+    encoder_output = transformer.transformer_encoder(
+        encoder_input,
+        encoder_self_attention_bias,
+        self._hparams,
+        nonpadding=transformer.features_to_nonpadding(features, input_key))
+    encoder_output = tf.expand_dims(encoder_output, 2)
+
+    encoder_output = tf.reduce_mean(tf.squeeze(encoder_output, axis=2), axis=1)
+
+    return encoder_output
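The model is a dual encoder: `body` returns an embedding of the docstring and, when targets are present, a loss that pulls the docstring and code embeddings together (the TODO notes that negative sampling is still missing). The intended downstream use is nearest-neighbour code search over those embeddings. A purely illustrative numpy sketch of that retrieval step follows; the shapes and names are made up and nothing here is part of the commit.

```
# Illustrative only: rank code embeddings against a docstring-query embedding
# by cosine similarity, which is what the dual encoder above is trained to support.
import numpy as np

def rank_by_cosine(query_embedding, code_embeddings, top_k=5):
    """Return indices of the top_k code embeddings most similar to the query."""
    q = query_embedding / np.linalg.norm(query_embedding)
    c = code_embeddings / np.linalg.norm(code_embeddings, axis=1, keepdims=True)
    scores = c.dot(q)
    return np.argsort(-scores)[:top_k]

query_embedding = np.random.rand(128)        # stand-in for encode(features, 'inputs')
code_embeddings = np.random.rand(1000, 128)  # stand-in for precomputed encode(..., 'targets')
print(rank_by_cosine(query_embedding, code_embeddings))
```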

View File

@@ -1,8 +1,9 @@
 import os
 import logging
 import time
+import csv
+import io
 import apache_beam as beam
-import apache_beam.io as io
 from apache_beam import pvalue
 from apache_beam.metrics import Metrics
 from apache_beam.options.pipeline_options import StandardOptions, PipelineOptions, \
@@ -117,11 +118,11 @@ class ProcessGithubFiles(beam.PTransform):
                          'function_tokens', 'docstring_tokens']
     self.data_types = ['STRING', 'STRING', 'STRING', 'INTEGER', 'STRING', 'STRING', 'STRING']

-    self.num_shards = 1
+    self.num_shards = 10

   def expand(self, input_or_inputs):
     tokenize_result = (input_or_inputs
-      | "Read Github Dataset" >> io.Read(io.BigQuerySource(query=self.query_string,
+      | "Read Github Dataset" >> beam.io.Read(beam.io.BigQuerySource(query=self.query_string,
                                                            use_standard_sql=True))
       | "Split 'repo_path'" >> beam.ParDo(SplitRepoPath())
       | "Tokenize Code/Docstring Pairs" >> beam.ParDo(TokenizeCodeDocstring())
@@ -130,7 +131,7 @@ class ProcessGithubFiles(beam.PTransform):
     #pylint: disable=expression-not-assigned
     (tokenize_result.err_rows
-      | "Failed Row Tokenization" >> io.WriteToBigQuery(project=self.project,
+      | "Failed Row Tokenization" >> beam.io.WriteToBigQuery(project=self.project,
                                                              dataset=self.output_dataset,
                                                              table=self.output_table + '_failed',
                                                              schema=self.create_failed_output_schema())
@@ -145,7 +146,7 @@ class ProcessGithubFiles(beam.PTransform):
     #pylint: disable=expression-not-assigned
     (info_result.err_rows
-      | "Failed Function Info" >> io.WriteToBigQuery(project=self.project,
+      | "Failed Function Info" >> beam.io.WriteToBigQuery(project=self.project,
                                                          dataset=self.output_dataset,
                                                          table=self.output_table + '_failed',
                                                          schema=self.create_failed_output_schema())
@@ -156,24 +157,43 @@ class ProcessGithubFiles(beam.PTransform):
     # pylint: disable=expression-not-assigned
     (processed_rows
-      | "Filter Function tokens" >> beam.Map(lambda x: x['function_tokens'])
-      | "Write Function tokens" >> io.WriteToText('{}/raw_data/data'.format(self.storage_bucket),
-                                                  file_name_suffix='.function',
-                                                  num_shards=self.num_shards))
-    (processed_rows
-      | "Filter Docstring tokens" >> beam.Map(lambda x: x['docstring_tokens'])
-      | "Write Docstring tokens" >> io.WriteToText('{}/raw_data/data'.format(self.storage_bucket),
-                                                   file_name_suffix='.docstring',
+      | "Filter Tiny Docstrings" >> beam.Filter(
+          lambda row: len(row['docstring_tokens'].split(' ')) > 5)
+      | "Format For Write" >> beam.Map(self.format_for_write)
+      | "Write To File" >> beam.io.WriteToText('{}/data/pairs'.format(self.storage_bucket),
+                                               file_name_suffix='.csv',
                                                num_shards=self.num_shards))
     # pylint: enable=expression-not-assigned

     return (processed_rows
-      | "Save Tokens" >> io.WriteToBigQuery(project=self.project,
+      | "Save Tokens" >> beam.io.WriteToBigQuery(project=self.project,
                                                  dataset=self.output_dataset,
                                                  table=self.output_table,
                                                  schema=self.create_output_schema())
     )

+  def format_for_write(self, row):
+    """This method filters keys that we don't need in the
+    final CSV. It must ensure that there are no multi-line
+    column fields. For instance, 'original_function' is a
+    multi-line string and makes CSV parsing hard for any
+    derived Dataflow steps. This uses the CSV Writer
+    to handle all edge cases like quote escaping."""
+    filter_keys = [
+      'original_function',
+      'lineno',
+    ]
+    target_keys = [col for col in self.data_columns if col not in filter_keys]
+    target_values = [row[key].encode('utf-8') for key in target_keys]
+
+    with io.BytesIO() as fs:
+      cw = csv.writer(fs)
+      cw.writerow(target_values)
+      result_str = fs.getvalue().strip('\r\n')
+
+    return result_str
+
   def create_output_schema(self):
     table_schema = bigquery.TableSchema()
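`format_for_write` routes each row through `csv.writer` so that commas, quotes, and other special characters inside token strings are escaped instead of corrupting the single-line CSV output that the new problem later parses. A standalone sketch of that behaviour is below; it uses Python 3 syntax with `io.StringIO`, whereas the pipeline itself runs on Python 2 and uses `io.BytesIO`, and none of the sample values come from the commit.

```
# Standalone sketch of the csv.writer trick used by format_for_write above.
import csv
import io

def row_to_csv_line(values):
    buf = io.StringIO()
    csv.writer(buf).writerow(values)
    return buf.getvalue().strip('\r\n')

docstring_tokens = 'returns a dict of "inputs", "targets"'  # contains quotes and a comma
print(row_to_csv_line(['kubeflow/examples', 'add', docstring_tokens]))
# kubeflow/examples,add,"returns a dict of ""inputs"", ""targets"""
```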

View File

@@ -1,4 +1,4 @@
 astor~=0.6.0
-apache-beam[gcp]~=2.4.0
+apache-beam[gcp]~=2.5.0
 nltk~=3.3.0
 spacy~=2.0.0

View File

@@ -1,5 +1,6 @@
 from __future__ import print_function
 import argparse
+import os

 import apache_beam as beam
 from preprocess.pipeline import create_pipeline_opts, ProcessGithubFiles
@@ -7,7 +8,10 @@ from preprocess.pipeline import create_pipeline_opts, ProcessGithubFiles
 def parse_arguments(args):
   parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('-i', '--input', metavar='', type=str, help='Path to BigQuery SQL script')
+
+  default_script_file = os.path.abspath('{}/../../files/select_github_archive.sql'.format(__file__))
+  parser.add_argument('-i', '--input', metavar='', type=str, default=default_script_file,
+                      help='Path to BigQuery SQL script')
   parser.add_argument('-o', '--output', metavar='', type=str,
                       help='Output string of the format <dataset>:<table>')
   parser.add_argument('-p', '--project', metavar='', type=str, default='Project', help='Project ID')
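The `--input` flag now defaults to a SQL file located relative to the script via `__file__`. Because `__file__` names a file rather than a directory, the first `..` strips the file name, so the default resolves into a `files/` directory one level above `scripts/`. A small illustrative check, where the script location is made up:

```
# Illustrative only: how the new default for --input resolves.
import os

script_path = '/home/user/examples/code_search/preprocess/scripts/process_github_archive.py'  # stand-in for __file__

default_script_file = os.path.abspath(
    '{}/../../files/select_github_archive.sql'.format(script_path))
print(default_script_file)
# /home/user/examples/code_search/preprocess/files/select_github_archive.sql
```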