Language modeling using Transformer Networks (#129)

* Add Github language modeling problem

* Rename folders, update README with datagen and train scripts

* Fix linting
Sanyam Kapoor 2018-06-07 06:31:22 -07:00 committed by k8s-ci-robot
parent f4c8b7f80d
commit d3c781772c
5 changed files with 55 additions and 5 deletions

View File

@@ -80,16 +80,37 @@ step. It uses `tensor2tensor`.
* Generate `TFRecords` for training
```
-(venv3) $ t2t-datagen --t2t_usr_dir=summarizer/gh_function_summarizer --problem=github_function_summarizer \
+(venv3) $ t2t-datagen --t2t_usr_dir=language_task/t2t_problems --problem=github_function_summarizer \
--data_dir=~/data --tmp_dir=/tmp
```
* Train the transduction model using `Transformer Networks` and a base hyper-parameter set
```
-(venv3) $ t2t-trainer --t2t_usr_dir=summarizer/gh_function_summarizer --problem=github_function_summarizer \
+(venv3) $ t2t-trainer --t2t_usr_dir=language_task/t2t_problems --problem=github_function_summarizer \
--data_dir=~/data --model=transformer --hparams_set=transformer_base --output_dir=~/train
```
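As context for the `--problem` flags above: `tensor2tensor` resolves a problem name by snake-casing the registered class name, so the `GithubFunctionSummarizer` class is addressed as `github_function_summarizer`. A minimal sketch of that naming rule (this `camelcase_to_snakecase` helper is an illustrative re-implementation, not the library's own code):

```python
import re

def camelcase_to_snakecase(name):
    # Insert an underscore before each interior capitalized word, then
    # lowercase everything -- an illustrative re-implementation of the
    # naming rule tensor2tensor applies to registered problem classes.
    s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

print(camelcase_to_snakecase('GithubFunctionSummarizer'))
# -> github_function_summarizer
```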
## 3. Docstrings Language Model
This part trains a language model based on the docstrings in the dataset and uses `tensor2tensor`.
* Install dependencies
```
(venv3) $ pip install -r summarizer/requirements.txt
```
* Generate `TFRecords` for training
```
(venv3) $ t2t-datagen --t2t_usr_dir=language_task/t2t_problems --problem=github_docstring_language_model \
--data_dir=~/data --tmp_dir=/tmp
```
* Train the language model using `Transformer Networks` and a custom hyper-parameter set
```
(venv3) $ t2t-trainer --t2t_usr_dir=language_task/t2t_problems --problem=github_docstring_language_model \
--data_dir=~/data --model=transformer --hparams_set=transformer_gh_lm --output_dir=~/train
```
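Tying the flags to the code below: `generate_samples` in the language-model problem reads docstrings from `<data_dir>/<split>.docstring`, so with `--data_dir=~/data` the training split would be expected at `~/data/train.docstring` (assuming the split resolves to the string `train`). A small sketch of that path convention:

```python
import os

def docstrings_path(data_dir, dataset_split):
    # Mirrors the path construction in
    # GithubDocstringLanguageModel.generate_samples: the docstrings
    # for a split live at <data_dir>/<split>.docstring.
    return os.path.join(data_dir, '{}.docstring'.format(dataset_split))

print(docstrings_path(os.path.expanduser('~/data'), 'train'))
```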
# Acknowledgements
-This project derives from [hamelsmu/code_search](https://github.com/hamelsmu/code_search).
+This project derives from [hamelsmu/code_search](https://github.com/hamelsmu/code_search).

View File

@@ -1 +1,2 @@
from . import function_summarizer
+from . import docstring_lm

View File

@@ -0,0 +1,26 @@
import os

from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_problems
from tensor2tensor.models import transformer


@registry.register_problem
class GithubDocstringLanguageModel(text_problems.Text2SelfProblem):
    # pylint: disable=abstract-method
    """This class defines the Language Modeling problem for Github docstrings"""

    @property
    def is_generate_per_split(self):
        return False

    def generate_samples(self, data_dir, _tmp_dir, dataset_split):  #pylint: disable=no-self-use
        docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
        return text_problems.text2self_txt_iterator(docstrings_file_path)


@registry.register_hparams
def transformer_gh_lm():
    hparams = transformer.transformer_base()
    # TODO(sanyamkapoor): change language model embedding size
    return hparams
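For a sense of what `text2self_txt_iterator` feeds the trainer: in `tensor2tensor` it yields one `{"targets": line}` sample per line of the text file. A stand-in sketch of that behavior (not the library's implementation), exercised against a throwaway input file:

```python
import os
import tempfile

def text2self_samples(txt_path):
    # Stand-in for tensor2tensor's text_problems.text2self_txt_iterator:
    # each stripped line of the file becomes one {"targets": ...} sample.
    with open(txt_path) as txt_file:
        for line in txt_file:
            yield {'targets': line.strip()}

# Usage with a throwaway docstrings file.
with tempfile.NamedTemporaryFile('w', suffix='.docstring', delete=False) as tmp:
    tmp.write('Return the sum of two numbers.\nParse a config file.\n')
samples = list(text2self_samples(tmp.name))
os.unlink(tmp.name)
print(samples)
```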

View File

@@ -1,17 +1,19 @@
import os

from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_problems


@registry.register_problem
class GithubFunctionSummarizer(text_problems.Text2TextProblem):
    # pylint: disable=abstract-method
    """This class defines the problem of converting Python function code to docstring"""

    @property
    def is_generate_per_split(self):
        return False

-    def generate_samples(self, data_dir, _tmp_dir, dataset_split):  # pylint: disable=no-self-use
+    def generate_samples(self, data_dir, _tmp_dir, dataset_split):  #pylint: disable=no-self-use
        """This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""
        # TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}"
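The docstring above promises `{"inputs": ..., "targets": ...}` samples; in `tensor2tensor`, `text2text_txt_iterator` produces them by pairing line `i` of a source file with line `i` of a target file. A stand-in sketch (illustrative, not the library code), with throwaway function/docstring files:

```python
import os
import tempfile

def text2text_samples(source_path, target_path):
    # Stand-in for tensor2tensor's text_problems.text2text_txt_iterator:
    # pairs each source line (function code) with the matching target
    # line (its docstring).
    with open(source_path) as src, open(target_path) as tgt:
        for inputs, targets in zip(src, tgt):
            yield {'inputs': inputs.strip(), 'targets': targets.strip()}

# Usage with throwaway source/target files.
with tempfile.NamedTemporaryFile('w', suffix='.function', delete=False) as f_src:
    f_src.write('def add(a, b): return a + b\n')
with tempfile.NamedTemporaryFile('w', suffix='.docstring', delete=False) as f_tgt:
    f_tgt.write('Return the sum of two numbers.\n')
samples = list(text2text_samples(f_src.name, f_tgt.name))
os.unlink(f_src.name)
os.unlink(f_tgt.name)
print(samples)
```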