update task dispatcher (#10298)
This commit is contained in:
parent
227eab1c68
commit
d41efc3e96
|
|
@ -16,6 +16,7 @@
|
|||
import abc
|
||||
from typing import List
|
||||
|
||||
from kfp.dsl import pipeline_context
|
||||
from kfp.dsl import pipeline_task
|
||||
from kfp.dsl import structures
|
||||
from kfp.dsl.types import type_utils
|
||||
|
|
@ -100,6 +101,8 @@ class BaseComponent(abc.ABC):
|
|||
return pipeline_task.PipelineTask(
|
||||
component_spec=self.component_spec,
|
||||
args=task_inputs,
|
||||
execute_locally=pipeline_context.Pipeline.get_default_pipeline() is
|
||||
None,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -14,10 +14,8 @@
|
|||
"""Tests for kfp.dsl.base_component."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from kfp import dsl
|
||||
from kfp.dsl import pipeline_task
|
||||
from kfp.dsl import placeholders
|
||||
from kfp.dsl import python_component
|
||||
from kfp.dsl import structures
|
||||
|
|
@ -59,34 +57,6 @@ component_op = python_component.PythonComponent(
|
|||
|
||||
class BaseComponentTest(unittest.TestCase):
|
||||
|
||||
@patch.object(pipeline_task, 'PipelineTask', autospec=True)
|
||||
def test_instantiate_component_with_keyword_arguments(
|
||||
self, mock_PipelineTask):
|
||||
|
||||
component_op(input1='hello', input2=100, input3=1.23, input4=3.21)
|
||||
|
||||
mock_PipelineTask.assert_called_once_with(
|
||||
component_spec=component_op.component_spec,
|
||||
args={
|
||||
'input1': 'hello',
|
||||
'input2': 100,
|
||||
'input3': 1.23,
|
||||
'input4': 3.21,
|
||||
})
|
||||
|
||||
@patch.object(pipeline_task, 'PipelineTask', autospec=True)
|
||||
def test_instantiate_component_omitting_arguments_with_default(
|
||||
self, mock_PipelineTask):
|
||||
|
||||
component_op(input1='hello', input2=100)
|
||||
|
||||
mock_PipelineTask.assert_called_once_with(
|
||||
component_spec=component_op.component_spec,
|
||||
args={
|
||||
'input1': 'hello',
|
||||
'input2': 100,
|
||||
})
|
||||
|
||||
def test_instantiate_component_with_positional_arugment(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from kfp.dsl import placeholders
|
|||
from kfp.dsl import structures
|
||||
from kfp.dsl import utils
|
||||
from kfp.dsl.types import type_utils
|
||||
from kfp.local import task_dispatcher
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
TEMPORARILY_BLOCK_LOCAL_EXECUTION = True
|
||||
|
|
@ -99,7 +98,8 @@ class PipelineTask:
|
|||
self,
|
||||
component_spec: structures.ComponentSpec,
|
||||
args: Dict[str, Any],
|
||||
):
|
||||
execute_locally: bool = False,
|
||||
) -> None:
|
||||
"""Initilizes a PipelineTask instance."""
|
||||
# import within __init__ to avoid circular import
|
||||
from kfp.dsl.tasks_group import TasksGroup
|
||||
|
|
@ -181,21 +181,27 @@ class PipelineTask:
|
|||
if not isinstance(value, pipeline_channel.PipelineChannel)
|
||||
])
|
||||
|
||||
from kfp.dsl import pipeline_context
|
||||
if execute_locally:
|
||||
self._execute_locally(args=args)
|
||||
|
||||
# TODO: remove feature flag
|
||||
if not TEMPORARILY_BLOCK_LOCAL_EXECUTION and pipeline_context.Pipeline.get_default_pipeline(
|
||||
) is None:
|
||||
self._execute_locally()
|
||||
|
||||
def _execute_locally(self) -> None:
|
||||
def _execute_locally(self, args: Dict[str, Any]) -> None:
|
||||
"""Execute the pipeline task locally.
|
||||
|
||||
Set the task state to FINAL and update the outputs.
|
||||
"""
|
||||
from kfp.local import task_dispatcher
|
||||
|
||||
if self.pipeline_spec is not None:
|
||||
raise NotImplementedError(
|
||||
'Local pipeline execution is not currently supported.')
|
||||
|
||||
# TODO: remove feature flag
|
||||
if TEMPORARILY_BLOCK_LOCAL_EXECUTION:
|
||||
return
|
||||
|
||||
self._outputs = task_dispatcher.run_single_component(
|
||||
pipeline_spec=self.pipeline_spec,
|
||||
arguments=self.args,
|
||||
pipeline_spec=self.component_spec.to_pipeline_spec(),
|
||||
arguments=args,
|
||||
)
|
||||
self.state = TaskState.FINAL
|
||||
|
||||
|
|
|
|||
|
|
@ -227,6 +227,10 @@ def _get_type_string_from_component_argument(
|
|||
if argument_type in _TYPE_TO_TYPE_NAME:
|
||||
return _TYPE_TO_TYPE_NAME[argument_type]
|
||||
|
||||
if isinstance(argument_value, artifact_types.Artifact):
|
||||
raise ValueError(
|
||||
f'Input artifacts are not supported. Got input artifact of type {argument_value.__class__.__name__!r}.'
|
||||
)
|
||||
raise ValueError(
|
||||
f'Constant argument inputs must be one of type {list(_TYPE_TO_TYPE_NAME.values())} Got: {argument_value!r} of type {type(argument_value)!r}.'
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import sys
|
|||
import types
|
||||
from typing import List
|
||||
|
||||
_COMPONENT_NAME_PREFIX = 'comp-'
|
||||
COMPONENT_NAME_PREFIX = 'comp-'
|
||||
_EXECUTOR_LABEL_PREFIX = 'exec-'
|
||||
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ def sanitize_input_name(name: str) -> str:
|
|||
|
||||
def sanitize_component_name(name: str) -> str:
|
||||
"""Sanitizes component name."""
|
||||
return _COMPONENT_NAME_PREFIX + maybe_rename_for_k8s(name)
|
||||
return COMPONENT_NAME_PREFIX + maybe_rename_for_k8s(name)
|
||||
|
||||
|
||||
def sanitize_task_name(name: str) -> str:
|
||||
|
|
|
|||
|
|
@ -12,9 +12,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Objects for configuring local execution."""
|
||||
import abc
|
||||
import dataclasses
|
||||
|
||||
|
||||
class LocalRunnerType(abc.ABC):
|
||||
"""The ABC for user-facing Runner configurations.
|
||||
|
||||
Subclasses should be a dataclass.
|
||||
|
||||
They should implement a .validate() method.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate(self) -> None:
|
||||
"""Validates that the configuration arguments provided by the user are
|
||||
valid."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SubprocessRunner:
|
||||
"""Runner that indicates that local tasks should be run in a subprocess.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2023 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for constructing the ExecutorInput message."""
|
||||
import datetime
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from kfp.compiler import pipeline_spec_builder
|
||||
from kfp.dsl import utils
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
_EXECUTOR_OUTPUT_FILE = 'executor_output.json'
|
||||
|
||||
|
||||
def construct_executor_input(
|
||||
component_spec: pipeline_spec_pb2.ComponentSpec,
|
||||
arguments: Dict[str, Any],
|
||||
task_root: str,
|
||||
) -> pipeline_spec_pb2.ExecutorInput:
|
||||
"""Constructs the executor input message for a task execution."""
|
||||
input_parameter_keys = list(
|
||||
component_spec.input_definitions.parameters.keys())
|
||||
input_artifact_keys = list(
|
||||
component_spec.input_definitions.artifacts.keys())
|
||||
if input_artifact_keys:
|
||||
raise ValueError(
|
||||
'Input artifacts are not yet supported for local execution.')
|
||||
|
||||
output_parameter_keys = list(
|
||||
component_spec.output_definitions.parameters.keys())
|
||||
output_artifact_specs_dict = component_spec.output_definitions.artifacts
|
||||
|
||||
inputs = pipeline_spec_pb2.ExecutorInput.Inputs(
|
||||
parameter_values={
|
||||
param_name:
|
||||
pipeline_spec_builder.to_protobuf_value(arguments[param_name])
|
||||
if param_name in arguments else component_spec.input_definitions
|
||||
.parameters[param_name].default_value
|
||||
for param_name in input_parameter_keys
|
||||
},
|
||||
# input artifact constants are not supported yet
|
||||
artifacts={},
|
||||
)
|
||||
outputs = pipeline_spec_pb2.ExecutorInput.Outputs(
|
||||
parameters={
|
||||
param_name: pipeline_spec_pb2.ExecutorInput.OutputParameter(
|
||||
output_file=os.path.join(task_root, param_name))
|
||||
for param_name in output_parameter_keys
|
||||
},
|
||||
artifacts={
|
||||
artifact_name: make_artifact_list(
|
||||
name=artifact_name,
|
||||
artifact_type=artifact_spec.artifact_type,
|
||||
task_root=task_root,
|
||||
) for artifact_name, artifact_spec in
|
||||
output_artifact_specs_dict.items()
|
||||
},
|
||||
output_file=os.path.join(task_root, _EXECUTOR_OUTPUT_FILE),
|
||||
)
|
||||
return pipeline_spec_pb2.ExecutorInput(
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
|
||||
def get_local_pipeline_resource_name(pipeline_name: str) -> str:
|
||||
"""Gets the local pipeline resource name from the pipeline name in
|
||||
PipelineSpec.
|
||||
|
||||
Args:
|
||||
pipeline_name: The pipeline name provided by PipelineSpec.pipelineInfo.name.
|
||||
|
||||
Returns:
|
||||
The local pipeline resource name. Includes timestamp.
|
||||
"""
|
||||
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
|
||||
return f'{pipeline_name}-{timestamp}'
|
||||
|
||||
|
||||
def get_local_task_resource_name(component_name: str) -> str:
|
||||
"""Gets the local task resource name from the component name in
|
||||
PipelineSpec.
|
||||
|
||||
Args:
|
||||
component_name: The component name provided as the key for the component's ComponentSpec
|
||||
message. Takes the form comp-*.
|
||||
|
||||
Returns:
|
||||
The local task resource name.
|
||||
"""
|
||||
return component_name[len(utils.COMPONENT_NAME_PREFIX):]
|
||||
|
||||
|
||||
def construct_local_task_root(
|
||||
pipeline_root: str,
|
||||
pipeline_resource_name: str,
|
||||
task_resource_name: str,
|
||||
) -> str:
|
||||
"""Constructs the local task root directory for a task."""
|
||||
return os.path.join(
|
||||
pipeline_root,
|
||||
pipeline_resource_name,
|
||||
task_resource_name,
|
||||
)
|
||||
|
||||
|
||||
def make_artifact_list(
|
||||
name: str,
|
||||
artifact_type: pipeline_spec_pb2.ArtifactTypeSchema,
|
||||
task_root: str,
|
||||
) -> pipeline_spec_pb2.ArtifactList:
|
||||
"""Constructs an ArtifactList instance for an artifact in ExecutorInput."""
|
||||
return pipeline_spec_pb2.ArtifactList(artifacts=[
|
||||
pipeline_spec_pb2.RuntimeArtifact(
|
||||
name=name,
|
||||
type=artifact_type,
|
||||
uri=os.path.join(task_root, name),
|
||||
# metadata always starts empty for output artifacts
|
||||
metadata={},
|
||||
)
|
||||
])
|
||||
|
|
@ -0,0 +1,198 @@
|
|||
# Copyright 2023 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for executor_input_utils.py."""
|
||||
|
||||
import unittest
|
||||
|
||||
from google.protobuf import json_format
|
||||
from kfp.local import executor_input_utils
|
||||
from kfp.local import testing_utilities
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
|
||||
class GetLocalPipelineResourceName(testing_utilities.MockedDatetimeTestCase):
|
||||
|
||||
def test(self):
|
||||
actual = executor_input_utils.get_local_pipeline_resource_name(
|
||||
'my-pipeline')
|
||||
expected = 'my-pipeline-2023-10-10-13-32-59-420710'
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
class GetLocalTaskResourceName(unittest.TestCase):
|
||||
|
||||
def test(self):
|
||||
actual = executor_input_utils.get_local_task_resource_name(
|
||||
'comp-my-comp')
|
||||
expected = 'my-comp'
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
class TestConstructLocalTaskRoot(testing_utilities.MockedDatetimeTestCase):
|
||||
|
||||
def test(self):
|
||||
|
||||
task_root = executor_input_utils.construct_local_task_root(
|
||||
pipeline_root='/foo/bar',
|
||||
pipeline_resource_name='my-pipeline-2023-10-10-13-32-59-420710',
|
||||
task_resource_name='my-comp',
|
||||
)
|
||||
self.assertEqual(
|
||||
task_root,
|
||||
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/my-comp',
|
||||
)
|
||||
|
||||
|
||||
class TestConstructExecutorInput(unittest.TestCase):
|
||||
|
||||
def test_no_inputs(self):
|
||||
component_spec = pipeline_spec_pb2.ComponentSpec()
|
||||
json_format.ParseDict(
|
||||
{
|
||||
'outputDefinitions': {
|
||||
'parameters': {
|
||||
'Output': {
|
||||
'parameterType': 'STRING'
|
||||
}
|
||||
}
|
||||
},
|
||||
'executorLabel': 'exec-comp'
|
||||
}, component_spec)
|
||||
arguments = {}
|
||||
task_root = '/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp'
|
||||
|
||||
actual = executor_input_utils.construct_executor_input(
|
||||
component_spec=component_spec,
|
||||
arguments=arguments,
|
||||
task_root=task_root,
|
||||
)
|
||||
expected = pipeline_spec_pb2.ExecutorInput()
|
||||
json_format.ParseDict(
|
||||
{
|
||||
'inputs': {},
|
||||
'outputs': {
|
||||
'parameters': {
|
||||
'Output': {
|
||||
'outputFile':
|
||||
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/Output'
|
||||
}
|
||||
},
|
||||
'outputFile':
|
||||
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/executor_output.json'
|
||||
}
|
||||
}, expected)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_various_io_types(self):
|
||||
component_spec = pipeline_spec_pb2.ComponentSpec()
|
||||
json_format.ParseDict(
|
||||
{
|
||||
'inputDefinitions': {
|
||||
'parameters': {
|
||||
'boolean': {
|
||||
'parameterType': 'BOOLEAN'
|
||||
}
|
||||
}
|
||||
},
|
||||
'outputDefinitions': {
|
||||
'artifacts': {
|
||||
'out_a': {
|
||||
'artifactType': {
|
||||
'schemaTitle': 'system.Dataset',
|
||||
'schemaVersion': '0.0.1'
|
||||
}
|
||||
}
|
||||
},
|
||||
'parameters': {
|
||||
'Output': {
|
||||
'parameterType': 'NUMBER_INTEGER'
|
||||
}
|
||||
}
|
||||
},
|
||||
'executorLabel': 'exec-comp'
|
||||
}, component_spec)
|
||||
arguments = {'boolean': False}
|
||||
task_root = '/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp'
|
||||
|
||||
actual = executor_input_utils.construct_executor_input(
|
||||
component_spec=component_spec,
|
||||
arguments=arguments,
|
||||
task_root=task_root,
|
||||
)
|
||||
expected = pipeline_spec_pb2.ExecutorInput()
|
||||
json_format.ParseDict(
|
||||
{
|
||||
'inputs': {
|
||||
'parameterValues': {
|
||||
'boolean': False
|
||||
}
|
||||
},
|
||||
'outputs': {
|
||||
'parameters': {
|
||||
'Output': {
|
||||
'outputFile':
|
||||
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/Output'
|
||||
}
|
||||
},
|
||||
'artifacts': {
|
||||
'out_a': {
|
||||
'artifacts': [{
|
||||
'name':
|
||||
'out_a',
|
||||
'type': {
|
||||
'schemaTitle': 'system.Dataset',
|
||||
'schemaVersion': '0.0.1'
|
||||
},
|
||||
'uri':
|
||||
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/out_a',
|
||||
'metadata': {}
|
||||
}]
|
||||
}
|
||||
},
|
||||
'outputFile':
|
||||
'/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp/executor_output.json'
|
||||
}
|
||||
}, expected)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_input_artifacts_not_yet_supported(self):
|
||||
component_spec = pipeline_spec_pb2.ComponentSpec()
|
||||
json_format.ParseDict(
|
||||
{
|
||||
'inputDefinitions': {
|
||||
'artifacts': {
|
||||
'in_artifact': {
|
||||
'artifactType': {
|
||||
'schemaTitle': 'system.Artifact',
|
||||
'schemaVersion': '0.0.1'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
'executorLabel': 'exec-comp'
|
||||
}, component_spec)
|
||||
arguments = {}
|
||||
task_root = '/foo/bar/my-pipeline-2023-10-10-13-32-59-420710/comp'
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Input artifacts are not yet supported for local execution.'):
|
||||
executor_input_utils.construct_executor_input(
|
||||
component_spec=component_spec,
|
||||
arguments=arguments,
|
||||
task_root=task_root,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright 2023 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for working with placeholders."""
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from google.protobuf import json_format
|
||||
from kfp import dsl
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
|
||||
def make_random_id():
|
||||
"""Makes a random 8 digit integer."""
|
||||
return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def replace_placeholders(
|
||||
full_command: List[str],
|
||||
executor_input: str,
|
||||
pipeline_resource_name: str,
|
||||
task_resource_name: str,
|
||||
pipeline_root: str,
|
||||
) -> List[str]:
|
||||
"""Iterates over each element in the command and replaces placeholders."""
|
||||
unique_pipeline_id = make_random_id()
|
||||
unique_task_id = make_random_id()
|
||||
return [
|
||||
replace_placeholder_for_element(
|
||||
element=el,
|
||||
executor_input=executor_input,
|
||||
pipeline_resource_name=pipeline_resource_name,
|
||||
task_resource_name=task_resource_name,
|
||||
pipeline_root=pipeline_root,
|
||||
pipeline_job_id=unique_pipeline_id,
|
||||
pipeline_task_id=unique_task_id,
|
||||
) for el in full_command
|
||||
]
|
||||
|
||||
|
||||
def replace_placeholder_for_element(
|
||||
element: str,
|
||||
executor_input: pipeline_spec_pb2.ExecutorInput,
|
||||
pipeline_resource_name: str,
|
||||
task_resource_name: str,
|
||||
pipeline_root: str,
|
||||
pipeline_job_id: str,
|
||||
pipeline_task_id: str,
|
||||
) -> str:
|
||||
"""Replaces placeholders for a single element."""
|
||||
PLACEHOLDERS = {
|
||||
r'{{$.outputs.output_file}}': executor_input.outputs.output_file,
|
||||
r'{{$.outputMetadataUri}}': executor_input.outputs.output_file,
|
||||
r'{{$}}': json_format.MessageToJson(executor_input),
|
||||
dsl.PIPELINE_JOB_NAME_PLACEHOLDER: pipeline_resource_name,
|
||||
dsl.PIPELINE_JOB_ID_PLACEHOLDER: pipeline_job_id,
|
||||
dsl.PIPELINE_TASK_NAME_PLACEHOLDER: task_resource_name,
|
||||
dsl.PIPELINE_TASK_ID_PLACEHOLDER: pipeline_task_id,
|
||||
dsl.PIPELINE_ROOT_PLACEHOLDER: pipeline_root,
|
||||
}
|
||||
for placeholder, value in PLACEHOLDERS.items():
|
||||
element = element.replace(placeholder, value)
|
||||
|
||||
return element
|
||||
|
|
@ -14,6 +14,10 @@
|
|||
"""Code for dispatching a local task execution."""
|
||||
from typing import Any, Dict
|
||||
|
||||
from kfp import local
|
||||
from kfp.local import config
|
||||
from kfp.local import executor_input_utils
|
||||
from kfp.local import placeholder_utils
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
|
||||
|
|
@ -30,5 +34,61 @@ def run_single_component(
|
|||
Returns:
|
||||
A LocalTask instance.
|
||||
"""
|
||||
# TODO: implement and return outputs
|
||||
if config.LocalExecutionConfig.instance is None:
|
||||
raise RuntimeError(
|
||||
f"Local environment not initialized. Please run '{local.__name__}.{local.init.__name__}()' before executing tasks locally."
|
||||
)
|
||||
|
||||
return _run_single_component_implementation(
|
||||
pipeline_spec=pipeline_spec,
|
||||
arguments=arguments,
|
||||
pipeline_root=config.LocalExecutionConfig.instance.pipeline_root,
|
||||
runner=config.LocalExecutionConfig.instance.runner,
|
||||
)
|
||||
|
||||
|
||||
def _run_single_component_implementation(
|
||||
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
|
||||
arguments: Dict[str, Any],
|
||||
pipeline_root: str,
|
||||
runner: config.LocalRunnerType,
|
||||
) -> Dict[str, Any]:
|
||||
"""The implementation of a single component runner."""
|
||||
|
||||
component_name, component_spec = list(pipeline_spec.components.items())[0]
|
||||
|
||||
pipeline_resource_name = executor_input_utils.get_local_pipeline_resource_name(
|
||||
pipeline_spec.pipeline_info.name)
|
||||
task_resource_name = executor_input_utils.get_local_task_resource_name(
|
||||
component_name)
|
||||
task_root = executor_input_utils.construct_local_task_root(
|
||||
pipeline_root=pipeline_root,
|
||||
pipeline_resource_name=pipeline_resource_name,
|
||||
task_resource_name=task_resource_name,
|
||||
)
|
||||
executor_input = executor_input_utils.construct_executor_input(
|
||||
component_spec=component_spec,
|
||||
arguments=arguments,
|
||||
task_root=task_root,
|
||||
)
|
||||
|
||||
executor_spec = pipeline_spec.deployment_spec['executors'][
|
||||
component_spec.executor_label]
|
||||
|
||||
container = executor_spec['container']
|
||||
full_command = list(container['command']) + list(container['args'])
|
||||
|
||||
# image + full_command are "inputs" to local execution
|
||||
image = container['image']
|
||||
# TODO: handler container component placeholders when
|
||||
# ContainerRunner is implemented
|
||||
full_command = placeholder_utils.replace_placeholders(
|
||||
full_command=full_command,
|
||||
executor_input=executor_input,
|
||||
pipeline_resource_name=pipeline_resource_name,
|
||||
task_resource_name=task_resource_name,
|
||||
pipeline_root=pipeline_root,
|
||||
)
|
||||
|
||||
# TODO: call task handler and return outputs
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright 2023 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for task_dispatcher.py."""
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from kfp import dsl
|
||||
from kfp import local
|
||||
from kfp.dsl import Artifact
|
||||
from kfp.local import testing_utilities
|
||||
|
||||
|
||||
class TestLocalExecutionValidation(
|
||||
testing_utilities.LocalRunnerEnvironmentTestCase):
|
||||
|
||||
def test_env_not_initialized(self):
|
||||
|
||||
@dsl.component
|
||||
def identity(x: str) -> str:
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Local environment not initialized\. Please run 'kfp\.local\.init\(\)' before executing tasks locally\."
|
||||
):
|
||||
identity(x='foo')
|
||||
|
||||
|
||||
@parameterized.parameters([
|
||||
(local.SubprocessRunner(use_venv=False),),
|
||||
(local.SubprocessRunner(use_venv=True),),
|
||||
])
|
||||
class TestArgumentValidation(parameterized.TestCase):
|
||||
|
||||
def test_no_argument_no_default(self, runner):
|
||||
local.init(runner=runner)
|
||||
|
||||
@dsl.component
|
||||
def identity(x: str) -> str:
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r'identity\(\) missing 1 required argument: x'):
|
||||
identity()
|
||||
|
||||
def test_default_wrong_type(self, runner):
|
||||
local.init(runner=runner)
|
||||
|
||||
@dsl.component
|
||||
def identity(x: str) -> str:
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
dsl.types.type_utils.InconsistentTypeException,
|
||||
r"Incompatible argument passed to the input 'x' of component 'identity': Argument type 'NUMBER_INTEGER' is incompatible with the input type 'STRING'"
|
||||
):
|
||||
identity(x=1)
|
||||
|
||||
def test_extra_argument(self, runner):
|
||||
local.init(runner=runner)
|
||||
|
||||
@dsl.component
|
||||
def identity(x: str) -> str:
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r'identity\(\) got an unexpected keyword argument "y"\.'):
|
||||
identity(x='foo', y='bar')
|
||||
|
||||
def test_input_artifact_provided(self, runner):
|
||||
local.init(runner=runner)
|
||||
|
||||
@dsl.component
|
||||
def identity(a: Artifact) -> Artifact:
|
||||
return a
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Input artifacts are not supported. Got input artifact of type 'Artifact'."
|
||||
):
|
||||
identity(a=Artifact(name='a', uri='gs://bucket/foo'))
|
||||
|
||||
|
||||
@parameterized.parameters([
|
||||
(local.SubprocessRunner(use_venv=False),),
|
||||
(local.SubprocessRunner(use_venv=True),),
|
||||
])
|
||||
class TestLocalPipelineBlocked(testing_utilities.LocalRunnerEnvironmentTestCase
|
||||
):
|
||||
|
||||
def test_local_pipeline_unsupported_two_tasks(self, runner):
|
||||
local.init(runner=runner)
|
||||
|
||||
@dsl.component
|
||||
def identity(string: str) -> str:
|
||||
return string
|
||||
|
||||
@dsl.pipeline
|
||||
def my_pipeline():
|
||||
identity(string='foo')
|
||||
identity(string='bar')
|
||||
|
||||
# compile and load into a YamlComponent to ensure the NotImplementedError isn't simply being thrown because this is a GraphComponent
|
||||
my_pipeline = testing_utilities.compile_and_load_component(my_pipeline)
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Local pipeline execution is not currently supported\.',
|
||||
):
|
||||
my_pipeline()
|
||||
|
||||
def test_local_pipeline_unsupported_one_task_different_interface(
|
||||
self, runner):
|
||||
local.init(runner=runner)
|
||||
|
||||
@dsl.component
|
||||
def identity(string: str) -> str:
|
||||
return string
|
||||
|
||||
@dsl.pipeline
|
||||
def my_pipeline():
|
||||
identity(string='foo')
|
||||
|
||||
# compile and load into a YamlComponent to ensure the NotImplementedError isn't simply being thrown because this is a GraphComponent
|
||||
my_pipeline = testing_utilities.compile_and_load_component(my_pipeline)
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Local pipeline execution is not currently supported\.',
|
||||
):
|
||||
my_pipeline()
|
||||
|
||||
def test_local_pipeline_unsupported_if_is_graph_component(self, runner):
|
||||
local.init(runner=runner)
|
||||
|
||||
@dsl.component
|
||||
def identity(string: str) -> str:
|
||||
return string
|
||||
|
||||
# even if there is one task with the same interface as the pipeline, the code should catch that the pipeline is a GraphComponent and throw the NotImplementedError
|
||||
@dsl.pipeline
|
||||
def my_pipeline(string: str) -> str:
|
||||
return identity(string=string).output
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Local pipeline execution is not currently supported\.',
|
||||
):
|
||||
my_pipeline(string='foo')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2023 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for testing local execution."""
|
||||
|
||||
import datetime
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
from google.protobuf import json_format
|
||||
from kfp import components
|
||||
from kfp import dsl
|
||||
from kfp.local import config as local_config
|
||||
|
||||
|
||||
class LocalRunnerEnvironmentTestCase(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
from kfp.dsl import pipeline_task
|
||||
pipeline_task.TEMPORARILY_BLOCK_LOCAL_EXECUTION = False
|
||||
# start each test case without an uninitialized environment
|
||||
local_config.LocalExecutionConfig.instance = None
|
||||
|
||||
def tearDown(self) -> None:
|
||||
from kfp.dsl import pipeline_task
|
||||
pipeline_task.TEMPORARILY_BLOCK_LOCAL_EXECUTION = True
|
||||
|
||||
|
||||
class MockedDatetimeTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# set up patch, cleanup, and start
|
||||
patcher = mock.patch('kfp.local.executor_input_utils.datetime.datetime')
|
||||
self.addCleanup(patcher.stop)
|
||||
self.mock_datetime = patcher.start()
|
||||
|
||||
# set mock return values
|
||||
mock_now = mock.MagicMock(
|
||||
wraps=datetime.datetime(2023, 10, 10, 13, 32, 59, 420710))
|
||||
self.mock_datetime.now.return_value = mock_now
|
||||
mock_now.strftime.return_value = '2023-10-10-13-32-59-420710'
|
||||
|
||||
|
||||
def compile_and_load_component(
|
||||
base_component: dsl.base_component.BaseComponent,
|
||||
) -> dsl.yaml_component.YamlComponent:
|
||||
"""Compiles a component to PipelineSpec and reloads it as a
|
||||
YamlComponent."""
|
||||
return components.load_component_from_text(
|
||||
json_format.MessageToJson(base_component.pipeline_spec))
|
||||
Loading…
Reference in New Issue