pipelines/sdk/python/kfp/local/task_dispatcher_test.py

319 lines
11 KiB
Python
Executable File

# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for task_dispatcher.py.

The difference between these tests and the E2E tests is that the E2E tests
focus on how the runner should behave to be local-execution conformant,
whereas these tests focus on how the task dispatcher should behave,
irrespective of the runner. While there will inevitably be some overlap, we
should seek to minimize it.
"""
import io
import os
import re
import tempfile
import unittest
from unittest import mock

from absl.testing import parameterized

from kfp import dsl
from kfp import local
from kfp.dsl import Artifact
from kfp.dsl import Model
from kfp.dsl import Output
from kfp.local import testing_utilities
# NOTE: uses SubprocessRunner throughout to test the task dispatcher behavior.
# NOTE: When testing SubprocessRunner, use_venv=True throughout to avoid
# modifying the current code under test.
# If the dsl.component mocks are modified in a way that makes them not work,
# the code may install kfp from PyPI rather than from source. To mitigate the
# impact of such an error we should not install into the main test process'
# environment.
class TestLocalExecutionValidation(
        testing_utilities.LocalRunnerEnvironmentTestCase):
    """Checks that running a task requires an initialized local environment."""

    def test_env_not_initialized(self):
        # local.init() is deliberately never called in this test.
        # NOTE(review): relies on the base test case not leaving a previously
        # initialized environment behind — presumably reset in its setUp.
        not_initialized_msg = (
            r"Local environment not initialized\. Please run "
            r"'kfp\.local\.init\(\)' before executing tasks locally\.")

        @dsl.component
        def identity(x: str) -> str:
            return x

        with self.assertRaisesRegex(RuntimeError, not_initialized_msg):
            identity(x='foo')
class TestArgumentValidation(testing_utilities.LocalRunnerEnvironmentTestCase):
    """Tests that invalid call arguments are rejected before execution.

    Like every other test case in this module, this class derives from
    LocalRunnerEnvironmentTestCase rather than a bare parameterized.TestCase.
    Each test calls ``local.init``, so it should get the same environment
    setup and ``dsl.component`` handling as the rest of the module. The
    module NOTE explains why the mocked ``dsl.component`` matters: it keeps
    kfp from being installed from PyPI. NOTE(review): this presumably also
    resets the global local-execution config between tests; confirm in
    testing_utilities.
    """

    def test_no_argument_no_default(self):
        """Omitting a required input with no default raises TypeError."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        with self.assertRaisesRegex(
                TypeError, r'identity\(\) missing 1 required argument: x'):
            identity()

    def test_default_wrong_type(self):
        """Passing an int to a str-typed input raises a type mismatch."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        with self.assertRaisesRegex(
                dsl.types.type_utils.InconsistentTypeException,
                r"Incompatible argument passed to the input 'x' of component 'identity': Argument type 'NUMBER_INTEGER' is incompatible with the input type 'STRING'"
        ):
            identity(x=1)

    def test_extra_argument(self):
        """Passing an undeclared keyword argument raises TypeError."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        with self.assertRaisesRegex(
                TypeError,
                r'identity\(\) got an unexpected keyword argument "y"\.'):
            identity(x='foo', y='bar')

    def test_input_artifact_provided(self):
        """Passing an artifact instance as a task input is rejected."""
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def artifact_identity(a: Artifact) -> Artifact:
            return a

        # Periods are escaped so they match literally rather than any char.
        with self.assertRaisesRegex(
                ValueError,
                r"Input artifacts are not supported\. Got input artifact of type 'Artifact'\."
        ):
            artifact_identity(a=Artifact(name='a', uri='gs://bucket/foo'))
class TestSupportOfComponentTypes(
        testing_utilities.LocalRunnerEnvironmentTestCase):
    """Checks which component kinds the dispatcher will execute locally."""

    # Message expected whenever a pipeline is invoked for local execution.
    _PIPELINE_UNSUPPORTED = (
        r'Local pipeline execution is not currently supported\.')

    def _assert_pipeline_unsupported(self, pipeline, **arguments):
        """Invokes ``pipeline(**arguments)`` expecting NotImplementedError."""
        with self.assertRaisesRegex(
                NotImplementedError,
                self._PIPELINE_UNSUPPORTED,
        ):
            pipeline(**arguments)

    def test_local_pipeline_unsupported_two_tasks(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        @dsl.pipeline
        def my_pipeline():
            identity(x='foo')
            identity(x='bar')

        # Round-trip through compilation so the pipeline is loaded as a
        # YamlComponent; that way the error cannot be attributed merely to
        # the object being a GraphComponent.
        loaded_pipeline = testing_utilities.compile_and_load_component(
            my_pipeline)
        self._assert_pipeline_unsupported(loaded_pipeline)

    def test_local_pipeline_unsupported_one_task_different_interface(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        @dsl.pipeline
        def my_pipeline():
            identity(x='foo')

        # Same round-trip as above: rule out the GraphComponent explanation.
        loaded_pipeline = testing_utilities.compile_and_load_component(
            my_pipeline)
        self._assert_pipeline_unsupported(loaded_pipeline)

    def test_local_pipeline_unsupported_if_is_graph_component(self):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        # The lone inner task shares the pipeline's interface, yet the
        # GraphComponent itself must still be refused.
        @dsl.pipeline
        def my_pipeline(string: str) -> str:
            return identity(x=string).output

        self._assert_pipeline_unsupported(my_pipeline, string='foo')

    def test_can_run_loaded_component(self):
        # A venv keeps a non-local KFP install out of the test process.
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def identity(x: str) -> str:
            return x

        loaded_identity = testing_utilities.compile_and_load_component(
            identity)
        result = loaded_identity(x='hello').output

        # dsl.Condition overloads ==, so if local execution were skipped,
        # `result` would be a channel and `result == 'hello'` would yield a
        # truthy ConditionOperation that could mask the failure. Checking
        # the type first guards against that false negative.
        self.assertIsInstance(result, str)
        self.assertEqual(result, 'hello')
class TestExceptionHandlingAndLogging(
        testing_utilities.LocalRunnerEnvironmentTestCase):
    """Covers how task failures surface and what gets logged to stdout."""

    @mock.patch('sys.stdout', new_callable=io.StringIO)
    def test_user_code_throws_exception_if_raise_on_error(self, mock_stdout):
        local.init(
            runner=local.SubprocessRunner(use_venv=True),
            raise_on_error=True,
        )

        @dsl.component
        def fail_comp():
            raise Exception('String to match on')

        failure_pattern = (
            r"Task \x1b\[96m'fail-comp'\x1b\[0m finished with status "
            r"\x1b\[91mFAILURE\x1b\[0m")
        with self.assertRaisesRegex(RuntimeError, failure_pattern):
            fail_comp()

        # The component's exception text must also reach stdout.
        self.assertIn('Exception: String to match on', mock_stdout.getvalue())

    @mock.patch('sys.stdout', new_callable=io.StringIO)
    def test_user_code_no_exception_if_not_raise_on_error(self, mock_stdout):
        local.init(
            runner=local.SubprocessRunner(use_venv=True),
            raise_on_error=False,
        )

        @dsl.component
        def fail_comp():
            raise Exception('String to match on')

        task = fail_comp()
        # A failed task that does not raise exposes no outputs.
        self.assertDictEqual(task.outputs, {})

        captured = mock_stdout.getvalue()
        self.assertRegex(
            captured,
            r"\d+:\d+:\d+\.\d+ - ERROR - Task \x1b\[96m'fail-comp'\x1b\[0m "
            r"finished with status \x1b\[91mFAILURE\x1b\[0m",
        )
        self.assertIn('Exception: String to match on', captured)

    @mock.patch('sys.stdout', new_callable=io.StringIO)
    def test_all_logs(self, mock_stdout):
        local.init(runner=local.SubprocessRunner(use_venv=True))

        @dsl.component
        def many_type_component(
            num: int,
            model: Output[Model],
        ) -> str:
            print('Inside of my component!')
            model.metadata['foo'] = 'bar'
            return 'hello' * num

        many_type_component(num=2)

        # Ordered landmarks expected in stdout. Arbitrary text may sit between
        # consecutive landmarks, so the inner-process logs must appear nested
        # between the outer "Streamed logs" header and the outer status lines.
        # NOTE(review): the whitespace inside the outputs landmark must match
        # the logger's pretty-printing exactly.
        landmarks = [
            (r"\d+:\d+:\d+\.\d+ - INFO - Executing task \x1b\[96m'many-type-component'\x1b\[0m\n"
             r'\d+:\d+:\d+\.\d+ - INFO - Streamed logs:\n\n'),
            r'Looking for component ',
            r'Loading KFP component ',
            r'Got executor_input:',
            r'Inside of my component!',
            r'Wrote executor output file to',
            (r"\d+:\d+:\d+\.\d+ - INFO - Task \x1b\[96m'many-type-component'\x1b\[0m finished with status \x1b\[92mSUCCESS\x1b\[0m\n"
             r"\d+:\d+:\d+\.\d+ - INFO - Task \x1b\[96m'many-type-component'\x1b\[0m outputs:\n Output: 'hellohello'\n model: Model\( name='model',\n uri='[a-zA-Z0-9/_\.-]+/local_outputs/many-type-component-\d+-\d+-\d+-\d+-\d+-\d+-\d+/many-type-component/model',\n metadata={'foo': 'bar'} \)\n\n"
            ),
        ]
        # DOTALL so that each `.*` gap may span newline characters.
        self.assertRegex(
            mock_stdout.getvalue(),
            re.compile('.*'.join(landmarks), re.DOTALL),
        )
class TestPipelineRootPaths(testing_utilities.LocalRunnerEnvironmentTestCase):
    """Tests that tasks run with both relative and absolute pipeline roots."""

    def _assert_identity_task_runs(self):
        """Runs a trivial identity component and checks its string output."""

        # Defined inside the test (not at module level) to force install
        # from source.
        @dsl.component
        def identity(x: str) -> str:
            return x

        task = identity(x='foo')
        # Type check first: if local execution were skipped, the output would
        # be a channel whose overloaded == could mislead assertEqual.
        self.assertIsInstance(task.output, str)
        self.assertEqual(task.output, 'foo')

    def test_relpath(self):
        local.init(
            runner=local.SubprocessRunner(use_venv=True),
            pipeline_root='relpath_root',
        )
        self._assert_identity_task_runs()

    def test_abspath(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            pipeline_root = os.path.join(tmpdir, 'abspath_root')
            # Pin the precondition this test is named after.
            self.assertTrue(os.path.isabs(pipeline_root))
            local.init(
                runner=local.SubprocessRunner(use_venv=True),
                pipeline_root=pipeline_root,
            )
            self._assert_identity_task_runs()
if __name__ == '__main__':
    # Run this module's test cases when executed directly as a script.
    unittest.main()