mirror of https://github.com/kubeflow/examples.git
				
				
				
			Language modeling using Transformer Networks (#129)
* Add Github language modeling problem * Rename folders, update README with datagen and train scripts * Fix linting
This commit is contained in:
		
							parent
							
								
									f4c8b7f80d
								
							
						
					
					
						commit
						d3c781772c
					
				|  | @ -80,16 +80,37 @@ step. It uses `tensor2tensor`. | ||||||
| 
 | 
 | ||||||
| * Generate `TFRecords` for training | * Generate `TFRecords` for training | ||||||
| ``` | ``` | ||||||
| (venv3) $ t2t-datagen --t2t_usr_dir=summarizer/gh_function_summarizer --problem=github_function_summarizer \ | (venv3) $ t2t-datagen --t2t_usr_dir=language_task/t2t_problems --problem=github_function_summarizer \ | ||||||
|                       --data_dir=~/data --tmp_dir=/tmp |                       --data_dir=~/data --tmp_dir=/tmp | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| * Train transduction model using `Transformer Networks` and a base hyper-parameters set | * Train transduction model using `Transformer Networks` and a base hyper-parameters set | ||||||
| ``` | ``` | ||||||
| (venv3) $ t2t-trainer --t2t_usr_dir=summarizer/gh_function_summarizer --problem=github_function_summarizer \ | (venv3) $ t2t-trainer --t2t_usr_dir=language_task/t2t_problems --problem=github_function_summarizer \ | ||||||
|                       --data_dir=~/data --model=transformer --hparams_set=transformer_base --output_dir=~/train |                       --data_dir=~/data --model=transformer --hparams_set=transformer_base --output_dir=~/train | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
|  | ## 3. Docstrings Language Model | ||||||
|  | 
 | ||||||
|  | This part trains a language model based on the docstrings in the dataset and uses `tensor2tensor`. | ||||||
|  | 
 | ||||||
|  | * Install dependencies | ||||||
|  | ``` | ||||||
|  | (venv3) $ pip install -r summarizer/requirements.txt | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | * Generate `TFRecords` for training | ||||||
|  | ``` | ||||||
|  | (venv3) $ t2t-datagen --t2t_usr_dir=language_task/t2t_problems --problem=github_docstring_language_model \ | ||||||
|  |                       --data_dir=~/data --tmp_dir=/tmp | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | * Train language model using `Transformer Networks` and a custom hyper-parameters set | ||||||
|  | ``` | ||||||
|  | (venv3) $ t2t-trainer --t2t_usr_dir=language_task/t2t_problems --problem=github_docstring_language_model \ | ||||||
|  |                       --data_dir=~/data --model=transformer --hparams_set=transformer_gh_lm --output_dir=~/train | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
| # Acknowledgements | # Acknowledgements | ||||||
| 
 | 
 | ||||||
| This project derives from [hamelsmu/code_search](https://github.com/hamelsmu/code_search). | This project derives from [hamelsmu/code_search](https://github.com/hamelsmu/code_search). | ||||||
|  | @ -1 +1,2 @@ | ||||||
| from . import function_summarizer | from . import function_summarizer | ||||||
|  | from . import docstring_lm | ||||||
|  | @ -0,0 +1,26 @@ | ||||||
|  | import os | ||||||
|  | from tensor2tensor.utils import registry | ||||||
|  | from tensor2tensor.data_generators import text_problems | ||||||
|  | from tensor2tensor.models import transformer | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
@registry.register_problem
class GithubDocstringLanguageModel(text_problems.Text2SelfProblem):
  # pylint: disable=abstract-method

  """Language-modeling problem over the docstrings of the Github dataset."""

  @property
  def is_generate_per_split(self):
    # One generated dataset is split into train/eval automatically by t2t.
    return False

  def generate_samples(self, data_dir, _tmp_dir, dataset_split):  #pylint: disable=no-self-use
    """Return an iterator of {"targets": [text]} samples for `dataset_split`.

    Reads docstrings line-by-line from `<data_dir>/<split>.docstring`.
    """
    file_name = '{}.docstring'.format(dataset_split)
    return text_problems.text2self_txt_iterator(os.path.join(data_dir, file_name))
|  | 
 | ||||||
@registry.register_hparams
def transformer_gh_lm():
  """Hyper-parameter set for the Github docstring language model.

  Currently identical to `transformer_base`; registered separately so it
  can diverge later.
  """
  # TODO(sanyamkapoor): change language model embedding size
  return transformer.transformer_base()
|  | @ -1,17 +1,19 @@ | ||||||
| import os | import os | ||||||
| 
 |  | ||||||
| from tensor2tensor.utils import registry | from tensor2tensor.utils import registry | ||||||
| from tensor2tensor.data_generators import text_problems | from tensor2tensor.data_generators import text_problems | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @registry.register_problem | @registry.register_problem | ||||||
| class GithubFunctionSummarizer(text_problems.Text2TextProblem): | class GithubFunctionSummarizer(text_problems.Text2TextProblem): | ||||||
|  |   # pylint: disable=abstract-method | ||||||
|  | 
 | ||||||
|   """This class defines the problem of converting Python function code to docstring""" |   """This class defines the problem of converting Python function code to docstring""" | ||||||
| 
 | 
 | ||||||
|   @property |   @property | ||||||
|   def is_generate_per_split(self): |   def is_generate_per_split(self): | ||||||
|     return False |     return False | ||||||
| 
 | 
 | ||||||
|   def generate_samples(self, data_dir, _tmp_dir, dataset_split): # pylint: disable=no-self-use |   def generate_samples(self, data_dir, _tmp_dir, dataset_split):  #pylint: disable=no-self-use | ||||||
|     """This method returns the generator to return {"inputs": [text], "targets": [text]} dict""" |     """This method returns the generator to return {"inputs": [text], "targets": [text]} dict""" | ||||||
| 
 | 
 | ||||||
|     # TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}" |     # TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}" | ||||||
		Loading…
	
		Reference in New Issue