feat(sdk): add support for metadata placeholders (#8151)

This commit is contained in:
Scott_Xu 2022-08-17 16:10:06 -07:00 committed by GitHub
parent 51bea09833
commit 88a1b314c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 102 additions and 22 deletions

View File

@ -32,6 +32,8 @@ def container_with_artifact_output(
num_epochs,
'--model_path',
model.uri,
'--model_metadata',
model.metadata,
'--model_config_path',
model_config_path,
])

View File

@ -23,6 +23,8 @@ deploymentSpec:
- '{{$.inputs.parameters[''num_epochs'']}}'
- --model_path
- '{{$.outputs.artifacts[''model''].uri}}'
- --model_metadata
- '{{$.outputs.artifacts[''model''].metadata}}'
- --model_config_path
- '{{$.outputs.parameters[''model_config_path''].output_file}}'
command:
@ -50,4 +52,4 @@ root:
num_epochs:
parameterType: NUMBER_INTEGER
schemaVersion: 2.1.0
sdkVersion: kfp-2.0.0-beta.1
sdkVersion: kfp-2.0.0-beta.2

View File

@ -29,16 +29,23 @@ class ContainerComponentArtifactChannel:
self, _name: str
) -> Union[placeholders.InputUriPlaceholder, placeholders
.InputPathPlaceholder, placeholders.OutputUriPlaceholder,
placeholders.OutputPathPlaceholder]:
if _name not in ['uri', 'path']:
placeholders.OutputPathPlaceholder,
placeholders.InputMetadataPlaceholder,
placeholders.OutputMetadataPlaceholder]:
attr_to_placeholder_dict = {
'uri': {
'input': placeholders.InputUriPlaceholder,
'output': placeholders.OutputUriPlaceholder,
},
'path': {
'input': placeholders.InputPathPlaceholder,
'output': placeholders.OutputPathPlaceholder,
},
'metadata': {
'input': placeholders.InputMetadataPlaceholder,
'output': placeholders.OutputMetadataPlaceholder
},
}
if _name not in ['uri', 'path', 'metadata']:
raise AttributeError(f'Cannot access artifact attribute "{_name}".')
if self._io_type == 'input':
if _name == 'uri':
return placeholders.InputUriPlaceholder(self._var_name)
elif _name == 'path':
return placeholders.InputPathPlaceholder(self._var_name)
elif self._io_type == 'output':
if _name == 'uri':
return placeholders.OutputUriPlaceholder(self._var_name)
elif _name == 'path':
return placeholders.OutputPathPlaceholder(self._var_name)
return attr_to_placeholder_dict[_name][self._io_type](self._var_name)

View File

@ -29,6 +29,8 @@ class TestContainerComponentArtifactChannel(unittest.TestCase):
placeholders.InputUriPlaceholder('my_dataset'))
self.assertEqual(out_channel.path,
placeholders.OutputPathPlaceholder('my_result'))
self.assertEqual(out_channel.metadata,
placeholders.OutputMetadataPlaceholder('my_result'))
self.assertRaisesRegex(AttributeError,
r'Cannot access artifact attribute "name"',
lambda: in_channel.name)

View File

@ -163,7 +163,7 @@ class InputValuePlaceholder(base_model.BaseModel,
"""Class that holds an input value placeholder.
Attributes:
output_name: Name of the input.
input_name: Name of the input.
"""
input_name: str
_aliases = {'input_name': 'inputValue'}
@ -178,7 +178,7 @@ class InputPathPlaceholder(base_model.BaseModel,
"""Class that holds an input path placeholder.
Attributes:
output_name: Name of the input.
input_name: Name of the input.
"""
input_name: str
_aliases = {'input_name': 'inputPath'}
@ -193,7 +193,7 @@ class InputUriPlaceholder(base_model.BaseModel,
"""Class that holds an input uri placeholder.
Attributes:
output_name: Name of the input.
input_name: Name of the input.
"""
input_name: str
_aliases = {'input_name': 'inputUri'}
@ -203,12 +203,27 @@ class InputUriPlaceholder(base_model.BaseModel,
)
class InputMetadataPlaceholder(base_model.BaseModel,
RegexPlaceholderSerializationMixin):
"""Class that holds an input metadata placeholder.
Attributes:
input_name: Name of the input.
"""
input_name: str
_aliases = {'input_name': 'inputMetadata'}
_TO_PLACEHOLDER = "{{{{$.inputs.artifacts['{input_name}'].metadata}}}}"
_FROM_PLACEHOLDER = re.compile(
r"^\{\{\$\.inputs\.artifacts\[(?:''|'|\")(?P<input_name>.+?)(?:''|'|\")]\.metadata\}\}$"
)
class OutputParameterPlaceholder(base_model.BaseModel,
RegexPlaceholderSerializationMixin):
"""Class that holds an output parameter placeholder.
Attributes:
output_name: Name of the input.
output_name: Name of the output.
"""
output_name: str
_aliases = {'output_name': 'outputPath'}
@ -223,7 +238,7 @@ class OutputPathPlaceholder(base_model.BaseModel,
"""Class that holds an output path placeholder.
Attributes:
output_name: Name of the input.
output_name: Name of the output.
"""
output_name: str
_aliases = {'output_name': 'outputPath'}
@ -248,6 +263,21 @@ class OutputUriPlaceholder(base_model.BaseModel,
)
class OutputMetadataPlaceholder(base_model.BaseModel,
RegexPlaceholderSerializationMixin):
"""Class that holds an output metadata placeholder.
Attributes:
output_name: Name of the output.
"""
output_name: str
_aliases = {'output_name': 'outputMetadata'}
_TO_PLACEHOLDER = "{{{{$.outputs.artifacts['{output_name}'].metadata}}}}"
_FROM_PLACEHOLDER = re.compile(
r"^\{\{\$\.outputs\.artifacts\[(?:''|'|\")(?P<output_name>.+?)(?:''|'|\")]\.metadata\}\}$"
)
CommandLineElement = Union[str, ExecutorInputPlaceholder, InputValuePlaceholder,
InputPathPlaceholder, InputUriPlaceholder,
OutputParameterPlaceholder, OutputPathPlaceholder,

View File

@ -95,6 +95,22 @@ class TestInputUriPlaceholder(parameterized.TestCase):
placeholder_string)
class TestInputMetadataPlaceholder(parameterized.TestCase):
@parameterized.parameters([
("{{$.inputs.artifacts['input1'].metadata}}",
placeholders.InputMetadataPlaceholder('input1')),
])
def test_to_from_placeholder(
self, placeholder_string: str,
placeholder_obj: placeholders.InputMetadataPlaceholder):
self.assertEqual(
placeholders.InputMetadataPlaceholder.from_placeholder_string(
placeholder_string), placeholder_obj)
self.assertEqual(placeholder_obj.to_placeholder_string(),
placeholder_string)
class TestOutputPathPlaceholder(parameterized.TestCase):
@parameterized.parameters([
@ -143,6 +159,22 @@ class TestOutputUriPlaceholder(parameterized.TestCase):
placeholder_string)
class TestOutputMetadataPlaceholder(parameterized.TestCase):
@parameterized.parameters([
("{{$.outputs.artifacts['output1'].metadata}}",
placeholders.OutputMetadataPlaceholder('output1')),
])
def test_to_from_placeholder(
self, placeholder_string: str,
placeholder_obj: placeholders.OutputMetadataPlaceholder):
self.assertEqual(
placeholders.OutputMetadataPlaceholder.from_placeholder_string(
placeholder_string), placeholder_obj)
self.assertEqual(placeholder_obj.to_placeholder_string(),
placeholder_string)
class TestIfPresentPlaceholderStructure(parameterized.TestCase):
def test_else_transform(self):

View File

@ -448,13 +448,16 @@ def _check_valid_placeholder_reference(
elif isinstance(
placeholder,
(placeholders.InputValuePlaceholder, placeholders.InputPathPlaceholder,
placeholders.InputUriPlaceholder)):
placeholders.InputUriPlaceholder,
placeholders.InputMetadataPlaceholder)):
if placeholder.input_name not in valid_inputs:
raise ValueError(
f'Argument "{placeholder}" references non-existing input.')
elif isinstance(placeholder, (placeholders.OutputParameterPlaceholder,
placeholders.OutputPathPlaceholder,
placeholders.OutputUriPlaceholder)):
elif isinstance(
placeholder,
(placeholders.OutputParameterPlaceholder,
placeholders.OutputPathPlaceholder, placeholders.OutputUriPlaceholder,
placeholders.OutputMetadataPlaceholder)):
if placeholder.output_name not in valid_outputs:
raise ValueError(
f'Argument "{placeholder}" references non-existing output.')
@ -481,6 +484,8 @@ ValidCommandArgTypes = (str, placeholders.InputValuePlaceholder,
placeholders.InputUriPlaceholder,
placeholders.OutputPathPlaceholder,
placeholders.OutputUriPlaceholder,
placeholders.InputMetadataPlaceholder,
placeholders.OutputMetadataPlaceholder,
placeholders.IfPresentPlaceholder,
placeholders.ConcatPlaceholder)