# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods for compiler implementation that is IR-agnostic."""
import collections
import copy
from typing import DefaultDict, Dict, List, Mapping, Set, Tuple, Union

from kfp import dsl
from kfp.dsl import constants
from kfp.dsl import for_loop
from kfp.dsl import pipeline_channel
from kfp.dsl import pipeline_context
from kfp.dsl import pipeline_task
from kfp.dsl import tasks_group

GroupOrTaskType = Union[tasks_group.TasksGroup, pipeline_task.PipelineTask]

ILLEGAL_CROSS_DAG_ERROR_PREFIX = 'Illegal task dependency across DSL context managers.'


def additional_input_name_for_pipeline_channel(
        channel_or_name: Union[pipeline_channel.PipelineChannel, str]) -> str:
    """Gets the name for an additional (compiler-injected) input."""

    # Add a prefix to reduce the chance of name collisions between the
    # original component inputs and the injected input.
    return 'pipelinechannel--' + (
        channel_or_name.full_name if isinstance(
            channel_or_name, pipeline_channel.PipelineChannel) else
        channel_or_name)


def get_all_groups(
        root_group: tasks_group.TasksGroup) -> List[tasks_group.TasksGroup]:
    """Gets all groups (not including tasks) in a pipeline.

    Args:
        root_group: The root group of a pipeline.

    Returns:
        A list of all groups in topological order (parent first).
    """
    all_groups = []

    def _get_all_groups_helper(
        group: tasks_group.TasksGroup,
        all_groups: List[tasks_group.TasksGroup],
    ):
        all_groups.append(group)
        for group in group.groups:
            _get_all_groups_helper(group, all_groups)

    _get_all_groups_helper(root_group, all_groups)
    return all_groups


def get_parent_groups(
    root_group: tasks_group.TasksGroup,
) -> Tuple[Mapping[str, List[str]], Mapping[str, List[str]]]:
    """Get parent groups that contain the specified tasks.

    Each pipeline has a root group. Each group has a list of tasks (leaves)
    and groups. This function traverses the tree and gets the ancestor groups
    for all tasks.

    Args:
        root_group: The root group of a pipeline.

    Returns:
        A tuple. The first item is a mapping of task names to parent groups,
        and the second item is a mapping of group names to parent groups.
        A list of parent groups is a list of ancestor groups including the
        task/group itself. The list is sorted so that the farthest parent
        group is first and the task/group itself is last.
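
    For illustration (hypothetical names): if a root group 'pipeline' contains
    a group 'condition-1', which in turn contains a task 'print-op', the first
    mapping would contain
    {'print-op': ['pipeline', 'condition-1', 'print-op']} and the second
    mapping would contain {'condition-1': ['pipeline', 'condition-1']}.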
""" def _get_parent_groups_helper( current_groups: List[tasks_group.TasksGroup], tasks_to_groups: Dict[str, List[GroupOrTaskType]], groups_to_groups: Dict[str, List[GroupOrTaskType]], ) -> None: root_group = current_groups[-1] for group in root_group.groups: groups_to_groups[group.name] = [x.name for x in current_groups ] + [group.name] current_groups.append(group) _get_parent_groups_helper( current_groups=current_groups, tasks_to_groups=tasks_to_groups, groups_to_groups=groups_to_groups, ) del current_groups[-1] for task in root_group.tasks: tasks_to_groups[task.name] = [x.name for x in current_groups ] + [task.name] tasks_to_groups = {} groups_to_groups = {} current_groups = [root_group] _get_parent_groups_helper( current_groups=current_groups, tasks_to_groups=tasks_to_groups, groups_to_groups=groups_to_groups, ) return (tasks_to_groups, groups_to_groups) def get_channels_from_condition( operations: List[pipeline_channel.ConditionOperation], collected_channels: list, ) -> None: """Appends to collected_channels each pipeline channels used in each operand of each operation in operations.""" for operation in operations: for operand in [operation.left_operand, operation.right_operand]: if isinstance(operand, pipeline_channel.PipelineChannel): collected_channels.append(operand) def get_condition_channels_for_tasks( root_group: tasks_group.TasksGroup, ) -> Mapping[str, Set[pipeline_channel.PipelineChannel]]: """Gets channels referenced in conditions of tasks' parents. Args: root_group: The root group of a pipeline. Returns: A mapping of task name to a set of pipeline channels appeared in its parent dsl.Condition groups. """ conditions = collections.defaultdict(set) def _get_condition_channels_for_tasks_helper( group, current_conditions_channels, ): new_current_conditions_channels = current_conditions_channels if isinstance(group, tasks_group._ConditionBase): new_current_conditions_channels = list(current_conditions_channels) get_channels_from_condition( group.conditions, new_current_conditions_channels, ) for task in group.tasks: for channel in new_current_conditions_channels: conditions[task.name].add(channel) for group in group.groups: _get_condition_channels_for_tasks_helper( group, new_current_conditions_channels) _get_condition_channels_for_tasks_helper(root_group, []) return conditions def get_inputs_for_all_groups( pipeline: pipeline_context.Pipeline, task_name_to_parent_groups: Mapping[str, List[str]], group_name_to_parent_groups: Mapping[str, List[str]], condition_channels: Mapping[str, Set[pipeline_channel.PipelineParameterChannel]], name_to_for_loop_group: Mapping[str, tasks_group.ParallelFor], ) -> Mapping[str, List[Tuple[pipeline_channel.PipelineChannel, str]]]: """Get inputs and outputs of each group and op. Args: pipeline: The instantiated pipeline object. task_name_to_parent_groups: The dict of task name to list of parent groups. group_name_to_parent_groups: The dict of group name to list of parent groups. condition_channels: The dict of task name to a set of pipeline channels referenced by its parent condition groups. name_to_for_loop_group: The dict of for loop group name to loop group. Returns: A mapping with key being the group/task names and values being list of tuples (channel, producing_task_name). producing_task_name is the name of the task that produces the channel. If the channel is a pipeline argument (no producer task), then producing_task_name is None. 
""" inputs = collections.defaultdict(set) for task in pipeline.tasks.values(): # task's inputs and all channels used in conditions for that task are # considered. task_condition_inputs = list(condition_channels[task.name]) for channel in task.channel_inputs + task_condition_inputs: # If the value is already provided (immediate value), then no # need to expose it as input for its parent groups. if getattr(channel, 'value', None): continue # channels_to_add could be a list of PipelineChannels when loop # args are involved. Given a nested loops example as follows: # # def my_pipeline(loop_parameter: list): # with dsl.ParallelFor(loop_parameter) as item: # with dsl.ParallelFor(item.p_a) as item_p_a: # print_op(item_p_a.q_a) # # The print_op takes an input of # {{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item-subvar-q_a;}}. # Given this, we calculate the list of PipelineChannels potentially # needed by across DAG levels as follows: # # [{{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item-subvar-q_a}}, # {{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item}}, # {{channel:task=;name=loop_parameter-loop-item-subvar-p_a}}, # {{channel:task=;name=loop_parameter-loop-item}}, # {{chaenel:task=;name=loop_parameter}}] # # For the above example, the first loop needs the input of # {{channel:task=;name=loop_parameter}}, # the second loop needs the input of # {{channel:task=;name=loop_parameter-loop-item}} # and the print_op needs the input of # {{channel:task=;name=loop_parameter-loop-item-subvar-p_a-loop-item}} # # When we traverse a DAG in a top-down direction, we add channels # from the end, and pop it out when it's no longer needed by the # sub-DAG. # When we traverse a DAG in a bottom-up direction, we add # channels from the front, and pop it out when it's no longer # needed by the parent DAG. channels_to_add = collections.deque() channel_to_add = channel while isinstance(channel_to_add, ( for_loop.LoopParameterArgument, for_loop.LoopArtifactArgument, for_loop.LoopArgumentVariable, )): channels_to_add.append(channel_to_add) if isinstance(channel_to_add, for_loop.LoopArgumentVariable): channel_to_add = channel_to_add.loop_argument else: channel_to_add = channel_to_add.items_or_pipeline_channel if isinstance(channel_to_add, pipeline_channel.PipelineChannel): channels_to_add.append(channel_to_add) if channel.task: # The PipelineChannel is produced by a task. upstream_task = channel.task upstream_groups, downstream_groups = ( _get_uncommon_ancestors( task_name_to_parent_groups=task_name_to_parent_groups, group_name_to_parent_groups=group_name_to_parent_groups, task1=upstream_task, task2=task, )) for i, group_name in enumerate(downstream_groups): if i == 0: # If it is the first uncommon downstream group, then # the input comes from the first uncommon upstream # group. producer_task = upstream_groups[0] else: # If not the first downstream group, then the input # is passed down from its ancestor groups so the # upstream group is None. producer_task = None inputs[group_name].add((channels_to_add[-1], producer_task)) if group_name in name_to_for_loop_group: loop_group = name_to_for_loop_group[group_name] # Pop out the last elements from channels_to_add if it # is found in the current (loop) DAG. Downstreams # would only need the more specific versions for it. if channels_to_add[ -1].full_name in loop_group.loop_argument.full_name: channels_to_add.pop() if not channels_to_add: break else: # The PipelineChannel is not produced by a task. 
                # It's either a top-level pipeline input, or a constant value
                # used as loop items.

                # TODO: revisit if this is correct.
                if getattr(task, 'is_exit_handler', False):
                    continue

                # For a PipelineChannel that results from a constant value
                # used as loop items, we have to go bottom-up because the
                # PipelineChannel can originate from the middle of a DAG,
                # which is neither needed by nor visible to its parent DAG.
                if isinstance(channel, (
                        for_loop.LoopParameterArgument,
                        for_loop.LoopArtifactArgument,
                        for_loop.LoopArgumentVariable,
                )) and channel.is_with_items_loop_argument:
                    for group_name in task_name_to_parent_groups[
                            task.name][::-1]:

                        inputs[group_name].add((channels_to_add[0], None))
                        if group_name in name_to_for_loop_group:
                            # for example:
                            # loop_group.loop_argument.name = 'loop-item-param-1'
                            # channel.name = 'loop-item-param-1-subvar-a'
                            loop_group = name_to_for_loop_group[group_name]

                            if channels_to_add[
                                    0].full_name in loop_group.loop_argument.full_name:
                                channels_to_add.popleft()
                                if not channels_to_add:
                                    break
                else:
                    # For a PipelineChannel from a pipeline input, go top-down
                    # just like we do for a PipelineChannel produced by a task.
                    for group_name in task_name_to_parent_groups[task.name]:

                        inputs[group_name].add((channels_to_add[-1], None))
                        if group_name in name_to_for_loop_group:
                            loop_group = name_to_for_loop_group[group_name]
                            if channels_to_add[
                                    -1].full_name in loop_group.loop_argument.full_name:
                                channels_to_add.pop()
                                if not channels_to_add:
                                    break

    return inputs


class InvalidTopologyException(Exception):
    pass


def validate_parallel_for_fan_in_consumption_legal(
    consumer_task_name: str,
    upstream_groups: List[str],
    group_name_to_group: Dict[str, tasks_group.TasksGroup],
) -> None:
    """Checks that a dsl.Collected object is used in a way that results in an
    unambiguous pipeline topology and is therefore legal.

    Args:
        consumer_task_name: The name of the consumer task.
        upstream_groups: The names of the producer task's upstream groups,
            ordered from the outermost group at the beginning to the producer
            task at the end. This is produced by _get_uncommon_ancestors.
        group_name_to_group: Map of group name to TasksGroup, for fast
            lookups.
    """
    # handles cases like this:
    # @dsl.pipeline
    # def my_pipeline():
    #     with dsl.ParallelFor([1, 2, 3]) as x:
    #         t = double(num=x)
    #         x = add(dsl.Collected(t.output))
    #
    # and this:
    # @dsl.pipeline
    # def my_pipeline():
    #     t = double(num=1)
    #     x = add(dsl.Collected(t.output))
    producer_task_idx = -1
    producer_task_name = upstream_groups[producer_task_idx]
    if all(group_name_to_group[group_name].group_type !=
           tasks_group.TasksGroupType.FOR_LOOP
           for group_name in upstream_groups[:producer_task_idx]):
        raise InvalidTopologyException(
            f'dsl.{for_loop.Collected.__name__} can only be used to fan-in outputs produced by a task within a dsl.{tasks_group.ParallelFor.__name__} context to a task outside of the dsl.{tasks_group.ParallelFor.__name__} context. Producer task {producer_task_name} is either not in a dsl.{tasks_group.ParallelFor.__name__} context or is only in a dsl.{tasks_group.ParallelFor.__name__} that also contains consumer task {consumer_task_name}.'
        )

    # illegal if the producer has a parent conditional outside of its
    # outermost for loop, since the for loop may or may not be executed
    # for example, what happens if text == 'b'? the resulting execution
    # behavior is ambiguous.
    #
    # @dsl.pipeline
    # def my_pipeline(text: str = ''):
    #     with dsl.Condition(text == 'a'):
    #         with dsl.ParallelFor([1, 2, 3]) as x:
    #             t = double(num=x)
    #     x = add(nums=dsl.Collected(t.output))
    outermost_uncommon_upstream_group = upstream_groups[0]
    group = group_name_to_group[outermost_uncommon_upstream_group]
    if group.group_type in [
            tasks_group.TasksGroupType.CONDITION,
            tasks_group.TasksGroupType.EXIT_HANDLER,
    ]:
        raise InvalidTopologyException(
            f'{ILLEGAL_CROSS_DAG_ERROR_PREFIX} When using dsl.{for_loop.Collected.__name__} to fan-in outputs from a task within a dsl.{tasks_group.ParallelFor.__name__} context, the dsl.{tasks_group.ParallelFor.__name__} context manager cannot be nested within a dsl.{group.__class__.__name__} context manager unless the consumer task is too. Task {consumer_task_name} consumes from {producer_task_name} within a dsl.{group.__class__.__name__} context.'
        )
    elif group.group_type != tasks_group.TasksGroupType.FOR_LOOP:
        raise ValueError(
            f'Got unexpected group type when validating fanning-in outputs from a task in dsl.{tasks_group.ParallelFor.__name__}: {group.group_type}'
        )


def make_new_channel_for_collected_outputs(
    channel_name: str,
    starting_channel: pipeline_channel.PipelineChannel,
    task_name: str,
) -> pipeline_channel.PipelineChannel:
    """Creates a new list-typed PipelineParameterChannel/PipelineArtifactChannel
    for a dsl.Collected channel from the original task output."""
    if isinstance(starting_channel, pipeline_channel.PipelineParameterChannel):
        return pipeline_channel.PipelineParameterChannel(
            channel_name,
            channel_type='LIST',
            task_name=task_name,
        )
    elif isinstance(starting_channel,
                    pipeline_channel.PipelineArtifactChannel):
        return pipeline_channel.PipelineArtifactChannel(
            channel_name,
            channel_type=starting_channel.channel_type,
            task_name=task_name,
            is_artifact_list=True,
        )
    else:
        raise ValueError(
            f'Got unknown PipelineChannel: {starting_channel!r}. Expected an instance of {pipeline_channel.PipelineArtifactChannel.__name__!r} or {pipeline_channel.PipelineParameterChannel.__name__!r}.'
        )


def get_outputs_for_all_groups(
    pipeline: pipeline_context.Pipeline,
    task_name_to_parent_groups: Mapping[str, List[str]],
    group_name_to_parent_groups: Mapping[str, List[str]],
    all_groups: List[tasks_group.TasksGroup],
    pipeline_outputs_dict: Dict[str, pipeline_channel.PipelineChannel],
) -> Tuple[DefaultDict[str, Dict[str, pipeline_channel.PipelineChannel]],
           Dict[str, pipeline_channel.PipelineChannel]]:
    """Gets a dictionary of all TasksGroup names to an inner dictionary.

    The inner dictionary maps TasksGroup output keys to the channels
    corresponding to those keys. The dictionary is constructed from both data
    passing within the pipeline body and the outputs returned from the
    pipeline (e.g., return dsl.Collected(...)).

    Also returns, as the second item of the tuple, the updated
    pipeline_outputs_dict. This dict is modified so that each value
    (PipelineChannel) references the group that surfaces the task output,
    instead of the original task that produced it.
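
    For illustration (hypothetical names): if a task outside of a
    dsl.ParallelFor group 'for-loop-1' consumes dsl.Collected(t.output), where
    task 't' runs inside that group, the first returned dictionary would
    contain an entry for 'for-loop-1' mapping the injected output key (of the
    form 'pipelinechannel--...') to a list-typed channel referencing 't'.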
""" # unlike inputs, which will be surfaced as component input parameters, # consumers of surfaced outputs need to have a reference to what the parent # component calls them when they surface them, which will be different than # the producer task name and channel name (the information contained in the # pipeline channel) # for this reason, we use additional_input_name_for_pipeline_channel here # to set the name of the surfaced output once group_name_to_group = {group.name: group for group in all_groups} group_name_to_children = { group.name: [group.name for group in group.groups] + [task.name for task in group.tasks] for group in all_groups } outputs = collections.defaultdict(dict) processed_oneofs: Set[pipeline_channel.OneOfMixin] = set() # handle dsl.Collected consumed by tasks for task in pipeline.tasks.values(): for channel in task.channel_inputs: # TODO: migrate Collected to OneOfMixin style implementation, # then simplify this logic to align with OneOfMixin logic if isinstance(channel, dsl.Collected): producer_task = pipeline.tasks[channel.task_name] consumer_task = task upstream_groups, downstream_groups = ( _get_uncommon_ancestors( task_name_to_parent_groups=task_name_to_parent_groups, group_name_to_parent_groups=group_name_to_parent_groups, task1=producer_task, task2=consumer_task, )) validate_parallel_for_fan_in_consumption_legal( consumer_task_name=consumer_task.name, upstream_groups=upstream_groups, group_name_to_group=group_name_to_group, ) # producer_task's immediate parent group and the name by which # to surface the channel surfaced_output_name = additional_input_name_for_pipeline_channel( channel) # the highest-level task group that "consumes" the # collected output parent_consumer = downstream_groups[0] producer_task_name = upstream_groups.pop() # process from the upstream groups from the inside out for upstream_name in reversed(upstream_groups): outputs[upstream_name][ surfaced_output_name] = make_new_channel_for_collected_outputs( channel_name=channel.name, starting_channel=channel.output, task_name=producer_task_name, ) # on each iteration, mutate the channel being consumed so # that it references the last parent group surfacer channel.name = surfaced_output_name channel.task_name = upstream_name # for the next iteration, set the consumer to the current # surfacer (parent group) producer_task_name = upstream_name parent_of_current_surfacer = group_name_to_parent_groups[ upstream_name][-2] if parent_consumer in group_name_to_children[ parent_of_current_surfacer]: break elif isinstance(channel, pipeline_channel.OneOfMixin): if channel in processed_oneofs: continue # we want to mutate the oneof's inner channels ONLY where they # are used in the oneof, not if they are used separately # for example: we should only modify the copy of # foo.output in dsl.OneOf(foo.output), not if foo.output is # passed to another downstream task channel.channels = [copy.copy(c) for c in channel.channels] for inner_channel in channel.channels: producer_task = pipeline.tasks[inner_channel.task_name] consumer_task = task upstream_groups, downstream_groups = ( _get_uncommon_ancestors( task_name_to_parent_groups=task_name_to_parent_groups, group_name_to_parent_groups=group_name_to_parent_groups, task1=producer_task, task2=consumer_task, )) surfaced_output_name = additional_input_name_for_pipeline_channel( inner_channel) # 1. get the oneof # 2. find the task group that surfaced it # 3. 
find the inner tasks reponsible for upstream_name in reversed(upstream_groups): # skip the first task processed, since we don't need to add new outputs for the innermost task if upstream_name == inner_channel.task.name: continue # # once we've hit the outermost condition-branches group, we're done if upstream_name == channel.condition_branches_group.name: outputs[upstream_name][channel.name] = channel break # copy as a mechanism for "freezing" the inner channel # before we make updates for the next iteration outputs[upstream_name][ surfaced_output_name] = copy.copy(inner_channel) inner_channel.name = surfaced_output_name inner_channel.task_name = upstream_name processed_oneofs.add(channel) # handle dsl.Collected returned from pipeline # TODO: consider migrating dsl.Collected returns to pattern used by dsl.OneOf, where the OneOf constructor returns a parameter/artifact channel, which fits in more cleanly into the existing compiler abtractions for output_key, channel in pipeline_outputs_dict.items(): if isinstance(channel, for_loop.Collected): surfaced_output_name = additional_input_name_for_pipeline_channel( channel) upstream_groups = task_name_to_parent_groups[channel.task_name][1:] producer_task_name = upstream_groups.pop() # process upstream groups from the inside out, until getting to the pipeline level for upstream_name in reversed(upstream_groups): new_channel = make_new_channel_for_collected_outputs( channel_name=channel.name, starting_channel=channel.output, task_name=producer_task_name, ) # on each iteration, mutate the channel being consumed so # that it references the last parent group surfacer channel.name = surfaced_output_name channel.task_name = upstream_name # for the next iteration, set the consumer to the current # surfacer (parent group) producer_task_name = upstream_name outputs[upstream_name][surfaced_output_name] = new_channel # after surfacing from all inner TasksGroup, change the PipelineChannel output to also return from the correct TasksGroup pipeline_outputs_dict[ output_key] = make_new_channel_for_collected_outputs( channel_name=surfaced_output_name, starting_channel=channel.output, task_name=upstream_name, ) elif isinstance(channel, pipeline_channel.OneOfMixin): # if the output has already been consumed by a task before it is returned, we don't need to reprocess it if channel in processed_oneofs: continue # we want to mutate the oneof's inner channels ONLY where they # are used in the oneof, not if they are used separately # for example: we should only modify the copy of # foo.output in dsl.OneOf(foo.output), not if foo.output is passed # to another downstream task channel.channels = [copy.copy(c) for c in channel.channels] for inner_channel in channel.channels: producer_task = pipeline.tasks[inner_channel.task_name] upstream_groups = task_name_to_parent_groups[ inner_channel.task_name][1:] surfaced_output_name = additional_input_name_for_pipeline_channel( inner_channel) # 1. get the oneof # 2. find the task group that surfaced it # 3. 
                #    find the inner tasks responsible
                for upstream_name in reversed(upstream_groups):
                    # skip the first task processed, since we don't need to
                    # add new outputs for the innermost task
                    if upstream_name == inner_channel.task.name:
                        continue
                    #
                    # once we've hit the outermost condition-branches group,
                    # we're done
                    if upstream_name == channel.condition_branches_group.name:
                        outputs[upstream_name][channel.name] = channel
                        break

                    # copy as a mechanism for "freezing" the inner channel
                    # before we make updates for the next iteration
                    outputs[upstream_name][surfaced_output_name] = copy.copy(
                        inner_channel)
                    inner_channel.name = surfaced_output_name
                    inner_channel.task_name = upstream_name

    return outputs, pipeline_outputs_dict


def _get_uncommon_ancestors(
    task_name_to_parent_groups: Mapping[str, List[str]],
    group_name_to_parent_groups: Mapping[str, List[str]],
    task1: GroupOrTaskType,
    task2: GroupOrTaskType,
) -> Tuple[List[str], List[str]]:
    """Gets the uncommon ancestors between two tasks.

    For example, if task1's ancestor groups are [root, G1, G2, G3, task1] and
    task2's ancestor groups are [root, G1, G4, task2], then it returns the
    tuple ([G2, G3, task1], [G4, task2]).

    Args:
        task_name_to_parent_groups: The dict of task name to list of parent
            groups.
        group_name_to_parent_groups: The dict of group name to list of parent
            groups.
        task1: One of the two tasks.
        task2: The other task.

    Returns:
        A tuple of lists of uncommon ancestors, one list per task.
    """
    if task1.name in task_name_to_parent_groups:
        task1_groups = task_name_to_parent_groups[task1.name]
    elif task1.name in group_name_to_parent_groups:
        task1_groups = group_name_to_parent_groups[task1.name]
    else:
        raise ValueError(task1.name + ' does not exist.')

    if task2.name in task_name_to_parent_groups:
        task2_groups = task_name_to_parent_groups[task2.name]
    elif task2.name in group_name_to_parent_groups:
        task2_groups = group_name_to_parent_groups[task2.name]
    else:
        raise ValueError(task2.name + ' does not exist.')

    both_groups = [task1_groups, task2_groups]
    common_groups_len = sum(
        1 for x in zip(*both_groups) if x == (x[0],) * len(x))
    group1 = task1_groups[common_groups_len:]
    group2 = task2_groups[common_groups_len:]
    return (group1, group2)


def get_dependencies(
    pipeline: pipeline_context.Pipeline,
    task_name_to_parent_groups: Mapping[str, List[str]],
    group_name_to_parent_groups: Mapping[str, List[str]],
    group_name_to_group: Mapping[str, tasks_group.TasksGroup],
    condition_channels: Dict[str, pipeline_channel.PipelineChannel],
) -> Mapping[str, List[GroupOrTaskType]]:
    """Gets dependent groups and tasks for all tasks and groups.

    Args:
        pipeline: The instantiated pipeline object.
        task_name_to_parent_groups: The dict of task name to list of parent
            groups.
        group_name_to_parent_groups: The dict of group name to list of parent
            groups.
        group_name_to_group: The dict of group name to group.
        condition_channels: The dict of task name to a set of pipeline
            channels referenced by its parent condition groups.

    Returns:
        A mapping where the key is a group/task name and the value is a list
        of dependent groups/tasks. The dependencies are calculated as
        follows: if task2 depends on task1, and their ancestors are
        [root, G1, G2, task1] and [root, G1, G3, G4, task2], then G3 is
        dependent on G2. In other words, a dependency only exists between the
        first uncommon ancestors in their ancestor chains. Only sibling
        groups/tasks can have dependencies.

    Raises:
        InvalidTopologyException: if a task depends on a task inside a
            condition or loop group.
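
    For illustration, with the ancestor chains in the example above, the
    returned mapping would contain an entry like {'G3': {'G2'}}.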
""" dependencies = collections.defaultdict(set) for task in pipeline.tasks.values(): upstream_task_names: Set[Union[pipeline_task.PipelineTask, tasks_group.TasksGroup]] = set() task_condition_inputs = list(condition_channels[task.name]) all_channels = task.channel_inputs + task_condition_inputs upstream_task_names.update( {channel.task for channel in all_channels if channel.task}) # dependent tasks is tasks on which .after was called and can only be the names of PipelineTasks, not TasksGroups upstream_task_names.update( {pipeline.tasks[after_task] for after_task in task.dependent_tasks}) for upstream_task in upstream_task_names: upstream_names, downstream_names = _get_uncommon_ancestors( task_name_to_parent_groups=task_name_to_parent_groups, group_name_to_parent_groups=group_name_to_parent_groups, task1=upstream_task, task2=task, ) # uncommon upstream ancestor check uncommon_upstream_names = copy.deepcopy(upstream_names) # because a task's `upstream_groups` contains the upstream task's name, remove it so we can check it's parent TasksGroups uncommon_upstream_names.remove(upstream_task.name) if uncommon_upstream_names: upstream_parent_group = group_name_to_group.get( uncommon_upstream_names[0], None) # downstream tasks cannot depend on tasks in a dsl.ExitHandler or condition context if the downstream task is not also in the same context # this applies for .after() and data exchange if isinstance( upstream_parent_group, (tasks_group._ConditionBase, tasks_group.ExitHandler)): raise InvalidTopologyException( f'{ILLEGAL_CROSS_DAG_ERROR_PREFIX} A downstream task cannot depend on an upstream task within a dsl.{upstream_parent_group.__class__.__name__} context unless the downstream is within that context too. Found task {task.name} which depends on upstream task {upstream_task.name} within an uncommon dsl.{upstream_parent_group.__class__.__name__} context.' ) # same for dsl.ParallelFor, but only throw on data exchange (.after is allowed) # TODO: migrate Collected to OneOfMixin style implementation, # then make this validation dsl.Collected-aware elif isinstance(upstream_parent_group, tasks_group.ParallelFor): upstream_tasks_that_downstream_consumers_from = [ channel.task.name for channel in task._channel_inputs if channel.task is not None ] has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from # don't raise for .after if has_data_exchange: raise InvalidTopologyException( f'{ILLEGAL_CROSS_DAG_ERROR_PREFIX} A downstream task cannot depend on an upstream task within a dsl.{upstream_parent_group.__class__.__name__} context unless the downstream is within that context too or the outputs are begin fanned-in to a list using dsl.{for_loop.Collected.__name__}. Found task {task.name} which depends on upstream task {upstream_task.name} within an uncommon dsl.{upstream_parent_group.__class__.__name__} context.' ) dependencies[downstream_names[0]].add(upstream_names[0]) return dependencies def recursive_replace_placeholders(data: Union[Dict, List], old_value: str, new_value: str) -> Union[Dict, List]: """Recursively replaces values in a nested dict/list object. This method is used to replace PipelineChannel objects with input parameter placeholders in a nested object like worker_pool_specs for custom jobs. Args: data: A nested object that can contain dictionaries and/or lists. old_value: The value that will be replaced. new_value: The value to replace the old value with. Returns: A copy of data with all occurences of old_value replaced by new_value. 
""" if isinstance(data, dict): return { k: recursive_replace_placeholders(v, old_value, new_value) for k, v in data.items() } elif isinstance(data, list): return [ recursive_replace_placeholders(i, old_value, new_value) for i in data ] else: if isinstance(data, pipeline_channel.PipelineChannel): data = str(data) return new_value if data == old_value else data # Note that cpu_to_float assumes the string has already been validated by the _validate_cpu_request_limit method. def _cpu_to_float(cpu: str) -> float: """Converts the validated CPU request/limit string and to its numeric float value. Args: cpu: CPU requests or limits. This string should be a number or a number followed by an "m" to indicate millicores (1/1000). For more information, see `Specify a CPU Request and a CPU Limit Returns: The numeric value (float) of the cpu request/limit. """ return float(cpu[:-1]) / 1000 if cpu.endswith('m') else float(cpu) # Note that memory_to_float assumes the string has already been validated by the _validate_memory_request_limit method. def _memory_to_float(memory: str) -> float: """Converts the validated memory request/limit string to its numeric value. Args: memory: Memory requests or limits. This string should be a number or a number followed by one of "E", "Ei", "P", "Pi", "T", "Ti", "G", "Gi", "M", "Mi", "K", or "Ki". Returns: The numeric value (float) of the memory request/limit. """ if memory.endswith('E'): memory = float(memory[:-1]) * constants._E / constants._G elif memory.endswith('Ei'): memory = float(memory[:-2]) * constants._EI / constants._G elif memory.endswith('P'): memory = float(memory[:-1]) * constants._P / constants._G elif memory.endswith('Pi'): memory = float(memory[:-2]) * constants._PI / constants._G elif memory.endswith('T'): memory = float(memory[:-1]) * constants._T / constants._G elif memory.endswith('Ti'): memory = float(memory[:-2]) * constants._TI / constants._G elif memory.endswith('G'): memory = float(memory[:-1]) elif memory.endswith('Gi'): memory = float(memory[:-2]) * constants._GI / constants._G elif memory.endswith('M'): memory = float(memory[:-1]) * constants._M / constants._G elif memory.endswith('Mi'): memory = float(memory[:-2]) * constants._MI / constants._G elif memory.endswith('K'): memory = float(memory[:-1]) * constants._K / constants._G elif memory.endswith('Ki'): memory = float(memory[:-2]) * constants._KI / constants._G else: # By default interpret as a plain integer, in the unit of Bytes. memory = float(memory) / constants._G return memory