mirror of https://github.com/kubeflow/examples.git
Add similarity transformer body (#159)
* Add similarity transformer body
* Update pipeline to write a single CSV file
* Fix lint errors
* Use CSV writer to handle formatting rows
* Use direct transformer encoding methods with variable scopes
* Complete end-to-end training with new model and problem
* Read from multiple CSV files
parent 836ad70421
commit 5a9748bf8f
@@ -63,10 +63,9 @@ Results are saved back into a BigQuery table.
 * Execute the `Dataflow` job
 ```
-$ python preprocess/scripts/process_github_archive.py -i files/select_github_archive.sql \
-      -o code_search:function_docstrings -p kubeflow-dev -j process-github-archive \
-      --storage-bucket gs://kubeflow-dev --machine-type n1-highcpu-32 --num-workers 16 \
-      --max-num-workers 16
+$ python preprocess/scripts/process_github_archive.py -p kubeflow-dev -j process-github-archive \
+      --storage-bucket gs://kubeflow-examples/t2t-code-search -o code_search:function_docstrings \
+      --machine-type n1-highcpu-32 --num-workers 16 --max-num-workers 16
 ```

 ## 2. Model Training

@@ -102,23 +101,13 @@ See [GCR Pushing and Pulling Images](https://cloud.google.com/container-registry

 #### 2.2.1 Function Summarizer

-This part generates a model to summarize functions into docstrings using the data generated in previous
-step. It uses `tensor2tensor`.
-
-* Generate `TFRecords` for training
-```
-$ export MOUNT_DATA_DIR=/path/to/data/folder
-$ docker run --rm -it -v ${MOUNT_DATA_DIR}:/data ${BUILD_IMAGE_TAG} \
-    t2t-datagen --problem=github_function_summarizer --data_dir=/data
-```
-
 * Train transduction model using `Transformer Networks` and a base hyper-parameters set
 ```
 $ export MOUNT_DATA_DIR=/path/to/data/folder
 $ export MOUNT_OUTPUT_DIR=/path/to/output/folder
 $ docker run --rm -it -v ${MOUNT_DATA_DIR}:/data -v ${MOUNT_OUTPUT_DIR}:/output ${BUILD_IMAGE_TAG} \
-    t2t-trainer --problem=github_function_summarizer --data_dir=/data --output_dir=/output \
-    --model=transformer --hparams_set=transformer_base
+    --generate_data --problem=github_function_docstring --data_dir=/data --output_dir=/output \
+    --model=similarity_transformer --hparams_set=transformer_tiny
 ```

 ### 2.2 Train on Kubeflow

@@ -1 +1,2 @@
-import code_search.t2t.similarity_transformer
+from . import function_docstring
+from . import similarity_transformer
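The two-line `__init__.py` above is what makes the new problem and model visible to tensor2tensor: importing the package runs the `@registry.register_problem` and `@registry.register_model` decorators. A minimal sketch of checking that registration, assuming the package is importable under the `code_search.t2t` path used by the old absolute import and that `tensor2tensor` is installed:

```python
# Hypothetical sanity check: importing the package registers both the
# problem and the model with the tensor2tensor registry.
from tensor2tensor.utils import registry

import code_search.t2t  # noqa: F401  __init__ pulls in function_docstring and similarity_transformer

# Problem names are derived from the class name: GithubFunctionDocstring -> github_function_docstring.
print(registry.problem('github_function_docstring'))
print('similarity_transformer' in registry.list_models())  # expected: True
```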
@@ -0,0 +1,40 @@
+"""Github function/text similarity problems."""
+import csv
+import glob
+import os
+from tensor2tensor.data_generators import text_problems
+from tensor2tensor.utils import metrics
+from tensor2tensor.utils import registry
+
+
+@registry.register_problem
+class GithubFunctionDocstring(text_problems.Text2TextProblem):
+  # pylint: disable=abstract-method
+
+  """This class defines the problem of finding similarity between Python
+  function and docstring"""
+
+  @property
+  def is_generate_per_split(self):
+    return False
+
+  @property
+  def approx_vocab_size(self):
+    return 2**13
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):  # pylint: disable=no-self-use,unused-argument
+    """Returns a generator to return {"inputs": [text], "targets": [text]}."""
+
+    # TODO(sanyamkapoor): separate train/eval data set.
+    pair_files_glob = os.path.join(data_dir, 'pairs-*.csv')
+    for pairs_file_path in glob.glob(pair_files_glob):
+      with open(pairs_file_path, 'r') as csv_file:
+        pairs_reader = csv.reader(csv_file)
+        for row in pairs_reader:
+          function_tokens, docstring_tokens = row[-2:]
+          yield {'inputs': docstring_tokens, 'targets': function_tokens}
+
+  def eval_metrics(self):  # pylint: disable=no-self-use
+    return [
+        metrics.Metrics.ACC
+    ]
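Since `generate_samples` simply globs `pairs-*.csv` under `data_dir` and yields docstring/function token pairs from the last two columns, it can be exercised locally before running `t2t-datagen`. A sketch with a made-up CSV row and file name, and the `code_search.t2t` package path assumed:

```python
import csv
import os
import tempfile

from code_search.t2t.function_docstring import GithubFunctionDocstring

data_dir = tempfile.mkdtemp()
with open(os.path.join(data_dir, 'pairs-00000-of-00010.csv'), 'w') as f:
  # Illustrative columns only; the pipeline writes more, but the problem
  # reads just the last two (function_tokens, docstring_tokens).
  csv.writer(f).writerow(['repo', 'path', 'name', 'def add a b', 'add two numbers'])

problem = GithubFunctionDocstring()
for sample in problem.generate_samples(data_dir, tmp_dir=None, dataset_split=None):
  print(sample)  # {'inputs': 'add two numbers', 'targets': 'def add a b'}
```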
@@ -1,7 +1,10 @@
-import os
-from tensor2tensor.utils import t2t_model
+"""Using Transformer Networks for String similarities."""
+from tensor2tensor.data_generators import problem
+from tensor2tensor.layers import common_layers
+from tensor2tensor.models import transformer
 from tensor2tensor.utils import registry
-from tensor2tensor.data_generators import text_problems
+from tensor2tensor.utils import t2t_model
+import tensorflow as tf


 @registry.register_model
@@ -9,31 +12,56 @@ class SimilarityTransformer(t2t_model.T2TModel):
   # pylint: disable=abstract-method

   """
-  This class defines the model to compute similarity scores between functions and
-  docstrings
+  This class defines the model to compute similarity scores between functions
+  and docstrings
   """

   def body(self, features):
-    # TODO: need to fill this with Transformer encoder/decoder
-    # and loss calculation
-    raise NotImplementedError
-
-
-@registry.register_problem
-class GithubFunctionDocstring(text_problems.Text2TextProblem):
-  # pylint: disable=abstract-method
-
-  """This class defines the problem of finding similarity between Python function
-  and docstring"""
-
-  @property
-  def is_generate_per_split(self):
-    return False
-
-  def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
-    """This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""
-
-    functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split))
-    docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
-
-    return text_problems.text2text_txt_iterator(functions_file_path, docstrings_file_path)
+    """Body of the Similarity Transformer Network."""
+
+    with tf.variable_scope('string_embedding'):
+      string_embedding = self.encode(features, 'inputs')
+
+    loss = None
+    if 'targets' in features:
+      with tf.variable_scope('code_embedding'):
+        code_embedding = self.encode(features, 'targets')
+
+      cosine_dist = tf.losses.cosine_distance(
+          tf.nn.l2_normalize(string_embedding, axis=1),
+          tf.nn.l2_normalize(code_embedding, axis=1),
+          axis=1, reduction=tf.losses.Reduction.NONE)
+
+      # TODO(sanyamkapoor): need negative sampling, won't be all ones anymore.
+      labels = tf.one_hot(tf.ones(
+          tf.shape(features['targets'])[0], tf.int32), 2)
+      logits = tf.concat([cosine_dist, 1 - cosine_dist], axis=1)
+
+      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
+                                                     logits=logits)
+
+    if loss is not None:
+      return string_embedding, loss
+
+    return string_embedding
+
+  def encode(self, features, input_key):
+    hparams = self._hparams
+    inputs = common_layers.flatten4d3d(features[input_key])
+
+    (encoder_input, encoder_self_attention_bias, _) = (
+        transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
+                                                self._hparams))
+
+    encoder_input = tf.nn.dropout(encoder_input,
+                                  1.0 - hparams.layer_prepostprocess_dropout)
+    encoder_output = transformer.transformer_encoder(
+        encoder_input,
+        encoder_self_attention_bias,
+        self._hparams,
+        nonpadding=transformer.features_to_nonpadding(features, input_key))
+    encoder_output = tf.expand_dims(encoder_output, 2)
+
+    encoder_output = tf.reduce_mean(tf.squeeze(encoder_output, axis=2), axis=1)
+
+    return encoder_output
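To see what the loss above optimizes: `cosine_dist` is `[batch, 1]`, the logits stack `[cosine_dist, 1 - cosine_dist]` into `[batch, 2]`, and the all-ones labels put the positive class on the `1 - cosine_dist` column, so minimizing the sigmoid cross-entropy drives the cosine distance between a docstring embedding and its code embedding toward zero (the in-code TODO notes that negative pairs are still missing). A NumPy stand-in for the shapes and the loss, not part of the commit:

```python
import numpy as np

batch = 3
cosine_dist = np.array([[0.1], [0.8], [0.3]])                       # [batch, 1]
logits = np.concatenate([cosine_dist, 1.0 - cosine_dist], axis=1)   # [batch, 2]
labels = np.eye(2)[np.ones(batch, dtype=int)]                       # one-hot at index 1 -> [[0, 1], ...]

# Element-wise sigmoid cross-entropy: z * log(1 + e^-x) + (1 - z) * log(1 + e^x)
loss = labels * np.log1p(np.exp(-logits)) + (1 - labels) * np.log1p(np.exp(logits))
print(loss.mean())  # smaller when cosine_dist is small, i.e. the two embeddings agree
```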
@@ -1,8 +1,9 @@
 import os
 import logging
 import time
+import csv
+import io
 import apache_beam as beam
-import apache_beam.io as io
 from apache_beam import pvalue
 from apache_beam.metrics import Metrics
 from apache_beam.options.pipeline_options import StandardOptions, PipelineOptions, \
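The new stdlib `import io` would collide with the previous `import apache_beam.io as io` alias, which appears to be why the hunks below switch every `io.Read` / `io.WriteTo*` call to the fully qualified `beam.io.*` form. A tiny illustration of the two namespaces now in play (illustrative only, assuming `apache-beam[gcp]` from the requirements file is installed):

```python
import io                    # stdlib: io.BytesIO is used by format_for_write below
import apache_beam as beam   # Beam IO is now always spelled beam.io.*

buf = io.BytesIO()           # would have broken if `io` were still the Beam alias
read_transform = beam.io.Read
bq_source = beam.io.BigQuerySource
```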
@@ -117,11 +118,11 @@ class ProcessGithubFiles(beam.PTransform):
                          'function_tokens', 'docstring_tokens']
     self.data_types = ['STRING', 'STRING', 'STRING', 'INTEGER', 'STRING', 'STRING', 'STRING']

-    self.num_shards = 1
+    self.num_shards = 10

   def expand(self, input_or_inputs):
     tokenize_result = (input_or_inputs
-      | "Read Github Dataset" >> io.Read(io.BigQuerySource(query=self.query_string,
+      | "Read Github Dataset" >> beam.io.Read(beam.io.BigQuerySource(query=self.query_string,
                                                             use_standard_sql=True))
       | "Split 'repo_path'" >> beam.ParDo(SplitRepoPath())
       | "Tokenize Code/Docstring Pairs" >> beam.ParDo(TokenizeCodeDocstring())
@@ -130,7 +131,7 @@ class ProcessGithubFiles(beam.PTransform):

     #pylint: disable=expression-not-assigned
     (tokenize_result.err_rows
-     | "Failed Row Tokenization" >> io.WriteToBigQuery(project=self.project,
+     | "Failed Row Tokenization" >> beam.io.WriteToBigQuery(project=self.project,
                                                             dataset=self.output_dataset,
                                                             table=self.output_table + '_failed',
                                                             schema=self.create_failed_output_schema())
@@ -145,7 +146,7 @@ class ProcessGithubFiles(beam.PTransform):

     #pylint: disable=expression-not-assigned
     (info_result.err_rows
-     | "Failed Function Info" >> io.WriteToBigQuery(project=self.project,
+     | "Failed Function Info" >> beam.io.WriteToBigQuery(project=self.project,
                                                          dataset=self.output_dataset,
                                                          table=self.output_table + '_failed',
                                                          schema=self.create_failed_output_schema())
@@ -156,24 +157,43 @@ class ProcessGithubFiles(beam.PTransform):

     # pylint: disable=expression-not-assigned
     (processed_rows
-     | "Filter Function tokens" >> beam.Map(lambda x: x['function_tokens'])
-     | "Write Function tokens" >> io.WriteToText('{}/raw_data/data'.format(self.storage_bucket),
-                                                 file_name_suffix='.function',
-                                                 num_shards=self.num_shards))
-    (processed_rows
-     | "Filter Docstring tokens" >> beam.Map(lambda x: x['docstring_tokens'])
-     | "Write Docstring tokens" >> io.WriteToText('{}/raw_data/data'.format(self.storage_bucket),
-                                                  file_name_suffix='.docstring',
-                                                  num_shards=self.num_shards))
+     | "Filter Tiny Docstrings" >> beam.Filter(
+         lambda row: len(row['docstring_tokens'].split(' ')) > 5)
+     | "Format For Write" >> beam.Map(self.format_for_write)
+     | "Write To File" >> beam.io.WriteToText('{}/data/pairs'.format(self.storage_bucket),
+                                              file_name_suffix='.csv',
+                                              num_shards=self.num_shards))
     # pylint: enable=expression-not-assigned

     return (processed_rows
-            | "Save Tokens" >> io.WriteToBigQuery(project=self.project,
+            | "Save Tokens" >> beam.io.WriteToBigQuery(project=self.project,
                                                        dataset=self.output_dataset,
                                                        table=self.output_table,
                                                        schema=self.create_output_schema())
           )

+  def format_for_write(self, row):
+    """This method filters keys that we don't need in the
+    final CSV. It must ensure that there are no multi-line
+    column fields. For instance, 'original_function' is a
+    multi-line string and makes CSV parsing hard for any
+    derived Dataflow steps. This uses the CSV Writer
+    to handle all edge cases like quote escaping."""
+
+    filter_keys = [
+        'original_function',
+        'lineno',
+    ]
+    target_keys = [col for col in self.data_columns if col not in filter_keys]
+    target_values = [row[key].encode('utf-8') for key in target_keys]
+
+    with io.BytesIO() as fs:
+      cw = csv.writer(fs)
+      cw.writerow(target_values)
+      result_str = fs.getvalue().strip('\r\n')
+
+    return result_str
+
   def create_output_schema(self):
     table_schema = bigquery.TableSchema()

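`format_for_write` leans on `csv.writer` instead of a manual `','.join` so that embedded commas, quotes, and stray delimiters in token strings are escaped by the CSV rules that the `csv.reader` in the new T2T problem expects. A quick Python 2-style illustration (the pipeline targets Beam 2.5 on Python 2, hence `io.BytesIO`):

```python
import csv
import io

buf = io.BytesIO()
csv.writer(buf).writerow(['nwo/repo', 'file.py', 'def f(a, b)', 'tokens with "quotes", commas'])
print(buf.getvalue().strip('\r\n'))
# nwo/repo,file.py,"def f(a, b)","tokens with ""quotes"", commas"
```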
@@ -1,4 +1,4 @@
 astor~=0.6.0
-apache-beam[gcp]~=2.4.0
+apache-beam[gcp]~=2.5.0
 nltk~=3.3.0
 spacy~=2.0.0
@@ -1,5 +1,6 @@
 from __future__ import print_function
 import argparse
+import os
 import apache_beam as beam

 from preprocess.pipeline import create_pipeline_opts, ProcessGithubFiles
@@ -7,7 +8,10 @@ from preprocess.pipeline import create_pipeline_opts, ProcessGithubFiles

 def parse_arguments(args):
   parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('-i', '--input', metavar='', type=str, help='Path to BigQuery SQL script')
+
+  default_script_file = os.path.abspath('{}/../../files/select_github_archive.sql'.format(__file__))
+  parser.add_argument('-i', '--input', metavar='', type=str, default=default_script_file,
+                      help='Path to BigQuery SQL script')
   parser.add_argument('-o', '--output', metavar='', type=str,
                       help='Output string of the format <dataset>:<table>')
   parser.add_argument('-p', '--project', metavar='', type=str, default='Project', help='Project ID')