Add a new github function docstring extended problem (#225)

* Add a new github function docstring extended problem

* Fix lint errors

* Update images
This commit is contained in:
Sanyam Kapoor 2018-08-14 15:41:47 -07:00 committed by k8s-ci-robot
parent 8fce4a7799
commit 18829159b0
4 changed files with 41 additions and 12 deletions

View File

@ -9,8 +9,8 @@
numPsGpu: 0,
train_steps: 100,
eval_steps: 10,
image: 'gcr.io/kubeflow-dev/code-search:v20180802-c622aac',
imageGpu: 'gcr.io/kubeflow-dev/code-search:v20180802-c622aac-gpu',
image: 'gcr.io/kubeflow-dev/code-search:v20180814-66d27b9',
imageGpu: 'gcr.io/kubeflow-dev/code-search:v20180814-66d27b9-gpu',
imagePullSecrets: [],
dataDir: 'null',
outputDir: 'null',
@ -19,7 +19,7 @@
},
"t2t-code-search": {
workingDir: 'gs://example/prefix',
problem: 'github_function_docstring',
problem: 'github_function_docstring_extended',
model: 'similarity_transformer',
hparams_set: 'transformer_tiny',
},

View File

@ -5,3 +5,4 @@
#
from . import function_docstring
from . import similarity_transformer
from . import function_docstring_extended

View File

@ -18,6 +18,9 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
the docstring tokens and function tokens. The delimiter is
",".
"""
DATA_PATH_PREFIX = "gs://kubeflow-examples/t2t-code-search/raw_data"
@property
def pair_files_list(self):
"""Return URL and file names.
@ -45,11 +48,9 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
In this case, the tuple is of size 1 because the URL points
to a file itself.
"""
base_url = "gs://kubeflow-examples/t2t-code-search/raw_data"
return [
[
"{}/func-doc-pairs-000{:02}-of-00100.csv".format(base_url, i),
"{}/func-doc-pairs-000{:02}-of-00100.csv".format(self.DATA_PATH_PREFIX, i),
("func-doc-pairs-000{:02}-of-00100.csv".format(i),)
]
for i in range(100)
@ -68,7 +69,13 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
# FIXME(sanyamkapoor): This exists to handle memory explosion.
return int(3.5e5)
def generate_samples(self, _data_dir, tmp_dir, _dataset_split):
def get_csv_files(self, _data_dir, tmp_dir, _dataset_split):
return [
generator_utils.maybe_download(tmp_dir, file_list[0], uri)
for uri, file_list in self.pair_files_list
]
def generate_samples(self, data_dir, tmp_dir, dataset_split):
"""A generator to return data samples.Returns the data generator to return.
@ -82,14 +89,11 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
Each element yielded is of a Python dict of the form
{"inputs": "STRING", "targets": "STRING"}
"""
csv_files = [
generator_utils.maybe_download(tmp_dir, file_list[0], uri)
for uri, file_list in self.pair_files_list
]
csv_files = self.get_csv_files(data_dir, tmp_dir, dataset_split)
for pairs_file in csv_files:
tf.logging.debug("Reading {}".format(pairs_file))
with open(pairs_file, "r") as csv_file:
with tf.gfile.Open(pairs_file) as csv_file:
for line in csv_file:
reader = csv.reader(StringIO(line))
for docstring_tokens, function_tokens in reader:

View File

@ -0,0 +1,24 @@
from tensor2tensor.utils import registry
from .function_docstring import GithubFunctionDocstring
@registry.register_problem
class GithubFunctionDocstringExtended(GithubFunctionDocstring):
  """Function/docstring pairs problem reading raw CSVs from `data_dir`.

  Behaves exactly like `GithubFunctionDocstring`, except that `data_dir`
  is reinterpreted: instead of being only the output directory for
  TFRecords of a fixed dataset hosted at `self.DATA_PATH_PREFIX`, it is
  expected to already contain the raw CSV files of function/docstring
  pairs. Pointing `data_dir` at a new set of CSVs therefore lets a user
  retrain on updated data without touching the canonical dataset.
  """

  def get_csv_files(self, _data_dir, tmp_dir, _dataset_split):
    # Redirect the download prefix to the caller-supplied data directory,
    # then delegate to the parent implementation unchanged.
    self.DATA_PATH_PREFIX = _data_dir
    return GithubFunctionDocstring.get_csv_files(
        self, _data_dir, tmp_dir, _dataset_split)