# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for MLMD artifact ontology in KFP SDK."""
from typing import Any, Dict, Optional
from absl import logging
import enum
import importlib
from google.protobuf import json_format
import yaml

from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.dsl import serialization_utils

# Format string used to build the 'title' field of an artifact type schema.
_KFP_ARTIFACT_TITLE_PATTERN = 'kfp.{}'
# Module searched for first-party artifact classes during deserialization.
KFP_ARTIFACT_ONTOLOGY_MODULE = 'kfp.dsl.ontology_artifacts'
# YAML schema describing a generic artifact without any declared property.
DEFAULT_ARTIFACT_SCHEMA = 'title: kfp.Artifact\ntype: object\nproperties:\n'


# Enum for property types.
# This is introduced to decouple the MLMD ontology with Python built-in types.
class PropertyType(enum.Enum):
  INT = 1
  DOUBLE = 2
  STRING = 3


class Property(object):
  """Property specified for an Artifact."""

  # Mapping from Python enum to primitive type in the IR proto.
  _ALLOWED_PROPERTY_TYPES = {
      PropertyType.INT: pipeline_spec_pb2.PrimitiveType.INT,
      PropertyType.DOUBLE: pipeline_spec_pb2.PrimitiveType.DOUBLE,
      PropertyType.STRING: pipeline_spec_pb2.PrimitiveType.STRING,
  }

  def __init__(self, type: PropertyType, description: Optional[str] = None):
    """Constructs a Property.

    Args:
      type: The property type; must be a key of _ALLOWED_PROPERTY_TYPES.
      description: Optional human readable description of the property.

    Raises:
      ValueError: If `type` is not one of the allowed property types.
    """
    if type not in Property._ALLOWED_PROPERTY_TYPES:
      raise ValueError('Property type must be one of %s.'
                       % list(Property._ALLOWED_PROPERTY_TYPES.keys()))
    self.type = type
    self.description = description

  @classmethod
  def from_dict(cls, dict_data: Dict[str, Any]) -> 'Property':
    """Deserializes the Property object from YAML dict.

    Args:
      dict_data: Mapping with a 'type' entry ('string', 'int' or 'double') and
        an optional 'description' entry.

    Returns:
      The Property described by `dict_data`.

    Raises:
      TypeError: If the 'type' entry is missing, empty or not recognized.
    """
    if not dict_data.get('type'):
      raise TypeError('Missing type keyword in property dict.')
    if dict_data['type'] == 'string':
      kind = PropertyType.STRING
    elif dict_data['type'] == 'int':
      kind = PropertyType.INT
    elif dict_data['type'] == 'double':
      kind = PropertyType.DOUBLE
    else:
      raise TypeError('Got unknown type: %s' % dict_data['type'])
    # The description is optional on Property itself, so a schema entry
    # without one must not fail with a KeyError.
    return Property(
        type=kind,
        description=dict_data.get('description')
    )

  def get_ir_type(self):
    """Gets the IR primitive type."""
    return Property._ALLOWED_PROPERTY_TYPES[self.type]

  def get_type_name(self):
    """Gets the type name used in YAML instance.

    Returns:
      'int', 'double' or 'string'.

    Raises:
      TypeError: If the stored type is not a known PropertyType member.
    """
    if self.type == PropertyType.INT:
      return 'int'
    elif self.type == PropertyType.DOUBLE:
      return 'double'
    elif self.type == PropertyType.STRING:
      return 'string'
    else:
      raise TypeError('Unexpected property type: %s' % self.type)


class Artifact(object):
  """KFP Artifact Python class.

  Artifact Python class/object mainly serves following purposes in different
  period of its lifecycle.

  1. During compile time, users can use Artifact class to annotate I/O types of
     their components.
  2. At runtime, Artifact objects provide helper function/utilities to access
     the underlying RuntimeArtifact pb message, and provide additional layers
     of validation to ensure type compatibility.
  """

  # Name of the Artifact type.
  TYPE_NAME = 'Artifact'
  # Property schema.
  # Example usage:
  #
  # PROPERTIES = {
  #   'span': Property(type=PropertyType.INT),
  #   # Comma-separated list of splits for an artifact. Empty string means
  #   # the artifact has no split.
  #   'split_names': Property(type=PropertyType.STRING),
  # }
  PROPERTIES = None

  # Initialization flag to support setattr / getattr behavior.
  _initialized = False

  def __init__(self, instance_schema: Optional[str] = None):
    """Constructs an instance of Artifact.

    Args:
      instance_schema: YAML schema string. Required when instantiating the
        Artifact base class directly, in which case the type name and the
        properties are read from it. Must not be passed for subclasses, whose
        schema is generated from TYPE_NAME and PROPERTIES.

    Raises:
      ValueError: If `instance_schema` is missing for the base class, lacks
        the 'properties' or 'title' keys, or is passed for a subclass.
    """
    if self.__class__ == Artifact:
      if not instance_schema:
        raise ValueError(
            'The "instance_schema" argument must be set.')
      schema_yaml = yaml.safe_load(instance_schema)
      if 'properties' not in schema_yaml:
        raise ValueError('Invalid instance_schema, properties must be present. '
                         'Got %s' % instance_schema)
      # An empty 'properties:' entry is loaded as None.
      schema = schema_yaml['properties'] or {}
      if 'title' not in schema_yaml:
        raise ValueError('Invalid instance_schema, title must be present. '
                         'Got %s' % instance_schema)
      # _initialized is still False here, so these assignments create plain
      # instance attributes shadowing the class-level defaults.
      self.TYPE_NAME = schema_yaml['title']
      self.PROPERTIES = {}
      for k, v in schema.items():
        self.PROPERTIES[k] = Property.from_dict(v)
    else:
      if instance_schema:
        raise ValueError(
            'The "instance_schema" argument must not be passed for '
            'Artifact subclass %s.' % self.__class__)
      instance_schema = self.get_artifact_type()

    # MLMD artifact type schema string.
    self._type_schema = instance_schema

    # Instantiate a RuntimeArtifact pb message as the POD data structure.
    self._artifact = pipeline_spec_pb2.RuntimeArtifact()
    self._artifact.type.CopyFrom(pipeline_spec_pb2.ArtifactTypeSchema(
        instance_schema=instance_schema
    ))
    # Initialization flag to prevent recursive getattr / setattr errors.
    self._initialized = True

  @classmethod
  def get_artifact_type(cls) -> str:
    """Gets the instance_schema according to the Python schema spec.

    Returns:
      A YAML string with 'title', 'type' and 'properties' entries built from
      the class-level TYPE_NAME and PROPERTIES.
    """
    title = _KFP_ARTIFACT_TITLE_PATTERN.format(cls.TYPE_NAME)
    # Stays None (not an empty dict) when the class declares no property.
    schema_map = None
    if cls.PROPERTIES:
      schema_map = {}
      for k, v in cls.PROPERTIES.items():
        schema_map[k] = {
            'type': v.get_type_name(),
            'description': v.description
        }
    result_map = {
        'title': title,
        'type': 'object',
        'properties': schema_map
    }
    return serialization_utils.yaml_dump(result_map)

  @property
  def type_schema(self) -> str:
    """Gets the instance_schema specified for this Artifact object."""
    return self._type_schema

  def __repr__(self) -> str:
    return 'Artifact(artifact: {}, type_schema: {})'.format(
        str(self._artifact), str(self.type_schema))

  def __getattr__(self, name: str) -> Any:
    """Custom __getattr__ to allow access to artifact properties.

    Python only calls this hook when regular attribute lookup fails, so it is
    reached for names that are neither instance nor class attributes.

    Args:
      name: Name of the attribute being read.

    Returns:
      The value stored in the RuntimeArtifact for the declared property, or
      the empty default ('' / 0 / 0.0) when the property was never set.

    Raises:
      AttributeError: If `name` is not a declared property.
    """
    if name == '_artifact_type':
      # Prevent infinite recursion when used with copy.deepcopy().
      raise AttributeError()
    properties = self.PROPERTIES or {}
    if name not in properties:
      raise AttributeError(
          '%s artifact has no property %r.' % (self.TYPE_NAME, name))
    property_type = properties[name].type
    if property_type == PropertyType.STRING:
      if name not in self._artifact.properties:
        # Avoid populating empty property protobuf with the [] operator.
        return ''
      return self._artifact.properties[name].string_value
    elif property_type == PropertyType.INT:
      if name not in self._artifact.properties:
        # Avoid populating empty property protobuf with the [] operator.
        return 0
      return self._artifact.properties[name].int_value
    elif property_type == PropertyType.DOUBLE:
      if name not in self._artifact.properties:
        # Avoid populating empty property protobuf with the [] operator.
        return 0.0
      return self._artifact.properties[name].double_value
    else:
      raise Exception('Unknown MLMD type %r for property %r.'
                      % (property_type, name))

  def __setattr__(self, name: str, value: Any):
    """Custom __setattr__ to allow access to artifact properties.

    Args:
      name: Name of the attribute being written.
      value: Value to store. For declared properties it must match the
        declared type (str, int or float).

    Raises:
      AttributeError: If `name` is neither a declared property nor an
        attribute already known to the instance or its classes.
      Exception: If `value` does not match the declared property type.
    """
    if not self._initialized:
      # While __init__ is still running every assignment is a plain one.
      object.__setattr__(self, name, value)
      return
    properties = self.PROPERTIES or {}
    if name not in properties:
      if (name in self.__dict__ or
          any(name in c.__dict__ for c in self.__class__.mro())):
        # Use any provided getter / setter if available.
        object.__setattr__(self, name, value)
        return
      # In the case where we do not handle this via an explicit getter /
      # setter, we assume that the user implied an artifact attribute store,
      # and we raise an exception since such an attribute was not explicitly
      # defined in the Artifact PROPERTIES dictionary.
      raise AttributeError('Cannot set unknown property %r on artifact %r.'
                           % (name, self))
    property_type = properties[name].type
    if property_type == PropertyType.STRING:
      if not isinstance(value, str):
        raise Exception(
            'Expected string value for property %r; got %r instead.'
            % (name, value))
      self._artifact.properties[name].string_value = value
    elif property_type == PropertyType.INT:
      if not isinstance(value, int):
        raise Exception(
            'Expected integer value for property %r; got %r instead.'
            % (name, value))
      self._artifact.properties[name].int_value = value
    elif property_type == PropertyType.DOUBLE:
      if not isinstance(value, float):
        raise Exception(
            'Expected float value for property %r; got %r instead.'
            % (name, value))
      self._artifact.properties[name].double_value = value
    else:
      raise Exception('Unknown property type %r for property %r.'
                      % (property_type, name))

  @property
  def type(self):
    """Gets the Python class of this artifact."""
    return self.__class__

  @property
  def type_name(self):
    """Gets the TYPE_NAME of this artifact."""
    return self.TYPE_NAME

  @property
  def runtime_artifact(self) -> pipeline_spec_pb2.RuntimeArtifact:
    """Gets the underlying RuntimeArtifact message."""
    return self._artifact

  @runtime_artifact.setter
  def runtime_artifact(self, artifact: pipeline_spec_pb2.RuntimeArtifact):
    self._artifact = artifact

  @property
  def uri(self) -> str:
    """Gets the uri stored in the RuntimeArtifact message."""
    return self._artifact.uri

  @uri.setter
  def uri(self, uri: str) -> None:
    self._artifact.uri = uri

  @property
  def name(self) -> str:
    """Gets the name stored in the RuntimeArtifact message."""
    return self._artifact.name

  @name.setter
  def name(self, name: str) -> None:
    self._artifact.name = name

  # Custom property accessors.
  def set_string_custom_property(self, key: str, value: str):
    """Sets a custom property of string type."""
    self._artifact.custom_properties[key].string_value = value

  def set_int_custom_property(self, key: str, value: int):
    """Sets a custom property of int type."""
    self._artifact.custom_properties[key].int_value = value

  def set_float_custom_property(self, key: str, value: float):
    """Sets a custom property of float type."""
    self._artifact.custom_properties[key].double_value = value

  def has_custom_property(self, key: str) -> bool:
    """Returns whether a custom property named `key` has been set."""
    return key in self._artifact.custom_properties

  def get_string_custom_property(self, key: str) -> str:
    """Gets a custom property of string type ('' when it is not set)."""
    if key not in self._artifact.custom_properties:
      return ''
    return self._artifact.custom_properties[key].string_value

  def get_int_custom_property(self, key: str) -> int:
    """Gets a custom property of int type (0 when it is not set)."""
    if key not in self._artifact.custom_properties:
      return 0
    return self._artifact.custom_properties[key].int_value

  def get_float_custom_property(self, key: str) -> float:
    """Gets a custom property of float type (0.0 when it is not set)."""
    if key not in self._artifact.custom_properties:
      return 0.0
    return self._artifact.custom_properties[key].double_value

  @classmethod
  def get_from_runtime_artifact(
      cls, artifact: pipeline_spec_pb2.RuntimeArtifact) -> Any:
    """Deserializes an Artifact object from RuntimeArtifact message.

    The class named by the schema title is looked up in
    KFP_ARTIFACT_ONTOLOGY_MODULE. When it cannot be loaded or instantiated, a
    generic Artifact built from the instance schema is used instead.

    Args:
      artifact: The RuntimeArtifact message to wrap.

    Returns:
      An Artifact (or subclass) object whose runtime_artifact is `artifact`.
    """
    instance_schema = yaml.safe_load(artifact.type.instance_schema)
    # Drops the leading 'kfp.' of the title to obtain the class name.
    type_name = instance_schema['title'][len('kfp.'):]
    result = None
    try:
      artifact_cls = getattr(
          importlib.import_module(KFP_ARTIFACT_ONTOLOGY_MODULE), type_name)
      # TODO(numerology): Add deserialization tests for first party classes.
      result = artifact_cls()
    except (AttributeError, ImportError, ValueError):
      logging.warning((
          'Could not load artifact class %s.%s; using fallback deserialization '
          'for the relevant artifact. Please make sure that any artifact '
          'classes can be imported within your container or environment.'),
          KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
    # Explicit None check: the fallback must only trigger when no class was
    # instantiated, independently of the truthiness of the artifact object.
    if result is None:
      # Otherwise generate a generic Artifact object.
      result = Artifact(instance_schema=artifact.type.instance_schema)
    result.runtime_artifact = artifact
    return result

  @classmethod
  def deserialize(cls, data: str) -> Any:
    """Deserializes an Artifact object from a JSON string.

    Args:
      data: JSON representation of a RuntimeArtifact message. Unknown fields
        are ignored.

    Returns:
      The Artifact object built by get_from_runtime_artifact().
    """
    artifact = pipeline_spec_pb2.RuntimeArtifact()
    json_format.Parse(data, artifact, ignore_unknown_fields=True)
    return cls.get_from_runtime_artifact(artifact)

  def serialize(self) -> str:
    """Serializes the underlying RuntimeArtifact to a JSON string."""
    return json_format.MessageToJson(self._artifact, sort_keys=True)