"""Base class for all SageMaker components."""
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import signal
import string
import logging
import json
from types import FunctionType
import yaml
import random
from pathlib import Path
from time import sleep, strftime, gmtime
from abc import abstractmethod
from typing import Any, Dict, List, NamedTuple, Optional

from kubernetes import client, config
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException

from commonv2.sagemaker_component_spec import SageMakerComponentSpec
from commonv2.common_inputs import (
    SageMakerComponentBaseOutputs,
    SageMakerComponentCommonInputs,
)
from commonv2 import snake_to_camel, is_ack_requeue_error

# This handler is called whenever the @ComponentMetadata is applied.
# It allows the command line compiler to detect every component spec class.
_component_decorator_handler: Optional[FunctionType] = None


def ComponentMetadata(name: str, description: str, spec: object):
    """Decorator for SageMaker components.

    Used to define necessary metadata attributes about the component which
    will be used for logging output and for the component specification file.

    Usage:
    ```python
    @ComponentMetadata(
        name="SageMaker - Component Name",
        description="A cool new component we made!",
        spec=MyComponentSpec
    )
    ```

    Args:
        name: Value stored on the decorated class as `COMPONENT_NAME`.
        description: Value stored on the decorated class as
            `COMPONENT_DESCRIPTION`.
        spec: Value stored on the decorated class as `COMPONENT_SPEC`.

    Returns:
        A class decorator. It returns the decorated class, or whatever the
        module-level `_component_decorator_handler` returns when that handler
        is set and yields a truthy value.
    """

    def _component_metadata(cls):
        cls.COMPONENT_NAME = name
        cls.COMPONENT_DESCRIPTION = description
        cls.COMPONENT_SPEC = spec

        # Add handler for compiler
        if _component_decorator_handler:
            return _component_decorator_handler(cls) or cls
        return cls

    return _component_metadata


class SageMakerJobStatus(NamedTuple):
    """Generic representation of a job status."""

    is_completed: bool
    raw_status: str
    has_error: bool = False
    error_message: Optional[str] = None


class SageMakerComponent:
    """Base class for a KFP SageMaker component.

    An instance of a subclass of this component represents an instantiation of
    the component within a pipeline run. Use the `@ComponentMetadata` decorator
    to modify the component attributes listed below.

    Attributes:
        COMPONENT_NAME: The name of the component as displayed to the user.
        COMPONENT_DESCRIPTION: The description of the component as displayed
            to the user.
        COMPONENT_SPEC: The corresponding spec associated with the component.
        STATUS_POLL_INTERVAL: Number of seconds between polling for the job
            status.
        UPDATE_PROCESS_INTERVAL: Number of seconds to sleep while waiting for
            the controller to consume an update.
    """

    COMPONENT_NAME = ""
    COMPONENT_DESCRIPTION = ""
    COMPONENT_SPEC = SageMakerComponentSpec

    STATUS_POLL_INTERVAL = 30
    UPDATE_PROCESS_INTERVAL = 10

    # Parameters that will be filled by Do().
    # Assignment statements in Do() will be generated.
    job_name: str
    group: str
    version: str
    plural: str
    spaced_out_resource_name: str  # Used for Logs
    namespace: Optional[str] = None
    resource_upgrade: bool = False
    initial_status: dict
    update_supported: bool

    job_request_outline_location: str
    job_request_location: str

    def __init__(self):
        """Initialize a new component."""
        self._initialize_logging()

    def _initialize_logging(self):
        """Initializes the global logging structure."""
        logging.getLogger().setLevel(logging.INFO)

    def Do(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
        output_paths: SageMakerComponentBaseOutputs,
    ):
        """The main execution entrypoint for a component at runtime.

        Exits the process with status 1 when the Kubernetes client cannot be
        initialized or when `_do` reports failure. Exceptions raised by `_do`
        are logged and re-raised.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
            output_paths: Paths to the respective output locations.
        """
        # Verify that the kubernetes cluster is available
        try:
            self._init_configure_k8s()
        except Exception as e:
            logging.exception("Failed to initialize k8s client: %s", e)
            sys.exit(1)

        # Global try-catch in order to allow for safe abort
        try:
            # Successful execution
            if not self._do(inputs, outputs, output_paths):
                sys.exit(1)
        except Exception:
            logging.exception("An error occurred while running the component")
            raise

    def _init_configure_k8s(self):
        """Initializes the kubernetes client and configures the namespace.

        Issues a pod listing call against `self.namespace` so that a
        misconfigured client or cluster surfaces as an exception here.
        """
        _test_client = self._get_k8s_api_client()
        _test_api = client.CoreV1Api(_test_client)
        _test_api.list_namespaced_pod(namespace=self.namespace)

    def _get_k8s_api_client(self) -> ApiClient:
        """Create new client everytime to avoid token refresh issues."""
        # when run in k8s cluster
        config.load_incluster_config()
        return ApiClient()

    def _get_current_namespace(self):
        """Get the current namespace.

        Returns:
            The content of the in-cluster service account namespace file when
            it exists; otherwise the namespace of the active kube-config
            context, or "default" when that context has no namespace key.
        """
        ns_path = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
        if os.path.exists(ns_path):
            with open(ns_path) as f:
                return f.read().strip()
        try:
            _, active_context = config.list_kube_config_contexts()
            return active_context["context"]["namespace"]
        except KeyError:
            return "default"

    def _do(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
        output_paths: SageMakerComponentBaseOutputs,
    ) -> bool:
        """Run the component: submit the resource, poll it, write the outputs.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
            output_paths: Paths to the respective output locations.

        Returns:
            bool: True when the job completed without error and the outputs
            were written, False on any handled failure.
        """

        # Set up SIGTERM handling
        def signal_term_handler(signalNumber, frame):
            self._on_job_terminated()

        signal.signal(signal.SIGTERM, signal_term_handler)

        self.resource_upgrade = self._is_upgrade()
        if self.resource_upgrade and not self.update_supported:
            logging.error(
                f"Resource update is not supported for {self.spaced_out_resource_name}"
            )
            return False

        request = self._create_job_request(inputs, outputs)
        try:
            job = self._submit_job_request(request)
        except Exception:
            logging.exception(
                "An error occurred while attempting to submit the request"
            )
            return False

        created = self._verify_resource_consumption()
        if not created:
            return False

        self._after_submit_job_request(job, request, inputs, outputs)

        status: SageMakerJobStatus = SageMakerJobStatus(
            is_completed=False, raw_status="No Status"
        )
        try:
            while True:
                cr_condition = self._check_resource_conditions()
                if cr_condition:
                    # Recoverable condition: keep waiting.
                    sleep(self.STATUS_POLL_INTERVAL)
                    continue
                elif cr_condition is False:
                    # ACK.Terminal or special errors
                    # (Validation Exception/Invalid Input)
                    return False

                status = (
                    self._get_job_status()
                    if not self.resource_upgrade
                    else self._get_upgrade_status()
                )
                # Continue until complete
                if status and status.is_completed:
                    if self.resource_upgrade:
                        logging.info(
                            f"{self.spaced_out_resource_name} Update complete, final status: {status.raw_status}"
                        )
                    else:
                        logging.info(
                            f"{self.spaced_out_resource_name} Creation complete, final status: {status.raw_status}"
                        )
                    break

                sleep(self.STATUS_POLL_INTERVAL)
                logging.info(
                    f"{self.spaced_out_resource_name} is in status: {status.raw_status}"
                )
        except Exception:
            logging.exception(
                f"An error occurred while polling for {self.spaced_out_resource_name} status"
            )
            return False

        if status.has_error:
            logging.error(status.error_message)
            return False

        self._after_job_complete(job, request, inputs, outputs)
        self._write_all_outputs(output_paths, outputs)

        return True

    def _get_conditions_of_type(self, condition_type):
        """Return the resource status conditions whose type matches.

        Args:
            condition_type: The value compared against each condition's
                "type" entry.

        Returns:
            list: The matching condition entries, in their original order.
        """
        resource_conditions = self._get_resource()["status"]["conditions"]
        return [
            condition
            for condition in resource_conditions
            if condition["type"] == condition_type
        ]

    def _verify_resource_consumption(self) -> bool:
        """Verify that the resource has been successfully consumed by the controller.

        In the case of an update verify that the job arn exists.

        Returns:
            bool: Whether the resource consumed by the controller.
        """
        submission_ack_printed = False
        ERROR_NOT_CREATED_MESSAGE = "An error occurred while getting resource arn, ACK CR created but Sagemaker resource not created."
        ERROR_UPDATE_MESSAGE = "An error occured when getting the resource arn. Check the ACK Sagemaker Controller logs."

        try:
            while True:
                cr_condition = self._check_resource_conditions()
                if cr_condition:  # ACK.Recoverable
                    sleep(self.STATUS_POLL_INTERVAL)
                    continue
                elif cr_condition is False:
                    # During an update the condition may still describe the
                    # state before the update; wait until it is consumed.
                    if (
                        self.resource_upgrade
                        and not self.is_update_consumed_by_controller()
                    ):
                        sleep(self.UPDATE_PROCESS_INTERVAL)
                        continue
                    return False

                # Retrieve Sagemaker ARN
                arn = self.check_resource_initiation(submission_ack_printed)

                # Continue until complete
                if arn:
                    submission_ack_printed = True
                    if (
                        self.resource_upgrade
                        and not self.is_update_consumed_by_controller()
                    ):
                        sleep(self.UPDATE_PROCESS_INTERVAL)
                        continue
                    break

                sleep(self.STATUS_POLL_INTERVAL)
                logging.info(f"Getting arn for {self.job_name}")
        except Exception:
            err_msg = (
                ERROR_UPDATE_MESSAGE
                if self.resource_upgrade
                else ERROR_NOT_CREATED_MESSAGE
            )
            logging.exception(err_msg)
            return False
        return True

    def check_resource_initiation(self, submission_ack_printed: bool):
        """Check if resource has been initiated in Sagemaker.

        A resource is considered to be initiated if the resource ARN is
        present in the ack resource metadata. If the resource ARN is present
        in the ack resource metadata, the resource has been successfully
        created in Sagemaker.

        Args:
            submission_ack_printed (bool): Parameter to avoid printing the
                resource consumed message multiple times. The message is
                logged only while this is False.

        Returns:
            str: The ARN of the resource. None if the resource ARN is not
            present in the ack resource metadata, i.e. the resource has not
            been created in Sagemaker.
        """
        ack_status = self._get_resource()["status"]
        ack_resource_meta = ack_status.get("ackResourceMetadata", None)
        if not ack_resource_meta:
            return None

        arn = ack_resource_meta.get("arn", None)
        if arn is None:
            return None

        # Log the acknowledgement once: only when it has not been printed yet.
        if not submission_ack_printed:
            resource_consumed_message = (
                f"Created Sagemaker {self.spaced_out_resource_name} with ARN: {arn}"
                if not self.resource_upgrade
                else f"Submitting update for Sagemaker {self.spaced_out_resource_name} with ARN: {arn}"
            )
            logging.info(resource_consumed_message)
        return arn

    @abstractmethod
    def _get_job_status(self) -> SageMakerJobStatus:
        """Waits for the current job to complete.

        Returns:
            SageMakerJobStatus: A status object.
        """
        pass

    @abstractmethod
    def _get_upgrade_status(self) -> SageMakerJobStatus:
        """Waits for the resource upgrade to complete

        Returns:
            SageMakerJobStatus: A status object.
        """
        pass

    def is_update_consumed_by_controller(self):
        """Check if update has been consumed by the controller.

        In this case it is done by checking whether the current "status" of
        the resource differs from `self.initial_status` (presumably captured
        by the subclass before the update is submitted; verify against the
        subclass).

        Returns:
            bool: False while the status still equals `self.initial_status`,
            True otherwise.
        """
        current_resource = self._get_resource()
        current_status = current_resource.get("status", None)
        # Python == is deep equal between dicts.
        return current_status != self.initial_status

    def _get_resource(self):
        """Get the custom resource detail similar to:
        kubectl describe trainingjob JOB_NAME -n NAMESPACE.

        Returns:
            None or object: None if the resource doesnt exist in server,
            otherwise the custom object.
        """
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)

        if self.namespace is None:
            job_description = _api.get_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                self.job_name.lower(),
            )
        else:
            job_description = _api.get_namespaced_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.namespace.lower(),
                self.plural.lower(),
                self.job_name.lower(),
            )
        return job_description

    @abstractmethod
    def _create_job_request(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ) -> Dict:
        """Creates the ACK custom object.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.

        Returns:
            dict: A dictionary object representing the custom object.
        """
        pass

    def _create_job_yaml(
        self,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ) -> Dict:
        """Creates the ACK request object to execute the component.

        Loads the YAML outline at `self.job_request_outline_location`, fills
        in the resource name and region annotation, copies every input whose
        camel-cased name is a key of the outline's spec, and drops spec
        entries that are still None.

        Args:
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.

        Returns:
            dict: A dictionary object representing the request.
        """
        # NOTE(review): FullLoader is used on the outline file; presumably it
        # is a trusted template shipped with the component - confirm.
        with open(self.job_request_outline_location) as job_request_outline:
            job_request_dict = yaml.load(job_request_outline, Loader=yaml.FullLoader)
            job_request_spec = job_request_dict["spec"]

            # populate meta data
            job_request_dict["metadata"]["name"] = self.job_name
            job_request_dict["metadata"]["annotations"][
                "services.k8s.aws/region"
            ] = getattr(inputs, "region")

            # populate spec from inputs
            for para in vars(inputs):
                camel_para = snake_to_camel(para)
                if camel_para in job_request_spec:
                    value = getattr(inputs, para)
                    # Empty dicts/lists leave the outline value untouched.
                    if value not in [{}, []]:
                        job_request_spec[camel_para] = value

            # clean up empty fields in job_request_spec
            filtered = {k: v for k, v in job_request_spec.items() if v is not None}
            job_request_spec.clear()
            job_request_spec.update(filtered)
            job_request_dict["spec"] = job_request_spec

            logging.info(f"Custom resource: {json.dumps(job_request_dict, indent=2)}")

        return job_request_dict

    @abstractmethod
    def _submit_job_request(self, request: Dict) -> Dict:
        """Submits a pre-defined request object to SageMaker.

        The `request` argument should be provided as the result of the
        `_create_job_request` method.

        Args:
            request: A request object to execute the component.

        Returns:
            dict: The job object that was created.

        Raises:
            Exception: If SageMaker responded with an error during the request.
        """
        pass

    def _patch_custom_resource(self, custom_resource: dict):
        """Patch a custom resource in ACK

        Args:
            custom_resource: A dictionary object representing the custom object.

        Returns:
            dict: The job object that was patched
        """
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)

        if self.namespace is None:
            return _api.patch_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                self.job_name.lower(),
                custom_resource,
            )
        return _api.patch_namespaced_custom_object(
            self.group.lower(),
            self.version.lower(),
            self.namespace.lower(),
            self.plural.lower(),
            self.job_name.lower(),
            custom_resource,
        )

    def _create_custom_resource(self, custom_resource: dict):
        """Submit a custom_resource to the ACK cluster.

        Args:
            custom_resource: A dictionary object representing the custom object.

        Returns:
            The API response of the create call.
        """
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)

        if self.namespace is None:
            return _api.create_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                custom_resource,
            )
        return _api.create_namespaced_custom_object(
            self.group.lower(),
            self.version.lower(),
            self.namespace.lower(),
            self.plural.lower(),
            custom_resource,
        )

    def _wait_resource_consumed_by_controller(
        self,
        wait_periods,
        period_length,
    ):
        """Wait for the custom resource to be consumed by the controller.

        The resource counts as consumed once it carries a "status" key.

        Args:
            wait_periods: The number of times to wait for the resource to be
                consumed.
            period_length: The length of time to wait between polling.

        Returns:
            The resource once it has a "status" key, or None when the
            resource does not exist or the wait timed out.
        """
        if not self._get_resource_exists():
            logging.error("Resource %s does not exist", self.job_name)
            return None

        for _ in range(wait_periods):
            resource = self._get_resource()
            if "status" in resource:
                return resource
            sleep(period_length)

        logging.error(
            "Wait for resource %s to be consumed by controller timed out",
            self.job_name,
        )
        return None

    def _get_resource_exists(self) -> bool:
        """Check if the custom resource exists.

        Returns:
            bool: True if the resource exists, False otherwise.
        """
        try:
            return self._get_resource() is not None
        except ApiException:
            return False

    def _create_resource(
        self,
        cr_spec: object,
        wait_periods=6,
        period_length=10,
    ):
        """Create a resource from the spec and wait to be consumed by controller.

        Args:
            cr_spec: A dictionary object representing the custom object.
            wait_periods: The number of times to wait for the resource to be
                created.
            period_length: The length of time to wait between polling.

        Returns:
            The resource as returned once the controller has consumed it.

        Raises:
            Exception: If the resource was not consumed within the wait.
        """
        self._create_custom_resource(cr_spec)
        resource = self._wait_resource_consumed_by_controller(
            wait_periods, period_length
        )
        if resource is None:
            logging.error(f"Resource {self.job_name} is not created.")
            logging.error(
                "Possible reason: ACK controller may not have been configured properly."
            )
            raise Exception(f"Resource {self.job_name} is not created.")

        logging.info(f"Created custom resource with name: {self.job_name}")
        return resource

    def _check_resource_conditions(self):
        """Check the status of the custom resource.

        * loop through all conditions
        * if recoverable and condition set to true, print out message and
          return true (let outside polling loop goes on forever and let user
          decide if should stop)
        * if terminal and condition set up true, print out message and return
          false
        * Returns None if there are no error conditions.
        """
        status_conditions = self._get_resource()["status"]["conditions"]

        for condition in status_conditions:
            condition_type = condition["type"]
            condition_status = condition["status"]
            condition_message = condition.get("message", "No error message found.")

            # If the controller has not consumed the update, any existing
            # error will not be representative of the new state.
            if self.resource_upgrade and not self.is_update_consumed_by_controller():
                continue

            if condition_type == "ACK.Terminal" and condition_status == "True":
                logging.error(json.dumps(condition, indent=2))
                logging.error(
                    "Terminating the run because resource encountered a Terminal condition. Please describe the resource for further debugging and retry with correct parameters."
                )
                return False

            if condition_type == "ACK.Recoverable" and condition_status == "True":
                # ACK requeue errors are not real errors.
                if is_ack_requeue_error(condition_message):
                    continue

                logging.error(json.dumps(condition, indent=2))
                if "ValidationException" in condition_message:
                    logging.error(
                        "Terminating the run because resource encountered a Validation Exception. Please describe the resource for further debugging and retry with correct parameters."
                    )
                    return False
                elif "InvalidParameter" in condition_message:
                    logging.error(
                        "Terminating the run because resource encountered InvalidParameters. Please describe the resource for further debugging and retry with correct parameters."
                    )
                    return False
                else:
                    logging.error(
                        "Waiting for error to be resolved . . . Please fix the error or terminate the job and retry with correct parameters if this is not a transient error"
                    )
                    return True

        return None

    def _get_resource_synced_status(self, ack_statuses: Dict):
        """Retrieve the resource sync status.

        Args:
            ack_statuses: The status mapping of the custom resource.

        Returns:
            None when the mapping has no "conditions" entry; otherwise whether
            the first "ACK.ResourceSynced" condition has status "True"
            (False when no such condition is present).
        """
        conditions = ack_statuses.get("conditions", None)  # Conditions has to be there
        if conditions is None:
            return None
        for condition in conditions:
            if condition["type"] == "ACK.ResourceSynced":
                return condition["status"] == "True"
        return False

    @abstractmethod
    def _after_submit_job_request(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ):
        """Handles any events required after submitting a job to SageMaker.

        Args:
            job: The job returned after creation.
            request: The request submitted prior.
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
        """
        pass

    @abstractmethod
    def _after_job_complete(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerComponentCommonInputs,
        outputs: SageMakerComponentBaseOutputs,
    ):
        """Handles any events after the job has been completed.

        Args:
            job: The job object that was created.
            request: The request object used to execute the component.
            inputs: A populated list of user inputs.
            outputs: An unpopulated list of component output variables.
        """
        pass

    @abstractmethod
    def _on_job_terminated(self):
        """Handles any SIGTERM events."""
        pass

    def _delete_custom_resource(self):
        """Delete the custom resource from the cluster.

        When the component is performing an update (`self.resource_upgrade`),
        nothing is deleted: the update is left to proceed and only a message
        is logged.

        Returns:
            response, bool: response is the API server response for the delete
            operation (None when no delete was issued). The bool is always
            True in this implementation.
        """
        _api_client = self._get_k8s_api_client()
        _api = client.CustomObjectsApi(_api_client)

        # Bind the response up front so the early-return below cannot hit an
        # unassigned local.
        _response = None

        if self.resource_upgrade:
            logging.info("Recieved termination signal, stopping component but resource update will still proceed if started. Please rerun the component with the desired configuration to revert the update.")
            return _response, True

        logging.info("Recieved termination signal, deleting custom resource %s", (self.job_name))
        if self.namespace is None:
            _response = _api.delete_cluster_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.plural.lower(),
                self.job_name.lower(),
            )
        else:
            _response = _api.delete_namespaced_custom_object(
                self.group.lower(),
                self.version.lower(),
                self.namespace.lower(),
                self.plural.lower(),
                self.job_name.lower(),
            )
        return _response, True

    @staticmethod
    def _generate_unique_timestamped_id(
        prefix: str = "",
        size: int = 4,
        chars: str = string.ascii_uppercase + string.digits,
        max_length: int = 32,
    ) -> str:
        """Generate a pseudo-random string of characters appended to a timestamp.

        Format of the ID is as follows: `prefix-YYYYMMDDHHMMSS-unique`. If the
        length of the total ID exceeds `max_length`, it will be truncated from
        the beginning (prefix will be trimmed).

        Args:
            prefix: A prefix to append to the random suffix.
            size: The number of unique characters to append to the ID.
            chars: A list of characters to use in the random suffix.
            max_length: The maximum length of the generated ID.

        Returns:
            string: A pseudo-random string with included timestamp and prefix.
        """
        unique = "".join(random.choice(chars) for _ in range(size)).lower()
        return f'{prefix}{"-" if prefix else ""}{strftime("%Y%m%d%H%M%S", gmtime())}-{unique}'[
            -max_length:
        ]

    def _write_all_outputs(
        self,
        output_paths: SageMakerComponentBaseOutputs,
        outputs: SageMakerComponentBaseOutputs,
    ):
        """Writes all of the outputs specified by the component to their
        respective file paths.

        Outputs whose value is None are written as "N/A". Outputs without a
        matching path are logged as errors and skipped.

        Args:
            output_paths: A populated list of output paths.
            outputs: A populated list of output values.
        """
        # Encode it if it's a list or dict (not primitive)
        encoded_types = (list, dict)

        for output_key, output_value in outputs.__dict__.items():
            if output_value is None:
                output_value = "N/A"

            output_path = output_paths.__dict__.get(output_key)
            if not output_path:
                logging.error(f"Could not find output path for {output_key}")
                continue

            self._write_output(
                output_path,
                output_value,
                json_encode=isinstance(output_value, encoded_types),
            )
            logging.info(f"Wrote output '{output_key}' to '{output_path}'")

    def _write_output(
        self, output_path: str, output_value: Any, json_encode: bool = False
    ):
        """Write an output value to the associated path, dumping as a JSON
        object if specified.

        Parent directories of `output_path` are created when missing.

        Args:
            output_path: The file path of the output.
            output_value: The output value to write to the file.
            json_encode: True if the value should be encoded as a JSON object.
        """
        write_value = json.dumps(output_value) if json_encode else output_value

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        Path(output_path).write_text(write_value)

    def _is_upgrade(self):
        """If the resource already exists the component assumes that the user
        wants to upgrade.

        Returns:
            Bool: If the resource is being upgraded or not.

        Raises:
            Exception: Any API error other than a 404 is re-raised.
        """
        try:
            resource = self._get_resource()
            if resource is None:
                return False
            logging.info("Existing resource detected. Starting Update.")
        except client.exceptions.ApiException as error:
            if error.status == 404:
                logging.info("Resource does not exist. Creating a new resource.")
                return False
            else:
                raise error
        return True