pipelines/components/aws/sagemaker/common/sagemaker_component.py

598 lines
20 KiB
Python

"""Base class for all SageMaker components."""
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import re
import signal
import string
import logging
import json
from enum import Enum, auto
from types import FunctionType
import yaml
import random
from pathlib import Path
from time import sleep, strftime, gmtime
from abc import abstractmethod
from typing import Any, Type, Dict, List, NamedTuple, Optional
from .sagemaker_component_spec import SageMakerComponentSpec
from .boto3_manager import Boto3Manager
from .common_inputs import (
SageMakerComponentBaseOutputs,
SageMakerComponentCommonInputs,
SpotInstanceInputs,
)
# This handler is called whenever the @ComponentMetadata is applied.
# It allows the command line compiler to detect every component spec class.
_component_decorator_handler: Optional[FunctionType] = None
def ComponentMetadata(name: str, description: str, spec: object):
"""Decorator for SageMaker components.
Used to define necessary metadata attributes about the component which will
be used for logging output and for the component specification file.
Usage:
```python
@ComponentMetadata(
name="SageMaker - Component Name",
description="A cool new component we made!",
spec=MyComponentSpec
)
"""
def _component_metadata(cls):
cls.COMPONENT_NAME = name
cls.COMPONENT_DESCRIPTION = description
cls.COMPONENT_SPEC = spec
# Add handler for compiler
if _component_decorator_handler:
return _component_decorator_handler(cls) or cls
return cls
return _component_metadata
class SageMakerJobStatus(NamedTuple):
"""Generic representation of a job status."""
is_completed: bool
raw_status: str
has_error: bool = False
error_message: Optional[str] = None
class DebugRulesStatus(Enum):
COMPLETED = auto()
ERRORED = auto()
INPROGRESS = auto()
@classmethod
def from_describe(cls, response):
has_error = False
for debug_rule in response["DebugRuleEvaluationStatuses"]:
if debug_rule["RuleEvaluationStatus"] == "Error":
has_error = True
if debug_rule["RuleEvaluationStatus"] == "InProgress":
return DebugRulesStatus.INPROGRESS
if has_error:
return DebugRulesStatus.ERRORED
else:
return DebugRulesStatus.COMPLETED
class SageMakerComponent:
"""Base class for a KFP SageMaker component.
An instance of a subclass of this component represents an instantiation of the
component within a pipeline run. Use the `@ComponentMetadata` decorator to
modify the component attributes listed below.
Attributes:
COMPONENT_NAME: The name of the component as displayed to the user.
COMPONENT_DESCRIPTION: The description of the component as displayed to
the user.
COMPONENT_SPEC: The correspending spec associated with the component.
STATUS_POLL_INTERVAL: Number of seconds between polling for the job
status.
"""
COMPONENT_NAME = ""
COMPONENT_DESCRIPTION = ""
COMPONENT_SPEC = SageMakerComponentSpec
STATUS_POLL_INTERVAL = 30
def __init__(self):
"""Initialize a new component."""
self._initialize_logging()
def _initialize_logging(self):
"""Initializes the global logging structure."""
logging.getLogger().setLevel(logging.INFO)
def Do(
self,
inputs: SageMakerComponentCommonInputs,
outputs: SageMakerComponentBaseOutputs,
output_paths: SageMakerComponentBaseOutputs,
):
"""The main execution entrypoint for a component at runtime.
Args:
inputs: A populated list of user inputs.
outputs: An unpopulated list of component output variables.
output_paths: Paths to the respective output locations.
"""
# Global try-catch in order to allow for safe abort
try:
self._configure_aws_clients(inputs)
# Successful execution
if not self._do(inputs, outputs, output_paths):
sys.exit(1)
except Exception as e:
logging.exception("An error occurred while running the component")
raise e
def _configure_aws_clients(self, inputs: SageMakerComponentCommonInputs):
"""Configures the internal AWS clients for the component.
Args:
inputs: A populated list of user inputs.
"""
self._sm_client = Boto3Manager.get_sagemaker_client(
self._get_component_version(),
inputs.region,
endpoint_url=inputs.endpoint_url,
assume_role_arn=inputs.assume_role,
)
self._cw_client = Boto3Manager.get_cloudwatch_client(
inputs.region, assume_role_arn=inputs.assume_role
)
def _do(
self,
inputs: SageMakerComponentCommonInputs,
outputs: SageMakerComponentBaseOutputs,
output_paths: SageMakerComponentBaseOutputs,
) -> bool:
# Set up SIGTERM handling
def signal_term_handler(signalNumber, frame):
self._on_job_terminated()
signal.signal(signal.SIGTERM, signal_term_handler)
request = self._create_job_request(inputs, outputs)
try:
job = self._submit_job_request(request)
except Exception as e:
logging.exception(
"An error occurred while attempting to submit the request"
)
return False
self._after_submit_job_request(job, request, inputs, outputs)
status: SageMakerJobStatus = SageMakerJobStatus(
is_completed=False, raw_status="No Status"
)
try:
while True:
status = self._get_job_status()
# Continue until complete
if status and status.is_completed:
break
sleep(self.STATUS_POLL_INTERVAL)
logging.info(f"Job is in status: {status.raw_status}")
except Exception as e:
logging.exception("An error occurred while polling for job status")
return False
finally:
self._print_logs_for_job()
if status.has_error:
logging.error(status.error_message)
return False
self._after_job_complete(job, request, inputs, outputs)
self._write_all_outputs(output_paths, outputs)
return True
@abstractmethod
def _get_job_status(self) -> SageMakerJobStatus:
"""Waits for the current job to complete.
Returns:
SageMakerJobStatus: A status object.
"""
pass
@abstractmethod
def _create_job_request(
self,
inputs: SageMakerComponentCommonInputs,
outputs: SageMakerComponentBaseOutputs,
) -> Dict:
"""Creates the boto3 request object to execute the component.
Args:
inputs: A populated list of user inputs.
outputs: An unpopulated list of component output variables.
Returns:
dict: A dictionary object representing the request.
"""
pass
@abstractmethod
def _submit_job_request(self, request: Dict) -> Dict:
"""Submits a pre-defined request object to SageMaker.
The `request` argument should be provided as the result of the
`_create_job_request` method.
Args:
request: A request object to execute the component.
Returns:
dict: The job object that was created.
Raises:
Exception: If SageMaker responded with an error during the request.
"""
pass
@abstractmethod
def _after_submit_job_request(
self,
job: object,
request: Dict,
inputs: SageMakerComponentCommonInputs,
outputs: SageMakerComponentBaseOutputs,
):
"""Handles any events required after submitting a job to SageMaker.
Args:
job: The job returned after creation.
request: The request submitted prior.
inputs: A populated list of user inputs.
outputs: An unpopulated list of component output variables.
"""
pass
@abstractmethod
def _after_job_complete(
self,
job: object,
request: Dict,
inputs: SageMakerComponentCommonInputs,
outputs: SageMakerComponentBaseOutputs,
):
"""Handles any events after the job has been completed.
Args:
job: The job object that was created.
request: The request object used to execute the component.
inputs: A populated list of user inputs.
outputs: An unpopulated list of component output variables.
"""
pass
@abstractmethod
def _on_job_terminated(self):
"""Handles any SIGTERM events."""
pass
@abstractmethod
def _print_logs_for_job(self):
"""Print the associated logs for the current job."""
pass
@staticmethod
def _generate_unique_timestamped_id(
prefix: str = "",
size: int = 4,
chars: str = string.ascii_uppercase + string.digits,
max_length: int = 63,
) -> str:
"""Generate a pseudo-random string of characters appended to a
timestamp.
Format of the ID is as follows: `prefix-YYYYMMDDHHMMSS-unique`. If the
length of the total ID exceeds `max_length`, it will be truncated from
the beginning (prefix will be trimmed).
Args:
prefix: A prefix to append to the random suffix.
size: The number of unique characters to append to the ID.
chars: A list of characters to use in the random suffix.
max_length: The maximum length of the generated ID.
Returns:
string: A pseudo-random string with included timestamp and prefix.
"""
unique = "".join(random.choice(chars) for _ in range(size))
return f'{prefix}{"-" if prefix else ""}{strftime("%Y%m%d%H%M%S", gmtime())}-{unique}'[
-max_length:
]
@staticmethod
def _enable_spot_instance_support(
request: Dict, inputs: SpotInstanceInputs,
) -> Dict:
"""Modifies a request object to add support for spot instance fields.
Args:
request: A request object to modify.
inputs: A populated list of user inputs.
Returns:
dict: The modified dictionary
"""
if inputs.max_run_time:
request["StoppingCondition"]["MaxRuntimeInSeconds"] = inputs.max_run_time
if inputs.spot_instance:
request["EnableManagedSpotTraining"] = inputs.spot_instance
if (
inputs.max_wait_time
>= request["StoppingCondition"]["MaxRuntimeInSeconds"]
):
request["StoppingCondition"][
"MaxWaitTimeInSeconds"
] = inputs.max_wait_time
else:
logging.error(
"Max wait time must be greater than or equal to max run time."
)
raise Exception("Could not create job request.")
if inputs.checkpoint_config and "S3Uri" in inputs.checkpoint_config:
request["CheckpointConfig"] = inputs.checkpoint_config
else:
logging.error(
"EnableManagedSpotTraining requires checkpoint config with an S3 uri."
)
raise Exception("Could not create job request.")
else:
# Remove any artifacts that require spot instance support
del request["StoppingCondition"]["MaxWaitTimeInSeconds"]
del request["CheckpointConfig"]
return request
@staticmethod
def _validate_hyperparameters(hyperparam_args: Dict) -> Dict:
"""Validates hyperparameters and returns the dictionary used for a
request.
Args:
hyperparam_args: HyperParameters as passed in by the user.
Returns:
dict: A validated set of HyperParameters.
"""
# Validate all values are strings
for key, value in hyperparam_args.items():
if not isinstance(value, str):
raise Exception(
f"Could not parse hyperparameters. Value for {key} was not a string."
)
return hyperparam_args
@staticmethod
def _enable_tag_support(
request: Dict, inputs: SageMakerComponentCommonInputs
) -> Dict:
"""Modifies a request object to add support for tag fields.
Args:
request: A request object to modify.
inputs: A populated list of user inputs.
Returns:
dict: The modified dictionary
"""
for key, val in inputs.tags.items():
request["Tags"].append({"Key": key, "Value": val})
return request
def _write_all_outputs(
self,
output_paths: SageMakerComponentBaseOutputs,
outputs: SageMakerComponentBaseOutputs,
):
"""Writes all of the outputs specified by the component to their
respective file paths.
Args:
output_paths: A populated list of output paths.
outputs: A populated list of output values.
"""
for output_key, output_value in outputs.__dict__.items():
output_path = output_paths.__dict__.get(output_key)
if not output_path:
logging.error(f"Could not find output path for {output_key}")
continue
# Encode it if it's a List or Dict (not primitive)
encoded_types = (List, Dict)
self._write_output(
output_path,
output_value,
json_encode=isinstance(output_value, encoded_types),
)
def _write_output(
self, output_path: str, output_value: Any, json_encode: bool = False
):
"""Write an output value to the associated path, dumping as a JSON
object if specified.
Args:
output_path: The file path of the output.
output_value: The output value to write to the file.
json_encode: True if the value should be encoded as a JSON object.
"""
write_value = json.dumps(output_value) if json_encode else output_value
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
Path(output_path).write_text(write_value)
@staticmethod
def _get_component_version() -> str:
"""Get component version from the first line of License file.
Returns:
str: The string version as specified in the License file.
"""
component_version = "NULL"
# Get license file using known common directory
license_file_path = os.path.abspath(
os.path.join(
SageMakerComponent._get_common_path(), "../THIRD-PARTY-LICENSES.txt"
)
)
with open(license_file_path, "r") as license_file:
version_match = re.search(
"Amazon SageMaker Components for Kubeflow Pipelines; version (([0-9]+[.])+[0-9]+)",
license_file.readline(),
)
if version_match is not None:
component_version = version_match.group(1)
return component_version
@staticmethod
def _get_common_path() -> str:
"""Gets the path of the common directory in the project.
Returns:
str: The `realpath` representation of the common directory.
"""
return os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
@staticmethod
def _get_request_template(template_name: str) -> Dict[str, Any]:
"""Loads and returns a template file as a Python construct.
Args:
template_name: The name corresponding to the template file to load.
Returns:
dict: A Python construct created by loading the template file.
"""
with open(
os.path.join(
SageMakerComponent._get_common_path(),
"templates",
f"{template_name}.template.yaml",
),
"r",
) as f:
request = yaml.safe_load(f)
return request
def _print_log_header(self, header_len, title=""):
"""Prints a header section for logs.
Args:
header_len: The maximum length of the header line.
title: The header title.
"""
logging.info(f"{title:*^{header_len}}")
def _print_cloudwatch_logs(self, log_grp: str, job_name: str):
"""Gets the CloudWatch logs for SageMaker jobs.
Args:
log_grp: The name of a CloudWatch log group.
job_name: The name of the job as defined in CloudWatch.
"""
CW_ERROR_MESSAGE = "Error in fetching CloudWatch logs for SageMaker job"
try:
logging.info(
"\n******************** CloudWatch logs for {} {} ********************\n".format(
log_grp, job_name
)
)
log_streams = self._cw_client.describe_log_streams(
logGroupName=log_grp, logStreamNamePrefix=job_name + "/"
)["logStreams"]
for log_stream in log_streams:
logging.info("\n***** {} *****\n".format(log_stream["logStreamName"]))
response = self._cw_client.get_log_events(
logGroupName=log_grp, logStreamName=log_stream["logStreamName"]
)
for event in response["events"]:
logging.info(event["message"])
logging.info(
"\n******************** End of CloudWatch logs for {} {} ********************\n".format(
log_grp, job_name
)
)
except Exception as e:
logging.error(CW_ERROR_MESSAGE)
logging.error(e)
def _get_model_artifacts_from_job(self, job_name: str):
"""Loads training job model artifact results from a completed job.
Args:
job_name: The name of the completed training job.
Returns:
str: The S3 model artifacts of the job.
"""
info = self._sm_client.describe_training_job(TrainingJobName=job_name)
model_artifact_url = info["ModelArtifacts"]["S3ModelArtifacts"]
return model_artifact_url
def _get_image_from_job(self, job_name: str):
"""Gets the training image URL from a training job.
Args:
job_name: The name of a training job.
Returns:
str: A training image URL.
"""
info = self._sm_client.describe_training_job(TrainingJobName=job_name)
if "TrainingImage" in info["AlgorithmSpecification"]:
image = info["AlgorithmSpecification"]["TrainingImage"]
else:
algorithm_name = info["AlgorithmSpecification"]["AlgorithmName"]
image = self._sm_client.describe_algorithm(AlgorithmName=algorithm_name)[
"TrainingSpecification"
]["TrainingImage"]
return image