From 18829159b07e2c417474dd734e319f2b70e6d8a2 Mon Sep 17 00:00:00 2001
From: Sanyam Kapoor
Date: Tue, 14 Aug 2018 15:41:47 -0700
Subject: [PATCH] Add a new github function docstring extended problem (#225)

* Add a new github function docstring extended problem

* Fix lint errors

* Update images
---
 .../kubeflow/components/params.libsonnet       |  6 ++---
 code_search/src/code_search/t2t/__init__.py    |  1 +
 .../src/code_search/t2t/function_docstring.py  | 22 ++++++++++-------
 .../t2t/function_docstring_extended.py         | 24 +++++++++++++++++++
 4 files changed, 41 insertions(+), 12 deletions(-)
 create mode 100644 code_search/src/code_search/t2t/function_docstring_extended.py

diff --git a/code_search/kubeflow/components/params.libsonnet b/code_search/kubeflow/components/params.libsonnet
index 3e46de4b..41e39b34 100644
--- a/code_search/kubeflow/components/params.libsonnet
+++ b/code_search/kubeflow/components/params.libsonnet
@@ -9,8 +9,8 @@
       numPsGpu: 0,
       train_steps: 100,
       eval_steps: 10,
-      image: 'gcr.io/kubeflow-dev/code-search:v20180802-c622aac',
-      imageGpu: 'gcr.io/kubeflow-dev/code-search:v20180802-c622aac-gpu',
+      image: 'gcr.io/kubeflow-dev/code-search:v20180814-66d27b9',
+      imageGpu: 'gcr.io/kubeflow-dev/code-search:v20180814-66d27b9-gpu',
       imagePullSecrets: [],
       dataDir: 'null',
       outputDir: 'null',
@@ -19,7 +19,7 @@
     },
     "t2t-code-search": {
       workingDir: 'gs://example/prefix',
-      problem: 'github_function_docstring',
+      problem: 'github_function_docstring_extended',
       model: 'similarity_transformer',
       hparams_set: 'transformer_tiny',
     },
diff --git a/code_search/src/code_search/t2t/__init__.py b/code_search/src/code_search/t2t/__init__.py
index a633cb8c..7a9f81ca 100644
--- a/code_search/src/code_search/t2t/__init__.py
+++ b/code_search/src/code_search/t2t/__init__.py
@@ -5,3 +5,4 @@
 #
 from . import function_docstring
 from . import similarity_transformer
+from . import function_docstring_extended
diff --git a/code_search/src/code_search/t2t/function_docstring.py b/code_search/src/code_search/t2t/function_docstring.py
index d3df4072..a9c3a175 100644
--- a/code_search/src/code_search/t2t/function_docstring.py
+++ b/code_search/src/code_search/t2t/function_docstring.py
@@ -18,6 +18,9 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
   the docstring tokens and function tokens. The delimiter is ",".
   """
 
+
+  DATA_PATH_PREFIX = "gs://kubeflow-examples/t2t-code-search/raw_data"
+
   @property
   def pair_files_list(self):
     """Return URL and file names.
@@ -45,11 +48,9 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
     In this case, the tuple is of size 1 because the URL points
     to a file itself.
     """
-    base_url = "gs://kubeflow-examples/t2t-code-search/raw_data"
-
     return [
       [
-        "{}/func-doc-pairs-000{:02}-of-00100.csv".format(base_url, i),
+        "{}/func-doc-pairs-000{:02}-of-00100.csv".format(self.DATA_PATH_PREFIX, i),
         ("func-doc-pairs-000{:02}-of-00100.csv".format(i),)
       ]
       for i in range(100)
@@ -68,7 +69,13 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
     # FIXME(sanyamkapoor): This exists to handle memory explosion.
     return int(3.5e5)
 
-  def generate_samples(self, _data_dir, tmp_dir, _dataset_split):
+  def get_csv_files(self, _data_dir, tmp_dir, _dataset_split):
+    return [
+      generator_utils.maybe_download(tmp_dir, file_list[0], uri)
+      for uri, file_list in self.pair_files_list
+    ]
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):
     """A generator to return data samples.Returns the data generator to return.
 
     Args:
@@ -82,14 +89,11 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
       Each element yielded is of a Python dict of the form
         {"inputs": "STRING", "targets": "STRING"}
     """
-    csv_files = [
-      generator_utils.maybe_download(tmp_dir, file_list[0], uri)
-      for uri, file_list in self.pair_files_list
-    ]
+    csv_files = self.get_csv_files(data_dir, tmp_dir, dataset_split)
 
     for pairs_file in csv_files:
       tf.logging.debug("Reading {}".format(pairs_file))
-      with open(pairs_file, "r") as csv_file:
+      with tf.gfile.Open(pairs_file) as csv_file:
         for line in csv_file:
           reader = csv.reader(StringIO(line))
           for docstring_tokens, function_tokens in reader:
diff --git a/code_search/src/code_search/t2t/function_docstring_extended.py b/code_search/src/code_search/t2t/function_docstring_extended.py
new file mode 100644
index 00000000..460647ca
--- /dev/null
+++ b/code_search/src/code_search/t2t/function_docstring_extended.py
@@ -0,0 +1,24 @@
+from tensor2tensor.utils import registry
+from .function_docstring import GithubFunctionDocstring
+
+
+@registry.register_problem
+class GithubFunctionDocstringExtended(GithubFunctionDocstring):
+  """Function Docstring problem with extended semantics.
+
+  This problem keeps all the properties of the original,
+  with one change - the semantics of `data_dir` now dictate that
+  the raw CSV files containing the function docstring pairs should
+  already be available in the `data_dir`. This allows for the user
+  to modify the `data_dir` to point to a new set of data points when
+  needed and train an updated model.
+
+  As a reminder, in the standard setting, `data_dir` is only meant
+  to be the output directory for TFRecords of an immutable dataset
+  elsewhere (more particularly at `self.DATA_PATH_PREFIX`).
+  """
+
+  def get_csv_files(self, _data_dir, tmp_dir, _dataset_split):
+    self.DATA_PATH_PREFIX = _data_dir
+    return super(GithubFunctionDocstringExtended,
+                 self).get_csv_files(_data_dir, tmp_dir, _dataset_split)
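Usage sketch (not part of the patch): the new class registers under the problem name `github_function_docstring_extended`, which is what the `params.libsonnet` change above points the trainer at. Assuming the raw `func-doc-pairs-*.csv` shards have already been copied into `data_dir`, TFRecord generation could be driven from Python roughly as follows; the directory paths and the `code_search.t2t` import path are illustrative placeholders, not values taken from the patch.

# Illustrative sketch only; paths below are placeholders.
from tensor2tensor.utils import registry

import code_search.t2t  # noqa: F401  # importing the package registers the T2T problems

data_dir = "/path/to/csv/shards"  # hypothetical: already contains func-doc-pairs-*.csv
tmp_dir = "/path/to/tmp"          # hypothetical: scratch space used by maybe_download

# GithubFunctionDocstringExtended.get_csv_files points DATA_PATH_PREFIX at
# data_dir, so the CSVs are resolved there instead of being fetched from
# gs://kubeflow-examples/t2t-code-search/raw_data.
problem = registry.problem("github_function_docstring_extended")
problem.generate_data(data_dir, tmp_dir, task_id=-1)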