mirror of https://github.com/kubeflow/examples.git
				
				
				
			Add a new similarity transformer model, register new problem (#146)
* Add a new similarity transformer model, register new problem * Remove useless constructor
This commit is contained in:
		
							parent
							
								
									656e1e3e7c
								
							
						
					
					
						commit
						f20161167e
					
				|  | @ -1,26 +0,0 @@ | |||
| import os | ||||
| from tensor2tensor.utils import registry | ||||
| from tensor2tensor.data_generators import text_problems | ||||
| from tensor2tensor.models import transformer | ||||
| 
 | ||||
| 
 | ||||
| @registry.register_problem | ||||
| class GithubDocstringLanguageModel(text_problems.Text2SelfProblem): | ||||
|   # pylint: disable=abstract-method | ||||
| 
 | ||||
|   """This class defines the Language Modeling problem for Github docstrings""" | ||||
| 
 | ||||
|   @property | ||||
|   def is_generate_per_split(self): | ||||
|     return False | ||||
| 
 | ||||
|   def generate_samples(self, data_dir, _tmp_dir, dataset_split):  #pylint: disable=no-self-use | ||||
|     docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split)) | ||||
| 
 | ||||
|     return text_problems.text2self_txt_iterator(docstrings_file_path) | ||||
| 
 | ||||
| @registry.register_hparams | ||||
| def transformer_gh_lm(): | ||||
|   hparams = transformer.transformer_base() | ||||
|   # TODO(sanyamkapoor): change language model embedding size | ||||
|   return hparams | ||||
|  | @ -1,13 +1,30 @@ | |||
| import os | ||||
| from tensor2tensor.utils import t2t_model | ||||
| from tensor2tensor.utils import registry | ||||
| from tensor2tensor.data_generators import text_problems | ||||
| 
 | ||||
| 
 | ||||
| @registry.register_problem | ||||
| class GithubFunctionSummarizer(text_problems.Text2TextProblem): | ||||
| @registry.register_model | ||||
| class SimilarityTransformer(t2t_model.T2TModel): | ||||
|   # pylint: disable=abstract-method | ||||
| 
 | ||||
|   """This class defines the problem of converting Python function code to docstring""" | ||||
|   """ | ||||
|   This class defines the model to compute similarity scores between functions and | ||||
|   docstrings | ||||
|   """ | ||||
| 
 | ||||
|   def body(self, features): | ||||
|     # TODO: need to fill this with Transformer encoder/decoder | ||||
|     # and loss calculation | ||||
|     raise NotImplementedError | ||||
| 
 | ||||
| 
 | ||||
| @registry.register_problem | ||||
| class GithubFunctionDocstring(text_problems.Text2TextProblem): | ||||
|   # pylint: disable=abstract-method | ||||
| 
 | ||||
|   """This class defines the problem of finding similarity between Python function | ||||
|    and docstring""" | ||||
| 
 | ||||
|   @property | ||||
|   def is_generate_per_split(self): | ||||
|  | @ -16,7 +33,6 @@ class GithubFunctionSummarizer(text_problems.Text2TextProblem): | |||
|   def generate_samples(self, data_dir, _tmp_dir, dataset_split):  #pylint: disable=no-self-use | ||||
|     """This method returns the generator to return {"inputs": [text], "targets": [text]} dict""" | ||||
| 
 | ||||
|     # TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}" | ||||
|     functions_file_path = os.path.join(data_dir, '{}.function'.format(dataset_split)) | ||||
|     docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split)) | ||||
| 
 | ||||
		Loading…
	
		Reference in New Issue