765 lines
26 KiB
Python
765 lines
26 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
__all__ = [
|
|
'InputSpec',
|
|
'OutputSpec',
|
|
|
|
'InputValuePlaceholder',
|
|
'InputPathPlaceholder',
|
|
'OutputPathPlaceholder',
|
|
'InputUriPlaceholder',
|
|
'OutputUriPlaceholder',
|
|
'InputMetadataPlaceholder',
|
|
'InputOutputPortNamePlaceholder',
|
|
'OutputMetadataPlaceholder',
|
|
'ConcatPlaceholder',
|
|
'IsPresentPlaceholder',
|
|
'IfPlaceholderStructure',
|
|
'IfPlaceholder',
|
|
|
|
'ContainerSpec',
|
|
'ContainerImplementation',
|
|
|
|
'ComponentSpec',
|
|
|
|
'ComponentReference',
|
|
|
|
'GraphInputReference',
|
|
'GraphInputArgument',
|
|
'TaskOutputReference',
|
|
'TaskOutputArgument',
|
|
|
|
'EqualsPredicate',
|
|
'NotEqualsPredicate',
|
|
'GreaterThanPredicate',
|
|
'GreaterThanOrEqualPredicate',
|
|
'LessThenPredicate',
|
|
'LessThenOrEqualPredicate',
|
|
'NotPredicate',
|
|
'AndPredicate',
|
|
'OrPredicate',
|
|
|
|
'RetryStrategySpec',
|
|
'CachingStrategySpec',
|
|
'ExecutionOptionsSpec',
|
|
'TaskSpec',
|
|
|
|
'GraphSpec',
|
|
'GraphImplementation',
|
|
|
|
'PipelineRunSpec',
|
|
]
|
|
|
|
from collections import OrderedDict
|
|
|
|
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
|
|
|
|
from .modelbase import ModelBase
|
|
|
|
|
|
PrimitiveTypes = Union[str, int, float, bool]
|
|
PrimitiveTypesIncludingNone = Optional[PrimitiveTypes]
|
|
|
|
TypeSpecType = Union[str, Dict, List]
|
|
|
|
|
|
class InputSpec(ModelBase):
|
|
'''Describes the component input specification'''
|
|
def __init__(self,
|
|
name: str,
|
|
type: Optional[TypeSpecType] = None,
|
|
description: Optional[str] = None,
|
|
default: Optional[PrimitiveTypes] = None,
|
|
optional: Optional[bool] = False,
|
|
annotations: Optional[Dict[str, Any]] = None,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class OutputSpec(ModelBase):
|
|
'''Describes the component output specification'''
|
|
def __init__(self,
|
|
name: str,
|
|
type: Optional[TypeSpecType] = None,
|
|
description: Optional[str] = None,
|
|
annotations: Optional[Dict[str, Any]] = None,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class InputValuePlaceholder(ModelBase): #Non-standard attr names
|
|
'''Represents the command-line argument placeholder that will be replaced at run-time by the input argument value.'''
|
|
_serialized_names = {
|
|
'input_name': 'inputValue',
|
|
}
|
|
|
|
def __init__(self,
|
|
input_name: str,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class InputPathPlaceholder(ModelBase): #Non-standard attr names
|
|
'''Represents the command-line argument placeholder that will be replaced at run-time by a local file path pointing to a file containing the input argument value.'''
|
|
_serialized_names = {
|
|
'input_name': 'inputPath',
|
|
}
|
|
|
|
def __init__(self,
|
|
input_name: str,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class OutputPathPlaceholder(ModelBase): #Non-standard attr names
|
|
'''Represents the command-line argument placeholder that will be replaced at run-time by a local file path pointing to a file where the program should write its output data.'''
|
|
_serialized_names = {
|
|
'output_name': 'outputPath',
|
|
}
|
|
|
|
def __init__(self,
|
|
output_name: str,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class InputUriPlaceholder(ModelBase): # Non-standard attr names
|
|
"""Represents a placeholder for the URI of an input artifact.
|
|
|
|
Represents the command-line argument placeholder that will be replaced at
|
|
run-time by the URI of the input artifact argument.
|
|
"""
|
|
_serialized_names = {
|
|
'input_name': 'inputUri',
|
|
}
|
|
|
|
def __init__(self,
|
|
input_name: str,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class OutputUriPlaceholder(ModelBase): # Non-standard attr names
|
|
"""Represents a placeholder for the URI of an output artifact.
|
|
|
|
Represents the command-line argument placeholder that will be replaced at
|
|
run-time by a URI of the output artifac where the program should write its
|
|
output data.
|
|
"""
|
|
_serialized_names = {
|
|
'output_name': 'outputUri',
|
|
}
|
|
|
|
def __init__(self,
|
|
output_name: str,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class InputMetadataPlaceholder(ModelBase): # Non-standard attr names
|
|
"""Represents the file path to an input artifact metadata.
|
|
|
|
During runtime, this command-line argument placeholder will be replaced
|
|
by the path where the metadata file associated with this artifact has been
|
|
written to. Currently only supported in v2 components.
|
|
"""
|
|
_serialized_names = {
|
|
'input_name': 'inputMetadata',
|
|
}
|
|
|
|
def __init__(self, input_name: str):
|
|
super().__init__(locals())
|
|
|
|
|
|
class InputOutputPortNamePlaceholder(ModelBase): # Non-standard attr names
|
|
"""Represents the output port name of an input artifact.
|
|
|
|
During compile time, this command-line argument placeholder will be replaced
|
|
by the actual output port name used by the producer task. Currently only
|
|
supported in v2 components.
|
|
"""
|
|
_serialized_names = {
|
|
'input_name': 'inputOutputPortName',
|
|
}
|
|
|
|
def __init__(self, input_name: str):
|
|
super().__init__(locals())
|
|
|
|
|
|
class OutputMetadataPlaceholder(ModelBase): # Non-standard attr names
|
|
"""Represents the output metadata JSON file location of this task.
|
|
|
|
This file will encode the metadata information produced by this task:
|
|
- Artifacts metadata, but not the content of the artifact, and
|
|
- output parameters.
|
|
|
|
Only supported in v2 components.
|
|
"""
|
|
_serialized_names = {
|
|
'output_name': 'outputMetadata',
|
|
}
|
|
|
|
def __init__(self, output_name):
|
|
super().__init__(locals())
|
|
|
|
|
|
CommandlineArgumentType = Union[
|
|
str,
|
|
InputValuePlaceholder,
|
|
InputPathPlaceholder,
|
|
OutputPathPlaceholder,
|
|
InputUriPlaceholder,
|
|
OutputUriPlaceholder,
|
|
InputMetadataPlaceholder,
|
|
InputOutputPortNamePlaceholder,
|
|
OutputMetadataPlaceholder,
|
|
'ConcatPlaceholder',
|
|
'IfPlaceholder',
|
|
]
|
|
|
|
|
|
class ConcatPlaceholder(ModelBase): #Non-standard attr names
|
|
'''Represents the command-line argument placeholder that will be replaced at run-time by the concatenated values of its items.'''
|
|
_serialized_names = {
|
|
'items': 'concat',
|
|
}
|
|
|
|
def __init__(self,
|
|
items: List[CommandlineArgumentType],
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class IsPresentPlaceholder(ModelBase): #Non-standard attr names
|
|
'''Represents the command-line argument placeholder that will be replaced at run-time by a boolean value specifying whether the caller has passed an argument for the specified optional input.'''
|
|
_serialized_names = {
|
|
'input_name': 'isPresent',
|
|
}
|
|
|
|
def __init__(self,
|
|
input_name: str,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
IfConditionArgumentType = Union[bool, str, IsPresentPlaceholder, InputValuePlaceholder]
|
|
|
|
|
|
class IfPlaceholderStructure(ModelBase): #Non-standard attr names
|
|
'''Used in by the IfPlaceholder - the command-line argument placeholder that will be replaced at run-time by the expanded value of either "then_value" or "else_value" depending on the submissio-time resolved value of the "cond" predicate.'''
|
|
_serialized_names = {
|
|
'condition': 'cond',
|
|
'then_value': 'then',
|
|
'else_value': 'else',
|
|
}
|
|
|
|
def __init__(self,
|
|
condition: IfConditionArgumentType,
|
|
then_value: Union[CommandlineArgumentType, List[CommandlineArgumentType]],
|
|
else_value: Optional[Union[CommandlineArgumentType, List[CommandlineArgumentType]]] = None,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class IfPlaceholder(ModelBase): #Non-standard attr names
|
|
'''Represents the command-line argument placeholder that will be replaced at run-time by the expanded value of either "then_value" or "else_value" depending on the submissio-time resolved value of the "cond" predicate.'''
|
|
_serialized_names = {
|
|
'if_structure': 'if',
|
|
}
|
|
|
|
def __init__(self,
|
|
if_structure: IfPlaceholderStructure,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class ContainerSpec(ModelBase):
|
|
'''Describes the container component implementation.'''
|
|
_serialized_names = {
|
|
'file_outputs': 'fileOutputs', #TODO: rename to something like legacy_unconfigurable_output_paths
|
|
}
|
|
|
|
def __init__(self,
|
|
image: str,
|
|
command: Optional[List[CommandlineArgumentType]] = None,
|
|
args: Optional[List[CommandlineArgumentType]] = None,
|
|
env: Optional[Mapping[str, str]] = None,
|
|
file_outputs: Optional[Mapping[str, str]] = None, #TODO: rename to something like legacy_unconfigurable_output_paths
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class ContainerImplementation(ModelBase):
|
|
'''Represents the container component implementation.'''
|
|
def __init__(self,
|
|
container: ContainerSpec,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
ImplementationType = Union[ContainerImplementation, 'GraphImplementation']
|
|
|
|
|
|
class MetadataSpec(ModelBase):
|
|
def __init__(self,
|
|
annotations: Optional[Dict[str, str]] = None,
|
|
labels: Optional[Dict[str, str]] = None,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class ComponentSpec(ModelBase):
|
|
'''Component specification. Describes the metadata (name, description, annotations and labels), the interface (inputs and outputs) and the implementation of the component.'''
|
|
def __init__(
|
|
self,
|
|
name: Optional[str] = None, #? Move to metadata?
|
|
description: Optional[str] = None, #? Move to metadata?
|
|
metadata: Optional[MetadataSpec] = None,
|
|
inputs: Optional[List[InputSpec]] = None,
|
|
outputs: Optional[List[OutputSpec]] = None,
|
|
implementation: Optional[ImplementationType] = None,
|
|
version: Optional[str] = 'google.com/cloud/pipelines/component/v1',
|
|
#tags: Optional[Set[str]] = None,
|
|
):
|
|
super().__init__(locals())
|
|
self._post_init()
|
|
|
|
def _post_init(self):
|
|
#Checking input names for uniqueness
|
|
self._inputs_dict = {}
|
|
if self.inputs:
|
|
for input in self.inputs:
|
|
if input.name in self._inputs_dict:
|
|
raise ValueError('Non-unique input name "{}"'.format(input.name))
|
|
self._inputs_dict[input.name] = input
|
|
|
|
#Checking output names for uniqueness
|
|
self._outputs_dict = {}
|
|
if self.outputs:
|
|
for output in self.outputs:
|
|
if output.name in self._outputs_dict:
|
|
raise ValueError('Non-unique output name "{}"'.format(output.name))
|
|
self._outputs_dict[output.name] = output
|
|
|
|
if isinstance(self.implementation, ContainerImplementation):
|
|
container = self.implementation.container
|
|
|
|
if container.file_outputs:
|
|
for output_name, path in container.file_outputs.items():
|
|
if output_name not in self._outputs_dict:
|
|
raise TypeError('Unconfigurable output entry "{}" references non-existing output.'.format({output_name: path}))
|
|
|
|
def verify_arg(arg):
|
|
if arg is None:
|
|
pass
|
|
elif isinstance(
|
|
arg, (str, int, float, bool, OutputMetadataPlaceholder)):
|
|
pass
|
|
elif isinstance(arg, list):
|
|
for arg2 in arg:
|
|
verify_arg(arg2)
|
|
elif isinstance(
|
|
arg, (InputUriPlaceholder, InputValuePlaceholder,
|
|
InputPathPlaceholder, IsPresentPlaceholder,
|
|
InputMetadataPlaceholder,
|
|
InputOutputPortNamePlaceholder)):
|
|
if arg.input_name not in self._inputs_dict:
|
|
raise TypeError(
|
|
'Argument "{}" references non-existing input.'.format(arg))
|
|
elif isinstance(arg, (OutputUriPlaceholder, OutputPathPlaceholder)):
|
|
if arg.output_name not in self._outputs_dict:
|
|
raise TypeError(
|
|
'Argument "{}" references non-existing output.'.format(arg))
|
|
elif isinstance(arg, ConcatPlaceholder):
|
|
for arg2 in arg.items:
|
|
verify_arg(arg2)
|
|
elif isinstance(arg, IfPlaceholder):
|
|
verify_arg(arg.if_structure.condition)
|
|
verify_arg(arg.if_structure.then_value)
|
|
verify_arg(arg.if_structure.else_value)
|
|
else:
|
|
raise TypeError('Unexpected argument "{}"'.format(arg))
|
|
|
|
verify_arg(container.command)
|
|
verify_arg(container.args)
|
|
|
|
if isinstance(self.implementation, GraphImplementation):
|
|
graph = self.implementation.graph
|
|
|
|
if graph.output_values is not None:
|
|
for output_name, argument in graph.output_values.items():
|
|
if output_name not in self._outputs_dict:
|
|
raise TypeError('Graph output argument entry "{}" references non-existing output.'.format({output_name: argument}))
|
|
|
|
if graph.tasks is not None:
|
|
for task in graph.tasks.values():
|
|
if task.arguments is not None:
|
|
for argument in task.arguments.values():
|
|
if isinstance(argument, GraphInputArgument) and argument.graph_input.input_name not in self._inputs_dict:
|
|
raise TypeError('Argument "{}" references non-existing input.'.format(argument))
|
|
|
|
def save(self, file_path: str):
|
|
'''Saves the component definition to file. It can be shared online and later loaded using the load_component function.'''
|
|
from ._yaml_utils import dump_yaml
|
|
component_yaml = dump_yaml(self.to_dict())
|
|
with open(file_path, 'w') as f:
|
|
f.write(component_yaml)
|
|
|
|
|
|
class ComponentReference(ModelBase):
|
|
'''Component reference. Contains information that can be used to locate and load a component by name, digest or URL'''
|
|
def __init__(self,
|
|
name: Optional[str] = None,
|
|
digest: Optional[str] = None,
|
|
tag: Optional[str] = None,
|
|
url: Optional[str] = None,
|
|
spec: Optional[ComponentSpec] = None,
|
|
):
|
|
super().__init__(locals())
|
|
self._post_init()
|
|
|
|
def _post_init(self) -> None:
|
|
if not any([self.name, self.digest, self.tag, self.url, self.spec]):
|
|
raise TypeError('Need at least one argument.')
|
|
|
|
|
|
class GraphInputReference(ModelBase):
|
|
'''References the input of the graph (the scope is a single graph).'''
|
|
_serialized_names = {
|
|
'input_name': 'inputName',
|
|
}
|
|
|
|
def __init__(self,
|
|
input_name: str,
|
|
type: Optional[TypeSpecType] = None, # Can be used to override the reference data type
|
|
):
|
|
super().__init__(locals())
|
|
|
|
def as_argument(self) -> 'GraphInputArgument':
|
|
return GraphInputArgument(graph_input=self)
|
|
|
|
def with_type(self, type_spec: TypeSpecType) -> 'GraphInputReference':
|
|
return GraphInputReference(
|
|
input_name=self.input_name,
|
|
type=type_spec,
|
|
)
|
|
|
|
def without_type(self) -> 'GraphInputReference':
|
|
return self.with_type(None)
|
|
|
|
class GraphInputArgument(ModelBase):
|
|
'''Represents the component argument value that comes from the graph component input.'''
|
|
_serialized_names = {
|
|
'graph_input': 'graphInput',
|
|
}
|
|
|
|
def __init__(self,
|
|
graph_input: GraphInputReference,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class TaskOutputReference(ModelBase):
|
|
'''References the output of some task (the scope is a single graph).'''
|
|
_serialized_names = {
|
|
'task_id': 'taskId',
|
|
'output_name': 'outputName',
|
|
}
|
|
|
|
def __init__(self,
|
|
output_name: str,
|
|
task_id: Optional[str] = None, # Used for linking to the upstream task in serialized component file.
|
|
task: Optional['TaskSpec'] = None, # Used for linking to the upstream task in runtime since Task does not have an ID until inserted into a graph.
|
|
type: Optional[TypeSpecType] = None, # Can be used to override the reference data type
|
|
):
|
|
super().__init__(locals())
|
|
if self.task_id is None and self.task is None:
|
|
raise TypeError('task_id and task cannot be None at the same time.')
|
|
|
|
def with_type(self, type_spec: TypeSpecType) -> 'TaskOutputReference':
|
|
return TaskOutputReference(
|
|
output_name=self.output_name,
|
|
task_id=self.task_id,
|
|
task=self.task,
|
|
type=type_spec,
|
|
)
|
|
|
|
def without_type(self) -> 'TaskOutputReference':
|
|
return self.with_type(None)
|
|
|
|
|
|
class TaskOutputArgument(ModelBase): #Has additional constructor for convenience
|
|
'''Represents the component argument value that comes from the output of another task.'''
|
|
_serialized_names = {
|
|
'task_output': 'taskOutput',
|
|
}
|
|
|
|
def __init__(self,
|
|
task_output: TaskOutputReference,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
@staticmethod
|
|
def construct(
|
|
task_id: str,
|
|
output_name: str,
|
|
) -> 'TaskOutputArgument':
|
|
return TaskOutputArgument(TaskOutputReference(
|
|
task_id=task_id,
|
|
output_name=output_name,
|
|
))
|
|
|
|
def with_type(self, type_spec: TypeSpecType) -> 'TaskOutputArgument':
|
|
return TaskOutputArgument(
|
|
task_output=self.task_output.with_type(type_spec),
|
|
)
|
|
|
|
def without_type(self) -> 'TaskOutputArgument':
|
|
return self.with_type(None)
|
|
|
|
ArgumentType = Union[PrimitiveTypes, GraphInputArgument, TaskOutputArgument]
|
|
|
|
|
|
class TwoOperands(ModelBase):
|
|
def __init__(self,
|
|
op1: ArgumentType,
|
|
op2: ArgumentType,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class BinaryPredicate(ModelBase): #abstract base type
|
|
def __init__(self,
|
|
operands: TwoOperands
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class EqualsPredicate(BinaryPredicate):
|
|
'''Represents the "equals" comparison predicate.'''
|
|
_serialized_names = {'operands': '=='}
|
|
|
|
|
|
class NotEqualsPredicate(BinaryPredicate):
|
|
'''Represents the "not equals" comparison predicate.'''
|
|
_serialized_names = {'operands': '!='}
|
|
|
|
|
|
class GreaterThanPredicate(BinaryPredicate):
|
|
'''Represents the "greater than" comparison predicate.'''
|
|
_serialized_names = {'operands': '>'}
|
|
|
|
|
|
class GreaterThanOrEqualPredicate(BinaryPredicate):
|
|
'''Represents the "greater than or equal" comparison predicate.'''
|
|
_serialized_names = {'operands': '>='}
|
|
|
|
|
|
class LessThenPredicate(BinaryPredicate):
|
|
'''Represents the "less than" comparison predicate.'''
|
|
_serialized_names = {'operands': '<'}
|
|
|
|
|
|
class LessThenOrEqualPredicate(BinaryPredicate):
|
|
'''Represents the "less than or equal" comparison predicate.'''
|
|
_serialized_names = { 'operands': '<='}
|
|
|
|
|
|
PredicateType = Union[
|
|
ArgumentType,
|
|
EqualsPredicate, NotEqualsPredicate, GreaterThanPredicate, GreaterThanOrEqualPredicate, LessThenPredicate, LessThenOrEqualPredicate,
|
|
'NotPredicate', 'AndPredicate', 'OrPredicate',
|
|
]
|
|
|
|
|
|
class TwoBooleanOperands(ModelBase):
|
|
def __init__(self,
|
|
op1: PredicateType,
|
|
op2: PredicateType,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class NotPredicate(ModelBase):
|
|
'''Represents the "not" logical operation.'''
|
|
_serialized_names = {'operand': 'not'}
|
|
|
|
def __init__(self,
|
|
operand: PredicateType
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class AndPredicate(ModelBase):
|
|
'''Represents the "and" logical operation.'''
|
|
_serialized_names = {'operands': 'and'}
|
|
|
|
def __init__(self,
|
|
operands: TwoBooleanOperands
|
|
) :
|
|
super().__init__(locals())
|
|
|
|
class OrPredicate(ModelBase):
|
|
'''Represents the "or" logical operation.'''
|
|
_serialized_names = {'operands': 'or'}
|
|
|
|
def __init__(self,
|
|
operands: TwoBooleanOperands
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class RetryStrategySpec(ModelBase):
|
|
_serialized_names = {
|
|
'max_retries': 'maxRetries',
|
|
}
|
|
|
|
def __init__(self,
|
|
max_retries: int,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class CachingStrategySpec(ModelBase):
|
|
_serialized_names = {
|
|
'max_cache_staleness': 'maxCacheStaleness',
|
|
}
|
|
|
|
def __init__(self,
|
|
max_cache_staleness: Optional[str] = None, # RFC3339 compliant duration: P30DT1H22M3S
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class ExecutionOptionsSpec(ModelBase):
|
|
_serialized_names = {
|
|
'retry_strategy': 'retryStrategy',
|
|
'caching_strategy': 'cachingStrategy',
|
|
}
|
|
|
|
def __init__(self,
|
|
retry_strategy: Optional[RetryStrategySpec] = None,
|
|
caching_strategy: Optional[CachingStrategySpec] = None,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class TaskSpec(ModelBase):
|
|
'''Task specification. Task is a "configured" component - a component supplied with arguments and other applied configuration changes.'''
|
|
_serialized_names = {
|
|
'component_ref': 'componentRef',
|
|
'is_enabled': 'isEnabled',
|
|
'execution_options': 'executionOptions'
|
|
}
|
|
|
|
def __init__(self,
|
|
component_ref: ComponentReference,
|
|
arguments: Optional[Mapping[str, ArgumentType]] = None,
|
|
is_enabled: Optional[PredicateType] = None,
|
|
execution_options: Optional[ExecutionOptionsSpec] = None,
|
|
annotations: Optional[Dict[str, Any]] = None,
|
|
):
|
|
super().__init__(locals())
|
|
#TODO: If component_ref is resolved to component spec, then check that the arguments correspond to the inputs
|
|
|
|
def _init_outputs(self):
|
|
#Adding output references to the task
|
|
if self.component_ref.spec is None:
|
|
return
|
|
task_outputs = OrderedDict()
|
|
for output in self.component_ref.spec.outputs or []:
|
|
task_output_ref = TaskOutputReference(
|
|
output_name=output.name,
|
|
task=self,
|
|
type=output.type, # TODO: Resolve type expressions. E.g. type: {TypeOf: Input 1}
|
|
)
|
|
task_output_arg = TaskOutputArgument(task_output=task_output_ref)
|
|
task_outputs[output.name] = task_output_arg
|
|
|
|
self.outputs = task_outputs
|
|
if len(task_outputs) == 1:
|
|
self.output = list(task_outputs.values())[0]
|
|
|
|
|
|
class GraphSpec(ModelBase):
|
|
'''Describes the graph component implementation. It represents a graph of component tasks connected to the upstream sources of data using the argument specifications. It also describes the sources of graph output values.'''
|
|
_serialized_names = {
|
|
'output_values': 'outputValues',
|
|
}
|
|
|
|
def __init__(self,
|
|
tasks: Mapping[str, TaskSpec],
|
|
output_values: Mapping[str, ArgumentType] = None,
|
|
):
|
|
super().__init__(locals())
|
|
self._post_init()
|
|
|
|
def _post_init(self):
|
|
#Checking task output references and preparing the dependency table
|
|
task_dependencies = {}
|
|
for task_id, task in self.tasks.items():
|
|
dependencies = set()
|
|
task_dependencies[task_id] = dependencies
|
|
if task.arguments is not None:
|
|
for argument in task.arguments.values():
|
|
if isinstance(argument, TaskOutputArgument):
|
|
dependencies.add(argument.task_output.task_id)
|
|
if argument.task_output.task_id not in self.tasks:
|
|
raise TypeError('Argument "{}" references non-existing task.'.format(argument))
|
|
|
|
#Topologically sorting tasks to detect cycles
|
|
task_dependents = {k: set() for k in task_dependencies.keys()}
|
|
for task_id, dependencies in task_dependencies.items():
|
|
for dependency in dependencies:
|
|
task_dependents[dependency].add(task_id)
|
|
task_number_of_remaining_dependencies = {k: len(v) for k, v in task_dependencies.items()}
|
|
sorted_tasks = OrderedDict()
|
|
def process_task(task_id):
|
|
if task_number_of_remaining_dependencies[task_id] == 0 and task_id not in sorted_tasks:
|
|
sorted_tasks[task_id] = self.tasks[task_id]
|
|
for dependent_task in task_dependents[task_id]:
|
|
task_number_of_remaining_dependencies[dependent_task] = task_number_of_remaining_dependencies[dependent_task] - 1
|
|
process_task(dependent_task)
|
|
for task_id in task_dependencies.keys():
|
|
process_task(task_id)
|
|
if len(sorted_tasks) != len(task_dependencies):
|
|
tasks_with_unsatisfied_dependencies = {k: v for k, v in task_number_of_remaining_dependencies.items() if v > 0}
|
|
task_wth_minimal_number_of_unsatisfied_dependencies = min(tasks_with_unsatisfied_dependencies.keys(), key=lambda task_id: tasks_with_unsatisfied_dependencies[task_id])
|
|
raise ValueError('Task "{}" has cyclical dependency.'.format(task_wth_minimal_number_of_unsatisfied_dependencies))
|
|
|
|
self._toposorted_tasks = sorted_tasks
|
|
|
|
|
|
class GraphImplementation(ModelBase):
|
|
'''Represents the graph component implementation.'''
|
|
def __init__(self,
|
|
graph: GraphSpec,
|
|
):
|
|
super().__init__(locals())
|
|
|
|
|
|
class PipelineRunSpec(ModelBase):
|
|
'''The object that can be sent to the backend to start a new Run.'''
|
|
_serialized_names = {
|
|
'root_task': 'rootTask',
|
|
#'on_exit_task': 'onExitTask',
|
|
}
|
|
|
|
def __init__(self,
|
|
root_task: TaskSpec,
|
|
#on_exit_task: Optional[TaskSpec] = None,
|
|
):
|
|
super().__init__(locals())
|