mirror of https://github.com/kubeflow/examples.git
Add similarity transformer body (#159)
* Add similarity transformer body
* Update pipeline to write a single CSV file
* Fix lint errors
* Use CSV writer to handle formatting rows
* Use direct transformer encoding methods with variable scopes
* Complete end-to-end training with new model and problem
* Read from multiple CSV files
parent: 836ad70421
commit: 5a9748bf8f
@@ -63,10 +63,9 @@ Results are saved back into a BigQuery table.

* Execute the `Dataflow` job
```
$ python preprocess/scripts/process_github_archive.py -i files/select_github_archive.sql \
    -o code_search:function_docstrings -p kubeflow-dev -j process-github-archive \
    --storage-bucket gs://kubeflow-dev --machine-type n1-highcpu-32 --num-workers 16 \
    --max-num-workers 16
$ python preprocess/scripts/process_github_archive.py -p kubeflow-dev -j process-github-archive \
    --storage-bucket gs://kubeflow-examples/t2t-code-search -o code_search:function_docstrings \
    --machine-type n1-highcpu-32 --num-workers 16 --max-num-workers 16
```

## 2. Model Training

@@ -102,23 +101,13 @@ See [GCR Pushing and Pulling Images](https://cloud.google.com/container-registry

#### 2.2.1 Function Summarizer

This part generates a model to summarize functions into docstrings using the data generated in the
previous step. It uses `tensor2tensor`.

* Generate `TFRecords` for training
```
$ export MOUNT_DATA_DIR=/path/to/data/folder
$ docker run --rm -it -v ${MOUNT_DATA_DIR}:/data ${BUILD_IMAGE_TAG} \
    t2t-datagen --problem=github_function_summarizer --data_dir=/data
```

* Train the transduction model using `Transformer Networks` and a base hyper-parameter set
```
$ export MOUNT_DATA_DIR=/path/to/data/folder
$ export MOUNT_OUTPUT_DIR=/path/to/output/folder
$ docker run --rm -it -v ${MOUNT_DATA_DIR}:/data -v ${MOUNT_OUTPUT_DIR}:/output ${BUILD_IMAGE_TAG} \
    t2t-trainer --problem=github_function_summarizer --data_dir=/data --output_dir=/output \
      --model=transformer --hparams_set=transformer_base
      --generate_data --problem=github_function_docstring --data_dir=/data --output_dir=/output \
      --model=similarity_transformer --hparams_set=transformer_tiny
```

### 2.2 Train on Kubeflow

@@ -1 +1,2 @@
import code_search.t2t.similarity_transformer
from . import function_docstring
from . import similarity_transformer

@@ -0,0 +1,40 @@
"""Github function/text similatrity problems."""
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
from tensor2tensor.data_generators import text_problems
|
||||
from tensor2tensor.utils import metrics
|
||||
from tensor2tensor.utils import registry
|
||||
|
||||
|
||||
@registry.register_problem
|
||||
class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
||||
# pylint: disable=abstract-method
|
||||
|
||||
"""This class defines the problem of finding similarity between Python
|
||||
function and docstring"""
|
||||
|
||||
@property
|
||||
def is_generate_per_split(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def approx_vocab_size(self):
|
||||
return 2**13
|
||||
|
||||
def generate_samples(self, data_dir, tmp_dir, dataset_split): # pylint: disable=no-self-use,unused-argument
|
||||
"""Returns a generator to return {"inputs": [text], "targets": [text]}."""
|
||||
|
||||
# TODO(sanyamkapoor): separate train/eval data set.
|
||||
pair_files_glob = os.path.join(data_dir, 'pairs-*.csv')
|
||||
for pairs_file_path in glob.glob(pair_files_glob):
|
||||
with open(pairs_file_path, 'r') as csv_file:
|
||||
pairs_reader = csv.reader(csv_file)
|
||||
for row in pairs_reader:
|
||||
function_tokens, docstring_tokens = row[-2:]
|
||||
yield {'inputs': docstring_tokens, 'targets': function_tokens}
|
||||
|
||||
def eval_metrics(self): # pylint: disable=no-self-use
|
||||
return [
|
||||
metrics.Metrics.ACC
|
||||
]
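
For a quick sanity check of the problem's data format, here is a minimal standalone sketch (standard library only, no tensor2tensor required) that writes one hypothetical `pairs-*.csv` row and replays the same last-two-columns mapping that `generate_samples` performs. The file name and token strings are made up for illustration.

```
import csv
import glob
import os
import tempfile

# Hypothetical pairs file mimicking the Dataflow output: the last two
# columns are function_tokens and docstring_tokens.
data_dir = tempfile.mkdtemp()
with open(os.path.join(data_dir, 'pairs-00000-of-00001.csv'), 'w') as f:
    csv.writer(f).writerow(
        ['nwo', 'path', 'name', 'def add(a, b) return a b', 'add two numbers'])

def iter_samples(data_dir):
    """Same iteration logic as GithubFunctionDocstring.generate_samples."""
    for path in glob.glob(os.path.join(data_dir, 'pairs-*.csv')):
        with open(path, 'r') as csv_file:
            for row in csv.reader(csv_file):
                function_tokens, docstring_tokens = row[-2:]
                yield {'inputs': docstring_tokens, 'targets': function_tokens}

for sample in iter_samples(data_dir):
    print(sample)
# {'inputs': 'add two numbers', 'targets': 'def add(a, b) return a b'}
```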

@@ -1,7 +1,10 @@
import os
from tensor2tensor.utils import t2t_model
"""Using Transformer Networks for String similarities."""
from tensor2tensor.data_generators import problem
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import t2t_model
import tensorflow as tf


@registry.register_model
@@ -9,31 +12,56 @@ class SimilarityTransformer(t2t_model.T2TModel):
  # pylint: disable=abstract-method

  """
  This class defines the model to compute similarity scores between functions and
  docstrings
  This class defines the model to compute similarity scores between functions
  and docstrings
  """

  def body(self, features):
    # TODO: need to fill this with Transformer encoder/decoder
    # and loss calculation
    raise NotImplementedError
    """Body of the Similarity Transformer Network."""

    with tf.variable_scope('string_embedding'):
      string_embedding = self.encode(features, 'inputs')


@registry.register_problem
class GithubFunctionDocstring(text_problems.Text2TextProblem):
  # pylint: disable=abstract-method
    loss = None
    if 'targets' in features:
      with tf.variable_scope('code_embedding'):
        code_embedding = self.encode(features, 'targets')

  """This class defines the problem of finding similarity between Python function
  and docstring"""
      cosine_dist = tf.losses.cosine_distance(
          tf.nn.l2_normalize(string_embedding, axis=1),
          tf.nn.l2_normalize(code_embedding, axis=1),
          axis=1, reduction=tf.losses.Reduction.NONE)

  @property
  def is_generate_per_split(self):
    return False
      # TODO(sanyamkapoor): need negative sampling, won't be all ones anymore.
      labels = tf.one_hot(tf.ones(
          tf.shape(features['targets'])[0], tf.int32), 2)
      logits = tf.concat([cosine_dist, 1 - cosine_dist], axis=1)

  def generate_samples(self, data_dir, _tmp_dir, dataset_split):  # pylint: disable=no-self-use
    """This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""
      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                     logits=logits)

    functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split))
    docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
    if loss is not None:
      return string_embedding, loss

    return text_problems.text2text_txt_iterator(functions_file_path, docstrings_file_path)
    return string_embedding

  def encode(self, features, input_key):
    hparams = self._hparams
    inputs = common_layers.flatten4d3d(features[input_key])

    (encoder_input, encoder_self_attention_bias, _) = (
        transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
                                                self._hparams))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input,
        encoder_self_attention_bias,
        self._hparams,
        nonpadding=transformer.features_to_nonpadding(features, input_key))
    encoder_output = tf.expand_dims(encoder_output, 2)

    encoder_output = tf.reduce_mean(tf.squeeze(encoder_output, axis=2), axis=1)

    return encoder_output
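
For intuition about the loss constructed in `body`, here is a minimal NumPy sketch of the same computation on made-up embeddings: L2-normalize both embeddings, take the per-row cosine distance, build two-class logits as `[dist, 1 - dist]`, and apply sigmoid cross-entropy against all-ones labels (no negative sampling yet, matching the TODO above). Batch size and dimensions are arbitrary; this is an illustration, not the tensor2tensor code path.

```
import numpy as np

rng = np.random.RandomState(0)
batch, dim = 4, 8                      # arbitrary sizes for the sketch
string_emb = rng.randn(batch, dim)     # stands in for the docstring encoding
code_emb = rng.randn(batch, dim)       # stands in for the function encoding

def l2_normalize(x, axis=1):
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

# Per-row cosine distance, shape (batch, 1), analogous to
# tf.losses.cosine_distance with Reduction.NONE.
cosine_dist = 1.0 - np.sum(
    l2_normalize(string_emb) * l2_normalize(code_emb), axis=1, keepdims=True)

# Two-class logits and "all positive" one-hot labels (class 1 = similar).
logits = np.concatenate([cosine_dist, 1.0 - cosine_dist], axis=1)
labels = np.eye(2)[np.ones(batch, dtype=int)]

# Numerically stable element-wise sigmoid cross-entropy, matching the
# formula behind tf.nn.sigmoid_cross_entropy_with_logits.
loss = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
print(loss.shape)  # (4, 2)
```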

@@ -1,8 +1,9 @@
import os
import logging
import time
import csv
import io
import apache_beam as beam
import apache_beam.io as io
from apache_beam import pvalue
from apache_beam.metrics import Metrics
from apache_beam.options.pipeline_options import StandardOptions, PipelineOptions, \
@@ -117,11 +118,11 @@ class ProcessGithubFiles(beam.PTransform):
                         'function_tokens', 'docstring_tokens']
    self.data_types = ['STRING', 'STRING', 'STRING', 'INTEGER', 'STRING', 'STRING', 'STRING']

    self.num_shards = 1
    self.num_shards = 10

  def expand(self, input_or_inputs):
    tokenize_result = (input_or_inputs
      | "Read Github Dataset" >> io.Read(io.BigQuerySource(query=self.query_string,
      | "Read Github Dataset" >> beam.io.Read(beam.io.BigQuerySource(query=self.query_string,
                                                            use_standard_sql=True))
      | "Split 'repo_path'" >> beam.ParDo(SplitRepoPath())
      | "Tokenize Code/Docstring Pairs" >> beam.ParDo(TokenizeCodeDocstring())
@@ -130,7 +131,7 @@ class ProcessGithubFiles(beam.PTransform):

    #pylint: disable=expression-not-assigned
    (tokenize_result.err_rows
      | "Failed Row Tokenization" >> io.WriteToBigQuery(project=self.project,
      | "Failed Row Tokenization" >> beam.io.WriteToBigQuery(project=self.project,
                                                        dataset=self.output_dataset,
                                                        table=self.output_table + '_failed',
                                                        schema=self.create_failed_output_schema())
@@ -145,7 +146,7 @@ class ProcessGithubFiles(beam.PTransform):

    #pylint: disable=expression-not-assigned
    (info_result.err_rows
      | "Failed Function Info" >> io.WriteToBigQuery(project=self.project,
      | "Failed Function Info" >> beam.io.WriteToBigQuery(project=self.project,
                                                     dataset=self.output_dataset,
                                                     table=self.output_table + '_failed',
                                                     schema=self.create_failed_output_schema())
@@ -156,24 +157,43 @@ class ProcessGithubFiles(beam.PTransform):

    # pylint: disable=expression-not-assigned
    (processed_rows
      | "Filter Function tokens" >> beam.Map(lambda x: x['function_tokens'])
      | "Write Function tokens" >> io.WriteToText('{}/raw_data/data'.format(self.storage_bucket),
                                                  file_name_suffix='.function',
                                                  num_shards=self.num_shards))
    (processed_rows
      | "Filter Docstring tokens" >> beam.Map(lambda x: x['docstring_tokens'])
      | "Write Docstring tokens" >> io.WriteToText('{}/raw_data/data'.format(self.storage_bucket),
                                                   file_name_suffix='.docstring',
                                                   num_shards=self.num_shards))
      | "Filter Tiny Docstrings" >> beam.Filter(
          lambda row: len(row['docstring_tokens'].split(' ')) > 5)
      | "Format For Write" >> beam.Map(self.format_for_write)
      | "Write To File" >> beam.io.WriteToText('{}/data/pairs'.format(self.storage_bucket),
                                               file_name_suffix='.csv',
                                               num_shards=self.num_shards))
    # pylint: enable=expression-not-assigned

    return (processed_rows
      | "Save Tokens" >> io.WriteToBigQuery(project=self.project,
      | "Save Tokens" >> beam.io.WriteToBigQuery(project=self.project,
                                                 dataset=self.output_dataset,
                                                 table=self.output_table,
                                                 schema=self.create_output_schema())
    )
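
The new write path is essentially a Filter/Map/WriteToText chain. Below is a minimal local sketch of that shape using the default DirectRunner; the rows, the output prefix `/tmp/pairs`, and the simplified formatting lambda (plain join, without the CSV quoting that `format_for_write` provides) are illustrative assumptions, not the project's real data or paths.

```
import apache_beam as beam

rows = [
    {'function_tokens': 'def add a b', 'docstring_tokens': 'add two numbers together very carefully'},
    {'function_tokens': 'def noop', 'docstring_tokens': 'too short'},
]

with beam.Pipeline() as p:  # DirectRunner by default
    (p
     | 'Create Rows' >> beam.Create(rows)
     | 'Filter Tiny Docstrings' >> beam.Filter(
         lambda row: len(row['docstring_tokens'].split(' ')) > 5)
     | 'Format For Write' >> beam.Map(
         lambda row: '{},{}'.format(row['function_tokens'], row['docstring_tokens']))
     | 'Write To File' >> beam.io.WriteToText(
         '/tmp/pairs', file_name_suffix='.csv', num_shards=1))
```

The second row is dropped by the filter, and a single shard `/tmp/pairs-00000-of-00001.csv` is written.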

  def format_for_write(self, row):
    """This method filters keys that we don't need in the
    final CSV. It must ensure that there are no multi-line
    column fields. For instance, 'original_function' is a
    multi-line string and makes CSV parsing hard for any
    derived Dataflow steps. This uses the CSV Writer
    to handle all edge cases like quote escaping."""

    filter_keys = [
      'original_function',
      'lineno',
    ]
    target_keys = [col for col in self.data_columns if col not in filter_keys]
    target_values = [row[key].encode('utf-8') for key in target_keys]

    with io.BytesIO() as fs:
      cw = csv.writer(fs)
      cw.writerow(target_values)
      result_str = fs.getvalue().strip('\r\n')

    return result_str
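
To see why a CSV writer is used here instead of a plain `','.join`, the sketch below serializes one row the same way (an in-memory buffer plus `csv.writer`) with a docstring that contains a comma and quotes. It uses `io.StringIO` for a Python 3 friendly illustration, whereas the method above writes UTF-8 bytes; the column names and values are made up.

```
import csv
import io

row = {
    'nwo': 'octocat/hello-world',          # made-up values for illustration
    'function_name': 'greet',
    'function_tokens': 'def greet name',
    'docstring_tokens': 'say "hello, world" politely',
}

target_keys = ['nwo', 'function_name', 'function_tokens', 'docstring_tokens']
with io.StringIO() as fs:
    cw = csv.writer(fs)
    cw.writerow([row[key] for key in target_keys])
    result_str = fs.getvalue().strip('\r\n')

print(result_str)
# octocat/hello-world,greet,def greet name,"say ""hello, world"" politely"
```

The embedded comma and quotes are escaped automatically, so downstream CSV readers see exactly one row per line.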

  def create_output_schema(self):
    table_schema = bigquery.TableSchema()
@@ -1,4 +1,4 @@
astor~=0.6.0
apache-beam[gcp]~=2.4.0
apache-beam[gcp]~=2.5.0
nltk~=3.3.0
spacy~=2.0.0
@@ -1,5 +1,6 @@
from __future__ import print_function
import argparse
import os
import apache_beam as beam

from preprocess.pipeline import create_pipeline_opts, ProcessGithubFiles
@@ -7,7 +8,10 @@ from preprocess.pipeline import create_pipeline_opts, ProcessGithubFiles

def parse_arguments(args):
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('-i', '--input', metavar='', type=str, help='Path to BigQuery SQL script')

  default_script_file = os.path.abspath('{}/../../files/select_github_archive.sql'.format(__file__))
  parser.add_argument('-i', '--input', metavar='', type=str, default=default_script_file,
                      help='Path to BigQuery SQL script')
  parser.add_argument('-o', '--output', metavar='', type=str,
                      help='Output string of the format <dataset>:<table>')
  parser.add_argument('-p', '--project', metavar='', type=str, default='Project', help='Project ID')
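
The new default resolves relative to `__file__` rather than the current working directory, which is why the README command can drop `-i`. The snippet below walks through the same path arithmetic with a hypothetical file location (the paths are illustrative, not the repository's real layout).

```
import os

# Pretend the script lives at /repo/preprocess/scripts/process_github_archive.py.
fake_file = '/repo/preprocess/scripts/process_github_archive.py'

# The file name itself counts as one path component, so '../..' climbs out of
# the file and out of 'scripts', landing the default under 'preprocess/files'.
default_script_file = os.path.abspath(
    '{}/../../files/select_github_archive.sql'.format(fake_file))
print(default_script_file)
# /repo/preprocess/files/select_github_archive.sql
```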