fix(sdk): fix compilation of boolean constant passed to component (#9390)
This commit is contained in:
parent
c01288d673
commit
96947e6fb9
|
|
@ -7,6 +7,8 @@
|
|||
## Deprecations
|
||||
|
||||
## Bug fixes and other changes
|
||||
* Fix compilation of boolean constant passed to component [\#9390](https://github.com/kubeflow/pipelines/pull/9390)
|
||||
|
||||
|
||||
## Documentation updates
|
||||
|
||||
|
|
|
|||
|
|
@ -1436,7 +1436,7 @@ class TestMultipleExitHandlerCompilation(unittest.TestCase):
|
|||
print_op(message='Inside second exit handler.')
|
||||
|
||||
|
||||
class TestBoolInputParameterWithDefaultSerializesCorrectly(unittest.TestCase):
|
||||
class TestBooleanInputCompiledCorrectly(unittest.TestCase):
|
||||
# test with default = True, may have false test successes due to protocol buffer boolean default of False
|
||||
def test_python_component(self):
|
||||
|
||||
|
|
@ -1569,6 +1569,27 @@ class TestBoolInputParameterWithDefaultSerializesCorrectly(unittest.TestCase):
|
|||
pipeline_spec.root.input_definitions.parameters['boolean']
|
||||
.default_value.bool_value, True)
|
||||
|
||||
def test_constant_passed_to_component(self):
|
||||
|
||||
@dsl.component
|
||||
def comp(boolean1: bool, boolean2: bool) -> bool:
|
||||
return boolean1
|
||||
|
||||
@dsl.pipeline
|
||||
def my_pipeline():
|
||||
comp(boolean1=True, boolean2=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipeline_spec_path = os.path.join(tmpdir, 'output.yaml')
|
||||
compiler.Compiler().compile(my_pipeline, pipeline_spec_path)
|
||||
pipeline_spec = pipeline_spec_from_file(pipeline_spec_path)
|
||||
self.assertTrue(
|
||||
pipeline_spec.root.dag.tasks['comp'].inputs.parameters['boolean1']
|
||||
.runtime_value.constant.bool_value)
|
||||
self.assertFalse(
|
||||
pipeline_spec.root.dag.tasks['comp'].inputs.parameters['boolean2']
|
||||
.runtime_value.constant.bool_value)
|
||||
|
||||
|
||||
# helper component defintions for the ValidLegalTopologies tests
|
||||
@dsl.component
|
||||
|
|
|
|||
|
|
@ -61,12 +61,13 @@ def to_protobuf_value(value: type_utils.PARAMETER_TYPES) -> struct_pb2.Value:
|
|||
Raises:
|
||||
ValueError if the given value is not one of the parameter types.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
# bool check must be above (int, float) check because bool is a subclass of int so isinstance(True, int) == True
|
||||
if isinstance(value, bool):
|
||||
return struct_pb2.Value(bool_value=value)
|
||||
elif isinstance(value, str):
|
||||
return struct_pb2.Value(string_value=value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return struct_pb2.Value(number_value=value)
|
||||
elif isinstance(value, bool):
|
||||
return struct_pb2.Value(bool_value=value)
|
||||
elif isinstance(value, dict):
|
||||
return struct_pb2.Value(
|
||||
struct_value=struct_pb2.Struct(
|
||||
|
|
|
|||
Loading…
Reference in New Issue