# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""KFP DSL compiler.
|
|
|
|
This is an experimental implementation of KFP compiler that compiles KFP
|
|
pipeline into Pipeline IR:
|
|
https://docs.google.com/document/d/1PUDuSQ8vmeKSBloli53mp7GIvzekaY7sggg6ywy35Dk/
|
|
"""
|
|
|
|
import collections
import inspect
import re
import uuid
from typing import (Any, Callable, Dict, List, Mapping, Optional, Set, Tuple,
                    Union)

import kfp
from google.protobuf import json_format
from kfp.pipeline_spec import pipeline_spec_pb2
from kfp import dsl
from kfp.compiler import pipeline_spec_builder as builder
from kfp.components import utils as component_utils
from kfp.components import component_factory
from kfp.components import for_loop
from kfp.components import pipeline_context
from kfp.components import pipeline_task
from kfp.components import tasks_group
from kfp.components.types import type_utils

_GroupOrTask = Union[tasks_group.TasksGroup, pipeline_task.PipelineTask]


class Compiler:
    """Experimental DSL compiler that targets the PipelineSpec IR.

    It compiles a pipeline function into a PipelineSpec JSON string.
    PipelineSpec is the IR protobuf message that defines a pipeline:
    https://github.com/kubeflow/pipelines/blob/237795539f7b85bac77435e2464367226ee19391/api/v2alpha1/pipeline_spec.proto#L8
    In this initial implementation, only components authored through the
    component YAML spec are supported, and advanced features such as
    conditions, static and dynamic loops, etc. are not supported.

    Example::

        @dsl.pipeline(
            name='name',
            description='description',
        )
        def my_pipeline(a: int = 1, b: str = "default value"):
            ...

        kfp.compiler.Compiler().compile(
            pipeline_func=my_pipeline,
            package_path='path/to/pipeline.json',
        )
    """

    def compile(
        self,
        pipeline_func: Callable[..., Any],
        package_path: str,
        pipeline_name: Optional[str] = None,
        pipeline_parameters: Optional[Mapping[str, Any]] = None,
        type_check: bool = True,
    ) -> None:
        """Compiles the given pipeline function into a pipeline job JSON file.

        Args:
            pipeline_func: Pipeline function with the @dsl.pipeline decorator.
            package_path: The output pipeline spec .json file path. For
                example, "~/pipeline_spec.json".
            pipeline_name: Optional; the name of the pipeline.
            pipeline_parameters: Optional; the mapping from parameter names to
                values.
            type_check: Optional; whether to enable the type check or not.
                Default is True.
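
        Example (illustrative only; ``my_pipeline`` below is assumed to be a
        pipeline function decorated with @dsl.pipeline, and the parameter
        name is a placeholder)::

            Compiler().compile(
                pipeline_func=my_pipeline,
                package_path='pipeline.json',
                pipeline_parameters={'param_a': 1},
                type_check=True,
            )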
"""
|
|
type_check_old_value = kfp.TYPE_CHECK
|
|
try:
|
|
kfp.TYPE_CHECK = type_check
|
|
pipeline_spec = self._create_pipeline(
|
|
pipeline_func=pipeline_func,
|
|
pipeline_name=pipeline_name,
|
|
pipeline_parameters_override=pipeline_parameters,
|
|
)
|
|
self._write_pipeline_spec_json(
|
|
pipeline_spec=pipeline_spec,
|
|
output_path=package_path,
|
|
)
|
|
finally:
|
|
kfp.TYPE_CHECK = type_check_old_value
|
|
|
|
    def _create_pipeline(
        self,
        pipeline_func: Callable[..., Any],
        pipeline_name: Optional[str] = None,
        pipeline_parameters_override: Optional[Mapping[str, Any]] = None,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates a pipeline instance and constructs the pipeline spec from
        it.

        Args:
            pipeline_func: The pipeline function with @dsl.pipeline decorator.
            pipeline_name: Optional; the name of the pipeline.
            pipeline_parameters_override: Optional; the mapping from parameter
                names to values.

        Returns:
            A PipelineSpec proto representing the compiled pipeline.
        """

        # Create the arg list with no default values and call pipeline function.
        # Assign type information to the PipelineChannel.
        pipeline_meta = component_factory.extract_component_interface(
            pipeline_func)
        pipeline_name = pipeline_name or pipeline_meta.name

        pipeline_root = getattr(pipeline_func, 'pipeline_root', None)

        args_list = []
        signature = inspect.signature(pipeline_func)

        for arg_name in signature.parameters:
            arg_type = pipeline_meta.inputs[arg_name].type
            if not type_utils.is_parameter_type(arg_type):
                raise TypeError(
                    'The pipeline argument "{arg_name}" is viewed as an artifact'
                    ' due to its type "{arg_type}". And we currently do not '
                    'support passing artifacts as pipeline inputs. Consider type'
                    ' annotating the argument with a primitive type, such as '
                    '"str", "int", "float", "bool", "dict", and "list".'.format(
                        arg_name=arg_name, arg_type=arg_type))
            args_list.append(
                dsl.PipelineParameterChannel(
                    name=arg_name, channel_type=arg_type))

        with pipeline_context.Pipeline(pipeline_name) as dsl_pipeline:
            pipeline_func(*args_list)

        if not dsl_pipeline.tasks:
            raise ValueError('Task is missing from pipeline.')

        self._validate_exit_handler(dsl_pipeline)

        pipeline_inputs = pipeline_meta.inputs or {}

        # Verify that pipeline_parameters_override contains only input names
        # that match the pipeline inputs definition.
        pipeline_parameters_override = pipeline_parameters_override or {}
        for input_name in pipeline_parameters_override:
            if input_name not in pipeline_inputs:
                raise ValueError(
                    'Pipeline parameter {} does not match any known '
                    'pipeline argument.'.format(input_name))

        # Fill in the default values.
        args_list_with_defaults = [
            dsl.PipelineParameterChannel(
                name=input_name,
                channel_type=input_spec.type,
                value=pipeline_parameters_override.get(input_name) or
                input_spec.default,
            ) for input_name, input_spec in pipeline_inputs.items()
        ]

        # Make the pipeline group name unique to prevent name clashes with
        # templates.
        pipeline_group = dsl_pipeline.groups[0]
        pipeline_group.name = uuid.uuid4().hex

        pipeline_spec = self._create_pipeline_spec(
            pipeline_args=args_list_with_defaults,
            pipeline=dsl_pipeline,
        )

        if pipeline_root:
            pipeline_spec.default_pipeline_root = pipeline_root

        return pipeline_spec

    def _write_pipeline_spec_json(
        self,
        pipeline_spec: pipeline_spec_pb2.PipelineSpec,
        output_path: str,
    ) -> None:
        """Writes pipeline spec into a json file.

        Args:
            pipeline_spec: IR pipeline spec.
            output_path: The file path to be written.

        Raises:
            ValueError: if the specified output path doesn't end with the
                acceptable extension.
        """
        json_text = json_format.MessageToJson(pipeline_spec, sort_keys=True)

        if output_path.endswith('.json'):
            with open(output_path, 'w') as json_file:
                json_file.write(json_text)
        else:
            raise ValueError(
                'The output path {} should end with ".json".'.format(
                    output_path))

    def _validate_exit_handler(self,
                               pipeline: pipeline_context.Pipeline) -> None:
        """Makes sure there is only one global exit handler.

        This is temporary to be compatible with KFP v1.

        Raises:
            ValueError if there is more than one exit handler.
        """

        def _validate_exit_handler_helper(
            group: tasks_group.TasksGroup,
            exiting_task_names: List[str],
            handler_exists: bool,
        ) -> None:

            if isinstance(group, dsl.ExitHandler):
                if handler_exists or len(exiting_task_names) > 1:
                    raise ValueError(
                        'Only one global exit_handler is allowed and all ops need to be included.'
                    )
                handler_exists = True

            if group.tasks:
                exiting_task_names.extend([x.name for x in group.tasks])

            for group in group.groups:
                _validate_exit_handler_helper(
                    group=group,
                    exiting_task_names=exiting_task_names,
                    handler_exists=handler_exists,
                )

        _validate_exit_handler_helper(
            group=pipeline.groups[0],
            exiting_task_names=[],
            handler_exists=False,
        )

    def _create_pipeline_spec(
        self,
        pipeline_args: List[dsl.PipelineChannel],
        pipeline: pipeline_context.Pipeline,
    ) -> pipeline_spec_pb2.PipelineSpec:
        """Creates a pipeline spec object.

        Args:
            pipeline_args: The list of pipeline input parameters.
            pipeline: The instantiated pipeline object.

        Returns:
            A PipelineSpec proto representing the compiled pipeline.

        Raises:
            ValueError if the argument is of unsupported types.
        """
        self._validate_pipeline_name(pipeline.name)

        deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
        pipeline_spec = pipeline_spec_pb2.PipelineSpec()

        pipeline_spec.pipeline_info.name = pipeline.name
        pipeline_spec.sdk_version = 'kfp-{}'.format(kfp.__version__)
        # Schema version 2.1.0 is required for kfp-pipeline-spec>0.1.13
        pipeline_spec.schema_version = '2.1.0'

        pipeline_spec.root.CopyFrom(
            builder.build_component_spec_for_group(
                pipeline_channels=pipeline_args,
                is_root_group=True,
            ))

        root_group = pipeline.groups[0]

        all_groups = self._get_all_groups(root_group)
        group_name_to_group = {group.name: group for group in all_groups}
        task_name_to_parent_groups, group_name_to_parent_groups = (
            self._get_parent_groups(root_group))
        condition_channels = self._get_condition_channels_for_tasks(root_group)
        name_to_for_loop_group = {
            group_name: group
            for group_name, group in group_name_to_group.items()
            if isinstance(group, dsl.ParallelFor)
        }
        inputs = self._get_inputs_for_all_groups(
            pipeline=pipeline,
            pipeline_args=pipeline_args,
            root_group=root_group,
            task_name_to_parent_groups=task_name_to_parent_groups,
            group_name_to_parent_groups=group_name_to_parent_groups,
            condition_channels=condition_channels,
            name_to_for_loop_group=name_to_for_loop_group,
        )
        dependencies = self._get_dependencies(
            pipeline=pipeline,
            root_group=root_group,
            task_name_to_parent_groups=task_name_to_parent_groups,
            group_name_to_parent_groups=group_name_to_parent_groups,
            group_name_to_group=group_name_to_group,
            condition_channels=condition_channels,
        )

        for group in all_groups:
            self._build_spec_by_group(
                pipeline_spec=pipeline_spec,
                deployment_config=deployment_config,
                group=group,
                inputs=inputs,
                dependencies=dependencies,
                rootgroup_name=root_group.name,
                task_name_to_parent_groups=task_name_to_parent_groups,
                group_name_to_parent_groups=group_name_to_parent_groups,
                name_to_for_loop_group=name_to_for_loop_group,
            )

        # TODO: refactor to support multiple exit handlers per pipeline.
        if pipeline.groups[0].groups:
            first_group = pipeline.groups[0].groups[0]
            if isinstance(first_group, dsl.ExitHandler):
                exit_task = first_group.exit_task
                exit_task_name = component_utils.sanitize_task_name(
                    exit_task.name)
                exit_handler_group_task_name = component_utils.sanitize_task_name(
                    first_group.name)
                input_parameters_in_current_dag = [
                    input_name for input_name in
                    pipeline_spec.root.input_definitions.parameters
                ]
                exit_task_task_spec = builder.build_task_spec_for_exit_task(
                    task=exit_task,
                    dependent_task=exit_handler_group_task_name,
                    pipeline_inputs=pipeline_spec.root.input_definitions,
                )

                exit_task_component_spec = builder.build_component_spec_for_exit_task(
                    task=exit_task)

                exit_task_container_spec = builder.build_container_spec_for_task(
                    task=exit_task)

                # Add the exit task's task spec.
                pipeline_spec.root.dag.tasks[exit_task_name].CopyFrom(
                    exit_task_task_spec)

                # Add the exit task's component spec if it does not exist.
                component_name = exit_task_task_spec.component_ref.name
                if component_name not in pipeline_spec.components:
                    pipeline_spec.components[component_name].CopyFrom(
                        exit_task_component_spec)

                # Add the exit task's container spec if it does not exist.
                executor_label = exit_task_component_spec.executor_label
                if executor_label not in deployment_config.executors:
                    deployment_config.executors[
                        executor_label].container.CopyFrom(
                            exit_task_container_spec)
                    pipeline_spec.deployment_spec.update(
                        json_format.MessageToDict(deployment_config))

        return pipeline_spec

    def _validate_pipeline_name(self, name: str) -> None:
        """Validate pipeline name.

        A valid pipeline name should match ^[a-z0-9][a-z0-9-]{0,127}$.
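
        For example (illustrative names)::

            my-pipeline-123      # valid
            My_Pipeline          # invalid: uppercase letters and underscore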

        Args:
            name: The pipeline name.

        Raises:
            ValueError if the pipeline name doesn't conform to the regular expression.
        """
        pattern = re.compile(r'^[a-z0-9][a-z0-9-]{0,127}$')
        if not pattern.match(name):
            raise ValueError(
                'Invalid pipeline name: %s.\n'
                'Please specify a pipeline name that matches the regular '
                'expression "^[a-z0-9][a-z0-9-]{0,127}$" using '
                '`dsl.pipeline(name=...)` decorator.' % name)

    def _get_all_groups(
        self,
        root_group: tasks_group.TasksGroup,
    ) -> List[tasks_group.TasksGroup]:
        """Gets all groups (not including tasks) in a pipeline.

        Args:
            root_group: The root group of a pipeline.

        Returns:
            A list of all groups in topological order (parent first).
        """
        all_groups = []

        def _get_all_groups_helper(
            group: tasks_group.TasksGroup,
            all_groups: List[tasks_group.TasksGroup],
        ):
            all_groups.append(group)
            for group in group.groups:
                _get_all_groups_helper(group, all_groups)

        _get_all_groups_helper(root_group, all_groups)
        return all_groups

    def _get_parent_groups(
        self,
        root_group: tasks_group.TasksGroup,
    ) -> Tuple[Mapping[str, List[_GroupOrTask]], Mapping[str,
                                                         List[_GroupOrTask]]]:
        """Get parent groups that contain the specified tasks.

        Each pipeline has a root group. Each group has a list of tasks (leaf)
        and groups.
        This function traverses the tree and gets ancestor groups for all
        tasks.

        Args:
            root_group: The root group of a pipeline.

        Returns:
            A tuple. The first item is a mapping of task names to parent
            groups, and the second item is a mapping of group names to parent
            groups. A list of parent groups is a list of ancestor groups
            including the task/group itself. The list is sorted so that the
            farthest parent group is first and the task/group itself is last.
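
        Example (illustrative task/group names)::

            If task1's ancestor groups are [root, G1, G2], the first
            returned mapping contains the entry:

                {'task1': ['root', 'G1', 'G2', 'task1']}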
"""
|
|
|
|
def _get_parent_groups_helper(
|
|
current_groups: List[tasks_group.TasksGroup],
|
|
tasks_to_groups: Dict[str, List[_GroupOrTask]],
|
|
groups_to_groups: Dict[str, List[_GroupOrTask]],
|
|
) -> None:
|
|
root_group = current_groups[-1]
|
|
for group in root_group.groups:
|
|
|
|
groups_to_groups[group.name] = [x.name for x in current_groups
|
|
] + [group.name]
|
|
current_groups.append(group)
|
|
|
|
_get_parent_groups_helper(
|
|
current_groups=current_groups,
|
|
tasks_to_groups=tasks_to_groups,
|
|
groups_to_groups=groups_to_groups,
|
|
)
|
|
del current_groups[-1]
|
|
|
|
for task in root_group.tasks:
|
|
tasks_to_groups[task.name] = [x.name for x in current_groups
|
|
] + [task.name]
|
|
|
|
tasks_to_groups = {}
|
|
groups_to_groups = {}
|
|
current_groups = [root_group]
|
|
|
|
_get_parent_groups_helper(
|
|
current_groups=current_groups,
|
|
tasks_to_groups=tasks_to_groups,
|
|
groups_to_groups=groups_to_groups,
|
|
)
|
|
return (tasks_to_groups, groups_to_groups)
|
|
|
|
    # TODO: do we really need this?
    def _get_condition_channels_for_tasks(
        self,
        root_group: tasks_group.TasksGroup,
    ) -> Mapping[str, Set[dsl.PipelineChannel]]:
        """Gets channels referenced in conditions of tasks' parents.

        Args:
            root_group: The root group of a pipeline.

        Returns:
            A mapping of task name to a set of pipeline channels appearing in
            its parent dsl.Condition groups.
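
        Example (illustrative; ``flip_task`` and ``print_op`` are assumed
        placeholder components)::

            with dsl.Condition(flip_task.output == 'heads'):
                print_task = print_op(msg='heads')

            # The returned mapping then includes an entry from print_task's
            # task name to the pipeline channel flip_task.output.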
"""
|
|
conditions = collections.defaultdict(set)
|
|
|
|
def _get_condition_channels_for_tasks_helper(
|
|
group,
|
|
current_conditions_channels,
|
|
):
|
|
new_current_conditions_channels = current_conditions_channels
|
|
if isinstance(group, dsl.Condition):
|
|
new_current_conditions_channels = list(
|
|
current_conditions_channels)
|
|
if isinstance(group.condition.left_operand,
|
|
dsl.PipelineChannel):
|
|
new_current_conditions_channels.append(
|
|
group.condition.left_operand)
|
|
if isinstance(group.condition.right_operand,
|
|
dsl.PipelineChannel):
|
|
new_current_conditions_channels.append(
|
|
group.condition.right_operand)
|
|
for task in group.tasks:
|
|
for channel in new_current_conditions_channels:
|
|
conditions[task.name].add(channel)
|
|
for group in group.groups:
|
|
_get_condition_channels_for_tasks_helper(
|
|
group, new_current_conditions_channels)
|
|
|
|
_get_condition_channels_for_tasks_helper(root_group, [])
|
|
return conditions
|
|
|
|
    def _get_inputs_for_all_groups(
        self,
        pipeline: pipeline_context.Pipeline,
        pipeline_args: List[dsl.PipelineChannel],
        root_group: tasks_group.TasksGroup,
        task_name_to_parent_groups: Mapping[str, List[_GroupOrTask]],
        group_name_to_parent_groups: Mapping[str, List[tasks_group.TasksGroup]],
        condition_channels: Mapping[str, Set[dsl.PipelineParameterChannel]],
        name_to_for_loop_group: Mapping[str, dsl.ParallelFor],
    ) -> Mapping[str, List[Tuple[dsl.PipelineChannel, str]]]:
        """Get inputs and outputs of each group and op.

        Args:
            pipeline: The instantiated pipeline object.
            pipeline_args: The list of pipeline function arguments as
                PipelineChannel.
            root_group: The root group of the pipeline.
            task_name_to_parent_groups: The dict of task name to list of
                parent groups.
            group_name_to_parent_groups: The dict of group name to list of
                parent groups.
            condition_channels: The dict of task name to a set of pipeline
                channels referenced by its parent condition groups.
            name_to_for_loop_group: The dict of for loop group name to loop
                group.

        Returns:
            A mapping with key being the group/task names and values being
            lists of tuples (channel, producing_task_name).
            producing_task_name is the name of the task that produces the
            channel. If the channel is a pipeline argument (no producer task),
            then producing_task_name is None.
        """
        inputs = collections.defaultdict(set)

        for task in pipeline.tasks.values():
            # The task's inputs and all channels used in conditions for that
            # task are considered.
            task_condition_inputs = list(condition_channels[task.name])

            for channel in task.channel_inputs + task_condition_inputs:

                # If the value is already provided (immediate value), then no
                # need to expose it as input for its parent groups.
                if getattr(channel, 'value', None):
                    continue

                # channels_to_add could be a list of PipelineChannels when
                # loop args are involved. Given a nested loops example as
                # follows:
                #
                #     def my_pipeline(loop_parameter: list):
                #         with dsl.ParallelFor(loop_parameter) as item:
                #             with dsl.ParallelFor(item.p_a) as item_p_a:
                #                 print_op(item_p_a.q_a)
                #
                # The print_op takes an input of
                # {{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item-subvar-q_a;}}.
                # Given this, we calculate the list of PipelineChannels
                # potentially needed across DAG levels as follows:
                #
                # [{{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item-subvar-q_a}},
                #  {{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item}},
                #  {{channel:task=;name=loop_parameter-loop-item-subvar-p_a}},
                #  {{channel:task=;name=loop_parameter-loop-item}},
                #  {{channel:task=;name=loop_parameter}}]
                #
                # For the above example, the first loop needs the input of
                # {{channel:task=;name=loop_parameter}},
                # the second loop needs the input of
                # {{channel:task=;name=loop_parameter-loop-item}}
                # and the print_op needs the input of
                # {{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item}}
                #
                # When we traverse a DAG in a top-down direction, we add
                # channels from the end, and pop them out when they are no
                # longer needed by the sub-DAG.
                # When we traverse a DAG in a bottom-up direction, we add
                # channels from the front, and pop them out when they are no
                # longer needed by the parent DAG.
                channels_to_add = collections.deque()
                channel_to_add = channel

                while isinstance(channel_to_add, (
                        for_loop.LoopArgument,
                        for_loop.LoopArgumentVariable,
                )):
                    channels_to_add.append(channel_to_add)
                    if isinstance(channel_to_add,
                                  for_loop.LoopArgumentVariable):
                        channel_to_add = channel_to_add.loop_argument
                    else:
                        channel_to_add = channel_to_add.items_or_pipeline_channel

                if isinstance(channel_to_add, dsl.PipelineChannel):
                    channels_to_add.append(channel_to_add)

                if channel.task_name:
                    # The PipelineChannel is produced by a task.

                    upstream_task = pipeline.tasks[channel.task_name]
                    upstream_groups, downstream_groups = (
                        self._get_uncommon_ancestors(
                            task_name_to_parent_groups=task_name_to_parent_groups,
                            group_name_to_parent_groups=group_name_to_parent_groups,
                            task1=upstream_task,
                            task2=task,
                        ))

                    for i, group_name in enumerate(downstream_groups):
                        if i == 0:
                            # If it is the first uncommon downstream group,
                            # then the input comes from the first uncommon
                            # upstream group.
                            producer_task = upstream_groups[0]
                        else:
                            # If not the first downstream group, then the
                            # input is passed down from its ancestor groups,
                            # so the upstream group is None.
                            producer_task = None

                        inputs[group_name].add(
                            (channels_to_add[-1], producer_task))

                        if group_name in name_to_for_loop_group:
                            loop_group = name_to_for_loop_group[group_name]

                            # Pop out the last element of channels_to_add if
                            # it is found in the current (loop) DAG.
                            # Downstream tasks only need the more specific
                            # versions of it.
                            if channels_to_add[
                                    -1].full_name in loop_group.loop_argument.full_name:
                                channels_to_add.pop()
                                if not channels_to_add:
                                    break

                else:
                    # The PipelineChannel is not produced by a task. It's
                    # either a top-level pipeline input, or a constant value
                    # to loop items.

                    # TODO: revisit if this is correct.
                    if getattr(task, 'is_exit_handler', False):
                        continue

                    # For a PipelineChannel that results from a constant
                    # value used as loop items, we have to go bottom-up,
                    # because the PipelineChannel can originate from the
                    # middle of a DAG, where it is neither needed by nor
                    # visible to its parent DAG.
                    if isinstance(
                            channel,
                            (for_loop.LoopArgument,
                             for_loop.LoopArgumentVariable),
                    ) and channel.is_with_items_loop_argument:
                        for group_name in task_name_to_parent_groups[
                                task.name][::-1]:

                            inputs[group_name].add((channels_to_add[0], None))
                            if group_name in name_to_for_loop_group:
                                # for example:
                                # loop_group.loop_argument.name = 'loop-item-param-1'
                                # channel.name = 'loop-item-param-1-subvar-a'
                                loop_group = name_to_for_loop_group[group_name]

                                if channels_to_add[
                                        0].full_name in loop_group.loop_argument.full_name:
                                    channels_to_add.popleft()
                                    if not channels_to_add:
                                        break
                    else:
                        # For a PipelineChannel from a pipeline input, go
                        # top-down, just like we do for a PipelineChannel
                        # produced by a task.
                        for group_name in task_name_to_parent_groups[task.name]:

                            inputs[group_name].add((channels_to_add[-1], None))
                            if group_name in name_to_for_loop_group:
                                loop_group = name_to_for_loop_group[group_name]

                                if channels_to_add[
                                        -1].full_name in loop_group.loop_argument.full_name:
                                    channels_to_add.pop()
                                    if not channels_to_add:
                                        break

        return inputs

    def _get_uncommon_ancestors(
        self,
        task_name_to_parent_groups: Mapping[str, List[_GroupOrTask]],
        group_name_to_parent_groups: Mapping[str, List[tasks_group.TasksGroup]],
        task1: _GroupOrTask,
        task2: _GroupOrTask,
    ) -> Tuple[List[_GroupOrTask], List[_GroupOrTask]]:
        """Gets the unique ancestors between two tasks.

        For example, if task1's ancestor groups are [root, G1, G2, G3, task1]
        and task2's ancestor groups are [root, G1, G4, task2], then it returns
        a tuple ([G2, G3, task1], [G4, task2]).

        Args:
            task_name_to_parent_groups: The dict of task name to list of
                parent groups.
            group_name_to_parent_groups: The dict of group name to list of
                parent groups.
            task1: One of the two tasks.
            task2: The other task.

        Returns:
            A tuple of lists of uncommon ancestors, one for each task.
        """
        if task1.name in task_name_to_parent_groups:
            task1_groups = task_name_to_parent_groups[task1.name]
        elif task1.name in group_name_to_parent_groups:
            task1_groups = group_name_to_parent_groups[task1.name]
        else:
            raise ValueError(task1.name + ' does not exist.')

        if task2.name in task_name_to_parent_groups:
            task2_groups = task_name_to_parent_groups[task2.name]
        elif task2.name in group_name_to_parent_groups:
            task2_groups = group_name_to_parent_groups[task2.name]
        else:
            raise ValueError(task2.name + ' does not exist.')

        # The number of ancestor groups shared by the two tasks is the length
        # of the common prefix of their ancestor lists.
        both_groups = [task1_groups, task2_groups]
        common_groups_len = sum(
            1 for x in zip(*both_groups) if x == (x[0],) * len(x))
        group1 = task1_groups[common_groups_len:]
        group2 = task2_groups[common_groups_len:]
        return (group1, group2)

    def _get_dependencies(
        self,
        pipeline: pipeline_context.Pipeline,
        root_group: tasks_group.TasksGroup,
        task_name_to_parent_groups: Mapping[str, List[_GroupOrTask]],
        group_name_to_parent_groups: Mapping[str, List[tasks_group.TasksGroup]],
        group_name_to_group: Mapping[str, tasks_group.TasksGroup],
        condition_channels: Dict[str, dsl.PipelineChannel],
    ) -> Mapping[str, List[_GroupOrTask]]:
        """Gets dependent groups and tasks for all tasks and groups.

        Args:
            pipeline: The instantiated pipeline object.
            root_group: The root group of the pipeline.
            task_name_to_parent_groups: The dict of task name to list of
                parent groups.
            group_name_to_parent_groups: The dict of group name to list of
                parent groups.
            group_name_to_group: The dict of group name to group.
            condition_channels: The dict of task name to a set of pipeline
                channels referenced by its parent condition groups.

        Returns:
            A mapping where the key is a group/task name and the value is a
            list of dependent groups/tasks. The dependencies are calculated
            in the following way: if task2 depends on task1, and their
            ancestors are [root, G1, G2, task1] and [root, G1, G3, G4, task2],
            then G3 is dependent on G2. In other words, a dependency only
            exists between the first uncommon ancestors in their ancestor
            chains. Only sibling groups/tasks can have dependencies.

        Raises:
            RuntimeError: if a task depends on a task inside a condition or
                loop group.
        """
        dependencies = collections.defaultdict(set)
        for task in pipeline.tasks.values():
            upstream_task_names = set()
            task_condition_inputs = list(condition_channels[task.name])
            for channel in task.channel_inputs + task_condition_inputs:
                if channel.task_name:
                    upstream_task_names.add(channel.task_name)
            upstream_task_names |= set(task.dependent_tasks)

            for upstream_task_name in upstream_task_names:
                # The dependency could be either a task or a tasks group.
                if upstream_task_name in pipeline.tasks:
                    upstream_task = pipeline.tasks[upstream_task_name]
                elif upstream_task_name in group_name_to_group:
                    upstream_task = group_name_to_group[upstream_task_name]
                else:
                    raise ValueError(
                        f'Compiler cannot find task: {upstream_task_name}.')

                upstream_groups, downstream_groups = self._get_uncommon_ancestors(
                    task_name_to_parent_groups=task_name_to_parent_groups,
                    group_name_to_parent_groups=group_name_to_parent_groups,
                    task1=upstream_task,
                    task2=task,
                )

                # If a task depends on a condition group or a loop group, it
                # must have explicitly depended on a task inside the group.
                # This is not allowed, because it leads to ambiguous
                # expectations for runtime behaviors.
                dependent_group = group_name_to_group.get(
                    upstream_groups[0], None)
                if isinstance(dependent_group,
                              (tasks_group.Condition, tasks_group.ParallelFor)):
                    raise RuntimeError(
                        f'Task {task.name} cannot depend on any task inside'
                        f' the group: {upstream_groups[0]}.')

                dependencies[downstream_groups[0]].add(upstream_groups[0])

        return dependencies

    def _build_spec_by_group(
        self,
        pipeline_spec: pipeline_spec_pb2.PipelineSpec,
        deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
        group: tasks_group.TasksGroup,
        inputs: Mapping[str, List[Tuple[dsl.PipelineChannel, str]]],
        dependencies: Dict[str, List[_GroupOrTask]],
        rootgroup_name: str,
        task_name_to_parent_groups: Mapping[str, List[_GroupOrTask]],
        group_name_to_parent_groups: Mapping[str, List[tasks_group.TasksGroup]],
        name_to_for_loop_group: Mapping[str, dsl.ParallelFor],
    ) -> None:
        """Generates IR spec given a TasksGroup.

        Args:
            pipeline_spec: The pipeline_spec to update in place.
            deployment_config: The deployment_config to hold all executors.
                The spec is updated in place.
            group: The TasksGroup to generate spec for.
            inputs: The inputs dictionary. The keys are group/task names and
                the values are lists of tuples (channel, producing_task_name).
            dependencies: The group dependencies dictionary. The keys are
                group or task names, and the values are lists of dependent
                groups or tasks.
            rootgroup_name: The name of the root group. Used to determine
                whether the component spec for the current group should be
                the root dag.
            task_name_to_parent_groups: The dict of task name to parent
                groups. Key is task name. Value is a list of ancestor groups
                including the task itself. The list of a given task is sorted
                so that the farthest group is first and the task itself is
                last.
            group_name_to_parent_groups: The dict of group name to parent
                groups. Key is the group name. Value is a list of ancestor
                groups including the group itself. The list of a given group
                is sorted so that the farthest group is first and the group
                itself is last.
            name_to_for_loop_group: The dict of for loop group name to loop
                group.
        """
        group_component_name = component_utils.sanitize_component_name(
            group.name)

        if group.name == rootgroup_name:
            group_component_spec = pipeline_spec.root
        else:
            group_component_spec = pipeline_spec.components[
                group_component_name]

        task_name_to_task_spec = {}
        task_name_to_component_spec = {}

        # Generate task specs and component specs for the dag.
        subgroups = group.groups + group.tasks
        for subgroup in subgroups:

            subgroup_inputs = inputs.get(subgroup.name, [])
            subgroup_channels = [channel for channel, _ in subgroup_inputs]

            subgroup_component_name = (
                component_utils.sanitize_component_name(subgroup.name))

            tasks_in_current_dag = [
                component_utils.sanitize_task_name(subgroup.name)
                for subgroup in subgroups
            ]
            input_parameters_in_current_dag = [
                input_name for input_name in
                group_component_spec.input_definitions.parameters
            ]
            input_artifacts_in_current_dag = [
                input_name for input_name in
                group_component_spec.input_definitions.artifacts
            ]
            is_parent_component_root = (
                group_component_spec == pipeline_spec.root)

            if isinstance(subgroup, pipeline_task.PipelineTask):

                subgroup_task_spec = builder.build_task_spec_for_task(
                    task=subgroup,
                    parent_component_inputs=group_component_spec
                    .input_definitions,
                    tasks_in_current_dag=tasks_in_current_dag,
                    input_parameters_in_current_dag=input_parameters_in_current_dag,
                    input_artifacts_in_current_dag=input_artifacts_in_current_dag,
                )
                task_name_to_task_spec[subgroup.name] = subgroup_task_spec

                subgroup_component_spec = builder.build_component_spec_for_task(
                    task=subgroup)
                task_name_to_component_spec[
                    subgroup.name] = subgroup_component_spec

                executor_label = subgroup_component_spec.executor_label

                if executor_label not in deployment_config.executors:
                    if subgroup.container_spec is not None:
                        subgroup_container_spec = builder.build_container_spec_for_task(
                            task=subgroup)
                        deployment_config.executors[
                            executor_label].container.CopyFrom(
                                subgroup_container_spec)
                    elif subgroup.importer_spec is not None:
                        subgroup_importer_spec = builder.build_importer_spec_for_task(
                            task=subgroup)
                        deployment_config.executors[
                            executor_label].importer.CopyFrom(
                                subgroup_importer_spec)
            elif isinstance(subgroup, dsl.ParallelFor):

                # "Punch the hole", adding additional inputs (other than loop
                # arguments, which will be handled separately) needed by its
                # subgroups or tasks.
                loop_subgroup_channels = []

                for channel in subgroup_channels:
                    # Skip 'withItems' loop arguments if they come from an
                    # inner loop.
                    if isinstance(
                            channel,
                            (for_loop.LoopArgument,
                             for_loop.LoopArgumentVariable),
                    ) and channel.is_with_items_loop_argument:
                        withitems_loop_arg_found_in_self_or_upstream = False
                        for group_name in group_name_to_parent_groups[
                                subgroup.name][::-1]:
                            if group_name in name_to_for_loop_group:
                                loop_group = name_to_for_loop_group[group_name]
                                if channel.name in loop_group.loop_argument.name:
                                    withitems_loop_arg_found_in_self_or_upstream = True
                                    break
                        if not withitems_loop_arg_found_in_self_or_upstream:
                            continue
                    loop_subgroup_channels.append(channel)

                if subgroup.items_is_pipeline_channel:
                    # This loop_argument is based on a pipeline channel, i.e.,
                    # rather than a static list, it is either the output of
                    # another task or an input as a global pipeline parameter.
                    loop_subgroup_channels.append(
                        subgroup.loop_argument.items_or_pipeline_channel)

                loop_subgroup_channels.append(subgroup.loop_argument)

                subgroup_component_spec = builder.build_component_spec_for_group(
                    pipeline_channels=loop_subgroup_channels,
                    is_root_group=False,
                )

                subgroup_task_spec = builder.build_task_spec_for_group(
                    group=subgroup,
                    pipeline_channels=loop_subgroup_channels,
                    tasks_in_current_dag=tasks_in_current_dag,
                    is_parent_component_root=is_parent_component_root,
                )

            elif isinstance(subgroup, dsl.Condition):

                # "Punch the hole", adding inputs needed by its subgroups or
                # tasks.
                condition_subgroup_channels = list(subgroup_channels)
                for operand in [
                        subgroup.condition.left_operand,
                        subgroup.condition.right_operand,
                ]:
                    if isinstance(operand, dsl.PipelineChannel):
                        condition_subgroup_channels.append(operand)

                subgroup_component_spec = builder.build_component_spec_for_group(
                    pipeline_channels=condition_subgroup_channels,
                    is_root_group=False,
                )

                subgroup_task_spec = builder.build_task_spec_for_group(
                    group=subgroup,
                    pipeline_channels=condition_subgroup_channels,
                    tasks_in_current_dag=tasks_in_current_dag,
                    is_parent_component_root=is_parent_component_root,
                )

            elif isinstance(subgroup, dsl.ExitHandler):

                subgroup_component_spec = builder.build_component_spec_for_group(
                    pipeline_channels=subgroup_channels,
                    is_root_group=False,
                )

                subgroup_task_spec = builder.build_task_spec_for_group(
                    group=subgroup,
                    pipeline_channels=subgroup_channels,
                    tasks_in_current_dag=tasks_in_current_dag,
                    is_parent_component_root=is_parent_component_root,
                )

            else:
                raise RuntimeError(
                    f'Unexpected task/group type: Got {subgroup} of type '
                    f'{type(subgroup)}.')

            # Generate the dependencies section for this task.
            if dependencies.get(subgroup.name, None):
                group_dependencies = list(dependencies[subgroup.name])
                group_dependencies.sort()
                subgroup_task_spec.dependent_tasks.extend([
                    component_utils.sanitize_task_name(dep)
                    for dep in group_dependencies
                ])

            # Add the component spec if it does not exist.
            if subgroup_component_name not in pipeline_spec.components:
                pipeline_spec.components[subgroup_component_name].CopyFrom(
                    subgroup_component_spec)

            # Add the task spec.
            group_component_spec.dag.tasks[subgroup.name].CopyFrom(
                subgroup_task_spec)

        pipeline_spec.deployment_spec.update(
            json_format.MessageToDict(deployment_config))

        # Surface metrics outputs to the top.
        builder.populate_metrics_in_dag_outputs(
            tasks=group.tasks,
            task_name_to_parent_groups=task_name_to_parent_groups,
            task_name_to_task_spec=task_name_to_task_spec,
            task_name_to_component_spec=task_name_to_component_spec,
            pipeline_spec=pipeline_spec,
        )