update task dispatcher (#10298)

Connor McCarthy 2023-12-12 14:23:49 -05:00 committed by GitHub
parent 227eab1c68
commit d41efc3e96
12 changed files with 731 additions and 44 deletions

View File

@ -16,6 +16,7 @@
import abc
from typing import List
from kfp.dsl import pipeline_context
from kfp.dsl import pipeline_task
from kfp.dsl import structures
from kfp.dsl.types import type_utils
@ -100,6 +101,8 @@ class BaseComponent(abc.ABC):
return pipeline_task.PipelineTask(
component_spec=self.component_spec,
args=task_inputs,
execute_locally=pipeline_context.Pipeline.get_default_pipeline() is
None,
)
@property
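For context (this paragraph and sketch are editorial, not part of the diff): the new execute_locally argument hinges on whether a default pipeline is currently registered. A minimal sketch of that check, using the same pipeline_context API imported above:

from kfp.dsl import pipeline_context

def should_execute_locally() -> bool:
    # Inside a @dsl.pipeline function a Pipeline object is registered as
    # the default; outside of one (e.g. in a plain script) no default
    # exists, so the task is dispatched to local execution instead of
    # being recorded for compilation.
    return pipeline_context.Pipeline.get_default_pipeline() is None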

View File

@ -14,10 +14,8 @@
"""Tests for kfp.dsl.base_component."""
import unittest
from unittest.mock import patch
from kfp import dsl
from kfp.dsl import pipeline_task
from kfp.dsl import placeholders
from kfp.dsl import python_component
from kfp.dsl import structures
@ -59,34 +57,6 @@ component_op = python_component.PythonComponent(
class BaseComponentTest(unittest.TestCase):
@patch.object(pipeline_task, 'PipelineTask', autospec=True)
def test_instantiate_component_with_keyword_arguments(
self, mock_PipelineTask):
component_op(input1='hello', input2=100, input3=1.23, input4=3.21)
mock_PipelineTask.assert_called_once_with(
component_spec=component_op.component_spec,
args={
'input1': 'hello',
'input2': 100,
'input3': 1.23,
'input4': 3.21,
})
@patch.object(pipeline_task, 'PipelineTask', autospec=True)
def test_instantiate_component_omitting_arguments_with_default(
self, mock_PipelineTask):
component_op(input1='hello', input2=100)
mock_PipelineTask.assert_called_once_with(
component_spec=component_op.component_spec,
args={
'input1': 'hello',
'input2': 100,
})
def test_instantiate_component_with_positional_argument(self):
with self.assertRaisesRegex(
TypeError,

View File

@ -28,7 +28,6 @@ from kfp.dsl import placeholders
from kfp.dsl import structures
from kfp.dsl import utils
from kfp.dsl.types import type_utils
from kfp.local import task_dispatcher
from kfp.pipeline_spec import pipeline_spec_pb2
TEMPORARILY_BLOCK_LOCAL_EXECUTION = True
@ -99,7 +98,8 @@ class PipelineTask:
self,
component_spec: structures.ComponentSpec,
args: Dict[str, Any],
- ):
+ execute_locally: bool = False,
+ ) -> None:
"""Initilizes a PipelineTask instance."""
# import within __init__ to avoid circular import
from kfp.dsl.tasks_group import TasksGroup
@ -181,21 +181,27 @@ class PipelineTask:
if not isinstance(value, pipeline_channel.PipelineChannel)
])
- from kfp.dsl import pipeline_context
+ if execute_locally:
+ self._execute_locally(args=args)
- # TODO: remove feature flag
- if not TEMPORARILY_BLOCK_LOCAL_EXECUTION and pipeline_context.Pipeline.get_default_pipeline(
- ) is None:
- self._execute_locally()
- def _execute_locally(self) -> None:
+ def _execute_locally(self, args: Dict[str, Any]) -> None:
"""Execute the pipeline task locally.
Set the task state to FINAL and update the outputs.
"""
from kfp.local import task_dispatcher
if self.pipeline_spec is not None:
raise NotImplementedError(
'Local pipeline execution is not currently supported.')
# TODO: remove feature flag
if TEMPORARILY_BLOCK_LOCAL_EXECUTION:
return
self._outputs = task_dispatcher.run_single_component(
- pipeline_spec=self.pipeline_spec,
- arguments=self.args,
+ pipeline_spec=self.component_spec.to_pipeline_spec(),
+ arguments=args,
)
self.state = TaskState.FINAL
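Putting the two changes together, the user-facing flow implied by this diff is sketched below. This is illustrative only: the runner arguments are copied from the test files later in this commit, and local execution is still gated by the TEMPORARILY_BLOCK_LOCAL_EXECUTION flag.

from kfp import dsl
from kfp import local

local.init(runner=local.SubprocessRunner(use_venv=True))

@dsl.component
def identity(x: str) -> str:
    return x

# Called outside a @dsl.pipeline context, so execute_locally is True and
# the task is routed through task_dispatcher.run_single_component.
task = identity(x='foo')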

View File

@ -227,6 +227,10 @@ def _get_type_string_from_component_argument(
if argument_type in _TYPE_TO_TYPE_NAME:
return _TYPE_TO_TYPE_NAME[argument_type]
if isinstance(argument_value, artifact_types.Artifact):
raise ValueError(
f'Input artifacts are not supported. Got input artifact of type {argument_value.__class__.__name__!r}.'
)
raise ValueError(
f'Constant argument inputs must be one of type {list(_TYPE_TO_TYPE_NAME.values())}. Got: {argument_value!r} of type {type(argument_value)!r}.'
)

View File

@ -20,7 +20,7 @@ import sys
import types
from typing import List
- _COMPONENT_NAME_PREFIX = 'comp-'
+ COMPONENT_NAME_PREFIX = 'comp-'
_EXECUTOR_LABEL_PREFIX = 'exec-'
@ -69,7 +69,7 @@ def sanitize_input_name(name: str) -> str:
def sanitize_component_name(name: str) -> str:
"""Sanitizes component name."""
- return _COMPONENT_NAME_PREFIX + maybe_rename_for_k8s(name)
+ return COMPONENT_NAME_PREFIX + maybe_rename_for_k8s(name)
def sanitize_task_name(name: str) -> str:
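The rename from _COMPONENT_NAME_PREFIX to COMPONENT_NAME_PREFIX makes the prefix part of the module's public surface so that kfp.local can strip it back off (see executor_input_utils below). An illustrative round trip, assuming maybe_rename_for_k8s lowercases and hyphenates as in the KFP SDK:

from kfp.dsl import utils

name = utils.sanitize_component_name('My Component')  # 'comp-my-component'
task_name = name[len(utils.COMPONENT_NAME_PREFIX):]   # 'my-component'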

View File

@ -12,9 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects for configuring local execution."""
import abc
import dataclasses
class LocalRunnerType(abc.ABC):
"""The ABC for user-facing Runner configurations.
Subclasses should be dataclasses.
They should implement a .validate() method.
"""
@abc.abstractmethod
def validate(self) -> None:
"""Validates that the configuration arguments provided by the user are
valid."""
raise NotImplementedError
@dataclasses.dataclass
class SubprocessRunner:
"""Runner that indicates that local tasks should be run in a subprocess.

View File

@ -0,0 +1,132 @@
# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for constructing the ExecutorInput message."""
import datetime
import os
from typing import Any, Dict
from kfp.compiler import pipeline_spec_builder
from kfp.dsl import utils
from kfp.pipeline_spec import pipeline_spec_pb2
_EXECUTOR_OUTPUT_FILE = 'executor_output.json'
def construct_executor_input(
component_spec: pipeline_spec_pb2.ComponentSpec,
arguments: Dict[str, Any],
task_root: str,
) -> pipeline_spec_pb2.ExecutorInput:
"""Constructs the executor input message for a task execution."""
input_parameter_keys = list(
component_spec.input_definitions.parameters.keys())
input_artifact_keys = list(
component_spec.input_definitions.artifacts.keys())
if input_artifact_keys:
raise ValueError(
'Input artifacts are not yet supported for local execution.')
output_parameter_keys = list(
component_spec.output_definitions.parameters.keys())
output_artifact_specs_dict = component_spec.output_definitions.artifacts
inputs = pipeline_spec_pb2.ExecutorInput.Inputs(
parameter_values={
param_name:
pipeline_spec_builder.to_protobuf_value(arguments[param_name])
if param_name in arguments else component_spec.input_definitions
.parameters[param_name].default_value
for param_name in input_parameter_keys
},
# input artifact constants are not supported yet
artifacts={},
)
outputs = pipeline_spec_pb2.ExecutorInput.Outputs(
parameters={
param_name: pipeline_spec_pb2.ExecutorInput.OutputParameter(
output_file=os.path.join(task_root, param_name))
for param_name in output_parameter_keys
},
artifacts={
artifact_name: make_artifact_list(
name=artifact_name,
artifact_type=artifact_spec.artifact_type,
task_root=task_root,
) for artifact_name, artifact_spec in
output_artifact_specs_dict.items()
},
output_file=os.path.join(task_root, _EXECUTOR_OUTPUT_FILE),
)
return pipeline_spec_pb2.ExecutorInput(
inputs=inputs,
outputs=outputs,
)
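A hedged usage sketch of the function defined above, mirroring the test cases later in this commit: building ExecutorInput for a component with one string input and the default Output parameter (the task_root path is a made-up example):

from google.protobuf import json_format
from kfp.pipeline_spec import pipeline_spec_pb2

component_spec = pipeline_spec_pb2.ComponentSpec()
json_format.ParseDict(
    {
        'inputDefinitions': {
            'parameters': {'x': {'parameterType': 'STRING'}}
        },
        'outputDefinitions': {
            'parameters': {'Output': {'parameterType': 'STRING'}}
        },
        'executorLabel': 'exec-comp',
    }, component_spec)

executor_input = construct_executor_input(
    component_spec=component_spec,
    arguments={'x': 'foo'},
    task_root='/tmp/my-pipeline/my-comp',
)
# executor_input.outputs.output_file ==
#     '/tmp/my-pipeline/my-comp/executor_output.json'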
def get_local_pipeline_resource_name(pipeline_name: str) -> str:
"""Gets the local pipeline resource name from the pipeline name in
PipelineSpec.
Args:
pipeline_name: The pipeline name provided by PipelineSpec.pipelineInfo.name.
Returns:
The local pipeline resource name. Includes timestamp.
"""
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
return f'{pipeline_name}-{timestamp}'
def get_local_task_resource_name(component_name: str) -> str:
"""Gets the local task resource name from the component name in
PipelineSpec.
Args:
component_name: The component name provided as the key for the component's ComponentSpec
message. Takes the form comp-*.
Returns:
The local task resource name.
"""
return component_name[len(utils.COMPONENT_NAME_PREFIX):]
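For example (the timestamp suffix varies at runtime; the values here match the mocked datetime used in this commit's tests):

get_local_pipeline_resource_name('my-pipeline')
# -> 'my-pipeline-2023-10-10-13-32-59-420710'
get_local_task_resource_name('comp-my-comp')
# -> 'my-comp'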
def construct_local_task_root(
pipeline_root: str,
pipeline_resource_name: str,
task_resource_name: str,
) -> str:
"""Constructs the local task root directory for a task."""
return os.path.join(
pipeline_root,
pipeline_resource_name,
task_resource_name,
)
def make_artifact_list(
name: str,
artifact_type: pipeline_spec_pb2.ArtifactTypeSchema,
task_root: str,
) -> pipeline_spec_pb2.ArtifactList:
"""Constructs an ArtifactList instance for an artifact in ExecutorInput."""
return pipeline_spec_pb2.ArtifactList(artifacts=[
pipeline_spec_pb2.RuntimeArtifact(
name=name,
type=artifact_type,
uri=os.path.join(task_root, name),
# metadata always starts empty for output artifacts
metadata={},
)
])

View File

@ -0,0 +1,198 @@
# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for executor_input_utils.py."""
import unittest
from google.protobuf import json_format
from kfp.local import executor_input_utils
from kfp.local import testing_utilities
from kfp.pipeline_spec import pipeline_spec_pb2
class GetLocalPipelineResourceName(testing_utilities.MockedDatetimeTestCase):
def test(self):
actual = executor_input_utils.get_local_pipeline_resource_name(
'my-pipeline')
expected = 'my-pipeline-2023-10-10-13-32-59-420710'
self.assertEqual(actual, expected)
class GetLocalTaskResourceName(unittest.TestCase):
def test(self):
actual = executor_input_utils.get_local_task_resource_name(
'comp-my-comp')
expected = 'my-comp'
self.assertEqual(actual, expected)
class TestConstructLocalTaskRoot(testing_utilities.MockedDatetimeTestCase):
def test(self):
task_root = executor_input_utils.construct_local_task_root(
pipeline_root='/foo/bar',
pipeline_resource_name='my-pipeline-2023-10-10-13-32-59-420710',
task_resource_name='my-comp',
)
self.assertEqual(
task_root,
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/my-comp',
)
class TestConstructExecutorInput(unittest.TestCase):
def test_no_inputs(self):
component_spec = pipeline_spec_pb2.ComponentSpec()
json_format.ParseDict(
{
'outputDefinitions': {
'parameters': {
'Output': {
'parameterType': 'STRING'
}
}
},
'executorLabel': 'exec-comp'
}, component_spec)
arguments = {}
task_root = '/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp'
actual = executor_input_utils.construct_executor_input(
component_spec=component_spec,
arguments=arguments,
task_root=task_root,
)
expected = pipeline_spec_pb2.ExecutorInput()
json_format.ParseDict(
{
'inputs': {},
'outputs': {
'parameters': {
'Output': {
'outputFile':
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/Output'
}
},
'outputFile':
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/executor_output.json'
}
}, expected)
self.assertEqual(actual, expected)
def test_various_io_types(self):
component_spec = pipeline_spec_pb2.ComponentSpec()
json_format.ParseDict(
{
'inputDefinitions': {
'parameters': {
'boolean': {
'parameterType': 'BOOLEAN'
}
}
},
'outputDefinitions': {
'artifacts': {
'out_a': {
'artifactType': {
'schemaTitle': 'system.Dataset',
'schemaVersion': '0.0.1'
}
}
},
'parameters': {
'Output': {
'parameterType': 'NUMBER_INTEGER'
}
}
},
'executorLabel': 'exec-comp'
}, component_spec)
arguments = {'boolean': False}
task_root = '/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp'
actual = executor_input_utils.construct_executor_input(
component_spec=component_spec,
arguments=arguments,
task_root=task_root,
)
expected = pipeline_spec_pb2.ExecutorInput()
json_format.ParseDict(
{
'inputs': {
'parameterValues': {
'boolean': False
}
},
'outputs': {
'parameters': {
'Output': {
'outputFile':
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/Output'
}
},
'artifacts': {
'out_a': {
'artifacts': [{
'name':
'out_a',
'type': {
'schemaTitle': 'system.Dataset',
'schemaVersion': '0.0.1'
},
'uri':
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/out_a',
'metadata': {}
}]
}
},
'outputFile':
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/executor_output.json'
}
}, expected)
self.assertEqual(actual, expected)
def test_input_artifacts_not_yet_supported(self):
component_spec = pipeline_spec_pb2.ComponentSpec()
json_format.ParseDict(
{
'inputDefinitions': {
'artifacts': {
'in_artifact': {
'artifactType': {
'schemaTitle': 'system.Artifact',
'schemaVersion': '0.0.1'
}
}
}
},
'executorLabel': 'exec-comp'
}, component_spec)
arguments = {}
task_root = '/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp'
with self.assertRaisesRegex(
ValueError,
'Input artifacts are not yet supported for local execution.'):
executor_input_utils.construct_executor_input(
component_spec=component_spec,
arguments=arguments,
task_root=task_root,
)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,74 @@
# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with placeholders."""
import random
from typing import List
from google.protobuf import json_format
from kfp import dsl
from kfp.pipeline_spec import pipeline_spec_pb2
def make_random_id() -> str:
"""Makes a random ID of up to 8 digits, returned as a string."""
return str(random.randint(0, 99999999))
def replace_placeholders(
full_command: List[str],
executor_input: pipeline_spec_pb2.ExecutorInput,
pipeline_resource_name: str,
task_resource_name: str,
pipeline_root: str,
) -> List[str]:
"""Iterates over each element in the command and replaces placeholders."""
unique_pipeline_id = make_random_id()
unique_task_id = make_random_id()
return [
replace_placeholder_for_element(
element=el,
executor_input=executor_input,
pipeline_resource_name=pipeline_resource_name,
task_resource_name=task_resource_name,
pipeline_root=pipeline_root,
pipeline_job_id=unique_pipeline_id,
pipeline_task_id=unique_task_id,
) for el in full_command
]
def replace_placeholder_for_element(
element: str,
executor_input: pipeline_spec_pb2.ExecutorInput,
pipeline_resource_name: str,
task_resource_name: str,
pipeline_root: str,
pipeline_job_id: str,
pipeline_task_id: str,
) -> str:
"""Replaces placeholders for a single element."""
PLACEHOLDERS = {
r'{{$.outputs.output_file}}': executor_input.outputs.output_file,
r'{{$.outputMetadataUri}}': executor_input.outputs.output_file,
r'{{$}}': json_format.MessageToJson(executor_input),
dsl.PIPELINE_JOB_NAME_PLACEHOLDER: pipeline_resource_name,
dsl.PIPELINE_JOB_ID_PLACEHOLDER: pipeline_job_id,
dsl.PIPELINE_TASK_NAME_PLACEHOLDER: task_resource_name,
dsl.PIPELINE_TASK_ID_PLACEHOLDER: pipeline_task_id,
dsl.PIPELINE_ROOT_PLACEHOLDER: pipeline_root,
}
for placeholder, value in PLACEHOLDERS.items():
element = element.replace(placeholder, value)
return element
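An illustrative use of the above, assuming dsl.PIPELINE_JOB_NAME_PLACEHOLDER and the other dsl constants resolve to the usual {{$.pipeline_job_name}}-style strings:

from kfp.pipeline_spec import pipeline_spec_pb2

executor_input = pipeline_spec_pb2.ExecutorInput()
executor_input.outputs.output_file = '/root/p/t/executor_output.json'

resolved = replace_placeholders(
    full_command=['sh', '-c', 'echo {{$.outputs.output_file}}'],
    executor_input=executor_input,
    pipeline_resource_name='my-pipeline-2023-10-10-13-32-59-420710',
    task_resource_name='my-comp',
    pipeline_root='/root',
)
# resolved == ['sh', '-c', 'echo /root/p/t/executor_output.json']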

View File

@ -14,6 +14,10 @@
"""Code for dispatching a local task execution."""
from typing import Any, Dict
from kfp import local
from kfp.local import config
from kfp.local import executor_input_utils
from kfp.local import placeholder_utils
from kfp.pipeline_spec import pipeline_spec_pb2
@ -30,5 +34,61 @@ def run_single_component(
Returns:
The task outputs, keyed by output name.
"""
- # TODO: implement and return outputs
if config.LocalExecutionConfig.instance is None:
raise RuntimeError(
f"Local environment not initialized. Please run '{local.__name__}.{local.init.__name__}()' before executing tasks locally."
)
return _run_single_component_implementation(
pipeline_spec=pipeline_spec,
arguments=arguments,
pipeline_root=config.LocalExecutionConfig.instance.pipeline_root,
runner=config.LocalExecutionConfig.instance.runner,
)
def _run_single_component_implementation(
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
arguments: Dict[str, Any],
pipeline_root: str,
runner: config.LocalRunnerType,
) -> Dict[str, Any]:
"""The implementation of a single component runner."""
component_name, component_spec = list(pipeline_spec.components.items())[0]
pipeline_resource_name = executor_input_utils.get_local_pipeline_resource_name(
pipeline_spec.pipeline_info.name)
task_resource_name = executor_input_utils.get_local_task_resource_name(
component_name)
task_root = executor_input_utils.construct_local_task_root(
pipeline_root=pipeline_root,
pipeline_resource_name=pipeline_resource_name,
task_resource_name=task_resource_name,
)
executor_input = executor_input_utils.construct_executor_input(
component_spec=component_spec,
arguments=arguments,
task_root=task_root,
)
executor_spec = pipeline_spec.deployment_spec['executors'][
component_spec.executor_label]
container = executor_spec['container']
full_command = list(container['command']) + list(container['args'])
# image + full_command are "inputs" to local execution
image = container['image']
# TODO: handle container component placeholders when
# ContainerRunner is implemented
full_command = placeholder_utils.replace_placeholders(
full_command=full_command,
executor_input=executor_input,
pipeline_resource_name=pipeline_resource_name,
task_resource_name=task_resource_name,
pipeline_root=pipeline_root,
)
# TODO: call task handler and return outputs
return {}
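The dispatcher currently stops after placeholder resolution; executing full_command is left to a future runner implementation (see the TODOs above). As a loud assumption, not part of this commit, a SubprocessRunner handler might eventually reduce to something like:

import subprocess
from typing import List

def run_in_subprocess(full_command: List[str]) -> None:
    # Hypothetical handler: run the placeholder-resolved command on the
    # host. The real handler and its output handling are not in this diff.
    subprocess.run(full_command, check=True)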

View File

@ -0,0 +1,163 @@
# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for task_dispatcher.py."""
import unittest
from absl.testing import parameterized
from kfp import dsl
from kfp import local
from kfp.dsl import Artifact
from kfp.local import testing_utilities
class TestLocalExecutionValidation(
testing_utilities.LocalRunnerEnvironmentTestCase):
def test_env_not_initialized(self):
@dsl.component
def identity(x: str) -> str:
return x
with self.assertRaisesRegex(
RuntimeError,
r"Local environment not initialized\. Please run 'kfp\.local\.init\(\)' before executing tasks locally\."
):
identity(x='foo')
@parameterized.parameters([
(local.SubprocessRunner(use_venv=False),),
(local.SubprocessRunner(use_venv=True),),
])
class TestArgumentValidation(parameterized.TestCase):
def test_no_argument_no_default(self, runner):
local.init(runner=runner)
@dsl.component
def identity(x: str) -> str:
return x
with self.assertRaisesRegex(
TypeError, r'identity\(\) missing 1 required argument: x'):
identity()
def test_default_wrong_type(self, runner):
local.init(runner=runner)
@dsl.component
def identity(x: str) -> str:
return x
with self.assertRaisesRegex(
dsl.types.type_utils.InconsistentTypeException,
r"Incompatible argument passed to the input 'x' of component 'identity': Argument type 'NUMBER_INTEGER' is incompatible with the input type 'STRING'"
):
identity(x=1)
def test_extra_argument(self, runner):
local.init(runner=runner)
@dsl.component
def identity(x: str) -> str:
return x
with self.assertRaisesRegex(
TypeError,
r'identity\(\) got an unexpected keyword argument "y"\.'):
identity(x='foo', y='bar')
def test_input_artifact_provided(self, runner):
local.init(runner=runner)
@dsl.component
def identity(a: Artifact) -> Artifact:
return a
with self.assertRaisesRegex(
ValueError,
r"Input artifacts are not supported. Got input artifact of type 'Artifact'."
):
identity(a=Artifact(name='a', uri='gs://bucket/foo'))
@parameterized.parameters([
(local.SubprocessRunner(use_venv=False),),
(local.SubprocessRunner(use_venv=True),),
])
class TestLocalPipelineBlocked(testing_utilities.LocalRunnerEnvironmentTestCase
):
def test_local_pipeline_unsupported_two_tasks(self, runner):
local.init(runner=runner)
@dsl.component
def identity(string: str) -> str:
return string
@dsl.pipeline
def my_pipeline():
identity(string='foo')
identity(string='bar')
# compile and load into a YamlComponent to ensure the NotImplementedError isn't simply being thrown because this is a GraphComponent
my_pipeline = testing_utilities.compile_and_load_component(my_pipeline)
with self.assertRaisesRegex(
NotImplementedError,
r'Local pipeline execution is not currently supported\.',
):
my_pipeline()
def test_local_pipeline_unsupported_one_task_different_interface(
self, runner):
local.init(runner=runner)
@dsl.component
def identity(string: str) -> str:
return string
@dsl.pipeline
def my_pipeline():
identity(string='foo')
# compile and load into a YamlComponent to ensure the NotImplementedError isn't simply being thrown because this is a GraphComponent
my_pipeline = testing_utilities.compile_and_load_component(my_pipeline)
with self.assertRaisesRegex(
NotImplementedError,
r'Local pipeline execution is not currently supported\.',
):
my_pipeline()
def test_local_pipeline_unsupported_if_is_graph_component(self, runner):
local.init(runner=runner)
@dsl.component
def identity(string: str) -> str:
return string
# even if there is one task with the same interface as the pipeline, the code should catch that the pipeline is a GraphComponent and throw the NotImplementedError
@dsl.pipeline
def my_pipeline(string: str) -> str:
return identity(string=string).output
with self.assertRaisesRegex(
NotImplementedError,
r'Local pipeline execution is not currently supported\.',
):
my_pipeline(string='foo')
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,61 @@
# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for testing local execution."""
import datetime
import unittest
from unittest import mock
from absl.testing import parameterized
from google.protobuf import json_format
from kfp import components
from kfp import dsl
from kfp.local import config as local_config
class LocalRunnerEnvironmentTestCase(parameterized.TestCase):
def setUp(self):
from kfp.dsl import pipeline_task
pipeline_task.TEMPORARILY_BLOCK_LOCAL_EXECUTION = False
# start each test case with an uninitialized environment
local_config.LocalExecutionConfig.instance = None
def tearDown(self) -> None:
from kfp.dsl import pipeline_task
pipeline_task.TEMPORARILY_BLOCK_LOCAL_EXECUTION = True
class MockedDatetimeTestCase(unittest.TestCase):
def setUp(self):
# set up patch, cleanup, and start
patcher = mock.patch('kfp.local.executor_input_utils.datetime.datetime')
self.addCleanup(patcher.stop)
self.mock_datetime = patcher.start()
# set mock return values
mock_now = mock.MagicMock(
wraps=datetime.datetime(2023, 10, 10, 13, 32, 59, 420710))
self.mock_datetime.now.return_value = mock_now
mock_now.strftime.return_value = '2023-10-10-13-32-59-420710'
def compile_and_load_component(
base_component: dsl.base_component.BaseComponent,
) -> dsl.yaml_component.YamlComponent:
"""Compiles a component to PipelineSpec and reloads it as a
YamlComponent."""
return components.load_component_from_text(
json_format.MessageToJson(base_component.pipeline_spec))
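Hypothetical usage of these helpers in a test module (the test class and assertion are illustrative, but the mocked timestamp matches the value pinned above):

class ResourceNameTest(MockedDatetimeTestCase):

    def test_pipeline_resource_name_is_timestamped(self):
        from kfp.local import executor_input_utils
        actual = executor_input_utils.get_local_pipeline_resource_name('p')
        self.assertEqual(actual, 'p-2023-10-10-13-32-59-420710')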