# Copyright 2018 The Kubeflow Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from ._pipeline_param import PipelineParam from .types import check_types, InconsistentTypeException from ._ops_group import Graph import kfp.deprecated as kfp # @deprecated( # version='0.2.6', # reason='This decorator does not seem to be used, so we deprecate it. ' # 'If you need this decorator, please create an issue at ' # 'https://github.com/kubeflow/pipelines/issues', # ) # def python_component(name, # description=None, # base_image=None, # target_component_file: str = None): # """Decorator for Python component functions. # This decorator adds the metadata to the function object itself. # Args: # name: Human-readable name of the component # description: Optional. Description of the component # base_image: Optional. Docker container image to use as the base of the # component. Needs to have Python 3.5+ installed. # target_component_file: Optional. Local file to store the component # definition. The file can then be used for sharing. # Returns: # The same function (with some metadata fields set). # Example: # :: # @dsl.python_component( # name='my awesome component', # description='Come, Let\'s play', # base_image='tensorflow/tensorflow:1.11.0-py3', # ) # def my_component(a: str, b: int) -> str: # ... # """ # def _python_component(func): # func._component_human_name = name # if description: # func._component_description = description # if base_image: # func._component_base_image = base_image # if target_component_file: # func._component_target_component_file = target_component_file # return func # return _python_component def component(func): """Decorator for component functions that returns a ContainerOp. This is useful to enable type checking in the DSL compiler. Example: :: @dsl.component def foobar(model: TFModel(), step: MLStep()): return dsl.ContainerOp() """ from functools import wraps @wraps(func) def _component(*args, **kargs): from ..components._python_op import _extract_component_interface component_meta = _extract_component_interface(func) if kfp.TYPE_CHECK: arg_index = 0 for arg in args: if isinstance(arg, PipelineParam) and not check_types( arg.param_type, component_meta.inputs[arg_index].type): raise InconsistentTypeException( 'Component "' + component_meta.name + '" is expecting ' + component_meta.inputs[arg_index].name + ' to be type(' + str(component_meta.inputs[arg_index].type) + '), but the passed argument is type(' + str(arg.param_type) + ')') arg_index += 1 if kargs is not None: for key in kargs: if isinstance(kargs[key], PipelineParam): for input_spec in component_meta.inputs: if input_spec.name == key and not check_types( kargs[key].param_type, input_spec.type): raise InconsistentTypeException( 'Component "' + component_meta.name + '" is expecting ' + input_spec.name + ' to be type(' + str(input_spec.type) + '), but the passed argument is type(' + str(kargs[key].param_type) + ')') container_op = func(*args, **kargs) container_op._set_metadata(component_meta) return container_op return _component #TODO: combine the component and graph_component decorators into one def graph_component(func): """Decorator for graph component functions. This decorator returns an ops_group. Example: :: # Warning: caching is tricky when recursion is involved. Please be careful # and set proper max_cache_staleness in case of infinite loop. import kfp.dsl as dsl @dsl.graph_component def flip_component(flip_result): print_flip = PrintOp(flip_result) flipA = FlipCoinOp().after(print_flip) flipA.execution_options.caching_strategy.max_cache_staleness = "P0D" with dsl.Condition(flipA.output == 'heads'): flip_component(flipA.output) return {'flip_result': flipA.output} """ from functools import wraps @wraps(func) def _graph_component(*args, **kargs): # We need to make sure that the arguments are correctly mapped to inputs # regardless of the passing order signature = inspect.signature(func) bound_arguments = signature.bind(*args, **kargs) graph_ops_group = Graph(func.__name__) graph_ops_group.inputs = list(bound_arguments.arguments.values()) graph_ops_group.arguments = bound_arguments.arguments for input in graph_ops_group.inputs: if not isinstance(input, PipelineParam): raise ValueError('arguments to ' + func.__name__ + ' should be PipelineParams.') # Entering the Graph Context with graph_ops_group: # Call the function if not graph_ops_group.recursive_ref: func(*args, **kargs) return graph_ops_group return _graph_component