mirror of https://github.com/kubeflow/examples.git
Add a new similarity transformer model, register new problem (#146)
* Add a new similarity transformer model, register new problem * Remove useless constructor
This commit is contained in:
parent
656e1e3e7c
commit
f20161167e
|
|
@ -1,26 +0,0 @@
|
||||||
import os
|
|
||||||
from tensor2tensor.utils import registry
|
|
||||||
from tensor2tensor.data_generators import text_problems
|
|
||||||
from tensor2tensor.models import transformer
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register_problem
|
|
||||||
class GithubDocstringLanguageModel(text_problems.Text2SelfProblem):
|
|
||||||
# pylint: disable=abstract-method
|
|
||||||
|
|
||||||
"""This class defines the Language Modeling problem for Github docstrings"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_generate_per_split(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
|
|
||||||
docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
|
|
||||||
|
|
||||||
return text_problems.text2self_txt_iterator(docstrings_file_path)
|
|
||||||
|
|
||||||
@registry.register_hparams
|
|
||||||
def transformer_gh_lm():
|
|
||||||
hparams = transformer.transformer_base()
|
|
||||||
# TODO(sanyamkapoor): change language model embedding size
|
|
||||||
return hparams
|
|
||||||
|
|
@ -1,13 +1,30 @@
|
||||||
import os
|
import os
|
||||||
|
from tensor2tensor.utils import t2t_model
|
||||||
from tensor2tensor.utils import registry
|
from tensor2tensor.utils import registry
|
||||||
from tensor2tensor.data_generators import text_problems
|
from tensor2tensor.data_generators import text_problems
|
||||||
|
|
||||||
|
|
||||||
@registry.register_problem
|
@registry.register_model
|
||||||
class GithubFunctionSummarizer(text_problems.Text2TextProblem):
|
class SimilarityTransformer(t2t_model.T2TModel):
|
||||||
# pylint: disable=abstract-method
|
# pylint: disable=abstract-method
|
||||||
|
|
||||||
"""This class defines the problem of converting Python function code to docstring"""
|
"""
|
||||||
|
This class defines the model to compute similarity scores between functions and
|
||||||
|
docstrings
|
||||||
|
"""
|
||||||
|
|
||||||
|
def body(self, features):
|
||||||
|
# TODO: need to fill this with Transformer encoder/decoder
|
||||||
|
# and loss calculation
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register_problem
|
||||||
|
class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
||||||
|
# pylint: disable=abstract-method
|
||||||
|
|
||||||
|
"""This class defines the problem of finding similarity between Python function
|
||||||
|
and docstring"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_generate_per_split(self):
|
def is_generate_per_split(self):
|
||||||
|
|
@ -16,7 +33,6 @@ class GithubFunctionSummarizer(text_problems.Text2TextProblem):
|
||||||
def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
|
def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
|
||||||
"""This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""
|
"""This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""
|
||||||
|
|
||||||
# TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}"
|
|
||||||
functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split))
|
functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split))
|
||||||
docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
|
docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
|
||||||
|
|
||||||
Loading…
Reference in New Issue