65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
# Copyright 2018 The Kubeflow Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
import subprocess
import sys
import tempfile
import unittest
from contextlib import contextmanager
from pathlib import Path

import kfp.components as comp
|
|
|
|
@contextmanager
def components_local_output_dir_context(output_dir: str):
    """Temporarily set kfp's private ``_components._outputs_dir`` to ``output_dir``.

    The previous value is captured on entry and written back on exit, even
    when the body of the ``with`` statement raises.

    Args:
        output_dir: Value to install as ``comp._components._outputs_dir``
            while the context is active.

    Yields:
        ``output_dir``, unchanged.
    """
    previous_outputs_dir = comp._components._outputs_dir
    try:
        # Assigned inside ``try`` so the ``finally`` restore always runs.
        comp._components._outputs_dir = output_dir
        yield output_dir
    finally:
        comp._components._outputs_dir = previous_outputs_dir
|
|
|
|
class KerasTrainClassifierTestCase(unittest.TestCase):
    """Runs the Keras train-classifier component's script locally."""

    def test_handle_training_xor(self):
        """Train on the files in ``testdata/`` and check the script's results.

        The op is built from ``component.yaml`` so that kfp produces the
        command line. Then the first two command entries are replaced with the
        current interpreter and the local ``src/train.py``, and the command
        runs in a subprocess.
        """
        tests_root = os.path.abspath(os.path.dirname(__file__))
        component_root = os.path.abspath(os.path.join(tests_root, '..'))
        testdata_root = os.path.abspath(os.path.join(tests_root, 'testdata'))

        train_op = comp.load_component(os.path.join(component_root, 'component.yaml'))

        with tempfile.TemporaryDirectory() as temp_dir_name:
            # The op is constructed while kfp's outputs dir points at the temp
            # directory; file_outputs is read from train_task below.
            with components_local_output_dir_context(temp_dir_name):
                train_task = train_op(
                    training_set_features_path=os.path.join(testdata_root, 'training_set_features.tsv'),
                    training_set_labels_path=os.path.join(testdata_root, 'training_set_labels.tsv'),
                    output_model_uri=os.path.join(temp_dir_name, 'outputs/output_model/data'),
                    model_config=Path(testdata_root).joinpath('model_config.json').read_text(),
                    number_of_classes=2,
                    number_of_epochs=10,
                    batch_size=32,
                )

            full_command = train_task.command + train_task.arguments
            # Use the interpreter running this test (same virtualenv and
            # installed packages) instead of whatever 'python' is on PATH.
            full_command[0] = sys.executable
            full_command[1] = os.path.join(component_root, 'src', 'train.py')

            process = subprocess.run(full_command)
            # Without this check a crashing training script would still pass.
            self.assertEqual(process.returncode, 0, 'Training script failed')

            output_model_uri_file = train_task.file_outputs['output-model-uri']
            output_model_uri = Path(output_model_uri_file).read_text()
            self.assertTrue(
                output_model_uri.strip(),
                'Training script wrote an empty output model URI',
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly hands control to the unittest runner.
    unittest.main()
|