154 lines
4.7 KiB
Python
154 lines
4.7 KiB
Python
"""Defines common spec input and output classes."""
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import (
|
|
Callable,
|
|
List,
|
|
NewType,
|
|
Optional,
|
|
Union,
|
|
)
|
|
from .spec_input_parsers import SpecInputParsers
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SageMakerComponentInputValidator:
|
|
"""Defines the structure of a component input to be used for validation."""
|
|
|
|
input_type: Callable
|
|
description: str
|
|
required: bool = False
|
|
choices: Optional[List[object]] = None
|
|
default: Optional[object] = None
|
|
|
|
def to_argparse_mapping(self):
|
|
"""Maps each property to an argparse argument.
|
|
|
|
See: https://docs.python.org/3/library/argparse.html#adding-arguments
|
|
"""
|
|
return {
|
|
"type": self.input_type,
|
|
"help": self.description,
|
|
"required": self.required,
|
|
"choices": self.choices,
|
|
"default": self.default,
|
|
}
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SageMakerComponentOutputValidator:
|
|
"""Defines the structure of a component output."""
|
|
|
|
description: str
|
|
|
|
|
|
# The value of an input or output can be arbitrary.
|
|
# This could be replaced with a generic, as well.
|
|
SageMakerIOValue = NewType("SageMakerIOValue", object)
|
|
|
|
# Allow the component input to represent either the validator or the value itself
|
|
# This saves on having to rewrite the struct for both types.
|
|
# Could replace `object` with T if TypedDict supported generics (see above).
|
|
SageMakerComponentInput = NewType(
|
|
"SageMakerComponentInput", Union[SageMakerComponentInputValidator, SageMakerIOValue]
|
|
)
|
|
SageMakerComponentOutput = NewType(
|
|
"SageMakerComponentOutput",
|
|
Union[SageMakerComponentOutputValidator, SageMakerIOValue],
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SageMakerComponentBaseInputs:
|
|
"""The base class for all component inputs."""
|
|
|
|
pass
|
|
|
|
|
|
@dataclass
|
|
class SageMakerComponentBaseOutputs:
|
|
"""A base class for all component outputs."""
|
|
|
|
pass
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SageMakerComponentCommonInputs(SageMakerComponentBaseInputs):
|
|
"""A set of common inputs for all components."""
|
|
|
|
region: SageMakerComponentInput
|
|
endpoint_url: SageMakerComponentInput
|
|
assume_role: SageMakerComponentInput
|
|
tags: SageMakerComponentInput
|
|
|
|
|
|
COMMON_INPUTS = SageMakerComponentCommonInputs(
|
|
region=SageMakerComponentInputValidator(
|
|
input_type=str,
|
|
required=True,
|
|
description="The region for the SageMaker resource.",
|
|
),
|
|
endpoint_url=SageMakerComponentInputValidator(
|
|
input_type=SpecInputParsers.nullable_string_argument,
|
|
required=False,
|
|
description="The URL to use when communicating with the SageMaker service.",
|
|
),
|
|
assume_role=SageMakerComponentInputValidator(
|
|
input_type=str,
|
|
required=False,
|
|
description="The ARN of an IAM role to assume when connecting to SageMaker.",
|
|
),
|
|
tags=SageMakerComponentInputValidator(
|
|
input_type=SpecInputParsers.yaml_or_json_dict,
|
|
required=False,
|
|
description="An array of key-value pairs, to categorize AWS resources.",
|
|
default={},
|
|
),
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SpotInstanceInputs(SageMakerComponentBaseInputs):
|
|
"""Inputs to enable spot instance support."""
|
|
|
|
spot_instance: SageMakerComponentInput
|
|
max_wait_time: SageMakerComponentInput
|
|
max_run_time: SageMakerComponentInput
|
|
checkpoint_config: SageMakerComponentInput
|
|
|
|
|
|
SPOT_INSTANCE_INPUTS = SpotInstanceInputs(
|
|
spot_instance=SageMakerComponentInputValidator(
|
|
input_type=SpecInputParsers.str_to_bool,
|
|
description="Use managed spot training.",
|
|
default=False,
|
|
),
|
|
max_wait_time=SageMakerComponentInputValidator(
|
|
input_type=int,
|
|
description="The maximum time in seconds you are willing to wait for a managed spot training job to complete.",
|
|
default=86400,
|
|
),
|
|
max_run_time=SageMakerComponentInputValidator(
|
|
input_type=int,
|
|
description="The maximum run time in seconds for the training job.",
|
|
default=86400,
|
|
),
|
|
checkpoint_config=SageMakerComponentInputValidator(
|
|
input_type=SpecInputParsers.yaml_or_json_dict,
|
|
description="Dictionary of information about the output location for managed spot training checkpoint data.",
|
|
default={},
|
|
),
|
|
)
|