mirror of https://github.com/kubeflow/examples.git
Language modeling using Transformer Networks (#129)
* Add Github language modeling problem
* Rename folders, update README with datagen and train scripts
* Fix linting
parent f4c8b7f80d
commit d3c781772c
@@ -80,16 +80,37 @@ step. It uses `tensor2tensor`.
* Generate `TFRecords` for training

```
-(venv3) $ t2t-datagen --t2t_usr_dir=summarizer/gh_function_summarizer --problem=github_function_summarizer \
+(venv3) $ t2t-datagen --t2t_usr_dir=language_task/t2t_problems --problem=github_function_summarizer \
  --data_dir=~/data --tmp_dir=/tmp
```

* Train a transduction model using `Transformer Networks` and a base hyper-parameter set

```
-(venv3) $ t2t-trainer --t2t_usr_dir=summarizer/gh_function_summarizer --problem=github_function_summarizer \
+(venv3) $ t2t-trainer --t2t_usr_dir=language_task/t2t_problems --problem=github_function_summarizer \
  --data_dir=~/data --model=transformer --hparams_set=transformer_base --output_dir=~/train
```

## 3. Docstrings Language Model

This part trains a language model on the docstrings in the dataset, using `tensor2tensor`.

* Install dependencies

```
(venv3) $ pip install -r summarizer/requirements.txt
```

* Generate `TFRecords` for training

```
(venv3) $ t2t-datagen --t2t_usr_dir=language_task/t2t_problems --problem=github_docstring_language_model \
  --data_dir=~/data --tmp_dir=/tmp
```

* Train the language model using `Transformer Networks` and a custom hyper-parameter set

```
(venv3) $ t2t-trainer --t2t_usr_dir=language_task/t2t_problems --problem=github_docstring_language_model \
  --data_dir=~/data --model=transformer --hparams_set=transformer_gh_lm --output_dir=~/train
```

# Acknowledgements

This project derives from [hamelsmu/code_search](https://github.com/hamelsmu/code_search).
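The `--t2t_usr_dir` flag is what makes these custom `--problem` names visible to `t2t-datagen` and `t2t-trainer`: the tools import that directory, and the `@registry.register_problem` decorators inside it register each problem under the snake_case form of its class name. As a quick sanity check (a sketch, not part of this commit; it assumes the renamed `language_task.t2t_problems` package is importable from the working directory), the registrations can be exercised directly from Python:

```
# Illustrative sanity check; not project code.
from tensor2tensor.utils import registry

import language_task.t2t_problems  # noqa: F401  (importing runs the register_* decorators)

summarizer = registry.problem('github_function_summarizer')
docstring_lm = registry.problem('github_docstring_language_model')
print(type(summarizer).__name__)    # GithubFunctionSummarizer
print(type(docstring_lm).__name__)  # GithubDocstringLanguageModel
```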
@@ -1 +1,2 @@
from . import function_summarizer
+from . import docstring_lm
@@ -0,0 +1,26 @@
import os

from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_problems
from tensor2tensor.models import transformer


@registry.register_problem
class GithubDocstringLanguageModel(text_problems.Text2SelfProblem):
  # pylint: disable=abstract-method

  """This class defines the Language Modeling problem for Github docstrings"""

  @property
  def is_generate_per_split(self):
    return False

  def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
    docstrings_file_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))

    return text_problems.text2self_txt_iterator(docstrings_file_path)


@registry.register_hparams
def transformer_gh_lm():
  hparams = transformer.transformer_base()
  # TODO(sanyamkapoor): change language model embedding size
  return hparams
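The `transformer_gh_lm` hparams set registered above is currently just `transformer_base()` plus a TODO about the language-model embedding size. As a hedged illustration of how that TODO might eventually be addressed (the variant name and values below are assumptions, not part of this commit), a custom set can override the base fields before returning:

```
# Hypothetical variant; illustration only, not code from this commit.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_gh_lm_small():
  """A smaller Transformer configuration for the docstring language model."""
  hparams = transformer.transformer_base()
  hparams.hidden_size = 256       # smaller embedding/hidden size (assumed value)
  hparams.filter_size = 1024      # scale the feed-forward width down to match
  hparams.num_hidden_layers = 4   # fewer layers for a single-domain corpus
  return hparams
```

A set like this would then be selected with `--hparams_set=transformer_gh_lm_small` in the `t2t-trainer` command from the README.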
@@ -1,17 +1,19 @@
import os

from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_problems


@registry.register_problem
class GithubFunctionSummarizer(text_problems.Text2TextProblem):
  # pylint: disable=abstract-method

  """This class defines the problem of converting Python function code to docstring"""

  @property
  def is_generate_per_split(self):
    return False

-  def generate_samples(self, data_dir, _tmp_dir, dataset_split): # pylint: disable=no-self-use
+  def generate_samples(self, data_dir, _tmp_dir, dataset_split): #pylint: disable=no-self-use
    """This method returns the generator to return {"inputs": [text], "targets": [text]} dict"""

    # TODO(sanyamkapoor): Merge with validation set file "valid.{function|docstring}"
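The rest of this hunk is cut off in this view. Purely for orientation, here is a generic sketch of the `Text2TextProblem` sample-generator pattern this class follows, assuming paired per-split `*.function` / `*.docstring` text files as suggested by the TODO above; it is an illustration, not the commit's actual implementation:

```
# Illustration only; not the code from this commit.
import os

from tensor2tensor.data_generators import text_problems


def generate_function_docstring_pairs(data_dir, dataset_split):
  """Yield {"inputs": <function text>, "targets": <docstring text>} dicts."""
  functions_path = os.path.join(data_dir, '{}.function'.format(dataset_split))
  docstrings_path = os.path.join(data_dir, '{}.docstring'.format(dataset_split))
  # text2text_txt_iterator pairs the two files line by line.
  return text_problems.text2text_txt_iterator(functions_path, docstrings_path)
```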