feat(sdk): add runtime logic for custom artifact types (support for custom artifact types pt. 3) (#8233)

* add runtime artifact instance creation logic

* refactor executor

* add executor tests

* add custom artifact type import handling and tests

* fix artifact class construction

* fix custom artifact type in tests

* add typing extensions dependency for all python versions

* use mock google namespace artifact for tests

* remove print statement

* update google artifact golden snapshot

* resolve some review feedback

* remove handling for OutputPath and InputPath custom artifact types; update function names and tests

* clarify named tuple tests

* update executor tests

* add artifact return and named tuple support; refactor; clean tests

* implement review feedback; clean up artifact names

* move test method
This commit is contained in:
Connor McCarthy 2022-09-14 19:00:40 -06:00 committed by GitHub
parent 916777e62f
commit 166d6bb917
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1364 additions and 368 deletions

View File

@ -28,6 +28,7 @@ from kfp.components import python_component
from kfp.components import structures
from kfp.components.container_component_artifact_channel import \
ContainerComponentArtifactChannel
from kfp.components.types import custom_artifact_types
from kfp.components.types import type_annotations
from kfp.components.types import type_utils
@ -171,7 +172,7 @@ def extract_component_interface(
# parameter_type is type_annotations.Artifact or one of its subclasses.
parameter_type = type_annotations.get_io_artifact_class(
parameter_type)
if not type_annotations.is_artifact(parameter_type):
if not type_annotations.is_artifact_class(parameter_type):
raise ValueError(
'Input[T] and Output[T] are only supported when T is a '
'subclass of Artifact. Found `{} with type {}`'.format(
@ -203,7 +204,7 @@ def extract_component_interface(
]:
io_name = _maybe_make_unique(io_name, output_names)
output_names.add(io_name)
if type_annotations.is_artifact(parameter_type):
if type_annotations.is_artifact_class(parameter_type):
schema_version = parameter_type.schema_version
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
@ -214,7 +215,7 @@ def extract_component_interface(
else:
io_name = _maybe_make_unique(io_name, input_names)
input_names.add(io_name)
if type_annotations.is_artifact(parameter_type):
if type_annotations.is_artifact_class(parameter_type):
schema_version = parameter_type.schema_version
input_spec = structures.InputSpec(
type=type_utils.create_bundled_artifact_type(
@ -277,7 +278,7 @@ def extract_component_interface(
# `def func(output_path: OutputPath()) -> str: ...`
output_names.add(output_name)
return_ann = signature.return_annotation
if type_annotations.is_artifact(signature.return_annotation):
if type_annotations.is_artifact_class(signature.return_annotation):
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
return_ann.schema_title, return_ann.schema_version))
@ -322,7 +323,7 @@ def _get_command_and_args_for_lightweight_component(
'from kfp import dsl',
'from kfp.dsl import *',
'from typing import *',
]
] + custom_artifact_types.get_custom_artifact_type_import_statements(func)
func_source = _get_function_source_definition(func)
source = textwrap.dedent('''

View File

@ -44,3 +44,7 @@ class TestGetPackagesToInstallCommand(unittest.TestCase):
concat_command = ' '.join(command)
for package in packages_to_install + pip_index_urls:
self.assertTrue(package in concat_command)
if __name__ == '__main__':
unittest.main()

View File

@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import json
import os
from typing import Any, Callable, Dict, List, Optional, Union
from kfp.components import task_final_status
@ -37,30 +38,40 @@ class Executor():
{}).get('artifacts', {}).items():
artifacts_list = artifacts.get('artifacts')
if artifacts_list:
self._input_artifacts[name] = self._make_input_artifact(
artifacts_list[0])
self._input_artifacts[name] = self.make_artifact(
artifacts_list[0],
name,
self._func,
)
for name, artifacts in self._input.get('outputs',
{}).get('artifacts', {}).items():
artifacts_list = artifacts.get('artifacts')
if artifacts_list:
self._output_artifacts[name] = self._make_output_artifact(
artifacts_list[0])
output_artifact = self.make_artifact(
artifacts_list[0],
name,
self._func,
)
self._output_artifacts[name] = output_artifact
self.makedirs_recursively(output_artifact.path)
self._return_annotation = inspect.signature(
self._func).return_annotation
self._executor_output = {}
@classmethod
def _make_input_artifact(cls, runtime_artifact: Dict):
return artifact_types.create_runtime_artifact(runtime_artifact)
def make_artifact(
self,
runtime_artifact: Dict,
name: str,
func: Callable,
) -> Any:
artifact_cls = func.__annotations__.get(name)
return create_artifact_instance(
runtime_artifact, artifact_cls=artifact_cls)
@classmethod
def _make_output_artifact(cls, runtime_artifact: Dict):
import os
artifact = artifact_types.create_runtime_artifact(runtime_artifact)
os.makedirs(os.path.dirname(artifact.path), exist_ok=True)
return artifact
def makedirs_recursively(self, path: str) -> None:
os.makedirs(os.path.dirname(path), exist_ok=True)
def _get_input_artifact(self, name: str):
return self._input_artifacts.get(name)
@ -170,7 +181,7 @@ class Executor():
@classmethod
def _is_artifact(cls, annotation: Any) -> bool:
if type(annotation) == type:
return type_annotations.is_artifact(annotation)
return type_annotations.is_artifact_class(annotation)
return False
@classmethod
@ -297,3 +308,20 @@ class Executor():
result = self._func(**func_kwargs)
self._write_executor_output(result)
def create_artifact_instance(
    runtime_artifact: Dict,
    artifact_cls=artifact_types.Artifact,
) -> artifact_types.Artifact:
    """Creates an artifact class instance from a runtime artifact dictionary.

    Args:
        runtime_artifact: Dictionary representing a JSON-encoded
            RuntimeArtifact.
        artifact_cls: Fallback artifact class to instantiate when the runtime
            artifact's schema title is not a registered first-party type
            (e.g., a user-defined custom artifact type).

    Returns:
        An instance of the resolved artifact class. (The previous `-> type`
        annotation was incorrect: an instance, not a class, is returned.)
    """
    # Prefer the registered KFP artifact type for the schema title; fall back
    # to the caller-provided class for custom artifact types.
    schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '')
    artifact_cls = artifact_types._SCHEMA_TITLE_TO_TYPE.get(
        schema_title, artifact_cls)
    return artifact_cls(
        uri=runtime_artifact.get('uri', ''),
        name=runtime_artifact.get('name', ''),
        metadata=runtime_artifact.get('metadata', {}),
    )

File diff suppressed because it is too large Load Diff

View File

@ -629,7 +629,6 @@ class ComponentSpec(base_model.BaseModel):
inputs = {}
for spec in component_dict.get('inputs', []):
type_ = spec.get('type')
print('TYPE', type_)
if isinstance(type_, str) and type_ == 'PipelineTaskFinalStatus':
inputs[utils.sanitize_input_name(

View File

@ -510,21 +510,3 @@ _SCHEMA_TITLE_TO_TYPE: Dict[str, Type[Artifact]] = {
Markdown,
]
}
def create_runtime_artifact(runtime_artifact: Dict) -> Artifact:
"""Creates an Artifact instance from the specified RuntimeArtifact.
Args:
runtime_artifact: Dictionary representing JSON-encoded RuntimeArtifact.
"""
schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '')
artifact_type = _SCHEMA_TITLE_TO_TYPE.get(schema_title)
if not artifact_type:
artifact_type = Artifact
return artifact_type(
uri=runtime_artifact.get('uri', ''),
name=runtime_artifact.get('name', ''),
metadata=runtime_artifact.get('metadata', {}),
)

View File

@ -56,105 +56,6 @@ class ArtifactsTest(parameterized.TestCase):
expected_json = json.load(json_file)
self.assertEqual(expected_json, metrics.metadata)
@parameterized.parameters(
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.Artifact'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.Artifact,
},
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.Model'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.Model,
},
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.Dataset'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.Dataset,
},
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.Metrics'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.Metrics,
},
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.ClassificationMetrics'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.ClassificationMetrics,
},
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.SlicedClassificationMetrics'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.SlicedClassificationMetrics,
},
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.HTML'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.HTML,
},
{
'runtime_artifact': {
'metadata': {},
'name': 'input_artifact_one',
'type': {
'schemaTitle': 'system.Markdown'
},
'uri': 'gs://some-bucket/input_artifact_one'
},
'expected_type': artifact_types.Markdown,
},
)
def test_create_runtime_artifact(
self,
runtime_artifact,
expected_type,
):
self.assertIsInstance(
artifact_types.create_runtime_artifact(runtime_artifact),
expected_type)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,191 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import inspect
from typing import Callable, Dict, List, Union
from kfp.components import component_factory
from kfp.components.types import type_annotations
from kfp.components.types import type_utils
RETURN_PREFIX = 'return-'
def get_custom_artifact_type_import_statements(func: Callable) -> List[str]:
    """Gets a list of custom artifact type import statements from a lightweight
    Python component function."""

    def to_statement(symbol: str) -> str:
        # Dotted names import the leaf object from its parent module; bare
        # names import the module itself.
        if '.' in symbol:
            parent, leaf = symbol.rsplit('.', 1)
            return f'from {parent} import {leaf}'
        return f'import {symbol}'

    return [
        to_statement(symbol)
        for symbol in get_custom_artifact_import_items_from_function(func)
    ]
def get_param_to_custom_artifact_class(func: Callable) -> Dict[str, type]:
    """Gets a map of parameter names to custom artifact classes.

    Return key is 'return-' for normal returns and 'return-<field>' for
    typing.NamedTuple returns.

    Args:
        func: The component function to inspect.

    Returns:
        Mapping from parameter/return name to the custom (non-first-party)
        artifact class annotating it. First-party KFP artifact types are
        excluded.
    """
    param_to_artifact_cls: Dict[str, type] = {}
    kfp_artifact_classes = set(type_utils._ARTIFACT_CLASSES_MAPPING.values())

    signature = inspect.signature(func)
    for name, param in signature.parameters.items():
        annotation = param.annotation
        if type_annotations.is_artifact_annotation(annotation):
            # Input[T]/Output[T]: unwrap T and record it only if custom.
            artifact_class = type_annotations.get_io_artifact_class(annotation)
            if artifact_class not in kfp_artifact_classes:
                param_to_artifact_cls[name] = artifact_class
        elif type_annotations.is_artifact_class(annotation):
            # Bare artifact annotation: record it only if custom. This branch
            # previously re-tested `artifact_class`, a variable bound only in
            # the Input[T]/Output[T] branch above (NameError on the first
            # iteration, stale value afterwards); test `annotation` itself.
            if annotation not in kfp_artifact_classes:
                param_to_artifact_cls[name] = annotation

    return_annotation = signature.return_annotation
    if return_annotation is inspect.Signature.empty:
        pass
    elif type_utils.is_typed_named_tuple_annotation(return_annotation):
        # One entry per NamedTuple field, keyed 'return-<field>'.
        for name, annotation in return_annotation.__annotations__.items():
            if type_annotations.is_artifact_class(
                    annotation) and annotation not in kfp_artifact_classes:
                param_to_artifact_cls[f'{RETURN_PREFIX}{name}'] = annotation
    elif type_annotations.is_artifact_class(
            return_annotation
    ) and return_annotation not in kfp_artifact_classes:
        param_to_artifact_cls[RETURN_PREFIX] = return_annotation
    return param_to_artifact_cls
def get_full_qualname_for_artifact(obj: type) -> str:
    """Gets the fully qualified name for an object. For example, for class Foo
    in module bar.baz, this function returns bar.baz.Foo.

    Note: typing.get_type_hints purports to do the same thing, but it behaves
    differently when executed within the scope of a test, so preferring this
    approach instead.

    Args:
        obj: The class or module for which to get the fully qualified name.

    Returns:
        The fully qualified name for the class.
    """
    qualname = obj.__qualname__
    module = obj.__module__
    # A missing module leaves just the qualified name (e.g., builtins-less
    # objects); otherwise prefix it, dot-separated.
    return qualname if module is None else f'{module}.{qualname}'
def get_symbol_import_path(artifact_class_base_symbol: str,
                           qualname: str) -> str:
    """Gets the fully qualified name of the symbol that must be imported for
    the custom artifact type annotation to be referenced successfully.

    Args:
        artifact_class_base_symbol: The base symbol from which the artifact class is referenced (e.g., aiplatform for aiplatform.VertexDataset).
        qualname: The fully qualified type annotation name as a string.

    Returns:
        The fully qualified names of the module or type to import.
    """
    parts = qualname.split('.')
    if artifact_class_base_symbol not in parts:
        raise TypeError(
            f"Module or type name aliases are not supported. You appear to be using an alias in your type annotation: '{qualname}'. This may be due to use of an 'as' statement in an import statement or a reassignment of a module or type to a new name. Reference the module and/or type using the name as defined in the source from which the module or type is imported."
        )
    # Keep everything up to and including the base symbol: that prefix is the
    # symbol that must be imported.
    end = parts.index(artifact_class_base_symbol) + 1
    return '.'.join(parts[:end])
def traverse_ast_node_values_to_get_id(obj: Union[ast.Slice, None]) -> str:
    """Follows `.value` links on an AST node until a node carrying an `.id`
    (an ast.Name) is reached, and returns that identifier."""
    node = obj
    # e.g. for `aiplatform.VertexDataset` the Attribute chain bottoms out at
    # Name(id='aiplatform').
    while not hasattr(node, 'id'):
        node = node.value
    return node.id
def get_custom_artifact_base_symbol_for_parameter(func: Callable,
                                                  arg_name: str) -> str:
    """Gets the symbol required for the custom artifact type annotation to be
    referenced correctly."""
    source = component_factory._get_function_source_definition(func)
    function_def = ast.parse(source).body[0]
    arg_nodes = {arg.arg: arg for arg in function_def.args.args}
    # The subscript slice of e.g. `Input[aiplatform.VertexDataset]` holds the
    # annotation whose base symbol we need.
    annotation_node = arg_nodes[arg_name].annotation
    return traverse_ast_node_values_to_get_id(annotation_node.slice)
def get_custom_artifact_base_symbol_for_return(func: Callable,
                                               return_name: str) -> str:
    """Gets the symbol required for the custom artifact type return annotation
    to be referenced correctly.

    Args:
        func: The component function.
        return_name: 'return-' for a bare return annotation, or
            'return-<field>' for a typing.NamedTuple field.

    Returns:
        The base symbol of the return annotation (e.g., 'aiplatform' for
        aiplatform.VertexDataset).

    Raises:
        TypeError: If the return annotation is not a supported form.
    """
    module_node = ast.parse(
        component_factory._get_function_source_definition(func))
    return_ann = module_node.body[0].returns

    if return_name == RETURN_PREFIX:
        if isinstance(return_ann, (ast.Name, ast.Attribute)):
            return traverse_ast_node_values_to_get_id(return_ann)
    elif isinstance(return_ann, ast.Call):
        # Bind the call target to its own name rather than shadowing the
        # `func` parameter, so the TypeError below names the component
        # function instead of an AST node.
        call_func = return_ann.func
        # handles NamedTuple and typing.NamedTuple
        if (isinstance(call_func, ast.Attribute) and
                call_func.value.id == 'typing' and
                call_func.attr == 'NamedTuple') or (
                    isinstance(call_func, ast.Name) and
                    call_func.id == 'NamedTuple'):
            nt_field_list = return_ann.args[1].elts
            for el in nt_field_list:
                if f'{RETURN_PREFIX}{el.elts[0].s}' == return_name:
                    return traverse_ast_node_values_to_get_id(el.elts[1])

    raise TypeError(f"Unexpected type annotation '{return_ann}' for {func}.")
def get_custom_artifact_import_items_from_function(func: Callable) -> List[str]:
    """Gets the fully qualified name of the symbol that must be imported for
    the custom artifact type annotation to be referenced successfully from a
    component function."""
    import_items: List[str] = []
    for param_name, artifact_class in get_param_to_custom_artifact_class(
            func).items():
        if param_name.startswith(RETURN_PREFIX):
            base_symbol = get_custom_artifact_base_symbol_for_return(
                func, param_name)
        else:
            base_symbol = get_custom_artifact_base_symbol_for_parameter(
                func, param_name)
        artifact_qualname = get_full_qualname_for_artifact(artifact_class)
        symbol_import_path = get_symbol_import_path(base_symbol,
                                                    artifact_qualname)
        # could use a set here, but we want deterministic import ordering
        # in compilation
        if symbol_import_path not in import_items:
            import_items.append(symbol_import_path)
    return import_items

View File

@ -0,0 +1,527 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
import tempfile
import textwrap
import typing
from typing import Any
import unittest
from absl.testing import parameterized
import kfp
from kfp import dsl
from kfp.components.types import artifact_types
from kfp.components.types import custom_artifact_types
from kfp.components.types.artifact_types import Artifact
from kfp.components.types.artifact_types import Dataset
from kfp.components.types.type_annotations import Input
from kfp.components.types.type_annotations import InputPath
from kfp.components.types.type_annotations import Output
from kfp.components.types.type_annotations import OutputPath
Alias = Artifact
artifact_types_alias = artifact_types
class _TestCaseWithThirdPartyPackage(parameterized.TestCase):
    """Test base class that fakes a third-party `aiplatform` package.

    Writes a module containing a custom artifact type to a temp directory and
    prepends that directory to sys.path so tests can `import aiplatform`.
    """

    @classmethod
    def setUpClass(cls):

        class VertexDataset:
            schema_title = 'google.VertexDataset'
            schema_version = '0.0.0'

        class_source = textwrap.dedent(inspect.getsource(VertexDataset))
        tmp_dir = tempfile.TemporaryDirectory()
        with open(os.path.join(tmp_dir.name, 'aiplatform.py'), 'w') as f:
            f.write(class_source)
        sys.path.append(tmp_dir.name)
        cls.tmp_dir = tmp_dir

    @classmethod
    def tearDownClass(cls):
        # Must be spelled `tearDownClass`: unittest only invokes the camelCase
        # hook, so the previous `teardownClass` was never called, leaking the
        # temp directory and the sys.path entry across test classes.
        sys.path.pop()
        cls.tmp_dir.cleanup()
class TestGetParamToCustomArtifactClass(_TestCaseWithThirdPartyPackage):
def test_no_ann(self):
def func():
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_primitives(self):
def func(a: str, b: int) -> str:
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_input_path(self):
def func(a: InputPath(str), b: InputPath('Dataset')) -> str:
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_output_path(self):
def func(a: OutputPath(str), b: OutputPath('Dataset')) -> str:
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_input_kfp_artifact(self):
def func(a: Input[Artifact]):
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_output_kfp_artifact(self):
def func(a: Output[Artifact]):
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_return_kfp_artifact1(self):
def func() -> Artifact:
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_return_kfp_artifact2(self):
def func() -> dsl.Artifact:
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_named_tuple_primitives(self):
def func() -> typing.NamedTuple('Outputs', [
('a', str),
('b', int),
]):
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {})
def test_input_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func(
a: Input[aiplatform.VertexDataset],
b: Input[VertexDataset],
c: dsl.Input[aiplatform.VertexDataset],
d: kfp.dsl.Input[VertexDataset],
):
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(
actual, {
'a': aiplatform.VertexDataset,
'b': aiplatform.VertexDataset,
'c': aiplatform.VertexDataset,
'd': aiplatform.VertexDataset,
})
def test_output_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func(
a: Output[aiplatform.VertexDataset],
b: Output[VertexDataset],
c: dsl.Output[aiplatform.VertexDataset],
d: kfp.dsl.Output[VertexDataset],
):
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(
actual, {
'a': aiplatform.VertexDataset,
'b': aiplatform.VertexDataset,
'c': aiplatform.VertexDataset,
'd': aiplatform.VertexDataset,
})
def test_return_google_artifact1(self):
import aiplatform
from aiplatform import VertexDataset
def func() -> VertexDataset:
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {'return-': aiplatform.VertexDataset})
def test_return_google_artifact2(self):
import aiplatform
def func() -> aiplatform.VertexDataset:
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(actual, {'return-': aiplatform.VertexDataset})
def test_named_tuple_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func() -> typing.NamedTuple('Outputs', [
('a', aiplatform.VertexDataset),
('b', VertexDataset),
]):
pass
actual = custom_artifact_types.get_param_to_custom_artifact_class(func)
self.assertEqual(
actual, {
'return-a': aiplatform.VertexDataset,
'return-b': aiplatform.VertexDataset,
})
class TestGetFullQualnameForArtifact(_TestCaseWithThirdPartyPackage):
# only gets called on artifacts, so don't need to test on all types
@parameterized.parameters([
(Alias, 'kfp.components.types.artifact_types.Artifact'),
(Artifact, 'kfp.components.types.artifact_types.Artifact'),
(Dataset, 'kfp.components.types.artifact_types.Dataset'),
])
def test(self, obj: Any, expected_qualname: str):
self.assertEqual(
custom_artifact_types.get_full_qualname_for_artifact(obj),
expected_qualname)
def test_aiplatform_artifact(self):
import aiplatform
self.assertEqual(
custom_artifact_types.get_full_qualname_for_artifact(
aiplatform.VertexDataset), 'aiplatform.VertexDataset')
class TestGetSymbolImportPath(parameterized.TestCase):
@parameterized.parameters([
{
'artifact_class_base_symbol': 'aiplatform',
'qualname': 'aiplatform.VertexDataset',
'expected': 'aiplatform'
},
{
'artifact_class_base_symbol': 'VertexDataset',
'qualname': 'aiplatform.VertexDataset',
'expected': 'aiplatform.VertexDataset'
},
{
'artifact_class_base_symbol': 'e',
'qualname': 'a.b.c.d.e',
'expected': 'a.b.c.d.e'
},
{
'artifact_class_base_symbol': 'c',
'qualname': 'a.b.c.d.e',
'expected': 'a.b.c'
},
])
def test(self, artifact_class_base_symbol: str, qualname: str,
expected: str):
actual = custom_artifact_types.get_symbol_import_path(
artifact_class_base_symbol, qualname)
self.assertEqual(actual, expected)
class TestGetCustomArtifactBaseSymbolForParameter(_TestCaseWithThirdPartyPackage
):
def test_input_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func(
a: Input[aiplatform.VertexDataset],
b: Input[VertexDataset],
c: dsl.Input[aiplatform.VertexDataset],
d: kfp.dsl.Input[VertexDataset],
):
pass
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'a')
self.assertEqual(actual, 'aiplatform')
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'b')
self.assertEqual(actual, 'VertexDataset')
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'c')
self.assertEqual(actual, 'aiplatform')
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'd')
self.assertEqual(actual, 'VertexDataset')
def test_output_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func(
a: Output[aiplatform.VertexDataset],
b: Output[VertexDataset],
c: dsl.Output[aiplatform.VertexDataset],
d: kfp.dsl.Output[VertexDataset],
):
pass
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'a')
self.assertEqual(actual, 'aiplatform')
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'b')
self.assertEqual(actual, 'VertexDataset')
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'c')
self.assertEqual(actual, 'aiplatform')
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter(
func, 'd')
self.assertEqual(actual, 'VertexDataset')
class TestGetCustomArtifactBaseSymbolForReturn(_TestCaseWithThirdPartyPackage):
def test_return_google_artifact1(self):
from aiplatform import VertexDataset
def func() -> VertexDataset:
pass
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return(
func, 'return-')
self.assertEqual(actual, 'VertexDataset')
def test_return_google_artifact2(self):
import aiplatform
def func() -> aiplatform.VertexDataset:
pass
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return(
func, 'return-')
self.assertEqual(actual, 'aiplatform')
def test_named_tuple_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func() -> typing.NamedTuple('Outputs', [
('a', aiplatform.VertexDataset),
('b', VertexDataset),
]):
pass
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return(
func, 'return-a')
self.assertEqual(actual, 'aiplatform')
actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return(
func, 'return-b')
self.assertEqual(actual, 'VertexDataset')
class TestGetCustomArtifactImportItemsFromFunction(
_TestCaseWithThirdPartyPackage):
def test_no_ann(self):
def func():
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_primitives(self):
def func(a: str, b: int) -> str:
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_input_path(self):
def func(a: InputPath(str), b: InputPath('Dataset')) -> str:
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_output_path(self):
def func(a: OutputPath(str), b: OutputPath('Dataset')) -> str:
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_input_kfp_artifact(self):
def func(a: Input[Artifact]):
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_output_kfp_artifact(self):
def func(a: Output[Artifact]):
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_return_kfp_artifact1(self):
def func() -> Artifact:
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_return_kfp_artifact2(self):
def func() -> dsl.Artifact:
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_named_tuple_primitives(self):
def func() -> typing.NamedTuple('Outputs', [
('a', str),
('b', int),
]):
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, [])
def test_input_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func(
a: Input[aiplatform.VertexDataset],
b: Input[VertexDataset],
c: dsl.Input[aiplatform.VertexDataset],
d: kfp.dsl.Input[VertexDataset],
):
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, ['aiplatform', 'aiplatform.VertexDataset'])
def test_output_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func(
a: Output[aiplatform.VertexDataset],
b: Output[VertexDataset],
c: dsl.Output[aiplatform.VertexDataset],
d: kfp.dsl.Output[VertexDataset],
):
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, ['aiplatform', 'aiplatform.VertexDataset'])
def test_return_google_artifact1(self):
import aiplatform
from aiplatform import VertexDataset
def func() -> VertexDataset:
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, ['aiplatform.VertexDataset'])
def test_return_google_artifact2(self):
import aiplatform
def func() -> aiplatform.VertexDataset:
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, ['aiplatform'])
def test_named_tuple_google_artifact(self):
import aiplatform
from aiplatform import VertexDataset
def func() -> typing.NamedTuple('Outputs', [
('a', aiplatform.VertexDataset),
('b', VertexDataset),
]):
pass
actual = custom_artifact_types.get_custom_artifact_import_items_from_function(
func)
self.assertEqual(actual, ['aiplatform', 'aiplatform.VertexDataset'])
if __name__ == '__main__':
unittest.main()

View File

@ -106,7 +106,7 @@ class InputPath:
def construct_type_for_inputpath_or_outputpath(
type_: Union[str, Type, None]) -> Union[str, None]:
if type_annotations.is_artifact(type_):
if type_annotations.is_artifact_class(type_):
return type_utils.create_bundled_artifact_type(type_.schema_title,
type_.schema_version)
elif isinstance(
@ -274,7 +274,7 @@ def get_short_type_name(type_name: str) -> str:
return type_name
def is_artifact(artifact_class_or_instance: Type) -> bool:
def is_artifact_class(artifact_class_or_instance: Type) -> bool:
# we do not yet support non-pre-registered custom artifact types with instance_schema attribute
return hasattr(artifact_class_or_instance, 'schema_title') and hasattr(
artifact_class_or_instance, 'schema_version')

View File

@ -161,31 +161,31 @@ class TestIsArtifact(parameterized.TestCase):
'obj': obj
} for obj in artifact_types._SCHEMA_TITLE_TO_TYPE.values()])
def test_true_class(self, obj):
self.assertTrue(type_annotations.is_artifact(obj))
self.assertTrue(type_annotations.is_artifact_class(obj))
@parameterized.parameters([{
'obj': obj(name='name', uri='uri', metadata={})
} for obj in artifact_types._SCHEMA_TITLE_TO_TYPE.values()])
def test_true_instance(self, obj):
self.assertTrue(type_annotations.is_artifact(obj))
self.assertTrue(type_annotations.is_artifact_class(obj))
@parameterized.parameters([{'obj': 'string'}, {'obj': 1}, {'obj': int}])
def test_false(self, obj):
self.assertFalse(type_annotations.is_artifact(obj))
self.assertFalse(type_annotations.is_artifact_class(obj))
def test_false_no_schema_title(self):
class NotArtifact:
schema_version = ''
self.assertFalse(type_annotations.is_artifact(NotArtifact))
self.assertFalse(type_annotations.is_artifact_class(NotArtifact))
def test_false_no_schema_version(self):
class NotArtifact:
schema_title = ''
self.assertFalse(type_annotations.is_artifact(NotArtifact))
self.assertFalse(type_annotations.is_artifact_class(NotArtifact))
if __name__ == '__main__':

View File

@ -28,6 +28,7 @@ PARAMETER_TYPES = Union[str, int, float, bool, dict, list]
# ComponentSpec I/O types to DSL ontology artifact classes mapping.
_ARTIFACT_CLASSES_MAPPING = {
'artifact': artifact_types.Artifact,
'model': artifact_types.Model,
'dataset': artifact_types.Dataset,
'metrics': artifact_types.Metrics,
@ -413,7 +414,7 @@ def _annotation_to_type_struct(annotation):
type_struct = get_canonical_type_name_for_type(annotation)
if type_struct:
return type_struct
elif type_annotations.is_artifact(annotation):
elif type_annotations.is_artifact_class(annotation):
schema_title = annotation.schema_title
else:
schema_title = str(annotation.__name__)
@ -423,3 +424,8 @@ def _annotation_to_type_struct(annotation):
schema_title = str(annotation)
type_struct = get_canonical_type_name_for_type(schema_title)
return type_struct or schema_title
def is_typed_named_tuple_annotation(annotation: Any) -> bool:
    """Returns True if the annotation looks like a typed NamedTuple class.

    Typed NamedTuples expose both the namedtuple `_fields` tuple and the
    per-field `__annotations__` mapping.
    """
    required_attrs = ('_fields', '__annotations__')
    return all(hasattr(annotation, attr) for attr in required_attrs)

View File

@ -1,8 +1,8 @@
#
# This file is autogenerated by pip-compile with python 3.9
# This file is autogenerated by pip-compile with python 3.7
# To update, run:
#
# pip-compile
# pip-compile --no-emit-index-url requirements.in
#
absl-py==1.2.0
# via -r requirements.in
@ -34,7 +34,7 @@ google-api-core==2.8.2
# -r requirements.in
# google-cloud-core
# google-cloud-storage
google-auth==2.10.0
google-auth==2.11.0
# via
# -r requirements.in
# google-api-core
@ -45,7 +45,7 @@ google-cloud-core==2.3.2
# via google-cloud-storage
google-cloud-storage==2.5.0
# via -r requirements.in
google-crc32c==1.3.0
google-crc32c==1.5.0
# via google-resumable-media
google-resumable-media==2.3.3
# via google-cloud-storage
@ -53,11 +53,15 @@ googleapis-common-protos==1.56.4
# via google-api-core
idna==3.3
# via requests
importlib-metadata==4.12.0
# via
# click
# jsonschema
jsonschema==3.2.0
# via -r requirements.in
kfp-pipeline-spec==0.1.16
# via -r requirements.in
kfp-server-api==2.0.0a3
kfp-server-api==2.0.0a4
# via -r requirements.in
kubernetes==23.6.0
# via -r requirements.in
@ -114,19 +118,25 @@ termcolor==1.1.0
# via fire
typer==0.6.1
# via -r requirements.in
typing-extensions==4.3.0 ; python_version < "3.9"
# via
# -r requirements.in
# importlib-metadata
uritemplate==3.0.1
# via -r requirements.in
urllib3==1.26.11
urllib3==1.26.12
# via
# kfp-server-api
# kubernetes
# requests
websocket-client==1.3.3
websocket-client==1.4.0
# via kubernetes
wheel==0.37.1
# via strip-hints
wrapt==1.14.1
# via deprecated
zipp==3.8.1
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@ -66,8 +66,9 @@ deploymentSpec:
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef model_consumer(model: Input[VertexModel],\n \
\ dataset: Input[VertexDataset]):\n print('Model')\n print('artifact.type:\
\ *\nfrom aiplatform import VertexModel\nfrom aiplatform import VertexDataset\n\
\ndef model_consumer(model: Input[VertexModel],\n dataset:\
\ Input[VertexDataset]):\n print('Model')\n print('artifact.type:\
\ ', type(model))\n print('artifact.name: ', model.name)\n print('artifact.uri:\
\ ', model.uri)\n print('artifact.metadata: ', model.metadata)\n\n \
\ print('Dataset')\n print('artifact.type: ', type(dataset))\n print('artifact.name:\
@ -98,9 +99,9 @@ deploymentSpec:
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef model_producer(model: Output[aiplatform.VertexModel]):\n\n \
\ assert isinstance(model, aiplatform.VertexModel), type(model)\n with\
\ open(model.path, 'w') as f:\n f.write('my model')\n\n"
\ *\nimport aiplatform\n\ndef model_producer(model: Output[aiplatform.VertexModel]):\n\
\n assert isinstance(model, aiplatform.VertexModel), type(model)\n \
\ with open(model.path, 'w') as f:\n f.write('my model')\n\n"
image: python:3.7
pipelineInfo:
name: pipeline-with-google-types