# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for subprocess_local_task_handler.py."""
|
|
import contextlib
|
|
import io
|
|
from typing import NamedTuple
|
|
import unittest
|
|
from unittest import mock
|
|
|
|
from absl.testing import parameterized
|
|
from kfp import dsl
|
|
from kfp import local
|
|
from kfp.dsl import Artifact
|
|
from kfp.dsl import Dataset
|
|
from kfp.dsl import Output
|
|
from kfp.local import subprocess_task_handler
|
|
from kfp.local import testing_utilities
|
|
|
|

# NOTE: When testing SubprocessRunner, use use_venv=True throughout to avoid
# modifying the current code under test.
# If the dsl.component mocks are modified in a way that makes them not work,
# the code may install kfp from PyPI rather than from source. To mitigate the
# impact of such an error, we should not install into the main test process'
# environment.
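#
# Concretely, tests that execute components therefore initialize the local
# runner as:
#
#     local.init(runner=local.SubprocessRunner(use_venv=True))
#
# so that any packages_to_install (and kfp itself) are installed into a
# temporary virtual environment rather than the environment running this
# test suite.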


class TestSubprocessRunner(testing_utilities.LocalRunnerEnvironmentTestCase):

    @mock.patch('sys.stdout', new_callable=io.StringIO)
    def test_basic(self, mock_stdout):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def comp():
            print('foobar!')

        comp()

        output = mock_stdout.getvalue().strip()

        self.assertContainsSubsequence(output, 'foobar!')

    def test_image_warning(self):
        with self.assertWarnsRegex(
                RuntimeWarning,
                r"You may be attemping to run a task that uses custom or non-Python base image 'my_custom_image' in a Python environment\. This may result in incorrect dependencies and/or incorrect behavior\."
        ):
            subprocess_task_handler.SubprocessTaskHandler(
                image='my_custom_image',
                # avoid catching the Container Component and
                # Containerized Python Component validation errors
                full_command=['kfp.dsl.executor_main'],
                pipeline_root='pipeline_root',
                runner=local.SubprocessRunner(use_venv=True),
            )

    def test_cannot_run_container_component(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.container_component
        def comp():
            return dsl.ContainerSpec(
                image='alpine',
                command=['echo'],
                args=['foo'],
            )

        with self.assertRaisesRegex(
                RuntimeError,
                r'The SubprocessRunner only supports running Lightweight Python Components\. You are attempting to run a Container Component\.',
        ):
            comp()

    def test_cannot_run_containerized_python_component(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component(target_image='foo')
        def comp():
            pass

        with self.assertRaisesRegex(
                RuntimeError,
                r'The SubprocessRunner only supports running Lightweight Python Components\. You are attempting to run a Containerized Python Component\.',
        ):
            comp()


class TestRunLocalSubprocess(unittest.TestCase):

    def test_simple_program(self):
        buffer = io.StringIO()

        with contextlib.redirect_stdout(buffer):
            subprocess_task_handler.run_local_subprocess([
                'echo',
                'foo!',
            ])

        output = buffer.getvalue().strip()

        self.assertEqual(output, 'foo!')


class TestUseCurrentPythonExecutable(
        testing_utilities.LocalRunnerEnvironmentTestCase):

    def test(self):
        full_command = ['python3 -c "from kfp import dsl"']
        actual = subprocess_task_handler.replace_python_executable(
            full_command=full_command,
            new_executable='/foo/bar/python3',
        )
        expected = ['/foo/bar/python3 -c "from kfp import dsl"']
        self.assertEqual(actual, expected)


class TestUseVenv(testing_utilities.LocalRunnerEnvironmentTestCase):

    @parameterized.parameters([
        ({
            'runner': local.SubprocessRunner(use_venv=True),
        }),
        ({
            'runner': local.SubprocessRunner(use_venv=True),
        }),
    ])
    def test_use_venv_true(self, **kwargs):
        local.init(**kwargs)

        @dsl.component(packages_to_install=['cloudpickle'])
        def installer_component():
            import cloudpickle
            print('Cloudpickle is installed:', cloudpickle)

        installer_component()

        # since the module was installed in the virtual environment, it
        # should not exist in the current environment
        with self.assertRaisesRegex(ModuleNotFoundError,
                                    r"No module named 'cloudpickle'"):
            import cloudpickle


class TestLightweightPythonComponentLogic(
        testing_utilities.LocalRunnerEnvironmentTestCase):

    def test_single_output_simple_case(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        actual = identity(x='hello').output
        expected = 'hello'
        self.assertIsInstance(actual, str)
        self.assertEqual(actual, expected)

    def test_many_primitives_in_and_out(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(
            string: str,
            integer: int,
            decimal: float,
            boolean: bool,
            l: list,
            d: dict,
        ) -> NamedTuple(
                'Outputs',
                string=str,
                integer=int,
                decimal=float,
                boolean=bool,
                l=list,
                d=dict):
            Outputs = NamedTuple(
                'Outputs',
                string=str,
                integer=int,
                decimal=float,
                boolean=bool,
                l=list,
                d=dict)
            return Outputs(
                string=string,
                integer=integer,
                decimal=decimal,
                boolean=boolean,
                l=l,
                d=d,
            )

        task = identity(
            string='foo',
            integer=1,
            decimal=3.14,
            boolean=True,
            l=['a', 'b'],
            d={'x': 'y'})
        self.assertIsInstance(task.outputs['string'], str)
        self.assertEqual(task.outputs['string'], 'foo')

        self.assertIsInstance(task.outputs['integer'], int)
        self.assertEqual(task.outputs['integer'], 1)

        self.assertIsInstance(task.outputs['decimal'], float)
        self.assertEqual(task.outputs['decimal'], 3.14)

        self.assertIsInstance(task.outputs['boolean'], bool)
        self.assertTrue(task.outputs['boolean'])

        self.assertIsInstance(task.outputs['l'], list)
        self.assertEqual(task.outputs['l'], ['a', 'b'])

        self.assertIsInstance(task.outputs['d'], dict)
        self.assertEqual(task.outputs['d'], {'x': 'y'})

    def test_single_output_not_available(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))
        from typing import NamedTuple

        @dsl.component
        def return_twice(x: str) -> NamedTuple('Outputs', x=str, y=str):
            Outputs = NamedTuple('Output', x=str, y=str)
            return Outputs(x=x, y=x)

        local_task = return_twice(x='foo')
        with self.assertRaisesRegex(
                AttributeError,
                r'The task has multiple outputs\. Please reference the output by its name\.'
        ):
            local_task.output

    def test_single_artifact_output_traditional(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def artifact_maker(x: str, a: Output[Artifact]):
            with open(a.path, 'w') as f:
                f.write(x)

            a.metadata['foo'] = 'bar'

        actual = artifact_maker(x='hello').output
        self.assertIsInstance(actual, Artifact)
        self.assertEqual(actual.name, 'a')
        self.assertTrue(actual.uri.endswith('/a'))
        self.assertEqual(actual.metadata, {'foo': 'bar'})
        with open(actual.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')

    def test_single_artifact_output_pythonic(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def artifact_maker(x: str) -> Artifact:
            artifact = Artifact(
                name='a', uri=dsl.get_uri('a'), metadata={'foo': 'bar'})
            with open(artifact.path, 'w') as f:
                f.write(x)

            return artifact

        actual = artifact_maker(x='hello').output
        self.assertIsInstance(actual, Artifact)
        self.assertEqual(actual.name, 'a')
        self.assertTrue(actual.uri.endswith('/a'))
        self.assertEqual(actual.metadata, {'foo': 'bar'})
        with open(actual.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')

    def test_multiple_artifact_outputs_traditional(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def double_artifact_maker(
            x: str,
            y: str,
            a: Output[Artifact],
            b: Output[Dataset],
        ):
            with open(a.path, 'w') as f:
                f.write(x)

            with open(b.path, 'w') as f:
                f.write(y)

            a.metadata['foo'] = 'bar'
            b.metadata['baz'] = 'bat'

        local_task = double_artifact_maker(x='hello', y='goodbye')

        actual_a = local_task.outputs['a']
        actual_b = local_task.outputs['b']

        self.assertIsInstance(actual_a, Artifact)
        self.assertEqual(actual_a.name, 'a')
        self.assertTrue(actual_a.uri.endswith('/a'))
        with open(actual_a.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')
        self.assertEqual(actual_a.metadata, {'foo': 'bar'})

        self.assertIsInstance(actual_b, Dataset)
        self.assertEqual(actual_b.name, 'b')
        self.assertTrue(actual_b.uri.endswith('/b'))
        self.assertEqual(actual_b.metadata, {'baz': 'bat'})
        with open(actual_b.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'goodbye')

    def test_multiple_artifact_outputs_pythonic(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def double_artifact_maker(
            x: str,
            y: str,
        ) -> NamedTuple('Outputs', a=Artifact, b=Dataset):
            a = Artifact(
                name='a', uri=dsl.get_uri('a'), metadata={'foo': 'bar'})
            b = Dataset(
                name='b', uri=dsl.get_uri('b'), metadata={'baz': 'bat'})

            with open(a.path, 'w') as f:
                f.write(x)

            with open(b.path, 'w') as f:
                f.write(y)

            Outputs = NamedTuple('Outputs', a=Artifact, b=Dataset)
            return Outputs(a=a, b=b)

        local_task = double_artifact_maker(x='hello', y='goodbye')

        actual_a = local_task.outputs['a']
        actual_b = local_task.outputs['b']

        self.assertIsInstance(actual_a, Artifact)
        self.assertEqual(actual_a.name, 'a')
        self.assertTrue(actual_a.uri.endswith('/a'))
        with open(actual_a.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')
        self.assertEqual(actual_a.metadata, {'foo': 'bar'})

        self.assertIsInstance(actual_b, Dataset)
        self.assertEqual(actual_b.name, 'b')
        self.assertTrue(actual_b.uri.endswith('/b'))
        with open(actual_b.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'goodbye')
        self.assertEqual(actual_b.metadata, {'baz': 'bat'})

    def test_str_input_uses_default(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str = 'hi') -> str:
            return x

        actual = identity().output
        expected = 'hi'
        self.assertIsInstance(actual, str)
        self.assertEqual(actual, expected)

    def test_placeholder_default_resolved(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str = dsl.PIPELINE_TASK_NAME_PLACEHOLDER) -> str:
            return x

        actual = identity().output
        expected = 'identity'
        self.assertIsInstance(actual, str)
        self.assertEqual(actual, expected)

    def test_outputpath(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def my_comp(out_param: dsl.OutputPath(str),) -> int:
            with open(out_param, 'w') as f:
                f.write('Hello' * 2)
            return 1

        task = my_comp()

        self.assertEqual(task.outputs['out_param'], 'HelloHello')
        self.assertEqual(task.outputs['Output'], 1)

    def test_outputpath_result_not_written(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        # use dsl.OutputPath(int) for more thorough testing:
        # we want to ensure that the code that converts the protobuf number
        # to a Python int permits unwritten outputs
        @dsl.component
        def my_comp(out_param: dsl.OutputPath(int)):
            pass

        task = my_comp()
        self.assertEmpty(task.outputs)


if __name__ == '__main__':
    unittest.main()