fix(sdk): Allow keyword-only arguments in pipeline function signature (#4544)
* add test for keyword-only arguments in pipeline func * fix: kwargs-only argument for pipeline func * test: kwargs generate same yaml as args * remove whole metadata * assert -> self.assertEqual * programmatic example --> fixed example * same name for both Co-authored-by: Alexey Volkov <alexey.volkov@ark-kun.com>
This commit is contained in:
parent
cad02dc283
commit
ce985bc287
|
|
@ -842,17 +842,22 @@ class Compiler(object):
|
|||
raise ValueError('Either specify pipeline params in the pipeline function, or in "params_list", but not both.')
|
||||
|
||||
args_list = []
|
||||
kwargs_dict = dict()
|
||||
signature = inspect.signature(pipeline_func)
|
||||
for arg_name in signature.parameters:
|
||||
for arg_name, arg in signature.parameters.items():
|
||||
arg_type = None
|
||||
for input in pipeline_meta.inputs or []:
|
||||
if arg_name == input.name:
|
||||
arg_type = input.type
|
||||
break
|
||||
args_list.append(dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type))
|
||||
param = dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type)
|
||||
if arg.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||
kwargs_dict[arg_name] = param
|
||||
else:
|
||||
args_list.append(param)
|
||||
|
||||
with dsl.Pipeline(pipeline_name) as dsl_pipeline:
|
||||
pipeline_func(*args_list)
|
||||
pipeline_func(*args_list, **kwargs_dict)
|
||||
|
||||
pipeline_conf = pipeline_conf or dsl_pipeline.conf # Configuration passed to the compiler is overriding. Unfortunately, it's not trivial to detect whether the dsl_pipeline.conf was ever modified.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List
|
||||
|
||||
import kfp
|
||||
import kfp.compiler as compiler
|
||||
|
|
@ -1108,3 +1109,39 @@ implementation:
|
|||
|
||||
def test_uri_artifact_passing(self):
|
||||
self._test_py_compile_yaml('uri_artifacts')
|
||||
|
||||
def test_keyword_only_argument_for_pipeline_func(self):
|
||||
def some_pipeline(casual_argument: str, *, keyword_only_argument: str):
|
||||
pass
|
||||
kfp.compiler.Compiler()._create_workflow(some_pipeline)
|
||||
|
||||
def test_keyword_only_argument_for_pipeline_func_identity(self):
|
||||
test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
|
||||
sys.path.append(test_data_dir)
|
||||
|
||||
# `@pipeline` is needed to make name the same for both functions
|
||||
|
||||
@pipeline(name="pipeline_func")
|
||||
def pipeline_func_arg(foo_arg: str, bar_arg: str):
|
||||
dsl.ContainerOp(
|
||||
name='foo',
|
||||
image='foo',
|
||||
command=['bar'],
|
||||
arguments=[foo_arg, ' and ', bar_arg]
|
||||
)
|
||||
|
||||
@pipeline(name="pipeline_func")
|
||||
def pipeline_func_kwarg(foo_arg: str, *, bar_arg: str):
|
||||
return pipeline_func_arg(foo_arg, bar_arg)
|
||||
|
||||
pipeline_yaml_arg = kfp.compiler.Compiler()._create_workflow(pipeline_func_arg)
|
||||
pipeline_yaml_kwarg = kfp.compiler.Compiler()._create_workflow(pipeline_func_kwarg)
|
||||
|
||||
# the yamls may differ in metadata
|
||||
def remove_metadata(yaml) -> None:
|
||||
del yaml['metadata']
|
||||
remove_metadata(pipeline_yaml_arg)
|
||||
remove_metadata(pipeline_yaml_kwarg)
|
||||
|
||||
# compare
|
||||
self.assertEqual(pipeline_yaml_arg, pipeline_yaml_kwarg)
|
||||
|
|
|
|||
Loading…
Reference in New Issue