pipelines/sdk/python/kfp/local/subprocess_task_handler_tes...

423 lines
14 KiB
Python

# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for subprocess_task_handler.py."""
import contextlib
import io
from typing import NamedTuple
import unittest
from unittest import mock
from absl.testing import parameterized
from kfp import dsl
from kfp import local
from kfp.dsl import Artifact
from kfp.dsl import Dataset
from kfp.dsl import Output
from kfp.local import subprocess_task_handler
from kfp.local import testing_utilities
# NOTE: When testing SubprocessRunner, always pass use_venv=True to avoid
# modifying the environment that runs the code under test.
# If the dsl.component mocks are modified in a way that makes them not work,
# the code may install kfp from PyPI rather than from source. To mitigate the
# impact of such an error, we should not install into the main test process'
# environment.
class TestSubprocessRunner(testing_utilities.LocalRunnerEnvironmentTestCase):
    """Tests for executing tasks with local.SubprocessRunner."""

    @mock.patch('sys.stdout', new_callable=io.StringIO)
    def test_basic(self, mock_stdout):
        """A component's print() output shows up on the caller's stdout."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def comp():
            print('foobar!')

        comp()

        output = mock_stdout.getvalue().strip()
        # Check for the contiguous substring. assertContainsSubsequence treats
        # strings as character sequences and passes whenever the characters of
        # 'foobar!' merely appear in order somewhere in the output, which long
        # subprocess logs almost always satisfy.
        self.assertIn('foobar!', output)

    def test_image_warning(self):
        """Constructing a handler with a custom image emits a RuntimeWarning."""
        # The regex must match the handler's warning text verbatim, including
        # its spelling.
        with self.assertWarnsRegex(
                RuntimeWarning,
                r"You may be attemping to run a task that uses custom or non-Python base image 'my_custom_image' in a Python environment\. This may result in incorrect dependencies and/or incorrect behavior\."
        ):
            subprocess_task_handler.SubprocessTaskHandler(
                image='my_custom_image',
                # avoid catching the Container Component and
                # Containerized Python Component validation errors
                full_command=['kfp.dsl.executor_main'],
                pipeline_root='pipeline_root',
                runner=local.SubprocessRunner(use_venv=True),
            )

    def test_cannot_run_container_component(self):
        """Container Components are rejected with a RuntimeError."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.container_component
        def comp():
            return dsl.ContainerSpec(
                image='alpine',
                command=['echo'],
                args=['foo'],
            )

        with self.assertRaisesRegex(
                RuntimeError,
                r'The SubprocessRunner only supports running Lightweight Python Components\. You are attempting to run a Container Component\.',
        ):
            comp()

    def test_cannot_run_containerized_python_component(self):
        """Components declared with a target_image are rejected."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component(target_image='foo')
        def comp():
            pass

        with self.assertRaisesRegex(
                RuntimeError,
                r'The SubprocessRunner only supports running Lightweight Python Components\. You are attempting to run a Containerized Python Component\.',
        ):
            comp()
class TestRunLocalSubproces(unittest.TestCase):
    """Tests for subprocess_task_handler.run_local_subprocess."""

    def test_simple_program(self):
        # The child's stdout is expected to be echoed onto our (redirected)
        # stdout, so capture it and compare.
        with contextlib.redirect_stdout(io.StringIO()) as captured:
            subprocess_task_handler.run_local_subprocess(['echo', 'foo!'])

        self.assertEqual(captured.getvalue().strip(), 'foo!')
class TestUseCurrentPythonExecutable(
        testing_utilities.LocalRunnerEnvironmentTestCase):
    """Tests for subprocess_task_handler.replace_python_executable."""

    def test(self):
        # A leading `python3` in the command string is swapped for the given
        # interpreter path; the rest of the command is left as-is.
        self.assertEqual(
            subprocess_task_handler.replace_python_executable(
                full_command=['python3 -c "from kfp import dsl"'],
                new_executable='/foo/bar/python3',
            ),
            ['/foo/bar/python3 -c "from kfp import dsl"'],
        )
class TestUseVenv(testing_utilities.LocalRunnerEnvironmentTestCase):
    """Tests that use_venv=True keeps installs out of the test process."""

    # Each case installs a package into a virtual environment, so only list
    # distinct runner configurations here; duplicate cases add run time
    # without adding coverage.
    @parameterized.parameters([
        {
            'runner': local.SubprocessRunner(use_venv=True),
        },
    ])
    def test_use_venv_true(self, **kwargs):
        local.init(**kwargs)

        @dsl.component(packages_to_install=['cloudpickle'])
        def installer_component():
            import cloudpickle
            print('Cloudpickle is installed:', cloudpickle)

        installer_component()

        # since the module was installed in the virtual environment, it should
        # not exist in the current environment
        with self.assertRaisesRegex(ModuleNotFoundError,
                                    r"No module named 'cloudpickle'"):
            import cloudpickle
class TestLightweightPythonComponentLogic(
        testing_utilities.LocalRunnerEnvironmentTestCase):
    """Tests parameter/artifact I/O of Lightweight Python Components run
    with the SubprocessRunner."""

    def test_single_output_simple_case(self):
        """A single str return value is exposed via task.output."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        actual = identity(x='hello').output
        expected = 'hello'
        self.assertIsInstance(actual, str)
        self.assertEqual(actual, expected)

    def test_many_primitives_in_and_out(self):
        """Every primitive type round-trips through inputs and outputs."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(
            string: str,
            integer: int,
            decimal: float,
            boolean: bool,
            l: list,
            d: dict,
        ) -> NamedTuple(
                'Outputs',
                string=str,
                integer=int,
                decimal=float,
                boolean=bool,
                l=list,
                d=dict):
            Outputs = NamedTuple(
                'Outputs',
                string=str,
                integer=int,
                decimal=float,
                boolean=bool,
                l=list,
                d=dict)
            return Outputs(
                string=string,
                integer=integer,
                decimal=decimal,
                boolean=boolean,
                l=l,
                d=d,
            )

        task = identity(
            string='foo',
            integer=1,
            decimal=3.14,
            boolean=True,
            l=['a', 'b'],
            d={'x': 'y'})
        self.assertIsInstance(task.outputs['string'], str)
        self.assertEqual(task.outputs['string'], 'foo')
        self.assertIsInstance(task.outputs['integer'], int)
        self.assertEqual(task.outputs['integer'], 1)
        self.assertIsInstance(task.outputs['decimal'], float)
        self.assertEqual(task.outputs['decimal'], 3.14)
        self.assertIsInstance(task.outputs['boolean'], bool)
        self.assertTrue(task.outputs['boolean'])
        self.assertIsInstance(task.outputs['l'], list)
        self.assertEqual(task.outputs['l'], ['a', 'b'])
        self.assertIsInstance(task.outputs['d'], dict)
        self.assertEqual(task.outputs['d'], {'x': 'y'})

    def test_single_output_not_available(self):
        """task.output raises AttributeError when there are several outputs."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        # NamedTuple comes from the module-level `from typing import
        # NamedTuple`; the typename matches the annotation for consistency.
        @dsl.component
        def return_twice(x: str) -> NamedTuple('Outputs', x=str, y=str):
            Outputs = NamedTuple('Outputs', x=str, y=str)
            return Outputs(x=x, y=x)

        local_task = return_twice(x='foo')
        with self.assertRaisesRegex(
                AttributeError,
                r'The task has multiple outputs\. Please reference the output by its name\.'
        ):
            local_task.output

    def test_single_artifact_output_traditional(self):
        """An Output[Artifact] parameter becomes the task's single output."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def artifact_maker(x: str, a: Output[Artifact]):
            with open(a.path, 'w') as f:
                f.write(x)
            a.metadata['foo'] = 'bar'

        actual = artifact_maker(x='hello').output
        self.assertIsInstance(actual, Artifact)
        self.assertEqual(actual.name, 'a')
        self.assertTrue(actual.uri.endswith('/a'))
        self.assertEqual(actual.metadata, {'foo': 'bar'})
        with open(actual.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')

    def test_single_artifact_output_pythonic(self):
        """A returned Artifact becomes the task's single output."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def artifact_maker(x: str) -> Artifact:
            artifact = Artifact(
                name='a', uri=dsl.get_uri('a'), metadata={'foo': 'bar'})
            with open(artifact.path, 'w') as f:
                f.write(x)
            return artifact

        actual = artifact_maker(x='hello').output
        self.assertIsInstance(actual, Artifact)
        self.assertEqual(actual.name, 'a')
        self.assertTrue(actual.uri.endswith('/a'))
        self.assertEqual(actual.metadata, {'foo': 'bar'})
        with open(actual.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')

    def test_multiple_artifact_outputs_traditional(self):
        """Several Output[...] parameters are exposed by name."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def double_artifact_maker(
            x: str,
            y: str,
            a: Output[Artifact],
            b: Output[Dataset],
        ):
            with open(a.path, 'w') as f:
                f.write(x)

            with open(b.path, 'w') as f:
                f.write(y)

            a.metadata['foo'] = 'bar'
            b.metadata['baz'] = 'bat'

        local_task = double_artifact_maker(x='hello', y='goodbye')
        actual_a = local_task.outputs['a']
        actual_b = local_task.outputs['b']

        self.assertIsInstance(actual_a, Artifact)
        self.assertEqual(actual_a.name, 'a')
        self.assertTrue(actual_a.uri.endswith('/a'))
        with open(actual_a.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')
        self.assertEqual(actual_a.metadata, {'foo': 'bar'})

        self.assertIsInstance(actual_b, Dataset)
        self.assertEqual(actual_b.name, 'b')
        self.assertTrue(actual_b.uri.endswith('/b'))
        self.assertEqual(actual_b.metadata, {'baz': 'bat'})
        with open(actual_b.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'goodbye')

    def test_multiple_artifact_outputs_pythonic(self):
        """Artifacts returned in a NamedTuple are exposed by field name."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def double_artifact_maker(
            x: str,
            y: str,
        ) -> NamedTuple(
                'Outputs', a=Artifact, b=Dataset):
            a = Artifact(
                name='a', uri=dsl.get_uri('a'), metadata={'foo': 'bar'})
            b = Dataset(name='b', uri=dsl.get_uri('b'), metadata={'baz': 'bat'})

            with open(a.path, 'w') as f:
                f.write(x)

            with open(b.path, 'w') as f:
                f.write(y)

            Outputs = NamedTuple('Outputs', a=Artifact, b=Dataset)
            return Outputs(a=a, b=b)

        local_task = double_artifact_maker(x='hello', y='goodbye')
        actual_a = local_task.outputs['a']
        actual_b = local_task.outputs['b']

        self.assertIsInstance(actual_a, Artifact)
        self.assertEqual(actual_a.name, 'a')
        self.assertTrue(actual_a.uri.endswith('/a'))
        with open(actual_a.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'hello')
        self.assertEqual(actual_a.metadata, {'foo': 'bar'})

        self.assertIsInstance(actual_b, Dataset)
        self.assertEqual(actual_b.name, 'b')
        self.assertTrue(actual_b.uri.endswith('/b'))
        with open(actual_b.path) as f:
            contents = f.read()
        self.assertEqual(contents, 'goodbye')
        self.assertEqual(actual_b.metadata, {'baz': 'bat'})

    def test_str_input_uses_default(self):
        """An omitted str input falls back to its Python default."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str = 'hi') -> str:
            return x

        actual = identity().output
        expected = 'hi'
        self.assertIsInstance(actual, str)
        self.assertEqual(actual, expected)

    def test_placeholder_default_resolved(self):
        """A placeholder default is resolved before the component runs."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str = dsl.PIPELINE_TASK_NAME_PLACEHOLDER) -> str:
            return x

        actual = identity().output
        # The task-name placeholder resolves to the component's name here.
        expected = 'identity'
        self.assertIsInstance(actual, str)
        self.assertEqual(actual, expected)

    def test_outputpath(self):
        """dsl.OutputPath outputs are read back alongside the return value."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def my_comp(out_param: dsl.OutputPath(str)) -> int:
            with open(out_param, 'w') as f:
                f.write('Hello' * 2)
            return 1

        task = my_comp()

        self.assertEqual(task.outputs['out_param'], 'HelloHello')
        self.assertEqual(task.outputs['Output'], 1)

    def test_outputpath_result_not_written(self):
        """An OutputPath that is never written yields no outputs."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        # use dsl.OutputPath(int) for more thorough testing
        # want to ensure that the code that converts protobuf number to
        # Python int permits unwritten outputs
        @dsl.component
        def my_comp(out_param: dsl.OutputPath(int)):
            pass

        task = my_comp()
        self.assertEmpty(task.outputs)
if __name__ == '__main__':
    # Run this module's test cases when it is executed as a script.
    unittest.main()