# Source path: pipelines/sdk/python/kfp/dsl/component_factory_test.py
#
# (web-view metadata: 194 lines, 6.5 KiB, Python)

# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import unittest
from kfp import dsl
from kfp.dsl import component_factory
from kfp.dsl import Input
from kfp.dsl import Output
from kfp.dsl import structures
from kfp.dsl.component_decorator import component
from kfp.dsl.types.artifact_types import Artifact
from kfp.dsl.types.artifact_types import Model
from kfp.dsl.types.type_annotations import OutputPath
class TestGetPackagesToInstallCommand(unittest.TestCase):
    """Tests for ``component_factory._get_packages_to_install_command``."""

    def test_with_no_packages_to_install(self):
        packages_to_install = []
        command = component_factory._get_packages_to_install_command(
            packages_to_install)
        # With nothing to install, no command fragments are expected at all.
        self.assertEqual(command, [])

    def test_with_packages_to_install_and_no_pip_index_url(self):
        packages_to_install = ['package1', 'package2']
        command = component_factory._get_packages_to_install_command(
            packages_to_install)
        # Join the returned strings so each expected value can be searched
        # for without depending on which list element contains it.
        concat_command = ' '.join(command)
        for package in packages_to_install:
            # subTest + assertIn reports every missing item and shows the
            # haystack, unlike assertTrue(x in y) ("False is not true").
            with self.subTest(package=package):
                self.assertIn(package, concat_command)

    def test_with_packages_to_install_with_pip_index_url(self):
        packages_to_install = ['package1', 'package2']
        pip_index_urls = ['https://myurl.org/simple']
        command = component_factory._get_packages_to_install_command(
            packages_to_install, pip_index_urls)
        concat_command = ' '.join(command)
        # Both the package names and the index URL must appear in the command.
        for expected in packages_to_install + pip_index_urls:
            with self.subTest(expected=expected):
                self.assertIn(expected, concat_command)
class TestInvalidParameterName(unittest.TestCase):
    """Tests that a component parameter literally named ``Output`` is rejected."""

    # The trailing period is escaped so it must match literally; an unescaped
    # ``.`` would accept any character in that position.
    _INVALID_NAME_PATTERN = r'"Output" is an invalid parameter name\.'

    def test_output_named_Output(self):
        with self.assertRaisesRegex(ValueError, self._INVALID_NAME_PATTERN):

            @component
            def comp(Output: OutputPath(str)):
                pass

    def test_output_named_Output_with_string_output(self):
        # The rejection is also expected when the component has a return value.
        with self.assertRaisesRegex(ValueError, self._INVALID_NAME_PATTERN):

            @component
            def comp(Output: OutputPath(str), text: str) -> str:
                pass
class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase):
    """Tests that ``Input[List[<artifact>]]`` is extracted as an artifact-list input."""

    def _assert_artifact_list_input(self, component_spec, expected_type: str):
        """Assert ``component_spec`` describes ``comp`` with one list input ``i``.

        Args:
            component_spec: Value returned by
                ``component_factory.extract_component_interface``.
            expected_type: Expected type string of input ``i``.
        """
        self.assertEqual(component_spec.name, 'comp')
        self.assertIsNone(component_spec.description)
        self.assertEqual(
            component_spec.inputs, {
                'i':
                    structures.InputSpec(
                        type=expected_type,
                        default=None,
                        is_artifact_list=True)
            })

    def test_python_component_input(self):

        def comp(i: Input[List[Model]]):
            ...

        component_spec = component_factory.extract_component_interface(comp)
        self._assert_artifact_list_input(component_spec, 'system.Model@0.0.1')

    def test_custom_container_component_input(self):

        def comp(i: Input[List[Artifact]]):
            ...

        component_spec = component_factory.extract_component_interface(
            comp, containerized=True)
        self._assert_artifact_list_input(component_spec,
                                         'system.Artifact@0.0.1')

    def test_pipeline_input(self):
        # NOTE(review): this body is identical to test_python_component_input
        # and never builds a pipeline, so it does not exercise pipeline input
        # extraction despite its name. Confirm the intended pipeline-specific
        # call and update accordingly.

        def comp(i: Input[List[Model]]):
            ...

        component_spec = component_factory.extract_component_interface(comp)
        self._assert_artifact_list_input(component_spec, 'system.Model@0.0.1')
class TestArtifactStringInInputpathOutputpath(unittest.TestCase):
    """Checks how string type names passed to InputPath/OutputPath are resolved."""

    def _check_single_artifact_paths(self, comp, expected_type):
        """Assert output ``o`` and input ``i`` of ``comp`` are single artifacts.

        Args:
            comp: A component produced by ``@dsl.component``.
            expected_type: Type string expected for both ``o`` and ``i``.
        """
        spec = comp.component_spec
        out_spec = spec.outputs['o']
        self.assertEqual(out_spec.type, expected_type)
        self.assertFalse(out_spec.is_artifact_list)
        in_spec = spec.inputs['i']
        self.assertEqual(in_spec.type, expected_type)
        self.assertFalse(in_spec.is_artifact_list)

    def test_unknown(self):
        # An unrecognised type name is expected to map to the generic Artifact.
        @dsl.component
        def comp(
            i: dsl.InputPath('MyCustomType'),
            o: dsl.OutputPath('MyCustomType'),
        ):
            ...

        self._check_single_artifact_paths(comp, 'system.Artifact@0.0.1')

    def test_known_v1_back_compat(self):
        # A known type name such as 'Dataset' is expected to map to its
        # system artifact type.
        @dsl.component
        def comp(
            i: dsl.InputPath('Dataset'),
            o: dsl.OutputPath('Dataset'),
        ):
            ...

        self._check_single_artifact_paths(comp, 'system.Dataset@0.0.1')
class TestOutputListsOfArtifactsTemporarilyBlocked(unittest.TestCase):
    """Checks that ``Output[List[...]]`` is rejected outside of pipelines."""

    # Shared by both tests; the pieces concatenate into a single raw pattern.
    _ERROR_PATTERN = (
        r"Output lists of artifacts are only supported for pipelines\. "
        r"Got output list of artifacts for output parameter 'output_list' "
        r"of component 'comp'\.")

    def test_python_component(self):
        with self.assertRaisesRegex(ValueError, self._ERROR_PATTERN):

            @dsl.component
            def comp(output_list: Output[List[Artifact]]):
                ...

    def test_container_component(self):
        with self.assertRaisesRegex(ValueError, self._ERROR_PATTERN):

            @dsl.container_component
            def comp(output_list: Output[List[Artifact]]):
                return dsl.ContainerSpec(image='alpine')
if __name__ == "__main__":
    # Run this module's TestCase classes when executed as a script.
    unittest.main()