fix(sdk): fix compilation of boolean constant passed to component (#9390)

This commit is contained in:
Connor McCarthy 2023-05-12 10:13:39 -07:00 committed by GitHub
parent c01288d673
commit 96947e6fb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 4 deletions

View File

@ -7,6 +7,8 @@
## Deprecations
## Bug fixes and other changes
* Fix compilation of boolean constant passed to component [\#9390](https://github.com/kubeflow/pipelines/pull/9390)
## Documentation updates

View File

@ -1436,7 +1436,7 @@ class TestMultipleExitHandlerCompilation(unittest.TestCase):
print_op(message='Inside second exit handler.')
class TestBoolInputParameterWithDefaultSerializesCorrectly(unittest.TestCase):
class TestBooleanInputCompiledCorrectly(unittest.TestCase):
# test with default = True, may have false test successes due to protocol buffer boolean default of False
def test_python_component(self):
@ -1569,6 +1569,27 @@ class TestBoolInputParameterWithDefaultSerializesCorrectly(unittest.TestCase):
pipeline_spec.root.input_definitions.parameters['boolean']
.default_value.bool_value, True)
def test_constant_passed_to_component(self):
@dsl.component
def comp(boolean1: bool, boolean2: bool) -> bool:
return boolean1
@dsl.pipeline
def my_pipeline():
comp(boolean1=True, boolean2=False)
with tempfile.TemporaryDirectory() as tmpdir:
pipeline_spec_path = os.path.join(tmpdir, 'output.yaml')
compiler.Compiler().compile(my_pipeline, pipeline_spec_path)
pipeline_spec = pipeline_spec_from_file(pipeline_spec_path)
self.assertTrue(
pipeline_spec.root.dag.tasks['comp'].inputs.parameters['boolean1']
.runtime_value.constant.bool_value)
self.assertFalse(
pipeline_spec.root.dag.tasks['comp'].inputs.parameters['boolean2']
.runtime_value.constant.bool_value)
# helper component defintions for the ValidLegalTopologies tests
@dsl.component

View File

@ -61,12 +61,13 @@ def to_protobuf_value(value: type_utils.PARAMETER_TYPES) -> struct_pb2.Value:
Raises:
ValueError if the given value is not one of the parameter types.
"""
if isinstance(value, str):
# bool check must be above (int, float) check because bool is a subclass of int so isinstance(True, int) == True
if isinstance(value, bool):
return struct_pb2.Value(bool_value=value)
elif isinstance(value, str):
return struct_pb2.Value(string_value=value)
elif isinstance(value, (int, float)):
return struct_pb2.Value(number_value=value)
elif isinstance(value, bool):
return struct_pb2.Value(bool_value=value)
elif isinstance(value, dict):
return struct_pb2.Value(
struct_value=struct_pb2.Struct(