# pipelines/components/aws/sagemaker/commonv2/sagemaker_component.py
"""Base class for all SageMaker components."""
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import signal
import string
import logging
import json
from types import FunctionType
import yaml
import random
from pathlib import Path
from time import sleep, strftime, gmtime
from abc import abstractmethod
from typing import Any, Dict, List, NamedTuple, Optional
from kubernetes import client, config
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
from commonv2.sagemaker_component_spec import SageMakerComponentSpec
from commonv2.common_inputs import (
SageMakerComponentBaseOutputs,
SageMakerComponentCommonInputs,
)
from commonv2 import snake_to_camel, is_ack_requeue_error
# This handler is called whenever the @ComponentMetadata is applied.
# It allows the command line compiler to detect every component spec class.
_component_decorator_handler: Optional[FunctionType] = None


def ComponentMetadata(name: str, description: str, spec: object):
    """Class decorator that attaches SageMaker component metadata.

    Sets the ``COMPONENT_NAME``, ``COMPONENT_DESCRIPTION`` and
    ``COMPONENT_SPEC`` attributes on the decorated class. These are used
    for logging output and for the component specification file.

    Usage:
    ```python
    @ComponentMetadata(
        name="SageMaker - Component Name",
        description="A cool new component we made!",
        spec=MyComponentSpec,
    )
    class MyComponent(SageMakerComponent): ...
    ```

    Args:
        name: Display name of the component.
        description: Human-readable description of the component.
        spec: The component spec class associated with the component.

    Returns:
        The decorator applied to the component class.
    """

    def _apply_metadata(cls):
        cls.COMPONENT_NAME = name
        cls.COMPONENT_DESCRIPTION = description
        cls.COMPONENT_SPEC = spec
        # Give the command-line compiler a chance to observe (or replace)
        # the decorated class via the registered handler.
        if _component_decorator_handler:
            return _component_decorator_handler(cls) or cls
        return cls

    return _apply_metadata
class SageMakerJobStatus(NamedTuple):
    """Generic representation of a job status.

    Returned by the status-polling methods of `SageMakerComponent`
    (`_get_job_status` / `_get_upgrade_status`).
    """

    # True once the job has reached a terminal state.
    is_completed: bool
    # Raw status string for display/logging.
    raw_status: str
    # True if the job ended in an error state.
    has_error: bool = False
    # Human-readable error details; expected to be set when has_error is True.
    error_message: Optional[str] = None
class SageMakerComponent:
    """Base class for a KFP SageMaker component.

    An instance of a subclass of this component represents an instantiation of the
    component within a pipeline run. Use the `@ComponentMetadata` decorator to
    modify the component attributes listed below.

    Attributes:
        COMPONENT_NAME: The name of the component as displayed to the user.
        COMPONENT_DESCRIPTION: The description of the component as displayed to
            the user.
        COMPONENT_SPEC: The corresponding spec associated with the component.
        STATUS_POLL_INTERVAL: Number of seconds between polling for the job
            status.
        UPDATE_PROCESS_INTERVAL: Number of seconds between polls while waiting
            for the controller to consume a resource update.
    """

    COMPONENT_NAME = ""
    COMPONENT_DESCRIPTION = ""
    COMPONENT_SPEC = SageMakerComponentSpec

    STATUS_POLL_INTERVAL = 30
    UPDATE_PROCESS_INTERVAL = 10

    # Parameters that will be filled in by Do().
    # The assignment statements in Do() are generated.
    job_name: str
    group: str  # API group of the ACK custom resource
    version: str  # API version of the ACK custom resource
    plural: str  # Plural name of the ACK custom resource
    spaced_out_resource_name: str  # Human-readable resource name, used for logs
    namespace: Optional[str] = None  # None selects the cluster-scoped API calls
    resource_upgrade: bool = False  # True when updating an existing resource
    initial_status: dict  # Status snapshot taken before submitting an update
    update_supported: bool  # Whether this resource type supports updates
    job_request_outline_location: str  # Path to the YAML request template
    job_request_location: str  # Path where the rendered request is written

    def __init__(self):
        """Initialize a new component."""
        self._initialize_logging()

    def _initialize_logging(self):
        """Initializes the global logging structure."""
        logging.getLogger().setLevel(logging.INFO)

    def Do(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
        output_paths: SageMakerComponentBaseOutputs,
    ):
        """The main execution entrypoint for a component at runtime.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
            output_paths: Paths to the respective output locations.
        """
        # Verify that the kubernetes cluster is available before doing work.
        try:
            self._init_configure_k8s()
        except Exception as e:
            logging.exception("Failed to initialize k8s client: %s", e)
            sys.exit(1)
        # Global try-catch in order to allow for safe abort
        try:
            # Successful execution
            if not self._do(inputs, outputs, output_paths):
                sys.exit(1)
        except Exception:
            logging.exception("An error occurred while running the component")
            # Re-raise with the original traceback intact.
            raise

    def _init_configure_k8s(self):
        """Initializes the kubernetes client and verifies cluster access.

        Raises:
            Exception: If the cluster cannot be reached or pods in the
                configured namespace cannot be listed.
        """
        _test_client = self._get_k8s_api_client()
        _test_api = client.CoreV1Api(_test_client)
        # A simple list call acts as a connectivity/permission probe.
        _test_api.list_namespaced_pod(namespace=self.namespace)

    def _get_k8s_api_client(self) -> ApiClient:
        """Create new client everytime to avoid token refresh issues."""
        # Runs in-cluster: the pod's service account credentials are used.
        config.load_incluster_config()
        return ApiClient()

    def _get_current_namespace(self):
        """Get the namespace the component is running in.

        Prefers the in-cluster service-account namespace file; falls back to
        the active kube-config context, and finally to "default".
        """
        ns_path = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
        if os.path.exists(ns_path):
            with open(ns_path) as f:
                return f.read().strip()
        try:
            _, active_context = config.list_kube_config_contexts()
            return active_context["context"]["namespace"]
        except KeyError:
            return "default"

    def _do(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
        output_paths: SageMakerComponentBaseOutputs,
    ) -> bool:
        """Submits the job request and polls the resource until completion.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
            output_paths: Paths to the respective output locations.

        Returns:
            bool: True if the job ran to completion, False on any failure.
        """
        # Set up SIGTERM handling so the resource can be cleaned up on abort.
        def signal_term_handler(signalNumber, frame):
            self._on_job_terminated()

        signal.signal(signal.SIGTERM, signal_term_handler)

        self.resource_upgrade = self._is_upgrade()
        if self.resource_upgrade and not self.update_supported:
            logging.error(
                f"Resource update is not supported for {self.spaced_out_resource_name}"
            )
            return False

        request = self._create_job_request(inputs, outputs)
        try:
            job = self._submit_job_request(request)
        except Exception:
            logging.exception(
                "An error occurred while attempting to submit the request"
            )
            return False

        if not self._verify_resource_consumption():
            return False
        self._after_submit_job_request(job, request, inputs, outputs)

        status: SageMakerJobStatus = SageMakerJobStatus(
            is_completed=False, raw_status="No Status"
        )
        try:
            while True:
                cr_condition = self._check_resource_conditions()
                if cr_condition:
                    # ACK.Recoverable: keep waiting for the error to resolve.
                    sleep(self.STATUS_POLL_INTERVAL)
                    continue
                elif (
                    cr_condition is False
                ):  # ACK.Terminal or special errors (Validation Exception/Invalid Input)
                    return False
                status = (
                    self._get_job_status()
                    if not self.resource_upgrade
                    else self._get_upgrade_status()
                )
                # Continue until complete
                if status and status.is_completed:
                    if self.resource_upgrade:
                        logging.info(
                            f"{self.spaced_out_resource_name} Update complete, final status: {status.raw_status}"
                        )
                    else:
                        logging.info(
                            f"{self.spaced_out_resource_name} Creation complete, final status: {status.raw_status}"
                        )
                    break
                sleep(self.STATUS_POLL_INTERVAL)
                logging.info(
                    f"{self.spaced_out_resource_name} is in status: {status.raw_status}"
                )
        except Exception:
            logging.exception(
                f"An error occurred while polling for {self.spaced_out_resource_name} status"
            )
            return False

        if status.has_error:
            logging.error(status.error_message)
            return False

        self._after_job_complete(job, request, inputs, outputs)
        self._write_all_outputs(output_paths, outputs)
        return True

    def _get_conditions_of_type(self, condition_type):
        """Return all status conditions matching the given type.

        Args:
            condition_type: The ACK condition type to filter on
                (e.g. "ACK.Terminal").

        Returns:
            list: The matching condition dictionaries (possibly empty).
        """
        resource_conditions = self._get_resource()["status"]["conditions"]
        return [
            condition
            for condition in resource_conditions
            if condition["type"] == condition_type
        ]

    def _verify_resource_consumption(self) -> bool:
        """Verify that the resource has been successfully consumed by the controller.

        In the case of an update verify that the job arn exists.

        Returns:
            bool: Whether the resource consumed by the controller.
        """
        submission_ack_printed = False
        ERROR_NOT_CREATED_MESSAGE = "An error occurred while getting resource arn, ACK CR created but Sagemaker resource not created."
        ERROR_UPDATE_MESSAGE = "An error occured when getting the resource arn. Check the ACK Sagemaker Controller logs."
        try:
            while True:
                cr_condition = self._check_resource_conditions()
                if cr_condition:  # ACK.Recoverable
                    sleep(self.STATUS_POLL_INTERVAL)
                    continue
                elif cr_condition is False:
                    # For an update, a stale error condition may linger until
                    # the controller picks up the new spec; keep waiting.
                    if (
                        self.resource_upgrade
                        and not self.is_update_consumed_by_controller()
                    ):
                        sleep(self.UPDATE_PROCESS_INTERVAL)
                        continue
                    return False
                # Retrieve Sagemaker ARN
                arn = self.check_resource_initiation(submission_ack_printed)
                # Continue until complete
                if arn:
                    submission_ack_printed = True
                    if (
                        self.resource_upgrade
                        and not self.is_update_consumed_by_controller()
                    ):
                        sleep(self.UPDATE_PROCESS_INTERVAL)
                        continue
                    break
                sleep(self.STATUS_POLL_INTERVAL)
                logging.info(f"Getting arn for {self.job_name}")
        except Exception:
            err_msg = (
                ERROR_UPDATE_MESSAGE
                if self.resource_upgrade
                else ERROR_NOT_CREATED_MESSAGE
            )
            logging.exception(err_msg)
            return False
        return True

    def check_resource_initiation(self, submission_ack_printed: bool):
        """Check if resource has been initiated in Sagemaker.

        A resource is considered to be initiated if the resource ARN is present
        in the ack resource metadata. If the resource ARN is present in the ack
        resource metadata, the resource has been successfully created in
        Sagemaker.

        Args:
            submission_ack_printed (bool): Parameter to avoid printing the
                resource consumed message multiple times.

        Returns:
            str: The ARN of the resource, or None if the ARN is not present in
                the ack resource metadata (i.e. the resource has not been
                created in Sagemaker yet).
        """
        ack_status = self._get_resource()["status"]
        ack_resource_meta = ack_status.get("ackResourceMetadata", None)
        if ack_resource_meta:
            arn = ack_resource_meta.get("arn", None)
            if arn is not None:
                # BUG FIX: the original condition was inverted
                # (`if submission_ack_printed:`), so the message was never
                # printed on creation and printed repeatedly during updates.
                if not submission_ack_printed:
                    resource_consumed_message = (
                        f"Created Sagemaker {self.spaced_out_resource_name} with ARN: {arn}"
                        if not self.resource_upgrade
                        else f"Submitting update for Sagemaker {self.spaced_out_resource_name} with ARN: {arn}"
                    )
                    logging.info(resource_consumed_message)
                return arn
        return None

    @abstractmethod
    def _get_job_status(self) -> SageMakerJobStatus:
        """Waits for the current job to complete.

        Returns:
            SageMakerJobStatus: A status object.
        """
        pass

    @abstractmethod
    def _get_upgrade_status(self) -> SageMakerJobStatus:
        """Waits for the resource upgrade to complete.

        Returns:
            SageMakerJobStatus: A status object.
        """
        pass

    def is_update_consumed_by_controller(self):
        """Check if the update has been consumed by the controller.

        The controller is considered to have consumed the update once the
        resource status differs from the snapshot taken before submission
        (`self.initial_status`).
        """
        current_resource = self._get_resource()
        current_status = current_resource.get("status", None)
        # Python == is a deep comparison between dicts.
        return current_status != self.initial_status

    def _get_resource(self):
        """Get the custom resource detail similar to: kubectl describe
        trainingjob JOB_NAME -n NAMESPACE.

        Returns:
            None or object: None if the resource doesnt exist in server,
            otherwise the custom object.
        """
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)
        if self.namespace is None:
            return _api.get_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                self.job_name.lower(),
            )
        return _api.get_namespaced_custom_object(
            self.group.lower(),
            self.version.lower(),
            self.namespace.lower(),
            self.plural.lower(),
            self.job_name.lower(),
        )

    @abstractmethod
    def _create_job_request(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ) -> Dict:
        """Creates the ACK custom object.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.

        Returns:
            dict: A dictionary object representing the custom object.
        """
        pass

    def _create_job_yaml(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ) -> Dict:
        """Creates the ACK request object to execute the component.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.

        Returns:
            dict: A dictionary object representing the request.
        """
        with open(self.job_request_outline_location) as job_request_outline:
            # safe_load is sufficient for the request template and avoids
            # arbitrary Python object construction.
            job_request_dict = yaml.safe_load(job_request_outline)

        job_request_spec = job_request_dict["spec"]

        # Populate metadata.
        job_request_dict["metadata"]["name"] = self.job_name
        job_request_dict["metadata"]["annotations"][
            "services.k8s.aws/region"
        ] = getattr(inputs, "region")

        # Populate spec from inputs; skip empty containers so template
        # defaults survive.
        for para in vars(inputs):
            camel_para = snake_to_camel(para)
            if camel_para in job_request_spec:
                value = getattr(inputs, para)
                if value not in [{}, []]:
                    job_request_spec[camel_para] = value

        # Clean up fields left as None in job_request_spec (done in place so
        # any existing references to the spec dict stay valid).
        filtered = {k: v for k, v in job_request_spec.items() if v is not None}
        job_request_spec.clear()
        job_request_spec.update(filtered)
        job_request_dict["spec"] = job_request_spec

        logging.info(f"Custom resource: {json.dumps(job_request_dict, indent=2)}")
        return job_request_dict

    @abstractmethod
    def _submit_job_request(self, request: Dict) -> Dict:
        """Submits a pre-defined request object to SageMaker.

        The `request` argument should be provided as the result of the
        `_create_job_request` method.

        Args:
            request: A request object to execute the component.

        Returns:
            dict: The job object that was created.

        Raises:
            Exception: If SageMaker responded with an error during the request.
        """
        pass

    def _patch_custom_resource(self, custom_resource: dict):
        """Patch a custom resource in ACK.

        Args:
            custom_resource: A dictionary object representing the custom object.

        Returns:
            dict: The job object that was patched.
        """
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)
        if self.namespace is None:
            return _api.patch_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                self.job_name.lower(),
                custom_resource,
            )
        return _api.patch_namespaced_custom_object(
            self.group.lower(),
            self.version.lower(),
            self.namespace.lower(),
            self.plural.lower(),
            self.job_name.lower(),
            custom_resource,
        )

    def _create_custom_resource(self, custom_resource: dict):
        """Submit a custom_resource to the ACK cluster.

        Args:
            custom_resource: A dictionary object representing the custom object.
        """
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)
        if self.namespace is None:
            return _api.create_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                custom_resource,
            )
        return _api.create_namespaced_custom_object(
            self.group.lower(),
            self.version.lower(),
            self.namespace.lower(),
            self.plural.lower(),
            custom_resource,
        )

    def _wait_resource_consumed_by_controller(
        self,
        wait_periods,
        period_length,
    ):
        """Wait for the custom resource to be consumed by the controller.

        A resource is considered consumed once its "status" field appears.

        Args:
            wait_periods: The number of times to poll for the resource status.
            period_length: The length of time (seconds) between polls.

        Returns:
            None or dict: The consumed resource, or None on timeout/absence.
        """
        if not self._get_resource_exists():
            logging.error("Resource %s does not exist", self.job_name)
            return None
        for _ in range(wait_periods):
            resource = self._get_resource()
            if "status" in resource:
                return resource
            sleep(period_length)
        logging.error(
            "Wait for resource %s to be consumed by controller timed out",
            self.job_name,
        )
        return None

    def _get_resource_exists(self) -> bool:
        """Check if the custom resource exists.

        Returns:
            bool: True if the resource exists, False otherwise.
        """
        try:
            return self._get_resource() is not None
        except ApiException:
            return False

    def _create_resource(
        self,
        cr_spec: object,
        wait_periods=6,
        period_length=10,
    ):
        """Create a resource from the spec and wait for it to be consumed by
        the controller.

        Args:
            cr_spec: A dictionary object representing the custom object.
            wait_periods: The number of times to poll for the resource.
            period_length: The length of time (seconds) between polls.

        Raises:
            Exception: If the resource was not consumed in time.
        """
        self._create_custom_resource(cr_spec)
        resource = self._wait_resource_consumed_by_controller(
            wait_periods, period_length
        )
        if resource is None:
            logging.error(f"Resource {self.job_name} is not created.")
            logging.error(
                f"Possible reason: ACK controller may not have been configured properly."
            )
            raise Exception(f"Resource {self.job_name} is not created.")
        logging.info(f"Created custom resource with name: {self.job_name}")
        return resource

    def _check_resource_conditions(self):
        """Check the error conditions of the custom resource.

        * Loops through all status conditions.
        * If a recoverable condition is set to true, logs it and returns True
          (the outer polling loop keeps waiting; the user decides if/when to stop).
        * If a terminal condition (or a non-retryable recoverable one such as a
          ValidationException/InvalidParameter) is set to true, logs it and
          returns False.
        * Returns None if there are no error conditions.
        """
        status_conditions = self._get_resource()["status"]["conditions"]
        for condition in status_conditions:
            condition_type = condition["type"]
            condition_status = condition["status"]
            condition_message = condition.get("message", "No error message found.")
            # If the controller has not consumed the update, any existing error
            # will not be representative of the new state.
            if self.resource_upgrade and not self.is_update_consumed_by_controller():
                continue
            if condition_type == "ACK.Terminal" and condition_status == "True":
                logging.error(json.dumps(condition, indent=2))
                logging.error(
                    "Terminating the run because resource encountered a Terminal condition. Please describe the resource for further debugging and retry with correct parameters."
                )
                return False
            if condition_type == "ACK.Recoverable" and condition_status == "True":
                # ACK requeue errors are not real errors.
                if is_ack_requeue_error(condition_message):
                    continue
                logging.error(json.dumps(condition, indent=2))
                if "ValidationException" in condition_message:
                    logging.error(
                        "Terminating the run because resource encountered a Validation Exception. Please describe the resource for further debugging and retry with correct parameters."
                    )
                    return False
                elif "InvalidParameter" in condition_message:
                    logging.error(
                        "Terminating the run because resource encountered InvalidParameters. Please describe the resource for further debugging and retry with correct parameters."
                    )
                    return False
                else:
                    logging.error(
                        "Waiting for error to be resolved . . . Please fix the error or terminate the job and retry with correct parameters if this is not a transient error"
                    )
                    return True
        return None

    def _get_resource_synced_status(self, ack_statuses: Dict):
        """Retrieve the resource sync status.

        Args:
            ack_statuses: The "status" dictionary of the custom resource.

        Returns:
            None if no conditions are present, otherwise True/False depending
            on whether an ACK.ResourceSynced condition reports "True".
        """
        conditions = ack_statuses.get("conditions", None)  # Conditions has to be there
        if conditions is None:
            return None
        for condition in conditions:
            if condition["type"] == "ACK.ResourceSynced":
                return condition["status"] == "True"
        return False

    @abstractmethod
    def _after_submit_job_request(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ):
        """Handles any events required after submitting a job to SageMaker.

        Args:
            job: The job returned after creation.
            request: The request submitted prior.
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
        """
        pass

    @abstractmethod
    def _after_job_complete(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ):
        """Handles any events after the job has been completed.

        Args:
            job: The job object that was created.
            request: The request object used to execute the component.
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
        """
        pass

    @abstractmethod
    def _on_job_terminated(self):
        """Handles any SIGTERM events."""
        pass

    def _delete_custom_resource(self):
        """Delete the custom resource from the cluster.

        During an update, the resource is deliberately left in place: an
        in-flight SageMaker update cannot be cancelled by deleting the CR.

        Returns:
            response, bool:
                response is the API server response for the delete operation
                (None if no delete was issued).
                bool is true if the delete request was accepted (or skipped).
        """
        if self.resource_upgrade:
            logging.info("Recieved termination signal, stopping component but resource update will still proceed if started. Please rerun the component with the desired configuration to revert the update.")
            # BUG FIX: the original returned the unbound local `_response`
            # here, raising NameError on SIGTERM during an update.
            return None, True
        logging.info("Recieved termination signal, deleting custom resource %s", (self.job_name))
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)
        if self.namespace is None:
            _response = _api.delete_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                self.job_name.lower(),
            )
        else:
            _response = _api.delete_namespaced_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.namespace.lower(),
                self.plural.lower(),
                self.job_name.lower(),
            )
        return _response, True

    @staticmethod
    def _generate_unique_timestamped_id(
        prefix: str = "",
        size: int = 4,
        chars: str = string.ascii_uppercase + string.digits,
        max_length: int = 32,
    ) -> str:
        """Generate a pseudo-random string of characters appended to a
        timestamp.

        Format of the ID is as follows: `prefix-YYYYMMDDHHMMSS-unique`. If the
        length of the total ID exceeds `max_length`, it will be truncated from
        the beginning (prefix will be trimmed).

        Args:
            prefix: A prefix to append to the random suffix.
            size: The number of unique characters to append to the ID.
            chars: A list of characters to use in the random suffix.
            max_length: The maximum length of the generated ID.

        Returns:
            string: A pseudo-random string with included timestamp and prefix.
        """
        unique = "".join(random.choice(chars) for _ in range(size)).lower()
        return f'{prefix}{"-" if prefix else ""}{strftime("%Y%m%d%H%M%S", gmtime())}-{unique}'[
            -max_length:
        ]

    def _write_all_outputs(
        self,
        output_paths: SageMakerComponentBaseOutputs,
        outputs: SageMakerComponentBaseOutputs,
    ):
        """Writes all of the outputs specified by the component to their
        respective file paths.

        Args:
            output_paths: A populated list of output paths.
            outputs: A populated list of output values.
        """
        for output_key, output_value in outputs.__dict__.items():
            if output_value is None:
                output_value = "N/A"
            output_path = output_paths.__dict__.get(output_key)
            if not output_path:
                logging.error(f"Could not find output path for {output_key}")
                continue
            # JSON-encode container values (lists/dicts); primitives are
            # written verbatim.
            encoded_types = (List, Dict)
            self._write_output(
                output_path,
                output_value,
                json_encode=isinstance(output_value, encoded_types),
            )
            logging.info(f"Wrote output '{output_key}' to '{output_path}'")

    def _write_output(
        self, output_path: str, output_value: Any, json_encode: bool = False
    ):
        """Write an output value to the associated path, dumping as a JSON
        object if specified.

        Args:
            output_path: The file path of the output.
            output_value: The output value to write to the file.
            json_encode: True if the value should be encoded as a JSON object.
        """
        write_value = json.dumps(output_value) if json_encode else output_value
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        Path(output_path).write_text(write_value)

    def _is_upgrade(self):
        """Detect whether the component run is an upgrade.

        If the resource already exists the component assumes that the user
        wants to upgrade it.

        Returns:
            Bool: If the resource is being upgraded or not.

        Raises:
            Exception: Re-raises any API error other than 404.
        """
        try:
            resource = self._get_resource()
            if resource is None:
                return False
            logging.info("Existing resource detected. Starting Update.")
        except client.exceptions.ApiException as error:
            if error.status == 404:
                logging.info("Resource does not exist. Creating a new resource.")
                return False
            raise error
        return True