mirror of https://github.com/kubeflow/examples.git
Add a new github function docstring extended problem (#225)
* Add a new github function docstring extended problem * Fix lint errors * Update images
This commit is contained in:
parent
8fce4a7799
commit
18829159b0
|
|
@ -9,8 +9,8 @@
|
|||
numPsGpu: 0,
|
||||
train_steps: 100,
|
||||
eval_steps: 10,
|
||||
image: 'gcr.io/kubeflow-dev/code-search:v20180802-c622aac',
|
||||
imageGpu: 'gcr.io/kubeflow-dev/code-search:v20180802-c622aac-gpu',
|
||||
image: 'gcr.io/kubeflow-dev/code-search:v20180814-66d27b9',
|
||||
imageGpu: 'gcr.io/kubeflow-dev/code-search:v20180814-66d27b9-gpu',
|
||||
imagePullSecrets: [],
|
||||
dataDir: 'null',
|
||||
outputDir: 'null',
|
||||
|
|
@ -19,7 +19,7 @@
|
|||
},
|
||||
"t2t-code-search": {
|
||||
workingDir: 'gs://example/prefix',
|
||||
problem: 'github_function_docstring',
|
||||
problem: 'github_function_docstring_extended',
|
||||
model: 'similarity_transformer',
|
||||
hparams_set: 'transformer_tiny',
|
||||
},
|
||||
|
|
|
|||
|
|
@ -5,3 +5,4 @@
|
|||
#
|
||||
from . import function_docstring
|
||||
from . import similarity_transformer
|
||||
from . import function_docstring_extended
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
the docstring tokens and function tokens. The delimiter is
|
||||
",".
|
||||
"""
|
||||
|
||||
DATA_PATH_PREFIX = "gs://kubeflow-examples/t2t-code-search/raw_data"
|
||||
|
||||
@property
|
||||
def pair_files_list(self):
|
||||
"""Return URL and file names.
|
||||
|
|
@ -45,11 +48,9 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
In this case, the tuple is of size 1 because the URL points
|
||||
to a file itself.
|
||||
"""
|
||||
base_url = "gs://kubeflow-examples/t2t-code-search/raw_data"
|
||||
|
||||
return [
|
||||
[
|
||||
"{}/func-doc-pairs-000{:02}-of-00100.csv".format(base_url, i),
|
||||
"{}/func-doc-pairs-000{:02}-of-00100.csv".format(self.DATA_PATH_PREFIX, i),
|
||||
("func-doc-pairs-000{:02}-of-00100.csv".format(i),)
|
||||
]
|
||||
for i in range(100)
|
||||
|
|
@ -68,7 +69,13 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
# FIXME(sanyamkapoor): This exists to handle memory explosion.
|
||||
return int(3.5e5)
|
||||
|
||||
def generate_samples(self, _data_dir, tmp_dir, _dataset_split):
|
||||
def get_csv_files(self, _data_dir, tmp_dir, _dataset_split):
  """Return one path per entry of `pair_files_list`.

  Every `(uri, file_list)` entry is passed to
  `generator_utils.maybe_download` together with `tmp_dir` and the first
  file name of the entry; the values it returns are collected in the same
  order as the entries.

  `_data_dir` and `_dataset_split` are accepted but not used here.
  """
  csv_paths = []
  for source_uri, file_names in self.pair_files_list:
    local_path = generator_utils.maybe_download(
        tmp_dir, file_names[0], source_uri)
    csv_paths.append(local_path)
  return csv_paths
|
||||
|
||||
def generate_samples(self, data_dir, tmp_dir, dataset_split):
|
||||
"""A generator to return data samples. Returns the data generator to return.
|
||||
|
||||
|
||||
|
|
@ -82,14 +89,11 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
Each element yielded is of a Python dict of the form
|
||||
{"inputs": "STRING", "targets": "STRING"}
|
||||
"""
|
||||
csv_files = [
|
||||
generator_utils.maybe_download(tmp_dir, file_list[0], uri)
|
||||
for uri, file_list in self.pair_files_list
|
||||
]
|
||||
csv_files = self.get_csv_files(data_dir, tmp_dir, dataset_split)
|
||||
|
||||
for pairs_file in csv_files:
|
||||
tf.logging.debug("Reading {}".format(pairs_file))
|
||||
with open(pairs_file, "r") as csv_file:
|
||||
with tf.gfile.Open(pairs_file) as csv_file:
|
||||
for line in csv_file:
|
||||
reader = csv.reader(StringIO(line))
|
||||
for docstring_tokens, function_tokens in reader:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
from tensor2tensor.utils import registry
|
||||
from .function_docstring import GithubFunctionDocstring
|
||||
|
||||
|
||||
@registry.register_problem
class GithubFunctionDocstringExtended(GithubFunctionDocstring):
  """Function/docstring problem whose raw CSVs are read from `data_dir`.

  All properties of `GithubFunctionDocstring` are kept. The single
  difference is the meaning of `data_dir`: the raw CSV files holding the
  function-docstring pairs are expected to already be present in
  `data_dir`. Pointing `data_dir` at a new set of data points is therefore
  enough to train an updated model.

  For comparison, in the standard setting `data_dir` is only the output
  directory for the TFRecords of an immutable dataset that lives elsewhere
  (namely at `self.DATA_PATH_PREFIX`).
  """

  def get_csv_files(self, _data_dir, tmp_dir, _dataset_split):
    """Collect the CSV files using `_data_dir` as the data path prefix.

    `_data_dir` is stored on this instance as `DATA_PATH_PREFIX`, and the
    lookup is then delegated to the parent class with the same arguments.
    """
    # Set on the instance so that the parent's lookup, which formats its
    # URLs from `self.DATA_PATH_PREFIX`, resolves against `_data_dir`.
    self.DATA_PATH_PREFIX = _data_dir
    parent = super(GithubFunctionDocstringExtended, self)
    return parent.get_csv_files(_data_dir, tmp_dir, _dataset_split)
|
||||
Loading…
Reference in New Issue