fix(sdk): Allow keyword-only arguments in pipeline function signature (#4544)

* add test for keyword-only arguments in pipeline func

* fix: kwargs-only argument for pipeline func

* test: kwargs generate same yaml as args

* remove whole metadata

* assert -> self.assertEqual

* programmatic example --> fixed example

* same name for both

Co-authored-by: Alexey Volkov <alexey.volkov@ark-kun.com>
This commit is contained in:
Michalina Kotwica 2021-01-30 03:31:02 +01:00 committed by GitHub
parent cad02dc283
commit ce985bc287
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 3 deletions

View File

@ -842,17 +842,22 @@ class Compiler(object):
raise ValueError('Either specify pipeline params in the pipeline function, or in "params_list", but not both.')
args_list = []
kwargs_dict = dict()
signature = inspect.signature(pipeline_func)
for arg_name in signature.parameters:
for arg_name, arg in signature.parameters.items():
arg_type = None
for input in pipeline_meta.inputs or []:
if arg_name == input.name:
arg_type = input.type
break
args_list.append(dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type))
param = dsl.PipelineParam(sanitize_k8s_name(arg_name, True), param_type=arg_type)
if arg.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs_dict[arg_name] = param
else:
args_list.append(param)
with dsl.Pipeline(pipeline_name) as dsl_pipeline:
pipeline_func(*args_list)
pipeline_func(*args_list, **kwargs_dict)
pipeline_conf = pipeline_conf or dsl_pipeline.conf # Configuration passed to the compiler is overriding. Unfortunately, it's not trivial to detect whether the dsl_pipeline.conf was ever modified.

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import kfp
import kfp.compiler as compiler
@ -1108,3 +1109,39 @@ implementation:
def test_uri_artifact_passing(self):
    """Run the shared ``_test_py_compile_yaml`` check for the ``uri_artifacts`` case."""
    case_name = 'uri_artifacts'
    self._test_py_compile_yaml(case_name)
def test_keyword_only_argument_for_pipeline_func(self):
    """Smoke test: a pipeline function with a keyword-only parameter compiles.

    The test makes no assertion on the result; it passes as long as
    ``_create_workflow`` does not raise for a signature that mixes a regular
    parameter with a keyword-only one (declared after the bare ``*``).
    """
    def some_pipeline(casual_argument: str, *, keyword_only_argument: str):
        pass

    workflow_compiler = kfp.compiler.Compiler()
    workflow_compiler._create_workflow(some_pipeline)
def test_keyword_only_argument_for_pipeline_func_identity(self):
    """Check that a keyword-only parameter compiles like a positional one.

    Two pipelines are compiled that differ only in how ``bar_arg`` is
    declared: positional-or-keyword in one, keyword-only in the other. The
    keyword-only variant simply delegates to the positional one. After the
    ``metadata`` section is removed from each compiled workflow, the two
    results must be equal.
    """
    # `@pipeline` gives both functions the same explicit name, so the
    # differing Python function names do not show up in the comparison.
    @pipeline(name="pipeline_func")
    def pipeline_func_arg(foo_arg: str, bar_arg: str):
        dsl.ContainerOp(
            name='foo',
            image='foo',
            command=['bar'],
            arguments=[foo_arg, ' and ', bar_arg]
        )

    @pipeline(name="pipeline_func")
    def pipeline_func_kwarg(foo_arg: str, *, bar_arg: str):
        return pipeline_func_arg(foo_arg, bar_arg)

    workflow_arg = kfp.compiler.Compiler()._create_workflow(pipeline_func_arg)
    workflow_kwarg = kfp.compiler.Compiler()._create_workflow(pipeline_func_kwarg)

    # The workflows may differ in metadata, so drop that section from both
    # before comparing. `del` (rather than `pop` with a default) keeps the
    # test failing loudly with KeyError if the section is ever absent.
    for workflow in (workflow_arg, workflow_kwarg):
        del workflow['metadata']

    self.assertEqual(workflow_arg, workflow_kwarg)