mirror of https://github.com/kubeflow/examples.git
				
				
				
Fix T2T memory problem (#205)

* Update T2T problems to workaround memory limitations
* Add max_samples_for_vocab to prevent memory overflow
* Fix a base URL to download data from, sweet spot for max samples
* Convert class variables to class properties
* Fix lint errors
* Use Python2/3 compatible code for StringIO
* Fix lint errors
* Fix source data files format
* Move to Text2TextProblem instead of TranslateProblem
* Update details for num_shards and T2T problem dataset
This commit is contained in:

parent 767c90ff20
commit fd2e750990
@@ -17,6 +17,9 @@ def create_function_embeddings(argv=None):
     - See `transforms.github_dataset.GithubBatchPredict` for details of tables created
     - Additionally, store CSV of docstring, original functions and other metadata for
       reverse index lookup during search engine queries.
+
+  NOTE: The number of output file shards have been fixed (at 100) to avoid a large
+  number of output files, making it manageable.
   """
   pipeline_opts = arguments.prepare_pipeline_opts(argv)
   args = pipeline_opts._visible_options  # pylint: disable=protected-access
@@ -37,7 +40,8 @@ def create_function_embeddings(argv=None):
     | "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
         ['nwo', 'path', 'function_name', 'lineno', 'original_function', 'function_embedding']))
     | "Write Embeddings to CSV" >> beam.io.WriteToText('{}/func-index'.format(args.data_dir),
-                                                       file_name_suffix='.csv')
+                                                       file_name_suffix='.csv',
+                                                       num_shards=100)
   )

   result = pipeline.run()
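For context, num_shards is a standard argument of beam.io.WriteToText. A minimal standalone sketch of how a fixed shard count behaves (not part of this commit; the "out" prefix and the elements are illustrative only):

    import apache_beam as beam

    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "Create" >> beam.Create(["a,b", "c,d"])
            # num_shards=100 forces exactly 100 output files, named
            # out-00000-of-00100.csv ... out-00099-of-00100.csv. The default
            # num_shards=0 lets the runner choose, which on a large dataset
            # can mean thousands of small files.
            | "Write" >> beam.io.WriteToText("out",
                                             file_name_suffix=".csv",
                                             num_shards=100)
        )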
@@ -18,6 +18,9 @@ def preprocess_github_dataset(argv=None):
     - See `transforms.github_dataset.TransformGithubDataset` for details of tables created
     - Additionally, store pairs of docstring and function tokens in a CSV file
      for training
+
+  NOTE: The number of output file shards have been fixed (at 100) to avoid a large
+  number of output files, making it manageable.
   """
   pipeline_opts = arguments.prepare_pipeline_opts(argv)
   args = pipeline_opts._visible_options  # pylint: disable=protected-access
@@ -40,7 +43,8 @@ def preprocess_github_dataset(argv=None):
     | "Format for CSV Write" >> beam.ParDo(dict_to_csv.DictToCSVString(
       ['docstring_tokens', 'function_tokens']))
     | "Write CSV" >> beam.io.WriteToText('{}/func-doc-pairs'.format(args.data_dir),
-                                         file_name_suffix='.csv')
+                                         file_name_suffix='.csv',
+                                         num_shards=100)
   )

   result = pipeline.run()
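One reason the shard count matters here: with exactly 100 shards, Beam's default shard naming makes the training CSV file names predictable, and the T2T problem further down in this diff hard-codes them. A small illustrative check, assuming Beam's default -SSSSS-of-NNNNN shard name template:

    expected = ["func-doc-pairs-{:05d}-of-00100.csv".format(i) for i in range(100)]
    print(expected[0])    # func-doc-pairs-00000-of-00100.csv
    print(expected[99])   # func-doc-pairs-00099-of-00100.csv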
@@ -1,32 +1,59 @@
 """Github function/text similatrity problems."""
 import csv
 import os
-from cStringIO import StringIO
+from six import StringIO
 from tensor2tensor.data_generators import generator_utils
-from tensor2tensor.data_generators import translate
+from tensor2tensor.data_generators import text_problems
 from tensor2tensor.utils import metrics
 from tensor2tensor.utils import registry


-##
-# These URLs are only for fallback purposes in case the specified
-# `data_dir` does not contain the data. However, note that the data
-# files must have the same naming pattern.
-# TODO: The memory is exploding, need to fix this.
-#
-_DATA_BASE_URL = 'gs://kubeflow-examples/t2t-code-search/data'
-_GITHUB_FUNCTION_DOCSTRING_FILES = [
-    'pairs-0000{}-of-00010.csv'.format(i)
-    for i in range(1)
-]
+import tensorflow as tf


 @registry.register_problem
-class GithubFunctionDocstring(translate.TranslateProblem):
-  # pylint: disable=abstract-method
-  """This class defines the problem of finding similarity between Python
-  function and docstring"""
+class GithubFunctionDocstring(text_problems.Text2TextProblem):
+  """Function and Docstring similarity Problem.
+
+  This problem contains the data consisting of function
+  and docstring pairs as CSV files. The files are structured
+  such that they contain two columns without headers containing
+  the docstring tokens and function tokens. The delimiter is
+  ",".
+  """

   @property
+  def pair_files_list(self):
+    """Return URL and file names.
+
+    This format is a convention across the Tensor2Tensor (T2T)
+    codebase. It should be noted that the file names are currently
+    hardcoded. This is to preserve the semantics of a T2T problem.
+    In case a change of these values is desired, one must subclass
+    and override this property.
+
+    # TODO(sanyamkapoor): Manually separate train/eval data set.
+
+    Returns:
+      A list of the format,
+        [
+          [
+            "STRING",
+            ("STRING", "STRING", ...)
+          ],
+          ...
+        ]
+      Each element is a list of size 2 where the first represents
+      the source URL and the next is an n-tuple of file names.
+
+      In this case, the tuple is of size 1 because the URL points
+      to a file itself.
+    """
+    base_url = "gs://kubeflow-examples/t2t-code-search/raw_data"
+
+    return [
+        [
+            "{}/func-doc-pairs-000{:02}-of-00100.csv".format(base_url, i),
+            ("func-doc-pairs-000{:02}-of-00100.csv".format(i),)
+        ]
+        for i in range(100)
+    ]
+
+  @property
   def is_generate_per_split(self):
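On the six import: cStringIO does not exist on Python 3, so the commit switches to six.StringIO, which resolves to the right implementation on either interpreter. A minimal standalone sketch of the parsing pattern used in generate_samples, with a made-up CSV line:

    import csv
    from six import StringIO  # works on Python 2 and Python 3 alike

    line = '"Adds two numbers .","def add ( a , b ) :"\n'
    reader = csv.reader(StringIO(line))
    for docstring_tokens, function_tokens in reader:
        print(docstring_tokens)   # Adds two numbers .
        print(function_tokens)    # def add ( a , b ) :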
@@ -36,30 +63,37 @@ class GithubFunctionDocstring(translate.TranslateProblem):
   def approx_vocab_size(self):
     return 2**13

-  def source_data_files(self, dataset_split):  # pylint: disable=no-self-use,unused-argument
-    # TODO(sanyamkapoor): separate train/eval data set.
-    return _GITHUB_FUNCTION_DOCSTRING_FILES
+  @property
+  def max_samples_for_vocab(self):
+    # FIXME(sanyamkapoor): This exists to handle memory explosion.
+    return int(3.5e5)

-  def generate_samples(self, data_dir, tmp_dir, dataset_split):  # pylint: disable=no-self-use,unused-argument
-    """Returns a generator to return {"inputs": [text], "targets": [text]}.
-
-    If the `data_dir` is a GCS path, all data is downloaded to the
-    `tmp_dir`.
+  def generate_samples(self, _data_dir, tmp_dir, _dataset_split):
+    """A generator to return data samples.Returns the data generator to return.

     Args:
       data_dir: A string representing the data directory.
       tmp_dir: A string representing the temporary directory and is
              used to download files if not already available.
       dataset_split: Train, Test or Eval.

     Yields:
       Each element yielded is of a Python dict of the form
         {"inputs": "STRING", "targets": "STRING"}
     """
-
-    download_dir = tmp_dir if data_dir.startswith('gs://') else data_dir
-    uri_base = data_dir if data_dir.startswith('gs://') else _DATA_BASE_URL
-    pair_csv_files = [
-        generator_utils.maybe_download(download_dir, filename, os.path.join(uri_base, filename))
-        for filename in self.source_data_files(dataset_split)
+    csv_files = [
+        generator_utils.maybe_download(tmp_dir, file_list[0], uri)
+        for uri, file_list in self.pair_files_list
     ]

-    for pairs_file in pair_csv_files:
-      with open(pairs_file, 'r') as csv_file:
+    for pairs_file in csv_files:
+      tf.logging.debug("Reading {}".format(pairs_file))
+      with open(pairs_file, "r") as csv_file:
         for line in csv_file:
-          reader = csv.reader(StringIO(line), delimiter=',')
-          function_tokens, docstring_tokens = next(reader)[-2:]  # pylint: disable=stop-iteration-return
-          yield {'inputs': docstring_tokens, 'targets': function_tokens}
+          reader = csv.reader(StringIO(line))
+          for docstring_tokens, function_tokens in reader:
+            yield {"inputs": docstring_tokens, "targets": function_tokens}

   def eval_metrics(self):  # pylint: disable=no-self-use
     return [
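The max_samples_for_vocab property is the core of the memory fix: when Text2TextProblem builds its subword vocabulary, it stops consuming samples from the generator after that many examples instead of streaming the whole corpus. A rough sketch of the capping idea (illustrative only, not the actual T2T internals):

    import itertools

    def sample_generator():
        i = 0
        while True:  # stands in for streaming millions of docstring/function pairs
            yield {"inputs": "docstring %d" % i, "targets": "function %d" % i}
            i += 1

    MAX_SAMPLES_FOR_VOCAB = int(3.5e5)  # the cap chosen in this commit
    vocab_samples = itertools.islice(sample_generator(), MAX_SAMPLES_FOR_VOCAB)
    # Only the first 350,000 samples are ever pulled for vocabulary construction,
    # so memory stays bounded no matter how large the corpus grows.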
@@ -9,19 +9,19 @@ import tensorflow as tf

 @registry.register_model
 class SimilarityTransformer(t2t_model.T2TModel):
   # pylint: disable=abstract-method
-  """
-  This class defines the model to compute similarity scores between functions
-  and docstrings
+  """Transformer Model for Similarity between two strings.
+
+  This model defines the architecture using two transformer
+  networks, each of which embed a string and the loss is
+  calculated as a Binary Cross-Entropy loss. Normalized
+  Dot Product is used as the distance measure between two
+  string embeddings.
   """

-  def top(self, body_output, features):  # pylint: disable=no-self-use,unused-argument
+  def top(self, body_output, _):  # pylint: disable=no-self-use
     return body_output

   def body(self, features):
     """Body of the Similarity Transformer Network."""

     with tf.variable_scope('string_embedding'):
       string_embedding = self.encode(features, 'inputs')
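For reference, a minimal sketch of the "Normalized Dot Product" the new docstring describes, i.e. cosine similarity between the two string embeddings. Shapes and variable names here are illustrative, not taken from the model code:

    import tensorflow as tf

    string_emb = tf.random_normal([8, 128])  # e.g. docstring embeddings [batch, hidden]
    code_emb = tf.random_normal([8, 128])    # e.g. function embeddings  [batch, hidden]

    string_emb = tf.nn.l2_normalize(string_emb, axis=1)
    code_emb = tf.nn.l2_normalize(code_emb, axis=1)

    # [batch, batch] similarity matrix; the diagonal holds matching pairs, which
    # a binary cross-entropy loss can push towards 1 (and off-diagonals towards 0).
    similarity = tf.matmul(string_emb, code_emb, transpose_b=True)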
@@ -57,21 +57,21 @@ class SimilarityTransformer(t2t_model.T2TModel):

     (encoder_input, encoder_self_attention_bias, _) = (
         transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
-                                                self._hparams))
+                                                hparams))

     encoder_input = tf.nn.dropout(encoder_input,
                                   1.0 - hparams.layer_prepostprocess_dropout)
     encoder_output = transformer.transformer_encoder(
         encoder_input,
         encoder_self_attention_bias,
-        self._hparams,
+        hparams,
         nonpadding=transformer.features_to_nonpadding(features, input_key))
-    encoder_output = tf.expand_dims(encoder_output, 2)

-    encoder_output = tf.reduce_mean(tf.squeeze(encoder_output, axis=2), axis=1)
+    encoder_output = tf.reduce_mean(encoder_output, axis=1)

     return encoder_output

-  def infer(self, features=None, **kwargs):  # pylint: disable=no-self-use,unused-argument
+  def infer(self, features=None, **kwargs):
+    del kwargs
     predictions, _ = self(features)
     return predictions
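The pooling change above drops the expand_dims/squeeze round trip and averages the encoder output directly over the sequence axis. A shape sketch with illustrative sizes:

    import tensorflow as tf

    encoder_output = tf.zeros([8, 40, 128])          # [batch, length, hidden]
    pooled = tf.reduce_mean(encoder_output, axis=1)  # [batch, hidden] -> (8, 128)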