mirror of https://github.com/kubeflow/examples.git
Add a new similarity transformer model, register new problem (#146)
* Add a new similarity transformer model, register new problem
* Remove useless constructor
This commit is contained in:
parent
656e1e3e7c
commit
f20161167e
|
|
@ -1,26 +0,0 @@
|
|||
import os
|
||||
from tensor2tensor.utils import registry
|
||||
from tensor2tensor.data_generators import text_problems
|
||||
from tensor2tensor.models import transformer
|
||||
|
||||
|
||||
@registry.register_problem
|
||||
class GithubDocstringLanguageModel(text_problems.Text2SelfProblem):
|
||||
# pylint: disable=abstract-method
|
||||
|
||||
"""This class defines the Language Modeling problem for Github docstrings"""
|
||||
|
||||
@property
|
||||
def is_generate_per_split(self):
|
||||
return False
|
||||
|
||||
def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
|
||||
docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
|
||||
|
||||
return text_problems.text2self_txt_iterator(docstrings_file_path)
|
||||
|
||||
@registry.register_hparams
|
||||
def transformer_gh_lm():
|
||||
hparams = transformer.transformer_base()
|
||||
# TODO(sanyamkapoor): change language model embedding size
|
||||
return hparams
|
||||
|
|
@ -1,13 +1,30 @@
|
|||
import os
|
||||
from tensor2tensor.utils import t2t_model
|
||||
from tensor2tensor.utils import registry
|
||||
from tensor2tensor.data_generators import text_problems
|
||||
|
||||
|
||||
@registry.register_problem
|
||||
class GithubFunctionSummarizer(text_problems.Text2TextProblem):
|
||||
@registry.register_model
|
||||
class SimilarityTransformer(t2t_model.T2TModel):
|
||||
# pylint: disable=abstract-method
|
||||
|
||||
"""This class defines the problem of converting Python function code to docstring"""
|
||||
"""
|
||||
This class defines the model to compute similarity scores between functions and
|
||||
docstrings
|
||||
"""
|
||||
|
||||
def body(self, features):
|
||||
# TODO: need to fill this with Transformer encoder/decoder
|
||||
# and loss calculation
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@registry.register_problem
|
||||
class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
||||
# pylint: disable=abstract-method
|
||||
|
||||
"""This class defines the problem of finding similarity between Python function
|
||||
and docstring"""
|
||||
|
||||
@property
|
||||
def is_generate_per_split(self):
|
||||
|
|
@ -16,7 +33,6 @@ class GithubFunctionSummarizer(text_problems.Text2TextProblem):
|
|||
def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
|
||||
"""This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""
|
||||
|
||||
# TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}"
|
||||
functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split))
|
||||
docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
|
||||
|
||||
Loading…
Reference in New Issue