diff --git a/code_search/language_task/t2t_problems/docstring_lm.py b/code_search/language_task/t2t_problems/docstring_lm.py deleted file mode 100644 index 24489f3f..00000000 --- a/code_search/language_task/t2t_problems/docstring_lm.py +++ /dev/null @@ -1,26 +0,0 @@ -import os -from tensor2tensor.utils import registry -from tensor2tensor.data_generators import text_problems -from tensor2tensor.models import transformer - - -@registry.register_problem -class GithubDocstringLanguageModel(text_problems.Text2SelfProblem): - # pylint: disable=abstract-method - - """This class defines the Language Modeling problem for Github docstrings""" - - @property - def is_generate_per_split(self): - return False - - def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use - docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split)) - - return text_problems.text2self_txt_iterator(docstrings_file_path) - -@registry.register_hparams -def transformer_gh_lm(): - hparams = transformer.transformer_base() - # TODO(sanyamkapoor): change language model embedding size - return hparams diff --git a/code_search/language_task/t2t_problems/function_summarizer.py b/code_search/language_task/t2t_problems/similarity_transformer.py similarity index 54% rename from code_search/language_task/t2t_problems/function_summarizer.py rename to code_search/language_task/t2t_problems/similarity_transformer.py index fbe6d196..9a7c5c58 100644 --- a/code_search/language_task/t2t_problems/function_summarizer.py +++ b/code_search/language_task/t2t_problems/similarity_transformer.py @@ -1,13 +1,30 @@ import os +from tensor2tensor.utils import t2t_model from tensor2tensor.utils import registry from tensor2tensor.data_generators import text_problems -@registry.register_problem -class GithubFunctionSummarizer(text_problems.Text2TextProblem): +@registry.register_model +class SimilarityTransformer(t2t_model.T2TModel): # pylint: disable=abstract-method - """This class defines the problem of converting Python function code to docstring""" + """ + This class defines the model to compute similarity scores between functions and + docstrings + """ + + def body(self, features): + # TODO: need to fill this with Transformer encoder/decoder + # and loss calculation + raise NotImplementedError + + +@registry.register_problem +class GithubFunctionDocstring(text_problems.Text2TextProblem): + # pylint: disable=abstract-method + + """This class defines the problem of finding similarity between Python function + and docstring""" @property def is_generate_per_split(self): @@ -16,7 +33,6 @@ class GithubFunctionSummarizer(text_problems.Text2TextProblem): def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use """This method returns the generator to return {"inputs": [text], "targets": [text]} dict""" - # TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}" functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split)) docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))