pipelines/components/contrib/sample/keras/train_classifier/tests/test_component.py

65 lines
2.5 KiB
Python

# Copyright 2018 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
from contextlib import contextmanager
from pathlib import Path

import kfp.components as comp
@contextmanager
def components_local_output_dir_context(output_dir: str):
    """Temporarily override ``kfp.components._components._outputs_dir``.

    While the context is active the private kfp module attribute
    ``_outputs_dir`` is set to *output_dir*. The value seen on entry is
    restored on exit, including when the ``with`` body raises.

    NOTE(review): this relies on a private kfp attribute. It presumably
    controls where file outputs of locally instantiated component tasks are
    placed; confirm against the installed kfp version.

    Args:
        output_dir: Value to install as the outputs directory for the
            duration of the context.

    Yields:
        The same *output_dir* that was passed in.
    """
    previous_outputs_dir = comp._components._outputs_dir
    try:
        comp._components._outputs_dir = output_dir
        yield output_dir
    finally:
        # Restore unconditionally so a failing body cannot leak the override
        # into later tests.
        comp._components._outputs_dir = previous_outputs_dir
class KerasTrainClassifierTestCase(unittest.TestCase):
    """End-to-end test for the Keras ``train_classifier`` component.

    Loads ``component.yaml`` from the component root, builds a task from the
    files in ``tests/testdata``, then runs the component's ``src/train.py``
    locally with the command line kfp generated for that task.
    """

    def test_handle_training_xor(self):
        """Run training on the test data and check it succeeds and reports a model URI."""
        tests_root = os.path.abspath(os.path.dirname(__file__))
        component_root = os.path.abspath(os.path.join(tests_root, '..'))
        testdata_root = os.path.abspath(os.path.join(tests_root, 'testdata'))

        train_op = comp.load_component(os.path.join(component_root, 'component.yaml'))

        with tempfile.TemporaryDirectory() as temp_dir_name:
            # Build the task while kfp's outputs dir points at the temp dir, so
            # the generated output-file paths land somewhere that is cleaned up.
            with components_local_output_dir_context(temp_dir_name):
                train_task = train_op(
                    training_set_features_path=os.path.join(testdata_root, 'training_set_features.tsv'),
                    training_set_labels_path=os.path.join(testdata_root, 'training_set_labels.tsv'),
                    output_model_uri=os.path.join(temp_dir_name, 'outputs/output_model/data'),
                    model_config=Path(testdata_root).joinpath('model_config.json').read_text(),
                    number_of_classes=2,
                    number_of_epochs=10,
                    batch_size=32,
                )

            # List `+` creates a new list, so the edits below do not mutate the task.
            full_command = train_task.command + train_task.arguments
            # NOTE(review): assumes the component command starts with
            # [<interpreter>, <script path>, ...]; confirm against component.yaml.
            # Use the interpreter running this test rather than whatever `python`
            # is on PATH, so the script sees the same installed packages.
            full_command[0] = sys.executable
            full_command[1] = os.path.join(component_root, 'src', 'train.py')

            # check=True: a failing training run must fail the test instead of
            # being silently ignored (raises subprocess.CalledProcessError).
            subprocess.run(full_command, check=True)

            output_model_uri_file = train_task.file_outputs['output-model-uri']
            output_model_uri = Path(output_model_uri_file).read_text()
            self.assertTrue(
                output_model_uri.strip(),
                'Component wrote an empty output-model-uri output',
            )
if __name__ == "__main__":
    # Allow running this module directly as a script, not only via a test runner.
    unittest.main()