Fix T2T memory problem (#205)

* Update T2T problems to work around memory limitations

* Add max_samples_for_vocab to prevent memory overflow

* Fix the base URL to download data from; settle on a sweet spot for max samples

* Convert class variables to class properties

* Fix lint errors

* Use Python2/3 compatible code for StringIO

* Fix lint errors

* Fix source data files format

* Move to Text2TextProblem instead of TranslateProblem

* Update details for num_shards and T2T problem dataset
Sanyam Kapoor 2018-08-01 13:37:41 -07:00 committed by k8s-ci-robot
parent 767c90ff20
commit fd2e750990
4 changed files with 94 additions and 52 deletions


@@ -17,6 +17,9 @@ def create_function_embeddings(argv=None):
- See `transforms.github_dataset.GithubBatchPredict` for details of tables created
- Additionally, store CSV of docstring, original functions and other metadata for
reverse index lookup during search engine queries.
NOTE: The number of output file shards is fixed (at 100) to keep the
number of output files manageable.
"""
pipeline_opts = arguments.prepare_pipeline_opts(argv)
args = pipeline_opts._visible_options # pylint: disable=protected-access
@@ -37,7 +40,8 @@ def create_function_embeddings(argv=None):
| "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
['nwo', 'path', 'function_name', 'lineno', 'original_function', 'function_embedding']))
| "Write Embeddings to CSV" >> beam.io.WriteToText('{}/func-index'.format(args.data_dir),
file_name_suffix='.csv')
file_name_suffix='.csv',
num_shards=100)
)
result = pipeline.run()
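
For context on the num_shards change: beam.io.WriteToText otherwise lets the runner pick the shard count at write time, so fixing num_shards caps how many output files are produced (this pipeline and the one in the next file both fix it at 100). A minimal standalone sketch, assuming the local DirectRunner and a hypothetical output prefix, not the project's actual pipeline:

import apache_beam as beam

with beam.Pipeline() as pipeline:  # DirectRunner by default
  _ = (
      pipeline
      | "Create rows" >> beam.Create(["doc tokens,function tokens"])
      | "Write CSV" >> beam.io.WriteToText(
          "/tmp/func-doc-pairs",     # hypothetical output prefix
          file_name_suffix=".csv",
          num_shards=100)            # caps output at 100 files
  )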


@@ -18,6 +18,9 @@ def preprocess_github_dataset(argv=None):
- See `transforms.github_dataset.TransformGithubDataset` for details of tables created
- Additionally, store pairs of docstring and function tokens in a CSV file
for training
NOTE: The number of output file shards is fixed (at 100) to keep the
number of output files manageable.
"""
pipeline_opts = arguments.prepare_pipeline_opts(argv)
args = pipeline_opts._visible_options # pylint: disable=protected-access
@@ -40,7 +43,8 @@ def preprocess_github_dataset(argv=None):
| "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
['docstring_tokens', 'function_tokens']))
| "Write CSV" >> beam.io.WriteToText('{}/func-doc-pairs'.format(args.data_dir),
file_name_suffix='.csv')
file_name_suffix='.csv',
num_shards=100)
)
result = pipeline.run()


@@ -1,32 +1,59 @@
"""Github function/text similatrity problems."""
import csv
import os
from cStringIO import StringIO
from six import StringIO
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import translate
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry
##
# These URLs are only for fallback purposes in case the specified
# `data_dir` does not contain the data. However, note that the data
# files must have the same naming pattern.
# TODO: The memory is exploding, need to fix this.
#
_DATA_BASE_URL = 'gs://kubeflow-examples/t2t-code-search/data'
_GITHUB_FUNCTION_DOCSTRING_FILES = [
'pairs-0000{}-of-00010.csv'.format(i)
for i in range(1)
]
import tensorflow as tf
@registry.register_problem
class GithubFunctionDocstring(translate.TranslateProblem):
# pylint: disable=abstract-method
class GithubFunctionDocstring(text_problems.Text2TextProblem):
"""Function and Docstring similarity Problem.
"""This class defines the problem of finding similarity between Python
function and docstring"""
The data consists of function and docstring pairs stored as CSV
files. Each file has two columns without headers, containing the
docstring tokens and the function tokens, delimited by ",".
"""
@property
def pair_files_list(self):
"""Return URL and file names.
This format is a convention across the Tensor2Tensor (T2T)
codebase. The file names are currently hardcoded to preserve the
semantics of a T2T problem; to change them, subclass and override
this property.
# TODO(sanyamkapoor): Manually separate train/eval data set.
Returns:
A list of the format,
[
[
"STRING",
("STRING", "STRING", ...)
],
...
]
Each element is a list of size 2: the first item is the source URL
and the second is an n-tuple of file names. Here, each tuple has
size 1 because the URL points directly to a single file.
"""
base_url = "gs://kubeflow-examples/t2t-code-search/raw_data"
return [
[
"{}/func-doc-pairs-000{:02}-of-00100.csv".format(base_url, i),
("func-doc-pairs-000{:02}-of-00100.csv".format(i),)
]
for i in range(100)
]
@property
def is_generate_per_split(self):
@@ -36,30 +63,37 @@ class GithubFunctionDocstring(translate.TranslateProblem):
def approx_vocab_size(self):
return 2**13
def source_data_files(self, dataset_split): # pylint: disable=no-self-use,unused-argument
# TODO(sanyamkapoor): separate train/eval data set.
return _GITHUB_FUNCTION_DOCSTRING_FILES
@property
def max_samples_for_vocab(self):
# FIXME(sanyamkapoor): This exists to handle memory explosion.
return int(3.5e5)
def generate_samples(self, data_dir, tmp_dir, dataset_split): # pylint: disable=no-self-use,unused-argument
"""Returns a generator to return {"inputs": [text], "targets": [text]}.
def generate_samples(self, _data_dir, tmp_dir, _dataset_split):
"""A generator to return data samples.Returns the data generator to return.
If the `data_dir` is a GCS path, all data is downloaded to the
`tmp_dir`.
Args:
data_dir: A string representing the data directory.
tmp_dir: A string representing the temporary directory and is
used to download files if not already available.
dataset_split: Train, Test or Eval.
Yields:
Each yielded element is a Python dict of the form
{"inputs": "STRING", "targets": "STRING"}
"""
download_dir = tmp_dir if data_dir.startswith('gs://') else data_dir
uri_base = data_dir if data_dir.startswith('gs://') else _DATA_BASE_URL
pair_csv_files = [
generator_utils.maybe_download(download_dir, filename, os.path.join(uri_base, filename))
for filename in self.source_data_files(dataset_split)
csv_files = [
generator_utils.maybe_download(tmp_dir, file_list[0], uri)
for uri, file_list in self.pair_files_list
]
for pairs_file in pair_csv_files:
with open(pairs_file, 'r') as csv_file:
for pairs_file in csv_files:
tf.logging.debug("Reading {}".format(pairs_file))
with open(pairs_file, "r") as csv_file:
for line in csv_file:
reader = csv.reader(StringIO(line), delimiter=',')
function_tokens, docstring_tokens = next(reader)[-2:] # pylint: disable=stop-iteration-return
yield {'inputs': docstring_tokens, 'targets': function_tokens}
reader = csv.reader(StringIO(line))
for docstring_tokens, function_tokens in reader:
yield {"inputs": docstring_tokens, "targets": function_tokens}
def eval_metrics(self): # pylint: disable=no-self-use
return [

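As the pair_files_list docstring above notes, pointing this problem at different data means subclassing and overriding that property (and, if memory is a concern, max_samples_for_vocab). A minimal sketch under those assumptions; the subclass name, bucket, and import path are illustrative, not part of this commit:

from github_function_docstring import GithubFunctionDocstring  # hypothetical module path
from tensor2tensor.utils import registry


@registry.register_problem
class GithubFunctionDocstringCustom(GithubFunctionDocstring):
  """Same problem, reading from a different (hypothetical) bucket."""

  @property
  def pair_files_list(self):
    # Same [[source_url, (file_name,)], ...] convention as the parent class.
    base_url = "gs://my-bucket/raw_data"  # hypothetical location
    return [
        [
            "{}/func-doc-pairs-000{:02}-of-00100.csv".format(base_url, i),
            ("func-doc-pairs-000{:02}-of-00100.csv".format(i),),
        ]
        for i in range(100)
    ]

  @property
  def max_samples_for_vocab(self):
    # Keep the cap on samples used for vocab generation to bound memory.
    return int(3.5e5)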

@@ -9,19 +9,19 @@ import tensorflow as tf
@registry.register_model
class SimilarityTransformer(t2t_model.T2TModel):
# pylint: disable=abstract-method
"""Transformer Model for Similarity between two strings.
"""
This class defines the model to compute similarity scores between functions
and docstrings
This model defines an architecture with two transformer
networks, each of which embeds a string. The loss is computed
as a binary cross-entropy loss, using the normalized dot
product as the distance measure between the two string
embeddings.
"""
def top(self, body_output, features): # pylint: disable=no-self-use,unused-argument
def top(self, body_output, _): # pylint: disable=no-self-use
return body_output
def body(self, features):
"""Body of the Similarity Transformer Network."""
with tf.variable_scope('string_embedding'):
string_embedding = self.encode(features, 'inputs')
@@ -57,21 +57,21 @@ class SimilarityTransformer(t2t_model.T2TModel):
(encoder_input, encoder_self_attention_bias, _) = (
transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
self._hparams))
hparams))
encoder_input = tf.nn.dropout(encoder_input,
1.0 - hparams.layer_prepostprocess_dropout)
encoder_output = transformer.transformer_encoder(
encoder_input,
encoder_self_attention_bias,
self._hparams,
hparams,
nonpadding=transformer.features_to_nonpadding(features, input_key))
encoder_output = tf.expand_dims(encoder_output, 2)
encoder_output = tf.reduce_mean(tf.squeeze(encoder_output, axis=2), axis=1)
encoder_output = tf.reduce_mean(encoder_output, axis=1)
return encoder_output
def infer(self, features=None, **kwargs): # pylint: disable=no-self-use,unused-argument
def infer(self, features=None, **kwargs):
del kwargs
predictions, _ = self(features)
return predictions
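
The updated model docstring describes the training signal: each transformer network embeds one string, the normalized dot product of the two embeddings serves as the distance, and the loss is binary cross-entropy. A minimal TensorFlow 1.x sketch of just that loss, separate from the T2TModel plumbing above; tensor names and shapes are assumptions:

import tensorflow as tf

def similarity_loss(string_embedding, code_embedding, labels):
  """Binary cross-entropy over the normalized dot product of two embeddings.

  string_embedding, code_embedding: float tensors of shape [batch, hidden].
  labels: float tensor of shape [batch] with values in {0.0, 1.0}.
  """
  string_embedding = tf.nn.l2_normalize(string_embedding, axis=-1)
  code_embedding = tf.nn.l2_normalize(code_embedding, axis=-1)
  # Normalized dot product, i.e. cosine similarity in [-1, 1].
  cosine_similarity = tf.reduce_sum(string_embedding * code_embedding, axis=-1)
  return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                 logits=cosine_similarity)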