194 lines
6.5 KiB
Python
194 lines
6.5 KiB
Python
# Copyright 2022 The Kubeflow Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import List
|
|
import unittest
|
|
|
|
from kfp import dsl
|
|
from kfp.dsl import component_factory
|
|
from kfp.dsl import Input
|
|
from kfp.dsl import Output
|
|
from kfp.dsl import structures
|
|
from kfp.dsl.component_decorator import component
|
|
from kfp.dsl.types.artifact_types import Artifact
|
|
from kfp.dsl.types.artifact_types import Model
|
|
from kfp.dsl.types.type_annotations import OutputPath
|
|
|
|
|
|
class TestGetPackagesToInstallCommand(unittest.TestCase):
    """Tests for ``component_factory._get_packages_to_install_command``."""

    def test_with_no_packages_to_install(self):
        """An empty package list produces an empty command."""
        packages_to_install = []

        command = component_factory._get_packages_to_install_command(
            packages_to_install)
        self.assertEqual(command, [])

    def test_with_packages_to_install_and_no_pip_index_url(self):
        """Each requested package shows up in the joined command."""
        packages_to_install = ['package1', 'package2']

        command = component_factory._get_packages_to_install_command(
            packages_to_install)
        concat_command = ' '.join(command)
        for package in packages_to_install:
            # assertIn prints both operands on failure; assertTrue would only
            # report "False is not true".
            self.assertIn(package, concat_command)

    def test_with_packages_to_install_with_pip_index_url(self):
        """Packages and the pip index URL all show up in the joined command."""
        packages_to_install = ['package1', 'package2']
        pip_index_urls = ['https://myurl.org/simple']

        command = component_factory._get_packages_to_install_command(
            packages_to_install, pip_index_urls)
        concat_command = ' '.join(command)
        # The loop covers package names and index URLs alike, hence the
        # neutral variable name.
        for expected_fragment in packages_to_install + pip_index_urls:
            self.assertIn(expected_fragment, concat_command)
|
|
|
|
|
|
class TestInvalidParameterName(unittest.TestCase):
    """Checks that an output parameter named ``Output`` raises ValueError."""

    def test_output_named_Output(self):
        """``Output`` as the only parameter is rejected at decoration time."""
        expected_error = r'"Output" is an invalid parameter name.'

        with self.assertRaisesRegex(ValueError, expected_error):

            @component
            def comp(Output: OutputPath(str)):
                pass

    def test_output_named_Output_with_string_output(self):
        """``Output`` is still rejected when the function also returns str."""
        expected_error = r'"Output" is an invalid parameter name.'

        with self.assertRaisesRegex(ValueError, expected_error):

            @component
            def comp(Output: OutputPath(str), text: str) -> str:
                pass
|
|
|
|
|
|
class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase):
    """Tests ``extract_component_interface`` on ``Input[List[...]]`` inputs."""

    def _assert_single_artifact_list_input(self, component_spec,
                                           expected_type):
        """Asserts the spec is named 'comp' with one artifact-list input 'i'.

        Args:
            component_spec: The object returned by
                ``component_factory.extract_component_interface``.
            expected_type: The expected ``type`` string of input ``i``.
        """
        self.assertEqual(component_spec.name, 'comp')
        # assertIsNone checks identity and gives a clearer failure message
        # than assertEqual(..., None).
        self.assertIsNone(component_spec.description)
        self.assertEqual(
            component_spec.inputs, {
                'i':
                    structures.InputSpec(
                        type=expected_type,
                        default=None,
                        is_artifact_list=True)
            })

    def test_python_component_input(self):
        """A plain function taking ``Input[List[Model]]``."""

        def comp(i: Input[List[Model]]):
            ...

        component_spec = component_factory.extract_component_interface(comp)
        self._assert_single_artifact_list_input(component_spec,
                                                'system.Model@0.0.1')

    def test_custom_container_component_input(self):
        """``Input[List[Artifact]]`` extracted with ``containerized=True``."""

        def comp(i: Input[List[Artifact]]):
            ...

        component_spec = component_factory.extract_component_interface(
            comp, containerized=True)
        self._assert_single_artifact_list_input(component_spec,
                                                'system.Artifact@0.0.1')

    def test_pipeline_input(self):
        """A function taking ``Input[List[Model]]``.

        NOTE(review): this body is identical to
        ``test_python_component_input``; confirm whether a pipeline-specific
        setup was intended here.
        """

        def comp(i: Input[List[Model]]):
            ...

        component_spec = component_factory.extract_component_interface(comp)
        self._assert_single_artifact_list_input(component_spec,
                                                'system.Model@0.0.1')
|
|
|
|
|
|
class TestArtifactStringInInputpathOutputpath(unittest.TestCase):
    """Tests type strings passed to ``dsl.InputPath`` / ``dsl.OutputPath``."""

    def test_unknown(self):
        """'MyCustomType' is expected to yield 'system.Artifact@0.0.1'."""

        @dsl.component
        def comp(
                i: dsl.InputPath('MyCustomType'),
                o: dsl.OutputPath('MyCustomType'),
        ):
            ...

        # Output is checked first, then input, as a single lookup each.
        output_spec = comp.component_spec.outputs['o']
        self.assertEqual(output_spec.type, 'system.Artifact@0.0.1')
        self.assertFalse(output_spec.is_artifact_list)

        input_spec = comp.component_spec.inputs['i']
        self.assertEqual(input_spec.type, 'system.Artifact@0.0.1')
        self.assertFalse(input_spec.is_artifact_list)

    def test_known_v1_back_compat(self):
        """'Dataset' is expected to yield 'system.Dataset@0.0.1'."""

        @dsl.component
        def comp(
                i: dsl.InputPath('Dataset'),
                o: dsl.OutputPath('Dataset'),
        ):
            ...

        output_spec = comp.component_spec.outputs['o']
        self.assertEqual(output_spec.type, 'system.Dataset@0.0.1')
        self.assertFalse(output_spec.is_artifact_list)

        input_spec = comp.component_spec.inputs['i']
        self.assertEqual(input_spec.type, 'system.Dataset@0.0.1')
        self.assertFalse(input_spec.is_artifact_list)
|
|
|
|
|
|
class TestOutputListsOfArtifactsTemporarilyBlocked(unittest.TestCase):
    """Checks that ``Output[List[Artifact]]`` on a component raises."""

    # Both tests expect this exact error pattern, so it is defined once.
    _EXPECTED_ERROR_REGEX = r"Output lists of artifacts are only supported for pipelines\. Got output list of artifacts for output parameter 'output_list' of component 'comp'\."

    def test_python_component(self):
        """``@dsl.component`` with an output artifact list raises ValueError."""
        with self.assertRaisesRegex(ValueError, self._EXPECTED_ERROR_REGEX):

            @dsl.component
            def comp(output_list: Output[List[Artifact]]):
                ...

    def test_container_component(self):
        """``@dsl.container_component`` raises the same ValueError."""
        with self.assertRaisesRegex(ValueError, self._EXPECTED_ERROR_REGEX):

            @dsl.container_component
            def comp(output_list: Output[List[Artifact]]):
                return dsl.ContainerSpec(image='alpine')
|
|
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|