# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import unittest

from kfp import dsl
from kfp.dsl import component_factory
from kfp.dsl import Input
from kfp.dsl import Output
from kfp.dsl import structures
from kfp.dsl.component_decorator import component
from kfp.dsl.types.artifact_types import Artifact
from kfp.dsl.types.artifact_types import Model
from kfp.dsl.types.type_annotations import OutputPath


class TestGetPackagesToInstallCommand(unittest.TestCase):
    """Tests for `component_factory._get_packages_to_install_command`."""

    def test_with_no_packages_to_install(self):
        """An empty package list yields an empty command list."""
        packages_to_install = []

        command = component_factory._get_packages_to_install_command(
            packages_to_install)

        self.assertEqual(command, [])

    def test_with_packages_to_install_and_no_pip_index_url(self):
        """Every requested package appears somewhere in the command."""
        packages_to_install = ['package1', 'package2']

        command = component_factory._get_packages_to_install_command(
            packages_to_install)
        concat_command = ' '.join(command)

        # subTest reports every missing package rather than only the first;
        # assertIn shows both operands on failure.
        for package in packages_to_install:
            with self.subTest(package=package):
                self.assertIn(package, concat_command)

    def test_with_packages_to_install_with_pip_index_url(self):
        """Packages and pip index URLs all appear in the command."""
        packages_to_install = ['package1', 'package2']
        pip_index_urls = ['https://myurl.org/simple']

        command = component_factory._get_packages_to_install_command(
            packages_to_install, pip_index_urls)
        concat_command = ' '.join(command)

        for expected in packages_to_install + pip_index_urls:
            with self.subTest(expected=expected):
                self.assertIn(expected, concat_command)


class TestInvalidParameterName(unittest.TestCase):
    """Decorating a component with a parameter named `Output` must fail."""

    def test_output_named_Output(self):
        """`Output` as the only parameter raises ValueError."""
        with self.assertRaisesRegex(ValueError,
                                    r'"Output" is an invalid parameter name.'):

            @component
            def comp(Output: OutputPath(str)):
                pass

    def test_output_named_Output_with_string_output(self):
        """`Output` alongside a str parameter and str return raises ValueError."""
        with self.assertRaisesRegex(ValueError,
                                    r'"Output" is an invalid parameter name.'):

            @component
            def comp(Output: OutputPath(str), text: str) -> str:
                pass


class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase):
    """`Input[List[<artifact>]]` annotations yield list-of-artifact specs."""

    def test_python_component_input(self):
        """`Input[List[Model]]` becomes a Model InputSpec flagged as a list."""

        def comp(i: Input[List[Model]]):
            ...

        component_spec = component_factory.extract_component_interface(comp)

        self.assertEqual(component_spec.name, 'comp')
        self.assertIsNone(component_spec.description)
        self.assertEqual(
            component_spec.inputs, {
                'i':
                    structures.InputSpec(
                        type='system.Model@0.0.1',
                        default=None,
                        is_artifact_list=True)
            })

    def test_custom_container_component_input(self):
        """With `containerized=True`, `Input[List[Artifact]]` is a list spec."""

        def comp(i: Input[List[Artifact]]):
            ...

        component_spec = component_factory.extract_component_interface(
            comp, containerized=True)

        self.assertEqual(component_spec.name, 'comp')
        self.assertIsNone(component_spec.description)
        self.assertEqual(
            component_spec.inputs, {
                'i':
                    structures.InputSpec(
                        type='system.Artifact@0.0.1',
                        default=None,
                        is_artifact_list=True)
            })

    def test_pipeline_input(self):
        """`Input[List[Model]]` becomes a Model InputSpec flagged as a list."""

        # NOTE(review): this body is identical to
        # test_python_component_input; confirm whether a pipeline-specific
        # code path was meant to be exercised here.
        def comp(i: Input[List[Model]]):
            ...

        component_spec = component_factory.extract_component_interface(comp)

        self.assertEqual(component_spec.name, 'comp')
        self.assertIsNone(component_spec.description)
        self.assertEqual(
            component_spec.inputs, {
                'i':
                    structures.InputSpec(
                        type='system.Model@0.0.1',
                        default=None,
                        is_artifact_list=True)
            })


class TestArtifactStringInInputpathOutputpath(unittest.TestCase):
    """String type names passed to `InputPath` / `OutputPath`."""

    def test_unknown(self):
        """An unrecognized type string is typed as the generic Artifact."""

        @dsl.component
        def comp(
                i: dsl.InputPath('MyCustomType'),
                o: dsl.OutputPath('MyCustomType'),
        ):
            ...

        self.assertEqual(comp.component_spec.outputs['o'].type,
                         'system.Artifact@0.0.1')
        self.assertFalse(comp.component_spec.outputs['o'].is_artifact_list)
        self.assertEqual(comp.component_spec.inputs['i'].type,
                         'system.Artifact@0.0.1')
        self.assertFalse(comp.component_spec.inputs['i'].is_artifact_list)

    def test_known_v1_back_compat(self):
        """The string 'Dataset' is typed as `system.Dataset@0.0.1`."""

        @dsl.component
        def comp(
                i: dsl.InputPath('Dataset'),
                o: dsl.OutputPath('Dataset'),
        ):
            ...

        self.assertEqual(comp.component_spec.outputs['o'].type,
                         'system.Dataset@0.0.1')
        self.assertFalse(comp.component_spec.outputs['o'].is_artifact_list)
        self.assertEqual(comp.component_spec.inputs['i'].type,
                         'system.Dataset@0.0.1')
        self.assertFalse(comp.component_spec.inputs['i'].is_artifact_list)


class TestOutputListsOfArtifactsTemporarilyBlocked(unittest.TestCase):
    """`Output[List[Artifact]]` on a component is rejected with ValueError."""

    def test_python_component(self):
        """`@dsl.component` with an output list of artifacts raises."""
        with self.assertRaisesRegex(
                ValueError,
                r"Output lists of artifacts are only supported for pipelines\. Got output list of artifacts for output parameter 'output_list' of component 'comp'\."
        ):

            @dsl.component
            def comp(output_list: Output[List[Artifact]]):
                ...

    def test_container_component(self):
        """`@dsl.container_component` with an output list of artifacts raises."""
        with self.assertRaisesRegex(
                ValueError,
                r"Output lists of artifacts are only supported for pipelines\. Got output list of artifacts for output parameter 'output_list' of component 'comp'\."
        ):

            @dsl.container_component
            def comp(output_list: Output[List[Artifact]]):
                return dsl.ContainerSpec(image='alpine')


if __name__ == '__main__':
    unittest.main()