mirror of https://github.com/kubeflow/examples.git
Fix T2T memory problem (#205)
* Update T2T problems to work around memory limitations
* Add max_samples_for_vocab to prevent memory overflow
* Fix a base URL to download data from, sweet spot for max samples
* Convert class variables to class properties
* Fix lint errors
* Use Python2/3 compatible code for StringIO
* Fix lint errors
* Fix source data files format
* Move to Text2TextProblem instead of TranslateProblem
* Update details for num_shards and T2T problem dataset
This commit is contained in:
parent 767c90ff20
commit fd2e750990
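For context on the headline fix: `max_samples_for_vocab` is a standard `Text2TextProblem` property in tensor2tensor that caps how many samples feed the subword-vocabulary builder, so the full dataset no longer has to be held while the vocab is generated. A minimal sketch of the pattern this commit adopts (the class name here is hypothetical, not the problem definition from the diff below):

from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class MyPairProblem(text_problems.Text2TextProblem):  # hypothetical problem
  """Sketch: bound vocab-building memory on a large dataset."""

  @property
  def approx_vocab_size(self):
    return 2**13  # target subword vocabulary size

  @property
  def max_samples_for_vocab(self):
    # Cap the samples consumed while building the vocabulary; left unset,
    # vocab generation iterates over every sample in the dataset.
    return int(3.5e5)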
@@ -17,6 +17,9 @@ def create_function_embeddings(argv=None):
   - See `transforms.github_dataset.GithubBatchPredict` for details of tables created
   - Additionally, store CSV of docstring, original functions and other metadata for
     reverse index lookup during search engine queries.
+
+  NOTE: The number of output file shards have been fixed (at 100) to avoid a large
+  number of output files, making it manageable.
   """
   pipeline_opts = arguments.prepare_pipeline_opts(argv)
   args = pipeline_opts._visible_options  # pylint: disable=protected-access
@@ -37,7 +40,8 @@ def create_function_embeddings(argv=None):
     | "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
         ['nwo', 'path', 'function_name', 'lineno', 'original_function', 'function_embedding']))
     | "Write Embeddings to CSV" >> beam.io.WriteToText('{}/func-index'.format(args.data_dir),
-                                                       file_name_suffix='.csv')
+                                                       file_name_suffix='.csv',
+                                                       num_shards=100)
   )
 
   result = pipeline.run()
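On `num_shards`: by default `beam.io.WriteToText` lets the runner choose the shard count, which on a large parallel job can mean a huge number of tiny output files; pinning `num_shards=100` forces exactly 100 files named like func-index-00042-of-00100.csv. A self-contained sketch of the same pattern (pipeline, paths, and data are placeholders):

import apache_beam as beam

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | "Create" >> beam.Create(["a,b", "c,d"])  # stand-in records
      # num_shards=0 (the default) lets the runner decide how many output
      # files to produce; num_shards=100 pins it to exactly 100.
      | "Write" >> beam.io.WriteToText("/tmp/func-index",
                                       file_name_suffix=".csv",
                                       num_shards=100)
  )

The trade-off is that a fixed shard count constrains write parallelism, which is acceptable here since keeping the output manageable is the goal.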
@@ -18,6 +18,9 @@ def preprocess_github_dataset(argv=None):
   - See `transforms.github_dataset.TransformGithubDataset` for details of tables created
   - Additionally, store pairs of docstring and function tokens in a CSV file
     for training
+
+  NOTE: The number of output file shards have been fixed (at 100) to avoid a large
+  number of output files, making it manageable.
   """
   pipeline_opts = arguments.prepare_pipeline_opts(argv)
   args = pipeline_opts._visible_options  # pylint: disable=protected-access
@@ -40,7 +43,8 @@ def preprocess_github_dataset(argv=None):
     | "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
         ['docstring_tokens', 'function_tokens']))
     | "Write CSV" >> beam.io.WriteToText('{}/func-doc-pairs'.format(args.data_dir),
-                                         file_name_suffix='.csv')
+                                         file_name_suffix='.csv',
+                                         num_shards=100)
   )
 
   result = pipeline.run()
@@ -1,32 +1,59 @@
 """Github function/text similatrity problems."""
 import csv
-import os
-from cStringIO import StringIO
+from six import StringIO
 from tensor2tensor.data_generators import generator_utils
-from tensor2tensor.data_generators import translate
+from tensor2tensor.data_generators import text_problems
 from tensor2tensor.utils import metrics
 from tensor2tensor.utils import registry
-
-
-##
-# These URLs are only for fallback purposes in case the specified
-# `data_dir` does not contain the data. However, note that the data
-# files must have the same naming pattern.
-# TODO: The memory is exploding, need to fix this.
-#
-_DATA_BASE_URL = 'gs://kubeflow-examples/t2t-code-search/data'
-_GITHUB_FUNCTION_DOCSTRING_FILES = [
-  'pairs-0000{}-of-00010.csv'.format(i)
-  for i in range(1)
-]
+import tensorflow as tf
 
 
 @registry.register_problem
-class GithubFunctionDocstring(translate.TranslateProblem):
-  """This class defines the problem of finding similarity between Python
-  function and docstring"""
+# pylint: disable=abstract-method
+class GithubFunctionDocstring(text_problems.Text2TextProblem):
+  """Function and Docstring similarity Problem.
+
+  This problem contains the data consisting of function
+  and docstring pairs as CSV files. The files are structured
+  such that they contain two columns without headers containing
+  the docstring tokens and function tokens. The delimiter is
+  ",".
+  """
+
+  @property
+  def pair_files_list(self):
+    """Return URL and file names.
+
+    This format is a convention across the Tensor2Tensor (T2T)
+    codebase. It should be noted that the file names are currently
+    hardcoded. This is to preserve the semantics of a T2T problem.
+    In case a change of these values is desired, one must subclass
+    and override this property.
+
+    # TODO(sanyamkapoor): Manually separate train/eval data set.
+
+    Returns:
+      A list of the format,
+        [
+          [
+            "STRING",
+            ("STRING", "STRING", ...)
+          ],
+          ...
+        ]
+      Each element is a list of size 2 where the first represents
+      the source URL and the next is an n-tuple of file names.
+
+      In this case, the tuple is of size 1 because the URL points
+      to a file itself.
+    """
+    base_url = "gs://kubeflow-examples/t2t-code-search/raw_data"
+
+    return [
+      [
+        "{}/func-doc-pairs-000{:02}-of-00100.csv".format(base_url, i),
+        ("func-doc-pairs-000{:02}-of-00100.csv".format(i),)
+      ]
+      for i in range(100)
+    ]
 
   @property
   def is_generate_per_split(self):
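The docstring above says a different data location requires subclassing and overriding `pair_files_list`; a hedged sketch of what such an override could look like (the subclass name, bucket, and file pattern are hypothetical, not part of this commit):

from tensor2tensor.utils import registry


@registry.register_problem
class GithubFunctionDocstringCustom(GithubFunctionDocstring):
  """Sketch: same problem semantics, different bucket/file layout."""

  @property
  def pair_files_list(self):
    base_url = "gs://my-bucket/raw_data"  # hypothetical location
    return [
        [
            "{}/pairs-{:05d}-of-00010.csv".format(base_url, i),
            ("pairs-{:05d}-of-00010.csv".format(i),),
        ]
        for i in range(10)
    ]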
@@ -36,30 +63,37 @@ class GithubFunctionDocstring(translate.TranslateProblem):
   def approx_vocab_size(self):
     return 2**13
 
-  def source_data_files(self, dataset_split):  # pylint: disable=no-self-use,unused-argument
-    # TODO(sanyamkapoor): separate train/eval data set.
-    return _GITHUB_FUNCTION_DOCSTRING_FILES
+  @property
+  def max_samples_for_vocab(self):
+    # FIXME(sanyamkapoor): This exists to handle memory explosion.
+    return int(3.5e5)
 
-  def generate_samples(self, data_dir, tmp_dir, dataset_split):  # pylint: disable=no-self-use,unused-argument
-    """Returns a generator to return {"inputs": [text], "targets": [text]}.
+  def generate_samples(self, _data_dir, tmp_dir, _dataset_split):
+    """A generator to return data samples.Returns the data generator to return.
 
-    If the `data_dir` is a GCS path, all data is downloaded to the
-    `tmp_dir`.
+    Args:
+      data_dir: A string representing the data directory.
+      tmp_dir: A string representing the temporary directory and is
+        used to download files if not already available.
+      dataset_split: Train, Test or Eval.
+
+    Yields:
+      Each element yielded is of a Python dict of the form
+        {"inputs": "STRING", "targets": "STRING"}
     """
 
-    download_dir = tmp_dir if data_dir.startswith('gs://') else data_dir
-    uri_base = data_dir if data_dir.startswith('gs://') else _DATA_BASE_URL
-    pair_csv_files = [
-      generator_utils.maybe_download(download_dir, filename, os.path.join(uri_base, filename))
-      for filename in self.source_data_files(dataset_split)
+    csv_files = [
+      generator_utils.maybe_download(tmp_dir, file_list[0], uri)
+      for uri, file_list in self.pair_files_list
     ]
 
-    for pairs_file in pair_csv_files:
-      with open(pairs_file, 'r') as csv_file:
+    for pairs_file in csv_files:
+      tf.logging.debug("Reading {}".format(pairs_file))
+      with open(pairs_file, "r") as csv_file:
         for line in csv_file:
-          reader = csv.reader(StringIO(line), delimiter=',')
-          function_tokens, docstring_tokens = next(reader)[-2:]  # pylint: disable=stop-iteration-return
-          yield {'inputs': docstring_tokens, 'targets': function_tokens}
+          reader = csv.reader(StringIO(line))
+          for docstring_tokens, function_tokens in reader:
+            yield {"inputs": docstring_tokens, "targets": function_tokens}
 
   def eval_metrics(self):  # pylint: disable=no-self-use
     return [
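A note on the parsing change above: wrapping each line in `csv.reader(StringIO(line))` looks redundant next to a plain `line.split(',')`, but it lets the `csv` module handle quoted fields that themselves contain commas, and the two-name unpacking enforces exactly two columns per row. In isolation (the sample line is made up):

import csv
from six import StringIO  # py2/3 compatible, as in the diff

line = '"sum a , b","def add ( a , b ) : return a + b"\n'
for docstring_tokens, function_tokens in csv.reader(StringIO(line)):
  # The quoted commas survive parsing; a bare line.split(',')
  # would mis-split both fields.
  sample = {"inputs": docstring_tokens, "targets": function_tokens}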
@@ -9,19 +9,19 @@ import tensorflow as tf
 
 @registry.register_model
 class SimilarityTransformer(t2t_model.T2TModel):
-  """
-  This class defines the model to compute similarity scores between functions
-  and docstrings
+  # pylint: disable=abstract-method
+  """Transformer Model for Similarity between two strings.
+
+  This model defines the architecture using two transformer
+  networks, each of which embed a string and the loss is
+  calculated as a Binary Cross-Entropy loss. Normalized
+  Dot Product is used as the distance measure between two
+  string embeddings.
   """
 
-  def top(self, body_output, features):  # pylint: disable=no-self-use,unused-argument
+  def top(self, body_output, _):  # pylint: disable=no-self-use
     return body_output
 
   def body(self, features):
     """Body of the Similarity Transformer Network."""
 
     with tf.variable_scope('string_embedding'):
       string_embedding = self.encode(features, 'inputs')
@@ -57,21 +57,21 @@ class SimilarityTransformer(t2t_model.T2TModel):
 
     (encoder_input, encoder_self_attention_bias, _) = (
         transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
-                                                self._hparams))
+                                                hparams))
 
     encoder_input = tf.nn.dropout(encoder_input,
                                   1.0 - hparams.layer_prepostprocess_dropout)
     encoder_output = transformer.transformer_encoder(
         encoder_input,
         encoder_self_attention_bias,
-        self._hparams,
+        hparams,
         nonpadding=transformer.features_to_nonpadding(features, input_key))
-    encoder_output = tf.expand_dims(encoder_output, 2)
 
-    encoder_output = tf.reduce_mean(tf.squeeze(encoder_output, axis=2), axis=1)
+    encoder_output = tf.reduce_mean(encoder_output, axis=1)
 
     return encoder_output
 
-  def infer(self, features=None, **kwargs):  # pylint: disable=no-self-use,unused-argument
+  def infer(self, features=None, **kwargs):
+    del kwargs
     predictions, _ = self(features)
     return predictions
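The class docstring earlier describes similarity as a normalized dot product between the two string embeddings; the `reduce_mean` change above mean-pools the encoder output over the time axis directly, instead of expanding and re-squeezing a dummy dimension. A minimal TF1-style sketch of that distance measure (shapes and inputs are illustrative stand-ins, not the model's code):

import tensorflow as tf

# Stand-ins for encoder outputs: [batch, length, hidden].
doc_encoding = tf.random_normal([8, 30, 128])
code_encoding = tf.random_normal([8, 50, 128])

# Mean-pool over the time axis, as `encode` now does.
doc_embedding = tf.reduce_mean(doc_encoding, axis=1)    # [8, 128]
code_embedding = tf.reduce_mean(code_encoding, axis=1)  # [8, 128]

# Normalized dot product == cosine similarity per batch element.
similarity = tf.reduce_sum(
    tf.nn.l2_normalize(doc_embedding, axis=1) *
    tf.nn.l2_normalize(code_embedding, axis=1), axis=1)  # [8]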