# examples/github_issue_summarization/notebooks/train_test.py
#
# 83 lines
# 2.5 KiB
# Python

import os
import tempfile
import unittest
import train
class TrainTest(unittest.TestCase):
    """Smoke tests that run ``train.main`` on the bundled sample data.

    Each test invokes ``train.main`` with output paths inside a fresh
    temporary directory and then checks that every requested output path
    exists.
    """

    def _run_training(self, model_name, extra_args=()):
        """Run ``train.main`` into a temp dir and assert its outputs exist.

        Args:
          model_name: File name (relative to the temp dir) passed via
            ``--output_model``.
          extra_args: Extra command-line flags appended after the common ones.
        """
        # Function-scope import keeps this helper self-contained with respect
        # to the module's import block.
        import shutil

        output_dir = tempfile.mkdtemp()
        # Remove the scratch directory even if training or an assertion fails.
        self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

        output_model = os.path.join(output_dir, model_name)
        body_pp_dpkl = os.path.join(output_dir, "body_pp.dpkl")
        # Must differ from body_pp_dpkl: if both flags share one path, the
        # existence check below cannot tell the two outputs apart.
        title_pp_dpkl = os.path.join(output_dir, "title_pp.dpkl")
        title_vecs = os.path.join(output_dir, "title.npy")
        body_vecs = os.path.join(output_dir, "body.npy")

        this_dir = os.path.dirname(__file__)
        input_data = os.path.join(this_dir, "test_data",
                                  "github_issues_sample.csv")

        args = [
            "--sample_size=100",
            "--num_epochs=1",
            "--input_data=" + input_data,
            "--output_model=" + output_model,
            "--output_body_preprocessor_dpkl=" + body_pp_dpkl,
            "--output_title_preprocessor_dpkl=" + title_pp_dpkl,
            "--output_train_title_vecs_npy=" + title_vecs,
            "--output_train_body_vecs_npy=" + body_vecs,
        ]
        args.extend(extra_args)

        train.main(args)

        output_files = [
            output_model, body_pp_dpkl, title_pp_dpkl, title_vecs, body_vecs
        ]
        for f in output_files:
            self.assertTrue(os.path.exists(f), "missing output: " + f)

    def test_keras(self):
        """Test training with the default (Keras) mode."""
        self._run_training("model.h5")

    # TODO(https://github.com/kubeflow/examples/issues/280)
    # TODO(https://github.com/kubeflow/examples/issues/196)
    # This test won't work until we fix the code to work using the estimator
    # API.
    @unittest.skip("skip estimator test")
    def test_estimator(self):
        """Test training with ``--mode=estimator``."""
        self._run_training("model", extra_args=["--mode=estimator"])
# Run the test suite only when this file is executed as a script.
if "__main__" == __name__:
    unittest.main()