feat(components): Sagemaker V2 Hosting components and tests (#9243)
* Hosting Components and test
* update dependency
* Regenerating with spec trimming
* handle None case
* address pr comments
* another way of handling update not supported
* test changes
* removing unused logic
* Staging pr
* Added READMEs
* Main doc changes
* minor edit
parent c22b40452e
commit 4818e849f8
@ -4,6 +4,11 @@ The version of the AWS SageMaker Components is determined by the docker image ta
|
|||
Repository: [Public ECR](https://gallery.ecr.aws/kubeflow-on-aws/aws-sagemaker-kfp-components) or [Dockerhub](https://hub.docker.com/repository/docker/amazon/aws-sagemaker-kfp-components). New releases after v1.1.1 use the public ECR repository.

---------------------------------------------

**Change log for version 2.2.0**

- Introducing SageMaker Hosting components v2. This release includes [Model](./Model/), [EndpointConfig](./EndpointConfig/), and [Endpoint](./Endpoint/).

> Pull request : [#9243](https://github.com/kubeflow/pipelines/pull/9243)

**Change log for version 2.1.0**

- Adds support for Managed Warm Pool clusters, Instance Groups, Retry Strategy in the Training Job component.
|
||||
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# SageMaker Endpoint Kubeflow Pipelines component v2

## Overview

Endpoint is one of the three components (along with EndpointConfig and Model) you would use to create a Hosting deployment on SageMaker.

Use this component to create [SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html) in a Kubeflow Pipelines workflow.

See the SageMaker Components for Kubeflow Pipelines versions section in [SageMaker Components for Kubeflow Pipelines](https://docs.aws.amazon.com/sagemaker/latest/dg/kubernetes-sagemaker-components-for-kubeflow-pipelines.html#kubeflow-pipeline-components) to learn about the differences between the version 1 and version 2 components.

### Kubeflow Pipelines backend compatibility
SageMaker components are currently supported with Kubeflow Pipelines backend v1. This means you will have to use KFP SDK 1.8.x to create your pipelines.

## Getting Started

Follow [this guide](https://github.com/kubeflow/pipelines/tree/master/samples/contrib/aws-samples#prerequisites) to set up the prerequisites for Endpoint depending on your deployment.

## Input Parameters
Find the high-level component input parameters and their descriptions in the [component's input specification](./component.yaml). Parameters of type `JsonObject` or `JsonArray` have nested fields; refer to the [Endpoint CRD specification](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/endpoint/) for the respective structure and pass the input in JSON format.

A quick way to see the JSON-style input is to copy the [sample Endpoint spec](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/endpoint/#spec) and convert it to JSON using a YAML-to-JSON converter like [this website](https://jsonformatter.org/yaml-to-json).

For example, the `endpointConfigName` in the `Endpoint` CRD looks like:

```
endpointConfigName: string
```

The `endpointConfigName` input for the component would be:

```
endpointConfigName = "my-config"
```

You might also want to look at the [Endpoint API reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpoint.html) for a detailed explanation of the parameters.
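
The JSON-format requirement for nested inputs can be sketched with Python's standard `json` module. This is a minimal illustration and not part of the component: the helper name `build_endpoint_inputs` and all sample values are hypothetical, and the `key`/`value` tag shape is an assumption that should be checked against the Endpoint CRD specification. Only the input names (`region`, `endpoint_config_name`, `endpoint_name`, `tags`) come from `component.yaml`.

```python
import json


def build_endpoint_inputs(region, endpoint_config_name, endpoint_name, tags=None):
    """Return a dict of component inputs with nested values serialized to JSON.

    Hypothetical helper for illustration; it is not shipped with the component.
    """
    return {
        "region": region,
        "endpoint_config_name": endpoint_config_name,
        "endpoint_name": endpoint_name,
        # JsonArray inputs are passed as JSON text, so serialize the list.
        "tags": json.dumps(tags or []),
    }


# Sample values are placeholders, not defaults of the component.
inputs = build_endpoint_inputs(
    region="us-west-2",
    endpoint_config_name="my-config",
    endpoint_name="my-endpoint",
    tags=[{"key": "team", "value": "ml-platform"}],  # assumed tag shape
)
print(inputs["tags"])
```

A dictionary like this could then be passed as keyword arguments to the loaded component inside a pipeline definition, assuming the keys match the input names in `component.yaml`.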

## References
- [Inference on SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html)
- [Endpoint CRD specification](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/endpoint/)
- [Endpoint API reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpoint.html)
|
|
@ -0,0 +1,106 @@
|
|||
name: "Sagemaker - Endpoint"
description: Create Endpoint
inputs:
  - {
      name: region,
      type: String,
      description: "The region to use for the training job",
    }
  ###########################GENERATED SECTION BELOW############################

  - {
      name: deployment_config,
      type: JsonObject,
      default: '{}',
      description: "The deployment configuration for an endpoint, which contains the desired deployment strategy and rollback configurations.",
    }
  - {
      name: endpoint_config_name,
      type: String,
      default: '',
      description: "The name of an endpoint configuration.",
    }
  - {
      name: endpoint_name,
      type: String,
      default: '',
      description: "The name of the endpoint.",
    }
  - {
      name: tags,
      type: JsonArray,
      default: '[]',
      description: "An array of key-value pairs.",
    }
  ###########################GENERATED SECTION ABOVE############################

outputs:
  ###########################GENERATED SECTION BELOW############################

  - {
      name: ack_resource_metadata,
      type: JsonObject,
      description: "All CRs managed by ACK have a common `Status.",
    }
  - {
      name: conditions,
      type: JsonArray,
      description: "All CRS managed by ACK have a common `Status.",
    }
  - {
      name: creation_time,
      type: String,
      description: "A timestamp that shows when the endpoint was created.",
    }
  - {
      name: endpoint_status,
      type: String,
      description: "The status of the endpoint.",
    }
  - {
      name: failure_reason,
      type: String,
      description: "If the status of the endpoint is Failed, the reason why it failed.",
    }
  - {
      name: last_modified_time,
      type: String,
      description: "A timestamp that shows when the endpoint was last modified.",
    }
  - {
      name: pending_deployment_summary,
      type: JsonObject,
      description: "Returns the summary of an in-progress deployment.",
    }
  - {
      name: production_variants,
      type: JsonArray,
      description: "An array of ProductionVariantSummary objects, one for each model hosted behind this endpoint.",
    }
  - {
      name: sagemaker_resource_name,
      type: String,
      description: "Resource name on Sagemaker",
    }
  ###########################GENERATED SECTION ABOVE############################

implementation:
  container:
    image: public.ecr.aws/kubeflow-on-aws/aws-sagemaker-kfp-components:2.2.0
    command: [python3]
    args:
      - Endpoint/src/Endpoint_component.py
      - --region
      - { inputValue: region }
      ###########################GENERATED SECTION BELOW############################
      - --deployment_config
      - { inputValue: deployment_config }
      - --endpoint_config_name
      - { inputValue: endpoint_config_name }
      - --endpoint_name
      - { inputValue: endpoint_name }
      - --tags
      - { inputValue: tags }

      ###########################GENERATED SECTION ABOVE############################
|
||||
|
|
@ -0,0 +1,197 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict
import json

from Endpoint.src.Endpoint_spec import (
    SageMakerEndpointInputs,
    SageMakerEndpointOutputs,
    SageMakerEndpointSpec,
)
from commonv2.sagemaker_component import (
    SageMakerComponent,
    ComponentMetadata,
    SageMakerJobStatus,
)
from commonv2 import snake_to_camel


@ComponentMetadata(
    name="SageMaker - Endpoint",
    description="",
    spec=SageMakerEndpointSpec,
)
class SageMakerEndpointComponent(SageMakerComponent):

    """SageMaker component for Endpoint."""

    def Do(self, spec: SageMakerEndpointSpec):

        self.namespace = self._get_current_namespace()
        logging.info("Current namespace: " + self.namespace)

        ############GENERATED SECTION BELOW############

        self.job_name = spec.inputs.endpoint_name = (
            spec.inputs.endpoint_name  # todo: need customize
            if spec.inputs.endpoint_name  # todo: need customize
            else SageMakerComponent._generate_unique_timestamped_id(prefix="endpoint")
        )

        self.group = "sagemaker.services.k8s.aws"
        self.version = "v1alpha1"
        self.plural = "endpoints"
        self.spaced_out_resource_name = "Endpoint"

        self.job_request_outline_location = "Endpoint/src/Endpoint_request.yaml.tpl"
        self.job_request_location = "Endpoint/src/Endpoint_request.yaml"
        self.update_supported = True
        ############GENERATED SECTION ABOVE############

        super().Do(spec.inputs, spec.outputs, spec.output_paths)

    def _create_job_request(
        self,
        inputs: SageMakerEndpointInputs,
        outputs: SageMakerEndpointOutputs,
    ) -> Dict:

        return super()._create_job_yaml(inputs, outputs)

    def _submit_job_request(self, request: Dict) -> object:

        if self.resource_upgrade:
            ack_resource = self._get_resource()
            self.initial_status = ack_resource.get("status", None)
            return super()._patch_custom_resource(request)
        else:
            return super()._create_resource(request, 12, 15)

    def _on_job_terminated(self):
        super()._delete_custom_resource()

    def _after_submit_job_request(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerEndpointInputs,
        outputs: SageMakerEndpointOutputs,
    ):
        logging.info(
            "Endpoint in Sagemaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/endpoints/{}".format(
                inputs.region, inputs.region, self.job_name
            )
        )
        logging.info(
            "CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/Endpoints/{}".format(
                inputs.region, inputs.region, self.job_name
            )
        )

    def _get_job_status(self):

        ack_resource = super()._get_resource()
        resourceSynced = self._get_resource_synced_status(ack_resource["status"])
        sm_job_status = ack_resource["status"]["endpointStatus"]
        if not resourceSynced:
            return SageMakerJobStatus(
                is_completed=False,
                raw_status=sm_job_status,
            )

        if sm_job_status == "InService":
            return SageMakerJobStatus(
                is_completed=True, has_error=False, raw_status="InService"
            )

        if sm_job_status == "Failed":
            message = ack_resource["status"]["failureReason"]
            return SageMakerJobStatus(
                is_completed=True,
                has_error=True,
                error_message=message,
                raw_status=sm_job_status,
            )

        if sm_job_status == "OutOfService":
            message = "Sagemaker endpoint is Out of Service"
            return SageMakerJobStatus(
                is_completed=True,
                has_error=True,
                error_message=message,
                raw_status=sm_job_status,
            )
        return SageMakerJobStatus(is_completed=False, raw_status=sm_job_status)

    def _get_upgrade_status(self):

        return self._get_job_status()

    def _after_job_complete(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerEndpointInputs,
        outputs: SageMakerEndpointOutputs,
    ):
        # prepare component outputs (defined in the spec)

        ack_statuses = super()._get_resource()["status"]

        ############GENERATED SECTION BELOW############

        outputs.ack_resource_metadata = str(
            ack_statuses["ackResourceMetadata"]
            if "ackResourceMetadata" in ack_statuses
            else None
        )
        outputs.conditions = str(
            ack_statuses["conditions"] if "conditions" in ack_statuses else None
        )
        outputs.creation_time = str(
            ack_statuses["creationTime"] if "creationTime" in ack_statuses else None
        )
        outputs.endpoint_status = str(
            ack_statuses["endpointStatus"] if "endpointStatus" in ack_statuses else None
        )
        outputs.failure_reason = str(
            ack_statuses["failureReason"] if "failureReason" in ack_statuses else None
        )
        outputs.last_modified_time = str(
            ack_statuses["lastModifiedTime"]
            if "lastModifiedTime" in ack_statuses
            else None
        )
        outputs.pending_deployment_summary = str(
            ack_statuses["pendingDeploymentSummary"]
            if "pendingDeploymentSummary" in ack_statuses
            else None
        )
        outputs.production_variants = str(
            ack_statuses["productionVariants"]
            if "productionVariants" in ack_statuses
            else None
        )
        outputs.sagemaker_resource_name = self.job_name

        ############GENERATED SECTION ABOVE############


if __name__ == "__main__":
    import sys

    spec = SageMakerEndpointSpec(sys.argv[1:])

    component = SageMakerEndpointComponent()
    component.Do(spec)
|
|
@ -0,0 +1,11 @@
|
|||
apiVersion: sagemaker.services.k8s.aws/v1alpha1
kind: Endpoint
metadata:
  name:
  annotations:
    services.k8s.aws/region:
spec:
  deploymentConfig:
  endpointConfigName:
  endpointName:
  tags:
|
|
@ -0,0 +1,126 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specification for the SageMaker - Endpoint"""

from dataclasses import dataclass

from typing import List
from commonv2.sagemaker_component_spec import (
    SageMakerComponentSpec,
    SageMakerComponentBaseOutputs,
)
from commonv2.spec_input_parsers import SpecInputParsers
from commonv2.common_inputs import (
    COMMON_INPUTS,
    SageMakerComponentCommonInputs,
    SageMakerComponentInput as Input,
    SageMakerComponentOutput as Output,
    SageMakerComponentInputValidator as InputValidator,
    SageMakerComponentOutputValidator as OutputValidator,
)


@dataclass(frozen=False)
class SageMakerEndpointInputs(SageMakerComponentCommonInputs):
    """Defines the set of inputs for the Endpoint component."""

    deployment_config: Input
    endpoint_config_name: Input
    endpoint_name: Input
    tags: Input


@dataclass
class SageMakerEndpointOutputs(SageMakerComponentBaseOutputs):
    """Defines the set of outputs for the Endpoint component."""

    ack_resource_metadata: Output
    conditions: Output
    creation_time: Output
    endpoint_status: Output
    failure_reason: Output
    last_modified_time: Output
    pending_deployment_summary: Output
    production_variants: Output
    sagemaker_resource_name: Output


class SageMakerEndpointSpec(
    SageMakerComponentSpec[SageMakerEndpointInputs, SageMakerEndpointOutputs]
):
    INPUTS: SageMakerEndpointInputs = SageMakerEndpointInputs(
        deployment_config=InputValidator(
            input_type=SpecInputParsers.yaml_or_json_dict,
            description="The deployment configuration for an endpoint, which contains the desired deployment strategy and rollback configurations.",
            required=False,
        ),
        endpoint_config_name=InputValidator(
            input_type=str,
            description="The name of an endpoint configuration.",
            required=True,
        ),
        endpoint_name=InputValidator(
            input_type=str, description="The name of the endpoint.", required=True
        ),
        tags=InputValidator(
            input_type=SpecInputParsers.yaml_or_json_list,
            description="An array of key-value pairs.",
            required=False,
        ),
        **vars(COMMON_INPUTS),
    )

    OUTPUTS = SageMakerEndpointOutputs(
        ack_resource_metadata=OutputValidator(
            description="All CRs managed by ACK have a common `Status.",
        ),
        conditions=OutputValidator(
            description="All CRS managed by ACK have a common `Status.",
        ),
        creation_time=OutputValidator(
            description="A timestamp that shows when the endpoint was created.",
        ),
        endpoint_status=OutputValidator(
            description="The status of the endpoint.",
        ),
        failure_reason=OutputValidator(
            description="If the status of the endpoint is Failed, the reason why it failed.",
        ),
        last_modified_time=OutputValidator(
            description="A timestamp that shows when the endpoint was last modified.",
        ),
        pending_deployment_summary=OutputValidator(
            description="Returns the summary of an in-progress deployment.",
        ),
        production_variants=OutputValidator(
            description="An array of ProductionVariantSummary objects, one for each model hosted behind this endpoint.",
        ),
        sagemaker_resource_name=OutputValidator(
            description="Resource name on Sagemaker",
        ),
    )

    def __init__(self, arguments: List[str]):
        super().__init__(arguments, SageMakerEndpointInputs, SageMakerEndpointOutputs)

    @property
    def inputs(self) -> SageMakerEndpointInputs:
        return self._inputs

    @property
    def outputs(self) -> SageMakerEndpointOutputs:
        return self._outputs

    @property
    def output_paths(self) -> SageMakerEndpointOutputs:
        return self._output_paths
|
|
@ -0,0 +1,64 @@
|
|||
# SageMaker Endpoint Config Kubeflow Pipelines component v2

## Overview

EndpointConfig is one of the three components (along with Endpoint and Model) you would use to create a Hosting deployment on SageMaker.

Use this component to create [SageMaker Endpoint Configurations](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html) in a Kubeflow Pipelines workflow.

See the SageMaker Components for Kubeflow Pipelines versions section in [SageMaker Components for Kubeflow Pipelines](https://docs.aws.amazon.com/sagemaker/latest/dg/kubernetes-sagemaker-components-for-kubeflow-pipelines.html#kubeflow-pipeline-components) to learn about the differences between the version 1 and version 2 components.

### Kubeflow Pipelines backend compatibility
SageMaker components are currently supported with Kubeflow Pipelines backend v1. This means you will have to use KFP SDK 1.8.x to create your pipelines.

## Getting Started

Follow [this guide](https://github.com/kubeflow/pipelines/tree/master/samples/contrib/aws-samples#prerequisites) to set up the prerequisites for Endpoint Config depending on your deployment.

## Input Parameters
Find the high-level component input parameters and their descriptions in the [component's input specification](./component.yaml). Parameters of type `JsonObject` or `JsonArray` have nested fields; refer to the [EndpointConfig CRD specification](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/endpointconfig/) for the respective structure and pass the input in JSON format.

A quick way to see the JSON-style input is to copy the [sample EndpointConfig spec](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/endpointconfig/#spec) and convert it to JSON using a YAML-to-JSON converter like [this website](https://jsonformatter.org/yaml-to-json).

For example, the `productionVariants` in the `EndpointConfig` CRD looks like:

```
productionVariants:
- acceleratorType: string
  containerStartupHealthCheckTimeoutInSeconds: integer
  coreDumpConfig:
    destinationS3URI: string
    kmsKeyID: string
  enableSSMAccess: boolean
  initialInstanceCount: integer
  initialVariantWeight: number
  instanceType: string
  modelDataDownloadTimeoutInSeconds: integer
  modelName: string
  serverlessConfig:
    maxConcurrency: integer
    memorySizeInMB: integer
  variantName: string
  volumeSizeInGB: integer
```

The `productionVariants` input for the component would be (not all parameters are included):

```
productionVariants = [
  {
    "initialInstanceCount": 1,
    "instanceType": "ml.m5.large",
    "modelName": "<my model>",
    "variantName": "<my variant>",
    "volumeSizeInGB": 10
  }
]
```

You might also want to look at the [EndpointConfig API reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html) for a detailed explanation of the parameters.
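
To make the JSON conversion concrete, here is a minimal sketch that builds the `productionVariants` example as a Python list and serializes it with the standard `json` module. The model and variant names are placeholders, and the variable names are illustrative assumptions rather than anything defined by the component.

```python
import json

# Same fields as the README example; the names are placeholders.
production_variants = [
    {
        "initialInstanceCount": 1,
        "instanceType": "ml.m5.large",
        "modelName": "my-model",
        "variantName": "my-variant",
        "volumeSizeInGB": 10,
    }
]

# JsonArray inputs are passed as JSON text rather than Python objects.
production_variants_json = json.dumps(production_variants)

# Round-trip to confirm the string is valid JSON with the expected content.
assert json.loads(production_variants_json) == production_variants
print(production_variants_json)
```

Building the structure in Python and serializing it once avoids hand-written JSON mistakes such as trailing commas or single quotes.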

## References
- [Inference on SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html)
- [EndpointConfig CRD specification](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/endpointconfig/)
- [EndpointConfig API reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html)
|
|
@ -0,0 +1,92 @@
|
|||
name: "Sagemaker - EndpointConfig"
description: Create EndpointConfig
inputs:
  - {
      name: region,
      type: String,
      description: "The region to use for the training job",
    }
  ###########################GENERATED SECTION BELOW############################

  - {
      name: async_inference_config,
      type: JsonObject,
      default: '{}',
      description: "Specifies configuration for how an endpoint performs asynchronous inference.",
    }
  - {
      name: data_capture_config,
      type: JsonObject,
      default: '{}',
      description: "Configuration to control how SageMaker captures inference data.",
    }
  - {
      name: endpoint_config_name,
      type: String,
      default: '',
      description: "The name of the endpoint configuration.",
    }
  - {
      name: kms_key_id,
      type: String,
      default: '',
      description: "The Amazon Resource Name (ARN) of a Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint.",
    }
  - {
      name: production_variants,
      type: JsonArray,
      default: '[]',
      description: "An array of ProductionVariant objects, one for each model that you want to host at this endpoint.",
    }
  - {
      name: tags,
      type: JsonArray,
      default: '[]',
      description: "An array of key-value pairs.",
    }
  ###########################GENERATED SECTION ABOVE############################

outputs:
  ###########################GENERATED SECTION BELOW############################

  - {
      name: ack_resource_metadata,
      type: JsonObject,
      description: "All CRs managed by ACK have a common `Status.",
    }
  - {
      name: conditions,
      type: JsonArray,
      description: "All CRS managed by ACK have a common `Status.",
    }
  - {
      name: sagemaker_resource_name,
      type: String,
      description: "Resource name on Sagemaker",
    }
  ###########################GENERATED SECTION ABOVE############################

implementation:
  container:
    image: public.ecr.aws/kubeflow-on-aws/aws-sagemaker-kfp-components:2.2.0
    command: [python3]
    args:
      - EndpointConfig/src/EndpointConfig_component.py
      - --region
      - { inputValue: region }
      ###########################GENERATED SECTION BELOW############################
      - --async_inference_config
      - { inputValue: async_inference_config }
      - --data_capture_config
      - { inputValue: data_capture_config }
      - --endpoint_config_name
      - { inputValue: endpoint_config_name }
      - --kms_key_id
      - { inputValue: kms_key_id }
      - --production_variants
      - { inputValue: production_variants }
      - --tags
      - { inputValue: tags }

      ###########################GENERATED SECTION ABOVE############################
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict
import json

from EndpointConfig.src.EndpointConfig_spec import (
    SageMakerEndpointConfigInputs,
    SageMakerEndpointConfigOutputs,
    SageMakerEndpointConfigSpec,
)
from commonv2.sagemaker_component import (
    SageMakerComponent,
    ComponentMetadata,
    SageMakerJobStatus,
)
from commonv2 import snake_to_camel


@ComponentMetadata(
    name="SageMaker - EndpointConfig",
    description="",
    spec=SageMakerEndpointConfigSpec,
)
class SageMakerEndpointConfigComponent(SageMakerComponent):

    """SageMaker component for EndpointConfig."""

    def Do(self, spec: SageMakerEndpointConfigSpec):

        self.namespace = self._get_current_namespace()
        logging.info("Current namespace: " + self.namespace)

        ############GENERATED SECTION BELOW############

        self.job_name = spec.inputs.endpoint_config_name = (
            spec.inputs.endpoint_config_name  # todo: need customize
            if spec.inputs.endpoint_config_name  # todo: need customize
            else SageMakerComponent._generate_unique_timestamped_id(
                prefix="endpoint-config"
            )
        )

        self.group = "sagemaker.services.k8s.aws"
        self.version = "v1alpha1"
        self.plural = "endpointconfigs"
        self.spaced_out_resource_name = "Endpoint Config"

        self.job_request_outline_location = (
            "EndpointConfig/src/EndpointConfig_request.yaml.tpl"
        )
        self.job_request_location = "EndpointConfig/src/EndpointConfig_request.yaml"
        self.update_supported = False
        ############GENERATED SECTION ABOVE############

        super().Do(spec.inputs, spec.outputs, spec.output_paths)

    def _create_job_request(
        self,
        inputs: SageMakerEndpointConfigInputs,
        outputs: SageMakerEndpointConfigOutputs,
    ) -> Dict:

        return super()._create_job_yaml(inputs, outputs)

    def _submit_job_request(self, request: Dict) -> object:

        return super()._create_resource(request, 12, 15)

    def _on_job_terminated(self):
        super()._delete_custom_resource()

    def _after_submit_job_request(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerEndpointConfigInputs,
        outputs: SageMakerEndpointConfigOutputs,
    ):
        logging.info(
            "Endpoint Config in Sagemaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/endpointConfig/{}".format(
                inputs.region, inputs.region, self.job_name
            )
        )

    def _get_job_status(self):
        return SageMakerJobStatus(is_completed=True, raw_status="Completed")

    def _get_upgrade_status(self):

        return self._get_job_status()

    def _after_job_complete(
        self,
        job: object,
        request: Dict,
        inputs: SageMakerEndpointConfigInputs,
        outputs: SageMakerEndpointConfigOutputs,
    ):
        # prepare component outputs (defined in the spec)

        ack_statuses = super()._get_resource()["status"]

        ############GENERATED SECTION BELOW############

        outputs.ack_resource_metadata = str(
            ack_statuses["ackResourceMetadata"]
            if "ackResourceMetadata" in ack_statuses
            else None
        )
        outputs.conditions = str(
            ack_statuses["conditions"] if "conditions" in ack_statuses else None
        )
        outputs.sagemaker_resource_name = self.job_name

        ############GENERATED SECTION ABOVE############


if __name__ == "__main__":
    import sys

    spec = SageMakerEndpointConfigSpec(sys.argv[1:])

    component = SageMakerEndpointConfigComponent()
    component.Do(spec)
|
|
@ -0,0 +1,13 @@
|
|||
apiVersion: sagemaker.services.k8s.aws/v1alpha1
kind: EndpointConfig
metadata:
  name:
  annotations:
    services.k8s.aws/region:
spec:
  asyncInferenceConfig:
  dataCaptureConfig:
  endpointConfigName:
  kmsKeyID:
  productionVariants:
  tags:
|
|
@ -0,0 +1,120 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Specification for the SageMaker - EndpointConfig"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import List
|
||||
from commonv2.sagemaker_component_spec import (
|
||||
SageMakerComponentSpec,
|
||||
SageMakerComponentBaseOutputs,
|
||||
)
|
||||
from commonv2.spec_input_parsers import SpecInputParsers
|
||||
from commonv2.common_inputs import (
|
||||
COMMON_INPUTS,
|
||||
SageMakerComponentCommonInputs,
|
||||
SageMakerComponentInput as Input,
|
||||
SageMakerComponentOutput as Output,
|
||||
SageMakerComponentInputValidator as InputValidator,
|
||||
SageMakerComponentOutputValidator as OutputValidator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class SageMakerEndpointConfigInputs(SageMakerComponentCommonInputs):
|
||||
"""Defines the set of inputs for the EndpointConfig component."""
|
||||
|
||||
async_inference_config: Input
|
||||
data_capture_config: Input
|
||||
endpoint_config_name: Input
|
||||
kms_key_id: Input
|
||||
production_variants: Input
|
||||
tags: Input
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageMakerEndpointConfigOutputs(SageMakerComponentBaseOutputs):
|
||||
"""Defines the set of outputs for the EndpointConfig component."""
|
||||
|
||||
ack_resource_metadata: Output
|
||||
conditions: Output
|
||||
sagemaker_resource_name: Output
|
||||
|
||||
|
||||
class SageMakerEndpointConfigSpec(
|
||||
SageMakerComponentSpec[
|
||||
SageMakerEndpointConfigInputs, SageMakerEndpointConfigOutputs
|
||||
]
|
||||
):
|
||||
INPUTS: SageMakerEndpointConfigInputs = SageMakerEndpointConfigInputs(
|
||||
async_inference_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Specifies configuration for how an endpoint performs asynchronous inference.",
|
||||
required=False,
|
||||
),
|
||||
data_capture_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Configuration to control how SageMaker captures inference data.",
|
||||
required=False,
|
||||
),
|
||||
endpoint_config_name=InputValidator(
|
||||
input_type=str,
|
||||
description="The name of the endpoint configuration.",
|
||||
required=True,
|
||||
),
|
||||
kms_key_id=InputValidator(
|
||||
input_type=str,
|
||||
description="The Amazon Resource Name (ARN) of an Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint.",
|
||||
required=False,
|
||||
),
|
||||
production_variants=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="An array of ProductionVariant objects, one for each model that you want to host at this endpoint.",
|
||||
required=True,
|
||||
),
|
||||
tags=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="An array of key-value pairs.",
|
||||
required=False,
|
||||
),
|
||||
**vars(COMMON_INPUTS),
|
||||
)
|
||||
|
||||
OUTPUTS = SageMakerEndpointConfigOutputs(
|
||||
ack_resource_metadata=OutputValidator(
|
||||
description="All CRs managed by ACK have a common `Status.",
|
||||
),
|
||||
conditions=OutputValidator(
|
||||
description="All CRS managed by ACK have a common `Status.",
|
||||
),
|
||||
sagemaker_resource_name=OutputValidator(
|
||||
description="Resource name on Sagemaker",
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, arguments: List[str]):
|
||||
super().__init__(
|
||||
arguments, SageMakerEndpointConfigInputs, SageMakerEndpointConfigOutputs
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> SageMakerEndpointConfigInputs:
|
||||
return self._inputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> SageMakerEndpointConfigOutputs:
|
||||
return self._outputs
|
||||
|
||||
@property
|
||||
def output_paths(self) -> SageMakerEndpointConfigOutputs:
|
||||
return self._output_paths
|
|
@ -0,0 +1,59 @@
|
|||
# SageMaker Model Kubeflow Pipelines component v2

## Overview

Model is one of the three components (along with Endpoint and EndpointConfig) that you use to create a Hosting deployment on SageMaker.

This component creates [SageMaker Models](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html) in a Kubeflow Pipelines workflow.

See the SageMaker Components for Kubeflow Pipelines versions section in [SageMaker Components for Kubeflow Pipelines](https://docs.aws.amazon.com/sagemaker/latest/dg/kubernetes-sagemaker-components-for-kubeflow-pipelines.html#kubeflow-pipeline-components) to learn about the differences between the version 1 and version 2 components.

### Kubeflow Pipelines backend compatibility
SageMaker components are currently supported with the Kubeflow Pipelines backend v1. This means you have to use KFP SDK 1.8.x to create your pipelines.

## Getting Started

Follow [this guide](https://github.com/kubeflow/pipelines/tree/master/samples/contrib/aws-samples#prerequisites) to set up the prerequisites for Model, depending on your deployment.

## Input Parameters
Find the high-level component input parameters and their descriptions in the [component's input specification](./component.yaml). Parameters of type `JsonObject` or `JsonArray` have nested fields. Refer to the [Model CRD specification](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/model/) for the respective structure and pass the input in JSON format.

A quick way to see the converted JSON-style input is to copy the [sample Model spec](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/model/#spec) and convert it to JSON using a YAML-to-JSON converter like [this website](https://jsonformatter.org/yaml-to-json).

For example, the `primaryContainer` field in the `Model` CRD looks like:

```
primaryContainer:
  containerHostname: string
  environment: {}
  image: string
  imageConfig:
    repositoryAccessMode: string
    repositoryAuthConfig:
      repositoryCredentialsProviderARN: string
  inferenceSpecificationName: string
  mode: string
  modelDataURL: string
  modelPackageName: string
  multiModelConfig:
    modelCacheSetting: string
```

The `primaryContainer` input for the component would be (not all parameters are included):

```
primaryContainer = {
    "containerHostname": "xgboost",
    "environment": {"my_env_key": "my_env_value"},
    "image": "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3",
    "modelDataURL": "s3://<path to model>",
    "mode": "SingleModel",
}
```
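
As a minimal sketch of the "pass the input in JSON format" step, the snippet below serializes a Python dictionary into a JSON string using only the standard library. The helper name `to_component_input` and the placeholder image URI are assumptions made for illustration; they are not part of the component or its API.

```python
import json

# Placeholder values for illustration only; substitute your own image URI and S3 path.
primary_container = {
    "containerHostname": "xgboost",
    "environment": {"my_env_key": "my_env_value"},
    "image": "<inference image URI>",
    "modelDataURL": "s3://<path to model>",
}


def to_component_input(value):
    """Serialize a dict or list to a JSON string.

    Hypothetical helper for this example, not part of the component.
    """
    if not isinstance(value, (dict, list)):
        raise TypeError("expected a dict (JsonObject) or a list (JsonArray)")
    return json.dumps(value)


primary_container_json = to_component_input(primary_container)

# The string parses back to the original structure, so no fields were lost.
assert json.loads(primary_container_json) == primary_container
```

The resulting string can then be supplied as the `primary_container` input listed in the component's input specification. The same approach should apply to the `JsonArray` inputs such as `tags`.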

You might also want to look at the [Model API reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html) for a detailed explanation of the parameters.

## References
- [Inference on SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html)
- [Model CRD specification](https://aws-controllers-k8s.github.io/community/reference/sagemaker/v1alpha1/model/)
- [Model API reference](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html)
|
|
@ -0,0 +1,108 @@
|
|||
name: "Sagemaker - Model"
|
||||
description: Create Model
|
||||
inputs:
|
||||
- {
|
||||
name: region,
|
||||
type: String,
|
||||
description: "The AWS region in which to create the model",
|
||||
}
|
||||
###########################GENERATED SECTION BELOW############################
|
||||
|
||||
- {
|
||||
name: containers,
|
||||
type: JsonArray,
|
||||
default: '[]',
|
||||
description: "Specifies the containers in the inference pipeline.",
|
||||
}
|
||||
- {
|
||||
name: enable_network_isolation,
|
||||
type: Bool,
|
||||
default: False,
|
||||
description: "Isolates the model container.",
|
||||
}
|
||||
- {
|
||||
name: execution_role_arn,
|
||||
type: String,
|
||||
default: '',
|
||||
description: "The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs.",
|
||||
}
|
||||
- {
|
||||
name: inference_execution_config,
|
||||
type: JsonObject,
|
||||
default: '{}',
|
||||
description: "Specifies details of how containers in a multi-container endpoint are called.",
|
||||
}
|
||||
- {
|
||||
name: model_name,
|
||||
type: String,
|
||||
default: '',
|
||||
description: "The name of the new model.",
|
||||
}
|
||||
- {
|
||||
name: primary_container,
|
||||
type: JsonObject,
|
||||
default: '{}',
|
||||
description: "The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions.",
|
||||
}
|
||||
- {
|
||||
name: tags,
|
||||
type: JsonArray,
|
||||
default: '[]',
|
||||
description: "An array of key-value pairs.",
|
||||
}
|
||||
- {
|
||||
name: vpc_config,
|
||||
type: JsonObject,
|
||||
default: '{}',
|
||||
description: "A VpcConfig object that specifies the VPC that you want your model to connect to.",
|
||||
}
|
||||
###########################GENERATED SECTION ABOVE############################
|
||||
|
||||
outputs:
|
||||
###########################GENERATED SECTION BELOW############################
|
||||
|
||||
- {
|
||||
name: ack_resource_metadata,
|
||||
type: JsonObject,
|
||||
description: "All CRs managed by ACK have a common `Status.",
|
||||
}
|
||||
- {
|
||||
name: conditions,
|
||||
type: JsonArray,
|
||||
description: "All CRS managed by ACK have a common `Status.",
|
||||
}
|
||||
- {
|
||||
name: sagemaker_resource_name,
|
||||
type: String,
|
||||
description: "Resource name on Sagemaker",
|
||||
}
|
||||
###########################GENERATED SECTION ABOVE############################
|
||||
|
||||
implementation:
|
||||
container:
|
||||
image: public.ecr.aws/kubeflow-on-aws/aws-sagemaker-kfp-components:2.2.0
|
||||
command: [python3]
|
||||
args:
|
||||
- Model/src/Model_component.py
|
||||
- --region
|
||||
- { inputValue: region }
|
||||
###########################GENERATED SECTION BELOW############################
|
||||
- --containers
|
||||
- { inputValue: containers }
|
||||
- --enable_network_isolation
|
||||
- { inputValue: enable_network_isolation }
|
||||
- --execution_role_arn
|
||||
- { inputValue: execution_role_arn }
|
||||
- --inference_execution_config
|
||||
- { inputValue: inference_execution_config }
|
||||
- --model_name
|
||||
- { inputValue: model_name }
|
||||
- --primary_container
|
||||
- { inputValue: primary_container }
|
||||
- --tags
|
||||
- { inputValue: tags }
|
||||
- --vpc_config
|
||||
- { inputValue: vpc_config }
|
||||
|
||||
###########################GENERATED SECTION ABOVE############################
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
import json
|
||||
|
||||
from Model.src.Model_spec import (
|
||||
SageMakerModelInputs,
|
||||
SageMakerModelOutputs,
|
||||
SageMakerModelSpec,
|
||||
)
|
||||
from commonv2.sagemaker_component import (
|
||||
SageMakerComponent,
|
||||
ComponentMetadata,
|
||||
SageMakerJobStatus,
|
||||
)
|
||||
from commonv2 import snake_to_camel
|
||||
|
||||
|
||||
@ComponentMetadata(
|
||||
name="SageMaker - Model",
|
||||
description="",
|
||||
spec=SageMakerModelSpec,
|
||||
)
|
||||
class SageMakerModelComponent(SageMakerComponent):
|
||||
|
||||
"""SageMaker component for Model."""
|
||||
|
||||
def Do(self, spec: SageMakerModelSpec):
|
||||
|
||||
self.namespace = self._get_current_namespace()
|
||||
logging.info("Current namespace: " + self.namespace)
|
||||
|
||||
############GENERATED SECTION BELOW############
|
||||
|
||||
self.job_name = spec.inputs.model_name = (
|
||||
spec.inputs.model_name # todo: need customize
|
||||
if spec.inputs.model_name # todo: need customize
|
||||
else SageMakerComponent._generate_unique_timestamped_id(prefix="model")
|
||||
)
|
||||
|
||||
self.group = "sagemaker.services.k8s.aws"
|
||||
self.version = "v1alpha1"
|
||||
self.plural = "models"
|
||||
self.spaced_out_resource_name = "Model"
|
||||
|
||||
self.job_request_outline_location = "Model/src/Model_request.yaml.tpl"
|
||||
self.job_request_location = "Model/src/Model_request.yaml"
|
||||
self.update_supported = False
|
||||
############GENERATED SECTION ABOVE############
|
||||
|
||||
super().Do(spec.inputs, spec.outputs, spec.output_paths)
|
||||
|
||||
def _create_job_request(
|
||||
self,
|
||||
inputs: SageMakerModelInputs,
|
||||
outputs: SageMakerModelOutputs,
|
||||
) -> Dict:
|
||||
|
||||
return super()._create_job_yaml(inputs, outputs)
|
||||
|
||||
def _submit_job_request(self, request: Dict) -> object:
|
||||
|
||||
return super()._create_resource(request, 12, 15)
|
||||
|
||||
def _on_job_terminated(self):
|
||||
super()._delete_custom_resource()
|
||||
|
||||
def _after_submit_job_request(
|
||||
self,
|
||||
job: object,
|
||||
request: Dict,
|
||||
inputs: SageMakerModelInputs,
|
||||
outputs: SageMakerModelOutputs,
|
||||
):
|
||||
logging.info(
|
||||
"Model in Sagemaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/models/{}".format(
|
||||
inputs.region, inputs.region, self.job_name
|
||||
)
|
||||
)
|
||||
|
||||
def _get_job_status(self):
|
||||
return SageMakerJobStatus(is_completed=True, raw_status="Completed")
|
||||
|
||||
def _get_upgrade_status(self):
|
||||
|
||||
return self._get_job_status()
|
||||
|
||||
def _after_job_complete(
|
||||
self,
|
||||
job: object,
|
||||
request: Dict,
|
||||
inputs: SageMakerModelInputs,
|
||||
outputs: SageMakerModelOutputs,
|
||||
):
|
||||
# prepare component outputs (defined in the spec)
|
||||
|
||||
ack_statuses = super()._get_resource()["status"]
|
||||
|
||||
############GENERATED SECTION BELOW############
|
||||
|
||||
outputs.ack_resource_metadata = str(
|
||||
ack_statuses["ackResourceMetadata"]
|
||||
if "ackResourceMetadata" in ack_statuses
|
||||
else None
|
||||
)
|
||||
outputs.conditions = str(
|
||||
ack_statuses["conditions"] if "conditions" in ack_statuses else None
|
||||
)
|
||||
outputs.sagemaker_resource_name = self.job_name
|
||||
|
||||
############GENERATED SECTION ABOVE############
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
spec = SageMakerModelSpec(sys.argv[1:])
|
||||
|
||||
component = SageMakerModelComponent()
|
||||
component.Do(spec)
|
|
@ -0,0 +1,15 @@
|
|||
apiVersion: sagemaker.services.k8s.aws/v1alpha1
|
||||
kind: Model
|
||||
metadata:
|
||||
name:
|
||||
annotations:
|
||||
services.k8s.aws/region:
|
||||
spec:
|
||||
containers:
|
||||
enableNetworkIsolation:
|
||||
executionRoleARN:
|
||||
inferenceExecutionConfig:
|
||||
modelName:
|
||||
primaryContainer:
|
||||
tags:
|
||||
vpcConfig:
|
|
@ -0,0 +1,126 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Specification for the SageMaker - Model"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import List
|
||||
from commonv2.sagemaker_component_spec import (
|
||||
SageMakerComponentSpec,
|
||||
SageMakerComponentBaseOutputs,
|
||||
)
|
||||
from commonv2.spec_input_parsers import SpecInputParsers
|
||||
from commonv2.common_inputs import (
|
||||
COMMON_INPUTS,
|
||||
SageMakerComponentCommonInputs,
|
||||
SageMakerComponentInput as Input,
|
||||
SageMakerComponentOutput as Output,
|
||||
SageMakerComponentInputValidator as InputValidator,
|
||||
SageMakerComponentOutputValidator as OutputValidator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class SageMakerModelInputs(SageMakerComponentCommonInputs):
|
||||
"""Defines the set of inputs for the Model component."""
|
||||
|
||||
containers: Input
|
||||
enable_network_isolation: Input
|
||||
execution_role_arn: Input
|
||||
inference_execution_config: Input
|
||||
model_name: Input
|
||||
primary_container: Input
|
||||
tags: Input
|
||||
vpc_config: Input
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageMakerModelOutputs(SageMakerComponentBaseOutputs):
|
||||
"""Defines the set of outputs for the Model component."""
|
||||
|
||||
ack_resource_metadata: Output
|
||||
conditions: Output
|
||||
sagemaker_resource_name: Output
|
||||
|
||||
|
||||
class SageMakerModelSpec(
|
||||
SageMakerComponentSpec[SageMakerModelInputs, SageMakerModelOutputs]
|
||||
):
|
||||
INPUTS: SageMakerModelInputs = SageMakerModelInputs(
|
||||
containers=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="Specifies the containers in the inference pipeline.",
|
||||
required=False,
|
||||
),
|
||||
enable_network_isolation=InputValidator(
|
||||
input_type=SpecInputParsers.str_to_bool,
|
||||
description="Isolates the model container.",
|
||||
required=False,
|
||||
),
|
||||
execution_role_arn=InputValidator(
|
||||
input_type=str,
|
||||
description="The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs.",
|
||||
required=True,
|
||||
),
|
||||
inference_execution_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Specifies details of how containers in a multi-container endpoint are called.",
|
||||
required=False,
|
||||
),
|
||||
model_name=InputValidator(
|
||||
input_type=str, description="The name of the new model.", required=True
|
||||
),
|
||||
primary_container=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions.",
|
||||
required=False,
|
||||
),
|
||||
tags=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="An array of key-value pairs.",
|
||||
required=False,
|
||||
),
|
||||
vpc_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="A VpcConfig object that specifies the VPC that you want your model to connect to.",
|
||||
required=False,
|
||||
),
|
||||
**vars(COMMON_INPUTS),
|
||||
)
|
||||
|
||||
OUTPUTS = SageMakerModelOutputs(
|
||||
ack_resource_metadata=OutputValidator(
|
||||
description="All CRs managed by ACK have a common `Status.",
|
||||
),
|
||||
conditions=OutputValidator(
|
||||
description="All CRS managed by ACK have a common `Status.",
|
||||
),
|
||||
sagemaker_resource_name=OutputValidator(
|
||||
description="Resource name on Sagemaker",
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, arguments: List[str]):
|
||||
super().__init__(arguments, SageMakerModelInputs, SageMakerModelOutputs)
|
||||
|
||||
@property
|
||||
def inputs(self) -> SageMakerModelInputs:
|
||||
return self._inputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> SageMakerModelOutputs:
|
||||
return self._outputs
|
||||
|
||||
@property
|
||||
def output_paths(self) -> SageMakerModelOutputs:
|
||||
return self._output_paths
|
|
@ -40,7 +40,9 @@ The Processing component enables you to submit processing jobs to Amazon SageMak
|
|||
|
||||
#### Hosting Deploy
|
||||
|
||||
The Deploy component enables you to deploy a model in Amazon SageMaker Hosting from a Kubeflow Pipelines workflow. For more information, see [SageMaker Hosting Services - Create Endpoint Kubeflow Pipeline component version 1](https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker/deploy).
|
||||
The Hosting components let you create Amazon SageMaker Hosting deployments directly from a Kubeflow Pipelines workflow. For more information, see [SageMaker Endpoint Kubeflow Pipelines component version 2](./Endpoint), [SageMaker Endpoint Config Kubeflow Pipelines component version 2](./EndpointConfig), and [SageMaker Model Kubeflow Pipelines component version 2](./Model).
|
||||
|
||||
For more information about version 1 of the Hosting components, see [SageMaker Hosting Services - Create Endpoint Kubeflow Pipeline component version 1](https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker/deploy).
|
||||
|
||||
#### Batch Transform component
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
** Amazon SageMaker Components for Kubeflow Pipelines; version 2.1.0 --
|
||||
** Amazon SageMaker Components for Kubeflow Pipelines; version 2.2.0 --
|
||||
https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker
|
||||
Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
** pathlib2; version 2.3.5 --
|
||||
|
@ -14,7 +14,9 @@ Copyright (c) 2016-2017 Jukka Lehtosalo and contributors
|
|||
** kubernetes; version 12.0.1 --
|
||||
https://github.com/kubernetes-client/python
|
||||
Copyright 2021 The Kubernetes Authors.
|
||||
|
||||
https://github.com/urllib3/urllib3
|
||||
** urllib3; 1.26.15 --
|
||||
Copyright (c) 2008-2020 Andrey Petrov and contributors
|
||||
|
||||
Apache License
|
||||
|
||||
|
|
|
@ -24,13 +24,13 @@ inputs:
|
|||
name: debug_hook_config,
|
||||
type: JsonObject,
|
||||
default: '{}',
|
||||
description: "Configuration information for the Debugger hook parameters, metric and tensor collections, and storage paths.",
|
||||
description: "Configuration information for the Amazon SageMaker Debugger hook parameters, metric and tensor collections, and storage paths.",
|
||||
}
|
||||
- {
|
||||
name: debug_rule_configurations,
|
||||
type: JsonArray,
|
||||
default: '[]',
|
||||
description: "Configuration information for Debugger rules for debugging output tensors.",
|
||||
description: "Configuration information for Amazon SageMaker Debugger rules for debugging output tensors.",
|
||||
}
|
||||
- {
|
||||
name: enable_inter_container_traffic_encryption,
|
||||
|
@ -84,13 +84,13 @@ inputs:
|
|||
name: profiler_config,
|
||||
type: JsonObject,
|
||||
default: '{}',
|
||||
description: "Configuration information for Debugger system monitoring, framework profiling, and storage paths.",
|
||||
description: "Configuration information for Amazon SageMaker Debugger system monitoring, framework profiling, and storage paths.",
|
||||
}
|
||||
- {
|
||||
name: profiler_rule_configurations,
|
||||
type: JsonArray,
|
||||
default: '[]',
|
||||
description: "Configuration information for Debugger rules for profiling system and framework metrics.",
|
||||
description: "Configuration information for Amazon SageMaker Debugger rules for profiling system and framework metrics.",
|
||||
}
|
||||
- {
|
||||
name: resource_config,
|
||||
|
@ -126,7 +126,7 @@ inputs:
|
|||
name: tensor_board_output_config,
|
||||
type: JsonObject,
|
||||
default: '{}',
|
||||
description: "Configuration of storage locations for the Debugger TensorBoard output data.",
|
||||
description: "Configuration of storage locations for the Amazon SageMaker Debugger TensorBoard output data.",
|
||||
}
|
||||
- {
|
||||
name: training_job_name,
|
||||
|
@ -155,16 +155,26 @@ outputs:
|
|||
type: JsonArray,
|
||||
description: "All CRS managed by ACK have a common `Status.",
|
||||
}
|
||||
- {
|
||||
name: creation_time,
|
||||
type: String,
|
||||
description: "A timestamp that indicates when the training job was created.",
|
||||
}
|
||||
- {
|
||||
name: debug_rule_evaluation_statuses,
|
||||
type: JsonArray,
|
||||
description: "Evaluation status of Debugger rules for debugging on a training job.",
|
||||
description: "Evaluation status of Amazon SageMaker Debugger rules for debugging on a training job.",
|
||||
}
|
||||
- {
|
||||
name: failure_reason,
|
||||
type: String,
|
||||
description: "If the training job failed, the reason it failed.",
|
||||
}
|
||||
- {
|
||||
name: last_modified_time,
|
||||
type: String,
|
||||
description: "A timestamp that indicates when the status of the training job was last modified.",
|
||||
}
|
||||
- {
|
||||
name: model_artifacts,
|
||||
type: JsonObject,
|
||||
|
@ -173,7 +183,7 @@ outputs:
|
|||
- {
|
||||
name: profiler_rule_evaluation_statuses,
|
||||
type: JsonArray,
|
||||
description: "Evaluation status of Debugger rules for profiling on a training job.",
|
||||
description: "Evaluation status of Amazon SageMaker Debugger rules for profiling on a training job.",
|
||||
}
|
||||
- {
|
||||
name: profiling_status,
|
||||
|
@ -199,7 +209,7 @@ outputs:
|
|||
|
||||
implementation:
|
||||
container:
|
||||
image: public.ecr.aws/kubeflow-on-aws/aws-sagemaker-kfp-components:2.1.0
|
||||
image: public.ecr.aws/kubeflow-on-aws/aws-sagemaker-kfp-components:2.2.0
|
||||
command: [python3]
|
||||
args:
|
||||
- TrainingJob/src/TrainingJob_component.py
|
||||
|
|
|
@ -54,11 +54,13 @@ class SageMakerTrainingJobComponent(SageMakerComponent):
|
|||
self.group = "sagemaker.services.k8s.aws"
|
||||
self.version = "v1alpha1"
|
||||
self.plural = "trainingjobs"
|
||||
self.spaced_out_resource_name = "Training Job"
|
||||
|
||||
self.job_request_outline_location = (
|
||||
"TrainingJob/src/TrainingJob_request.yaml.tpl"
|
||||
)
|
||||
self.job_request_location = "TrainingJob/src/TrainingJob_request.yaml"
|
||||
self.update_supported = False
|
||||
############GENERATED SECTION ABOVE############
|
||||
|
||||
super().Do(spec.inputs, spec.outputs, spec.output_paths)
|
||||
|
@ -84,7 +86,7 @@ class SageMakerTrainingJobComponent(SageMakerComponent):
|
|||
|
||||
def _submit_job_request(self, request: Dict) -> object:
|
||||
|
||||
return super()._create_resource(request, 6, 10)
|
||||
return super()._create_resource(request, 12, 15)
|
||||
|
||||
def _on_job_terminated(self):
|
||||
super()._delete_custom_resource()
|
||||
|
@ -179,6 +181,10 @@ class SageMakerTrainingJobComponent(SageMakerComponent):
|
|||
)
|
||||
return SageMakerJobStatus(is_completed=False, raw_status=sm_job_status)
|
||||
|
||||
def _get_upgrade_status(self):
|
||||
|
||||
return self._get_job_status()
|
||||
|
||||
def _after_job_complete(
|
||||
self,
|
||||
job: object,
|
||||
|
@ -200,6 +206,9 @@ class SageMakerTrainingJobComponent(SageMakerComponent):
|
|||
outputs.conditions = str(
|
||||
ack_statuses["conditions"] if "conditions" in ack_statuses else None
|
||||
)
|
||||
outputs.creation_time = str(
|
||||
ack_statuses["creationTime"] if "creationTime" in ack_statuses else None
|
||||
)
|
||||
outputs.debug_rule_evaluation_statuses = str(
|
||||
ack_statuses["debugRuleEvaluationStatuses"]
|
||||
if "debugRuleEvaluationStatuses" in ack_statuses
|
||||
|
@ -208,6 +217,11 @@ class SageMakerTrainingJobComponent(SageMakerComponent):
|
|||
outputs.failure_reason = str(
|
||||
ack_statuses["failureReason"] if "failureReason" in ack_statuses else None
|
||||
)
|
||||
outputs.last_modified_time = str(
|
||||
ack_statuses["lastModifiedTime"]
|
||||
if "lastModifiedTime" in ack_statuses
|
||||
else None
|
||||
)
|
||||
outputs.model_artifacts = str(
|
||||
ack_statuses["modelArtifacts"] if "modelArtifacts" in ack_statuses else None
|
||||
)
|
||||
|
|
|
@ -64,8 +64,10 @@ class SageMakerTrainingJobOutputs(SageMakerComponentBaseOutputs):
|
|||
|
||||
ack_resource_metadata: Output
|
||||
conditions: Output
|
||||
creation_time: Output
|
||||
debug_rule_evaluation_statuses: Output
|
||||
failure_reason: Output
|
||||
last_modified_time: Output
|
||||
model_artifacts: Output
|
||||
profiler_rule_evaluation_statuses: Output
|
||||
profiling_status: Output
|
||||
|
@ -80,7 +82,7 @@ class SageMakerTrainingJobSpec(
|
|||
INPUTS: SageMakerTrainingJobInputs = SageMakerTrainingJobInputs(
|
||||
algorithm_specification=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="The registry path of the Docker image that contains the training algorithm and algorithm-specific me",
|
||||
description="The registry path of the Docker image that contains the training algorithm and algorithm-specific metadata, including the input mode.",
|
||||
required=True,
|
||||
),
|
||||
checkpoint_config=InputValidator(
|
||||
|
@ -90,27 +92,27 @@ class SageMakerTrainingJobSpec(
|
|||
),
|
||||
debug_hook_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Configuration information for the Debugger hook parameters, metric and tensor collections, and stora",
|
||||
description="Configuration information for the Amazon SageMaker Debugger hook parameters, metric and tensor collections, and storage paths.",
|
||||
required=False,
|
||||
),
|
||||
debug_rule_configurations=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="Configuration information for Debugger rules for debugging output tensors.",
|
||||
description="Configuration information for Amazon SageMaker Debugger rules for debugging output tensors.",
|
||||
required=False,
|
||||
),
|
||||
enable_inter_container_traffic_encryption=InputValidator(
|
||||
input_type=SpecInputParsers.str_to_bool,
|
||||
description="To encrypt all communications between ML compute instances in distributed training, choose True. Enc",
|
||||
description="To encrypt all communications between ML compute instances in distributed training, choose True.",
|
||||
required=False,
|
||||
),
|
||||
enable_managed_spot_training=InputValidator(
|
||||
input_type=SpecInputParsers.str_to_bool,
|
||||
description="To train models using managed spot training, choose True. Managed spot training provides a fully man",
|
||||
description="To train models using managed spot training, choose True.",
|
||||
required=False,
|
||||
),
|
||||
enable_network_isolation=InputValidator(
|
||||
input_type=SpecInputParsers.str_to_bool,
|
||||
description="Isolates the training container. No inbound or outbound network calls can be made, except for calls",
|
||||
description="Isolates the training container.",
|
||||
required=False,
|
||||
),
|
||||
environment=InputValidator(
|
||||
|
@ -120,32 +122,32 @@ class SageMakerTrainingJobSpec(
|
|||
),
|
||||
experiment_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Associates a SageMaker job as a trial component with an experiment and trial. Specified when you cal",
|
||||
description="Associates a SageMaker job as a trial component with an experiment and trial.",
|
||||
required=False,
|
||||
),
|
||||
hyper_parameters=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Algorithm-specific parameters that influence the quality of the model. You set hyperparameters befor",
|
||||
description="Algorithm-specific parameters that influence the quality of the model.",
|
||||
required=False,
|
||||
),
|
||||
input_data_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="An array of Channel objects. Each channel is a named input source. InputDataConfig describes the inp",
|
||||
description="An array of Channel objects.",
|
||||
required=False,
|
||||
),
|
||||
output_data_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Specifies the path to the S3 location where you want to store model artifacts. SageMaker creates sub",
|
||||
description="Specifies the path to the S3 location where you want to store model artifacts.",
|
||||
required=True,
|
||||
),
|
||||
profiler_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Configuration information for Debugger system monitoring, framework profiling, and storage paths.",
|
||||
description="Configuration information for Amazon SageMaker Debugger system monitoring, framework profiling, and storage paths.",
|
||||
required=False,
|
||||
),
|
||||
profiler_rule_configurations=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="Configuration information for Debugger rules for profiling system and framework metrics.",
|
||||
description="Configuration information for Amazon SageMaker Debugger rules for profiling system and framework metrics.",
|
||||
required=False,
|
||||
),
|
||||
resource_config=InputValidator(
|
||||
|
@ -160,32 +162,30 @@ class SageMakerTrainingJobSpec(
|
|||
),
|
||||
role_arn=InputValidator(
|
||||
input_type=str,
|
||||
description="The Amazon Resource Name (ARN) of an IAM role that SageMaker can assume to perform tasks on your beh",
|
||||
description="The Amazon Resource Name (ARN) of an IAM role that SageMaker can assume to perform tasks on your behalf.",
|
||||
required=True,
|
||||
),
|
||||
stopping_condition=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Specifies a limit to how long a model training job can run. It also specifies how long a managed Spo",
|
||||
description="Specifies a limit to how long a model training job can run.",
|
||||
required=True,
|
||||
),
|
||||
tags=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_list,
|
||||
description="An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in di",
|
||||
description="An array of key-value pairs.",
|
||||
required=False,
|
||||
),
|
||||
tensor_board_output_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="Configuration of storage locations for the Debugger TensorBoard output data.",
|
||||
description="Configuration of storage locations for the Amazon SageMaker Debugger TensorBoard output data.",
|
||||
required=False,
|
||||
),
|
||||
training_job_name=InputValidator(
|
||||
input_type=str,
|
||||
description="The name of the training job. The name must be unique within an Amazon Web Services Region in an Ama",
|
||||
required=True,
|
||||
input_type=str, description="The name of the training job.", required=True
|
||||
),
|
||||
vpc_config=InputValidator(
|
||||
input_type=SpecInputParsers.yaml_or_json_dict,
|
||||
description="A VpcConfig object that specifies the VPC that you want your training job to connect to. Control acc",
|
||||
description="A VpcConfig object that specifies the VPC that you want your training job to connect to.",
|
||||
required=False,
|
||||
),
|
||||
**vars(COMMON_INPUTS),
|
||||
|
@ -193,31 +193,37 @@ class SageMakerTrainingJobSpec(
|
|||
|
||||
OUTPUTS = SageMakerTrainingJobOutputs(
|
||||
ack_resource_metadata=OutputValidator(
|
||||
description="All CRs managed by ACK have a common `Status.ACKResourceMetadata` member that is used to contain res",
|
||||
description="All CRs managed by ACK have a common `Status.ACKResourceMetadata` member.",
|
||||
),
|
||||
conditions=OutputValidator(
|
||||
description="All CRS managed by ACK have a common `Status.Conditions` member that contains a collection of `ackv1",
|
||||
description="All CRs managed by ACK have a common `Status.Conditions` member.",
|
||||
),
|
||||
creation_time=OutputValidator(
|
||||
description="A timestamp that indicates when the training job was created.",
|
||||
),
|
||||
debug_rule_evaluation_statuses=OutputValidator(
|
||||
description="Evaluation status of Debugger rules for debugging on a training job.",
|
||||
description="Evaluation status of Amazon SageMaker Debugger rules for debugging on a training job.",
|
||||
),
|
||||
failure_reason=OutputValidator(
|
||||
description="If the training job failed, the reason it failed.",
|
||||
),
|
||||
last_modified_time=OutputValidator(
|
||||
description="A timestamp that indicates when the status of the training job was last modified.",
|
||||
),
|
||||
model_artifacts=OutputValidator(
|
||||
description="Information about the Amazon S3 location that is configured for storing model artifacts.",
|
||||
),
|
||||
profiler_rule_evaluation_statuses=OutputValidator(
|
||||
description="Evaluation status of Debugger rules for profiling on a training job.",
|
||||
description="Evaluation status of Amazon SageMaker Debugger rules for profiling on a training job.",
|
||||
),
|
||||
profiling_status=OutputValidator(
|
||||
description="Profiling status of a training job.",
|
||||
),
|
||||
secondary_status=OutputValidator(
|
||||
description="Provides detailed information about the state of the training job. For detailed information on the s",
|
||||
description="Provides detailed information about the state of the training job.",
|
||||
),
|
||||
training_job_status=OutputValidator(
|
||||
description="The status of the training job. SageMaker provides the following training job statuses: * InProg",
|
||||
description="The status of the training job.",
|
||||
),
|
||||
warm_pool_status=OutputValidator(
|
||||
description="The status of the warm pool associated with the training job.",
|
||||
|
|
|
@ -10,10 +10,24 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
def snake_to_camel(name):
|
||||
"""Convert snake case to camel case."""
|
||||
if name == "role_arn":
|
||||
return "roleARN"
|
||||
overrides = {
|
||||
"role_arn": "roleARN",
|
||||
"execution_role_arn": "executionRoleARN",
|
||||
}
|
||||
if name in overrides:
|
||||
return overrides[name]
|
||||
temp = name.split("_")
|
||||
return temp[0] + "".join(ele.title() for ele in temp[1:])
|
||||
|
||||
def is_ack_requeue_error(error_msg):
|
||||
# ^(start) [one or more alphanumeric characters] in [one or more alphanumeric characters] state cannot be modified or deleted. $(end)
|
||||
# Alternative was using substring "state cannot be modified or deleted.", but there is a risk of validation errors including it too.
|
||||
requeue_regex = re.compile(r"^(\w+) in (\w+) state cannot be modified or deleted\.$")
|
||||
matches = requeue_regex.fullmatch(error_msg)
|
||||
if matches is not None:
|
||||
return True
|
||||
return False
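
The two helpers above can be exercised in isolation. The following is a minimal, self-contained sketch that restates them and shows the results for a few illustrative inputs. The sample field names and error strings are hypothetical, and the escaped period in the pattern is an assumption about the intended match rather than a statement about the committed code.

```python
import re

# Field names whose camelCase form cannot be derived mechanically.
OVERRIDES = {
    "role_arn": "roleARN",
    "execution_role_arn": "executionRoleARN",
}

# Assumed pattern: the whole message must be an ACK requeue notice.
REQUEUE_REGEX = re.compile(r"(\w+) in (\w+) state cannot be modified or deleted\.")


def snake_to_camel(name):
    """Convert snake_case to camelCase, honoring the ARN overrides."""
    if name in OVERRIDES:
        return OVERRIDES[name]
    head, *tail = name.split("_")
    return head + "".join(part.title() for part in tail)


def is_ack_requeue_error(error_msg):
    """Return True only when the entire message matches the requeue pattern."""
    return REQUEUE_REGEX.fullmatch(error_msg) is not None


print(snake_to_camel("endpoint_config_name"))
print(snake_to_camel("execution_role_arn"))
print(is_ack_requeue_error("Endpoint in Updating state cannot be modified or deleted."))
print(is_ack_requeue_error("ValidationException: Endpoint in Updating state cannot be modified or deleted."))
```

Matching the full message rather than a substring keeps a validation error that merely quotes the requeue text from being treated as a requeue.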
|
|
@ -35,7 +35,7 @@ from commonv2.common_inputs import (
|
|||
SageMakerComponentCommonInputs,
|
||||
)
|
||||
|
||||
from commonv2 import snake_to_camel
|
||||
from commonv2 import snake_to_camel, is_ack_requeue_error
|
||||
|
||||
# This handler is called whenever the @ComponentMetadata is applied.
|
||||
# It allows the command line compiler to detect every component spec class.
|
||||
|
@ -101,6 +101,7 @@ class SageMakerComponent:
|
|||
COMPONENT_SPEC = SageMakerComponentSpec
|
||||
|
||||
STATUS_POLL_INTERVAL = 30
|
||||
UPDATE_PROCESS_INTERVAL = 10
|
||||
|
||||
# parameters that will be filled by Do().
|
||||
# assignment statements in Do() will be generated
|
||||
|
@ -108,7 +109,11 @@ class SageMakerComponent:
|
|||
group: str
|
||||
version: str
|
||||
plural: str
|
||||
spaced_out_resource_name: str # Used for Logs
|
||||
namespace: Optional[str] = None
|
||||
resource_upgrade: bool = False
|
||||
initial_status: dict
|
||||
update_supported: bool
|
||||
|
||||
job_request_outline_location: str
|
||||
job_request_location: str
|
||||
|
@ -135,7 +140,7 @@ class SageMakerComponent:
|
|||
output_paths: Paths to the respective output locations.
|
||||
"""
|
||||
|
||||
# test if k8s is available
|
||||
# Verify that the kubernetes cluster is available
|
||||
try:
|
||||
self._init_configure_k8s()
|
||||
except Exception as e:
|
||||
|
@ -191,6 +196,12 @@ class SageMakerComponent:
|
|||
|
||||
signal.signal(signal.SIGTERM, signal_term_handler)
|
||||
|
||||
self.resource_upgrade = self._is_upgrade()
|
||||
if self.resource_upgrade and not self.update_supported:
|
||||
logging.error(
|
||||
f"Resource update is not supported for {self.spaced_out_resource_name}"
|
||||
)
|
||||
return False
|
||||
request = self._create_job_request(inputs, outputs)
|
||||
|
||||
try:
|
||||
|
@ -201,34 +212,8 @@ class SageMakerComponent:
|
|||
)
|
||||
return False
|
||||
|
||||
# check if the SM job is created by finding its arn
|
||||
try:
|
||||
while True:
|
||||
cr_condition = self._check_resource_conditions()
|
||||
if cr_condition: # ACK.Recoverable
|
||||
sleep(self.STATUS_POLL_INTERVAL)
|
||||
continue
|
||||
elif cr_condition == False:
|
||||
return False
|
||||
|
||||
arn = None
|
||||
ack_status = self._get_resource()["status"]
|
||||
ack_resource_meta = ack_status.get("ackResourceMetadata", None)
|
||||
if ack_resource_meta:
|
||||
arn = ack_resource_meta.get("arn", None)
|
||||
if arn is not None:
|
||||
logging.info(f"Created Sagemaker job with ARN: {arn}")
|
||||
|
||||
# Continue until complete
|
||||
if arn:
|
||||
break
|
||||
|
||||
sleep(self.STATUS_POLL_INTERVAL)
|
||||
logging.info(f"Getting arn for {self.job_name}")
|
||||
except Exception as e:
|
||||
logging.exception(
|
||||
"An error occurred while getting job arn, ACK CR created but Sagemaker job not created."
|
||||
)
|
||||
created = self._verify_resource_consumption()
|
||||
if not created:
|
||||
return False
|
||||
|
||||
self._after_submit_job_request(job, request, inputs, outputs)
|
||||
|
@ -242,19 +227,36 @@ class SageMakerComponent:
|
|||
if cr_condition:
|
||||
sleep(self.STATUS_POLL_INTERVAL)
|
||||
continue
|
||||
elif cr_condition == False: # ACK.Terminal
|
||||
elif (
|
||||
cr_condition == False
|
||||
): # ACK.Terminal or special errors (Validation Exception/Invalid Input)
|
||||
return False
|
||||
|
||||
status = self._get_job_status()
|
||||
status = (
|
||||
self._get_job_status()
|
||||
if not self.resource_upgrade
|
||||
else self._get_upgrade_status()
|
||||
)
|
||||
# Continue until complete
|
||||
if status and status.is_completed:
|
||||
logging.info(f"Job ended, final status: {status.raw_status}")
|
||||
if self.resource_upgrade:
|
||||
logging.info(
|
||||
f"{self.spaced_out_resource_name} Update complete, final status: {status.raw_status}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"{self.spaced_out_resource_name} Creation complete, final status: {status.raw_status}"
|
||||
)
|
||||
break
|
||||
|
||||
sleep(self.STATUS_POLL_INTERVAL)
|
||||
logging.info(f"Job is in status: {status.raw_status}")
|
||||
logging.info(
|
||||
f"{self.spaced_out_resource_name} is in status: {status.raw_status}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception("An error occurred while polling for job status")
|
||||
logging.exception(
|
||||
f"An error occurred while polling for {self.spaced_out_resource_name} status"
|
||||
)
|
||||
return False
|
||||
|
||||
if status.has_error:
|
||||
|
@ -266,6 +268,94 @@ class SageMakerComponent:
|
|||
|
||||
return True
|
||||
|
||||
def _get_conditions_of_type(self, condition_type):
|
||||
resource_conditions = self._get_resource()["status"]["conditions"]
|
||||
filtered_conditions = filter(
|
||||
lambda condition: (condition["type"] == condition_type), resource_conditions
|
||||
)
|
||||
return list(filtered_conditions)
|
||||
|
||||
def _verify_resource_consumption(self) -> bool:
|
||||
"""Verify that the resource has been successfully consumed by the controller.
|
||||
In the case of an update verify that the job arn exists.
|
||||
|
||||
Returns:
|
||||
bool: Whether the resource was consumed by the controller.
|
||||
"""
|
||||
submission_ack_printed = False
|
||||
ERROR_NOT_CREATED_MESSAGE = "An error occurred while getting resource arn, ACK CR created but Sagemaker resource not created."
|
||||
ERROR_UPDATE_MESSAGE = "An error occurred when getting the resource arn. Check the ACK Sagemaker Controller logs."
|
||||
|
||||
try:
|
||||
while True:
|
||||
cr_condition = self._check_resource_conditions()
|
||||
if cr_condition: # ACK.Recoverable
|
||||
sleep(self.STATUS_POLL_INTERVAL)
|
||||
continue
|
||||
elif cr_condition == False:
|
||||
if (
|
||||
self.resource_upgrade
|
||||
and not self.is_update_consumed_by_controller()
|
||||
):
|
||||
sleep(self.UPDATE_PROCESS_INTERVAL)
|
||||
continue
|
||||
return False
|
||||
|
||||
# Retrieve Sagemaker ARN
|
||||
arn = self.check_resource_initiation(submission_ack_printed)
|
||||
|
||||
# Continue until complete
|
||||
if arn:
|
||||
submission_ack_printed = True
|
||||
if (
|
||||
self.resource_upgrade
|
||||
and not self.is_update_consumed_by_controller()
|
||||
):
|
||||
sleep(self.UPDATE_PROCESS_INTERVAL)
|
||||
continue
|
||||
break
|
||||
|
||||
sleep(self.STATUS_POLL_INTERVAL)
|
||||
logging.info(f"Getting arn for {self.job_name}")
|
||||
except Exception as e:
|
||||
err_msg = (
|
||||
ERROR_UPDATE_MESSAGE
|
||||
if self.resource_upgrade
|
||||
else ERROR_NOT_CREATED_MESSAGE
|
||||
)
|
||||
logging.exception(err_msg)
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_resource_initiation(self, submission_ack_printed: bool):
|
||||
""" Check if resource has been initiated in Sagemaker.
|
||||
A resource is considered to be initiated if the resource ARN is present in the ack resource metadata.
|
||||
If the resource ARN is present in the ack resource metadata, the resource has been successfully
|
||||
created in Sagemaker.
|
||||
|
||||
Args:
|
||||
submission_ack_printed (bool): Parameter to avoid printing the resource consumed message
|
||||
multiple times.
|
||||
|
||||
Returns:
|
||||
str: The ARN of the resource. If the resource ARN is not present in the ack resource metadata,
|
||||
None is returned because the resource has not yet been created in Sagemaker.
|
||||
"""
|
||||
ack_status = self._get_resource()["status"]
|
||||
ack_resource_meta = ack_status.get("ackResourceMetadata", None)
|
||||
if ack_resource_meta:
|
||||
arn = ack_resource_meta.get("arn", None)
|
||||
if arn is not None:
|
||||
if not submission_ack_printed:
|
||||
resource_consumed_message = (
|
||||
f"Created Sagemaker {self.spaced_out_resource_name} with ARN: {arn}"
|
||||
if not self.resource_upgrade
|
||||
else f"Submitting update for Sagemaker {self.spaced_out_resource_name} with ARN: {arn}"
|
||||
)
|
||||
logging.info(resource_consumed_message)
|
||||
return arn
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def _get_job_status(self) -> SageMakerJobStatus:
|
||||
"""Waits for the current job to complete.
|
||||
|
@ -275,6 +365,26 @@ class SageMakerComponent:
|
|||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_upgrade_status(self) -> SageMakerJobStatus:
|
||||
"""Waits for the resource upgrade to complete
|
||||
|
||||
Returns:
|
||||
SageMakerJobStatus: A status object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_update_consumed_by_controller(self):
|
||||
"""Check if update has been consumed by the controller, in this case it is done by
|
||||
checking whether the resource's current status differs from `self.initial_status`.
|
||||
"""
|
||||
current_resource = self._get_resource()
|
||||
current_status = current_resource.get("status", None)
|
||||
# Python's == performs a deep comparison between dicts.
|
||||
if current_status == self.initial_status:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_resource(self):
|
||||
"""Get the custom resource detail similar to: kubectl describe
|
||||
trainingjob JOB_NAME -n NAMESPACE.
|
||||
|
@ -362,12 +472,6 @@ class SageMakerComponent:
|
|||
|
||||
logging.info(f"Custom resource: {json.dumps(job_request_dict, indent=2)}")
|
||||
|
||||
# write ACK custom object YAML to file
|
||||
# out_loc = self.job_request_location
|
||||
# with open(out_loc, "w+") as f:
|
||||
# yaml.dump(job_request_dict, f, default_flow_style=False)
|
||||
# print("FILE CREATED: " + out_loc)
|
||||
|
||||
return job_request_dict
|
||||
|
||||
@abstractmethod
|
||||
|
@ -388,6 +492,36 @@ class SageMakerComponent:
|
|||
"""
|
||||
pass
|
||||
|
||||
def _patch_custom_resource(self, custom_resource: dict):
|
||||
"""Patch a custom resource in ACK
|
||||
|
||||
Args:
|
||||
custom_resource: A dictionary object representing the custom object.
|
||||
Returns:
|
||||
dict: The job object that was patched
|
||||
|
||||
"""
|
||||
|
||||
_api_client = self._get_k8s_api_client()
|
||||
_api = client.CustomObjectsApi(_api_client)
|
||||
|
||||
if self.namespace is None:
|
||||
return _api.patch_cluster_custom_object(
|
||||
self.group.lower(),
|
||||
self.version.lower(),
|
||||
self.plural.lower(),
|
||||
self.job_name.lower(),
|
||||
custom_resource,
|
||||
)
|
||||
return _api.patch_namespaced_custom_object(
|
||||
self.group.lower(),
|
||||
self.version.lower(),
|
||||
self.namespace.lower(),
|
||||
self.plural.lower(),
|
||||
self.job_name.lower(),
|
||||
custom_resource,
|
||||
)
|
||||
|
||||
def _create_custom_resource(self, custom_resource: dict):
|
||||
"""Submit a custom_resource to the ACK cluster.
|
||||
|
||||
|
@ -493,6 +627,7 @@ class SageMakerComponent:
|
|||
* if recoverable and condition set to true, print out message and return true
|
||||
(let the outer polling loop continue indefinitely and let the user decide whether to stop)
|
||||
* if terminal and condition set to true, print out message and return false
|
||||
* Returns None if there are no error conditions.
|
||||
"""
|
||||
status_conditions = self._get_resource()["status"]["conditions"]
|
||||
|
||||
|
@ -501,6 +636,9 @@ class SageMakerComponent:
|
|||
condition_status = condition["status"]
|
||||
condition_message = condition.get("message", "No error message found.")
|
||||
|
||||
# If the controller has not consumed the update, any existing error will not be representative of the new state.
|
||||
if self.resource_upgrade and not self.is_update_consumed_by_controller():
|
||||
continue
|
||||
if condition_type == "ACK.Terminal" and condition_status == "True":
|
||||
logging.error(json.dumps(condition, indent=2))
|
||||
logging.error(
|
||||
|
@ -508,6 +646,9 @@ class SageMakerComponent:
|
|||
)
|
||||
return False
|
||||
if condition_type == "ACK.Recoverable" and condition_status == "True":
|
||||
# ACK requeue errors are not real errors.
|
||||
if is_ack_requeue_error(condition_message):
|
||||
continue
|
||||
logging.error(json.dumps(condition, indent=2))
|
||||
if "ValidationException" in condition_message:
|
||||
logging.error(
|
||||
|
@ -528,6 +669,8 @@ class SageMakerComponent:
|
|||
return None
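
The three-way result described in the docstring above (True for recoverable, False for terminal, None for no error) can be sketched as a standalone function. This is an illustrative simplification, not the component's method: the function name `classify_conditions` is hypothetical, logging is omitted, and treating a `ValidationException` message as a stop condition is an assumption based on the "special errors" comment in the polling loop.

```python
import re

# Assumed requeue pattern; such notices are skipped rather than reported.
REQUEUE = re.compile(r"(\w+) in (\w+) state cannot be modified or deleted\.")


def classify_conditions(conditions):
    """Return False for terminal errors, True for recoverable ones, None otherwise."""
    for condition in conditions:
        condition_type = condition["type"]
        is_active = condition["status"] == "True"
        message = condition.get("message", "No error message found.")

        if condition_type == "ACK.Terminal" and is_active:
            return False
        if condition_type == "ACK.Recoverable" and is_active:
            if REQUEUE.fullmatch(message):
                continue  # requeue notices are not real errors
            if "ValidationException" in message:
                return False  # assumed: validation errors will not resolve by waiting
            return True
    return None


print(classify_conditions([{"type": "ACK.ResourceSynced", "status": "False"}]))
print(classify_conditions([{"type": "ACK.Terminal", "status": "True", "message": "bad spec"}]))
print(classify_conditions([{"type": "ACK.Recoverable", "status": "True", "message": "throttled"}]))
```

A caller can then sleep and retry on True, abort on False, and proceed to the status check on None.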
|
||||
|
||||
def _get_resource_synced_status(self, ack_statuses: Dict):
|
||||
""" Retrieve the resource sync status
|
||||
"""
|
||||
conditions = ack_statuses.get("conditions", None) # Conditions has to be there
|
||||
if conditions is None:
|
||||
return None
|
||||
|
@ -590,10 +733,15 @@ class SageMakerComponent:
|
|||
response is APIserver response for the operation.
|
||||
bool is true if resource was removed from the server and false otherwise
|
||||
"""
|
||||
|
||||
_api_client = self._get_k8s_api_client()
|
||||
_api = client.CustomObjectsApi(_api_client)
|
||||
|
||||
logging.info("Deleting resource %s", (self.job_name))
|
||||
if self.resource_upgrade:
|
||||
logging.info("Received termination signal. Stopping the component; a resource update that has already started will still proceed. Rerun the component with the desired configuration to revert the update.")
|
||||
return None, True
|
||||
|
||||
logging.info("Received termination signal, deleting custom resource %s", self.job_name)
|
||||
_response = None
|
||||
if self.namespace is None:
|
||||
_response = _api.delete_cluster_custom_object(
|
||||
|
@ -689,3 +837,24 @@ class SageMakerComponent:
|
|||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(output_path).write_text(write_value)
|
||||
|
||||
def _is_upgrade(self):
|
||||
"""If the resource already exists the component assumes that the user wants to upgrade
|
||||
Returns:
|
||||
Bool: If the resource is being upgraded or not.
|
||||
Raises:
|
||||
Exception
|
||||
"""
|
||||
try:
|
||||
resource = self._get_resource()
|
||||
if resource is None:
|
||||
return False
|
||||
logging.info("Existing resource detected. Starting Update.")
|
||||
except client.exceptions.ApiException as error:
|
||||
if error.status == 404:
|
||||
logging.info("Resource does not exist. Creating a new resource.")
|
||||
return False
|
||||
else:
|
||||
raise error
|
||||
return True
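
The upgrade path above rests on two checks: whether the resource already exists (a 404 means "create"), and whether the controller has picked up the update (the status no longer equals the snapshot taken beforehand). Below is a minimal sketch of both checks against an in-memory stand-in. `ApiError` and `FakeCluster` are hypothetical substitutes for the Kubernetes client and are not part of the component.

```python
import copy


class ApiError(Exception):
    """Hypothetical stand-in for the Kubernetes client's API exception."""

    def __init__(self, status):
        super().__init__(f"API error {status}")
        self.status = status


class FakeCluster:
    """Hypothetical in-memory store of custom resources keyed by name."""

    def __init__(self):
        self.resources = {}

    def get(self, name):
        if name not in self.resources:
            raise ApiError(404)
        return self.resources[name]


def is_upgrade(cluster, name):
    """An existing resource means the run is an update; a 404 means a create."""
    try:
        cluster.get(name)
    except ApiError as error:
        if error.status == 404:
            return False
        raise
    return True


def is_update_consumed(cluster, name, initial_status):
    """The update counts as consumed once the status differs from the snapshot."""
    return cluster.get(name).get("status") != initial_status


cluster = FakeCluster()
print(is_upgrade(cluster, "my-endpoint"))

cluster.resources["my-endpoint"] = {"status": {"endpointStatus": "InService"}}
snapshot = copy.deepcopy(cluster.resources["my-endpoint"]["status"])
print(is_upgrade(cluster, "my-endpoint"))
print(is_update_consumed(cluster, "my-endpoint", snapshot))

cluster.resources["my-endpoint"]["status"]["endpointStatus"] = "Updating"
print(is_update_consumed(cluster, "my-endpoint", snapshot))
```

The snapshot is deep-copied so that later mutations of the live status do not silently change the baseline being compared against.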
|
||||
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
pathlib2==2.3.5
|
||||
pyyaml==5.4.1
|
||||
mypy-extensions==0.4.3
|
||||
kubernetes==12.0.1
|
||||
kubernetes==12.0.1
|
||||
urllib3==1.26.15
|
|
@ -1,49 +1,13 @@
|
|||
import pytest
|
||||
import os
|
||||
import utils
|
||||
import io
|
||||
import numpy
|
||||
import json
|
||||
import pickle
|
||||
import gzip
|
||||
|
||||
from utils import kfp_client_utils
|
||||
from utils import minio_utils
|
||||
from utils import sagemaker_utils
|
||||
|
||||
|
||||
def run_predict_mnist(boto3_session, endpoint_name, download_dir):
|
||||
"""https://github.com/awslabs/amazon-sagemaker-
|
||||
examples/blob/a8c20eeb72dc7d3e94aaaf28be5bf7d7cd5695cb.
|
||||
|
||||
/sagemaker-python-sdk/1P_kmeans_lowlevel/kmeans_mnist_lowlevel.ipynb
|
||||
"""
|
||||
# Download and load dataset
|
||||
region = boto3_session.region_name
|
||||
download_path = os.path.join(download_dir, "mnist.pkl.gz")
|
||||
boto3_session.resource("s3", region_name=region).Bucket(
|
||||
utils.get_s3_data_bucket()
|
||||
).download_file("algorithms/mnist.pkl.gz", download_path)
|
||||
with gzip.open(download_path, "rb") as f:
|
||||
train_set, valid_set, test_set = pickle.load(f, encoding="latin1")
|
||||
|
||||
# Function to create a csv from numpy array
|
||||
def np2csv(arr):
|
||||
csv = io.BytesIO()
|
||||
numpy.savetxt(csv, arr, delimiter=",", fmt="%g")
|
||||
return csv.getvalue().decode().rstrip()
|
||||
|
||||
# Run prediction on an image
|
||||
runtime = boto3_session.client("sagemaker-runtime")
|
||||
payload = np2csv(train_set[0][30:31])
|
||||
|
||||
response = runtime.invoke_endpoint(
|
||||
EndpointName=endpoint_name,
|
||||
ContentType="text/csv",
|
||||
Body=payload,
|
||||
)
|
||||
return json.loads(response["Body"].read().decode())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_file_dir",
|
||||
|
@ -125,7 +89,7 @@ def test_create_endpoint(
|
|||
assert instance_type == test_params["ExpectedInstanceType"]
|
||||
|
||||
# Validate the model for use by running a prediction
|
||||
result = run_predict_mnist(boto3_session, input_endpoint_name, download_dir)
|
||||
result = sagemaker_utils.run_predict_mnist(boto3_session, input_endpoint_name, download_dir)
|
||||
print(f"prediction result: {result}")
|
||||
assert json.dumps(result, sort_keys=True) == json.dumps(
|
||||
test_params["ExpectedPrediction"], sort_keys=True
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
import pytest
|
||||
import os
|
||||
import utils
|
||||
from utils import kfp_client_utils
|
||||
from utils import ack_utils
|
||||
from utils import sagemaker_utils
|
||||
import json
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_file_dir",
|
||||
[
|
||||
pytest.param(
|
||||
"resources/config/ack-hosting",
|
||||
marks=[pytest.mark.canary_test, pytest.mark.v2, pytest.mark.shallow_canary],
|
||||
),
|
||||
pytest.param("resources/config/ack-hosting-update", marks=pytest.mark.v2),
|
||||
],
|
||||
)
|
||||
def test_create_v2_endpoint(kfp_client, experiment_id, boto3_session, test_file_dir):
|
||||
download_dir = utils.mkdir(os.path.join(test_file_dir, "generated"))
|
||||
test_params = utils.load_params(
|
||||
utils.replace_placeholders(
|
||||
os.path.join(test_file_dir, "config.yaml"),
|
||||
os.path.join(download_dir, "config.yaml"),
|
||||
)
|
||||
)
|
||||
k8s_client = ack_utils.k8s_client()
|
||||
input_model_name = utils.generate_random_string(10) + "-v2-model"
|
||||
input_endpoint_config_name = (
|
||||
utils.generate_random_string(10) + "-v2-endpoint-config"
|
||||
)
|
||||
input_endpoint_name = utils.generate_random_string(10) + "-v2-endpoint"
|
||||
|
||||
test_params["Arguments"]["model_name"] = input_model_name
|
||||
test_params["Arguments"]["endpoint_config_name"] = input_endpoint_config_name
|
||||
test_params["Arguments"]["endpoint_name"] = input_endpoint_name
|
||||
test_params["Arguments"]["production_variants"][0]["modelName"] = input_model_name
|
||||
|
||||
if "ExpectedEndpointConfig" in test_params.keys():
|
||||
input_second_endpoint_config_name = (
|
||||
utils.generate_random_string(10) + "-v2-sec-endpoint-config"
|
||||
)
|
||||
test_params["Arguments"][
|
||||
"second_endpoint_config_name"
|
||||
] = input_second_endpoint_config_name
|
||||
test_params["Arguments"]["second_production_variants"][0][
|
||||
"modelName"
|
||||
] = input_model_name
|
||||
|
||||
try:
|
||||
_, _, _ = kfp_client_utils.compile_run_monitor_pipeline(
|
||||
kfp_client,
|
||||
experiment_id,
|
||||
test_params["PipelineDefinition"],
|
||||
test_params["Arguments"],
|
||||
download_dir,
|
||||
test_params["TestName"],
|
||||
test_params["Timeout"],
|
||||
)
|
||||
|
||||
endpoint_describe = ack_utils._get_resource(
|
||||
k8s_client, input_endpoint_name, "endpoints"
|
||||
)
|
||||
|
||||
assert endpoint_describe["status"]["endpointStatus"] == "InService"
|
||||
|
||||
# Verify that the update was successful by checking that the endpoint config name is the same as the second one.
|
||||
if "ExpectedEndpointConfig" in test_params.keys():
|
||||
assert endpoint_describe["spec"][
|
||||
"endpointConfigName"
|
||||
] == input_second_endpoint_config_name
|
||||
|
||||
# Validate the model for use by running a prediction
|
||||
result = sagemaker_utils.run_predict_mnist(
|
||||
boto3_session, input_endpoint_name, download_dir
|
||||
)
|
||||
print(f"prediction result: {result}")
|
||||
assert json.dumps(result, sort_keys=True) == json.dumps(
|
||||
test_params["ExpectedPrediction"], sort_keys=True
|
||||
)
|
||||
utils.remove_dir(download_dir)
|
||||
finally:
|
||||
ack_utils._delete_resource(k8s_client, input_endpoint_name, "endpoints")
|
||||
ack_utils._delete_resource(
|
||||
k8s_client, input_endpoint_config_name, "endpointconfigs"
|
||||
)
|
||||
ack_utils._delete_resource(k8s_client, input_model_name, "models")
|
||||
|
||||
|
||||
@pytest.mark.v2
|
||||
def test_terminate_v2_endpoint(kfp_client, experiment_id):
|
||||
test_file_dir = "resources/config/ack-hosting"
|
||||
download_dir = utils.mkdir(os.path.join(test_file_dir, "generated"))
|
||||
test_params = utils.load_params(
|
||||
utils.replace_placeholders(
|
||||
os.path.join(test_file_dir, "config.yaml"),
|
||||
os.path.join(download_dir, "config.yaml"),
|
||||
)
|
||||
)
|
||||
k8s_client = ack_utils.k8s_client()
|
||||
input_model_name = utils.generate_random_string(10) + "-v2-model"
|
||||
input_endpoint_config_name = (
|
||||
utils.generate_random_string(10) + "-v2-endpoint-config"
|
||||
)
|
||||
input_endpoint_name = utils.generate_random_string(10) + "-v2-endpoint"
|
||||
test_params["Arguments"]["model_name"] = input_model_name
|
||||
test_params["Arguments"]["endpoint_config_name"] = input_endpoint_config_name
|
||||
test_params["Arguments"]["endpoint_name"] = input_endpoint_name
|
||||
test_params["Arguments"]["production_variants"][0]["modelName"] = input_model_name
|
||||
try:
|
||||
run_id, _, _ = kfp_client_utils.compile_run_monitor_pipeline(
|
||||
kfp_client,
|
||||
experiment_id,
|
||||
test_params["PipelineDefinition"],
|
||||
test_params["Arguments"],
|
||||
download_dir,
|
||||
test_params["TestName"],
|
||||
60,
|
||||
"running",
|
||||
)
|
||||
assert ack_utils.wait_for_condition(
|
||||
k8s_client,
|
||||
input_endpoint_name,
|
||||
ack_utils.does_endpoint_exist,
|
||||
wait_periods=12,
|
||||
period_length=12,
|
||||
)
|
||||
kfp_client_utils.terminate_run(kfp_client, run_id)
|
||||
assert ack_utils.wait_for_condition(
|
||||
k8s_client,
|
||||
input_endpoint_name,
|
||||
ack_utils.is_endpoint_deleted,
|
||||
wait_periods=20,
|
||||
period_length=20,
|
||||
)
|
||||
finally:
|
||||
ack_utils._delete_resource(
|
||||
k8s_client, input_endpoint_config_name, "endpointconfigs"
|
||||
)
|
||||
ack_utils._delete_resource(k8s_client, input_model_name, "models")
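
The termination test relies on polling with a bounded number of attempts (`wait_periods`) and a fixed pause between them (`period_length`). The implementation of `ack_utils.wait_for_condition` is not shown in this diff, so the following is only a hypothetical sketch of such a helper with a simplified signature (a zero-argument `check` callable instead of a client and resource name).

```python
import time


def wait_for_condition(check, wait_periods=12, period_length=12, sleep=time.sleep):
    """Poll `check` up to `wait_periods` times, pausing `period_length` seconds between tries."""
    for attempt in range(wait_periods):
        if check():
            return True
        # Skip the pause after the final attempt; there is nothing left to wait for.
        if attempt < wait_periods - 1:
            sleep(period_length)
    return False


# Simulated resource that becomes visible on the third poll.
polls = {"count": 0}


def resource_exists():
    polls["count"] += 1
    return polls["count"] >= 3


print(wait_for_condition(resource_exists, wait_periods=5, period_length=0))
print(wait_for_condition(lambda: False, wait_periods=2, period_length=0))
```

Returning a boolean rather than raising lets the test wrap the call in a plain `assert`, as the tests above do.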
|
|
@ -12,7 +12,7 @@ import ast
|
|||
[
|
||||
pytest.param(
|
||||
"resources/config/ack-training-job",
|
||||
marks=[pytest.mark.canary_test, pytest.mark.shallow_canary],
|
||||
marks=[pytest.mark.canary_test, pytest.mark.shallow_canary, pytest.mark.v2],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -55,7 +55,7 @@ def test_trainingjobV2(kfp_client, experiment_id, test_file_dir):
|
|||
|
||||
# Verify Training job was successful on SageMaker
|
||||
print(f"training job name: {input_job_name}")
|
||||
train_response = ack_utils.describe_training_job(k8s_client, input_job_name)
|
||||
train_response = ack_utils._get_resource(k8s_client, input_job_name, "trainingjobs")
|
||||
assert train_response["status"]["trainingJobStatus"] == "Completed"
|
||||
|
||||
# Verify model artifacts output was generated from this run
|
||||
|
@ -66,7 +66,7 @@ def test_trainingjobV2(kfp_client, experiment_id, test_file_dir):
|
|||
|
||||
utils.remove_dir(download_dir)
|
||||
|
||||
|
||||
@pytest.mark.v2
|
||||
def test_terminate_trainingjob(kfp_client, experiment_id):
|
||||
k8s_client = ack_utils.k8s_client()
|
||||
test_file_dir = "resources/config/ack-training-job"
|
||||
|
|
|
@ -4,4 +4,5 @@ addopts = -rA
|
|||
markers =
|
||||
canary_test: test to be run as part of canaries.
|
||||
fsx_test: tests for FSx features
|
||||
shallow_canary: a subset of canary_test
|
||||
shallow_canary: a subset of canary_test
|
||||
v2: v2 component test
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
PipelineDefinition: resources/definition/update_hosting_v2_pipeline.py
|
||||
TestName: ack-update-endpoint-test
|
||||
Timeout: 3600
|
||||
ExpectedPrediction:
|
||||
predictions:
|
||||
- distance_to_cluster: 7.448746204376221
|
||||
closest_cluster: 2.0
|
||||
ExpectedEndpointConfig: ept2
|
||||
Arguments:
|
||||
region: ((REGION))
|
||||
execution_role_arn: ((SAGEMAKER_ROLE_ARN))
|
||||
primary_container:
|
||||
image: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1
|
||||
modelDataURL: s3://((DATA_BUCKET))/mnist_kmeans_example/model/kmeans-mnist-model/model.tar.gz
|
||||
containerHostname: xgboost
|
||||
mode: SingleModel
|
||||
production_variants:
|
||||
- variantName: variant-1
|
||||
initialVariantWeight: 1.0
|
||||
instanceType: ml.m5.xlarge
|
||||
initialInstanceCount: 1
|
||||
second_production_variants:
|
||||
- variantName: variant-1
|
||||
initialVariantWeight: 1.0
|
||||
instanceType: ml.m5.large
|
||||
initialInstanceCount: 1
|
|
@ -0,0 +1,20 @@
|
|||
PipelineDefinition: resources/definition/create_hosting_v2_pipeline.py
|
||||
TestName: ack-endpoint-test
|
||||
Timeout: 3600
|
||||
ExpectedPrediction:
|
||||
predictions:
|
||||
- distance_to_cluster: 7.448746204376221
|
||||
closest_cluster: 2.0
|
||||
Arguments:
|
||||
region: ((REGION))
|
||||
execution_role_arn: ((SAGEMAKER_ROLE_ARN))
|
||||
primary_container:
|
||||
image: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1
|
||||
modelDataURL: s3://((DATA_BUCKET))/mnist_kmeans_example/model/kmeans-mnist-model/model.tar.gz
|
||||
containerHostname: xgboost
|
||||
mode: SingleModel
|
||||
production_variants:
|
||||
- variantName: variant-1
|
||||
initialVariantWeight: 1.0
|
||||
instanceType: ml.m5.xlarge
|
||||
initialInstanceCount: 1
|
|
@ -0,0 +1,45 @@
|
|||
import kfp
|
||||
from kfp import components
|
||||
from kfp import dsl
|
||||
|
||||
sagemaker_Model_op = components.load_component_from_file("../../Model/component.yaml")
|
||||
|
||||
sagemaker_EndpointConfig_op = components.load_component_from_file(
|
||||
"../../EndpointConfig/component.yaml"
|
||||
)
|
||||
|
||||
sagemaker_Endpoint_op = components.load_component_from_file(
|
||||
"../../Endpoint/component.yaml"
|
||||
)
|
||||
|
||||
|
||||
@dsl.pipeline(name="CreateHosting", description="SageMaker Hosting")
|
||||
def Hosting(
|
||||
region="",
|
||||
execution_role_arn="",
|
||||
model_name="",
|
||||
primary_container="",
|
||||
endpoint_config_name="",
|
||||
production_variants="",
|
||||
endpoint_name="",
|
||||
):
|
||||
Model = sagemaker_Model_op(
|
||||
region=region,
|
||||
execution_role_arn=execution_role_arn,
|
||||
model_name=model_name,
|
||||
primary_container=primary_container,
|
||||
)
|
||||
EndpointConfig = sagemaker_EndpointConfig_op(
|
||||
region=region,
|
||||
endpoint_config_name=endpoint_config_name,
|
||||
production_variants=production_variants,
|
||||
).after(Model)
|
||||
|
||||
Endpoint = sagemaker_Endpoint_op(
|
||||
region=region,
|
||||
endpoint_config_name=endpoint_config_name,
|
||||
endpoint_name=endpoint_name,
|
||||
).after(EndpointConfig)
|
||||
|
||||
|
||||
kfp.compiler.Compiler().compile(Hosting, __file__ + ".tar.gz")
|
|
@ -0,0 +1,59 @@
|
|||
import kfp
|
||||
from kfp import components
|
||||
from kfp import dsl
|
||||
|
||||
sagemaker_Model_op = components.load_component_from_file("../../Model/component.yaml")
|
||||
|
||||
sagemaker_EndpointConfig_op = components.load_component_from_file(
|
||||
"../../EndpointConfig/component.yaml"
|
||||
)
|
||||
|
||||
sagemaker_Endpoint_op = components.load_component_from_file(
|
||||
"../../Endpoint/component.yaml"
|
||||
)
|
||||
|
||||
|
||||
@dsl.pipeline(name="Update Hosting", description="SageMaker Hosting")
|
||||
def UpdateHosting(
|
||||
region="",
|
||||
execution_role_arn="",
|
||||
model_name="",
|
||||
primary_container="",
|
||||
endpoint_config_name="",
|
||||
production_variants="",
|
||||
endpoint_name="",
|
||||
second_endpoint_config_name="",
|
||||
second_production_variants="",
|
||||
):
|
||||
Model = sagemaker_Model_op(
|
||||
region=region,
|
||||
execution_role_arn=execution_role_arn,
|
||||
model_name=model_name,
|
||||
primary_container=primary_container,
|
||||
)
|
||||
EndpointConfig = sagemaker_EndpointConfig_op(
|
||||
region=region,
|
||||
endpoint_config_name=endpoint_config_name,
|
||||
production_variants=production_variants,
|
||||
).after(Model)
|
||||
|
||||
Endpoint = sagemaker_Endpoint_op(
|
||||
region=region,
|
||||
endpoint_config_name=endpoint_config_name,
|
||||
endpoint_name=endpoint_name,
|
||||
).after(EndpointConfig)
|
||||
|
||||
SecondEndpointConfig = sagemaker_EndpointConfig_op(
|
||||
region=region,
|
||||
endpoint_config_name=second_endpoint_config_name,
|
||||
production_variants=second_production_variants,
|
||||
).after(Model)
|
||||
|
||||
EndpointUpdate = sagemaker_Endpoint_op(
|
||||
region=region,
|
||||
endpoint_config_name=second_endpoint_config_name,
|
||||
endpoint_name=endpoint_name,
|
||||
).after(Endpoint)
|
||||
|
||||
|
||||
kfp.compiler.Compiler().compile(UpdateHosting, __file__ + ".tar.gz")
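The ordering that the `.after(...)` calls in `UpdateHosting` declare can be sketched as a plain dependency graph. This is only an illustration using the standard-library `graphlib` module (Python 3.9+), not KFP itself; the step names are copied from the pipeline above and the `dependencies` dictionary is a hand-written assumption about what each `.after(...)` call implies.

```python
from graphlib import TopologicalSorter

# Each step maps to the steps it was chained after with .after(...) in UpdateHosting.
dependencies = {
    "Model": set(),
    "EndpointConfig": {"Model"},
    "Endpoint": {"EndpointConfig"},
    "SecondEndpointConfig": {"Model"},
    "EndpointUpdate": {"Endpoint"},
}

# One valid execution order that respects every declared edge.
order = list(TopologicalSorter(dependencies).static_order())
position = {step: index for index, step in enumerate(order)}

# Every declared dependency is scheduled before the step that needs it.
for step, upstream in dependencies.items():
    for parent in upstream:
        assert position[parent] < position[step]

# EndpointUpdate has no declared edge to SecondEndpointConfig, so this graph
# alone does not force the second config to exist before the update starts.
update_waits_for_second_config = "SecondEndpointConfig" in dependencies["EndpointUpdate"]
print(order)
print(update_waits_for_second_config)  # False
```

In practice the first endpoint creation probably takes much longer than creating the second endpoint config, so the test likely passes as written; chaining `EndpointUpdate` after both `Endpoint` and `SecondEndpointConfig` would make that ordering explicit rather than timing-dependent.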
|
|
@ -7,7 +7,7 @@ def k8s_client():
|
|||
return config.new_client_from_config()
|
||||
|
||||
|
||||
def _get_resource(k8s_client, job_name, kvars):
|
||||
def _get_resource(k8s_client, job_name, plural):
|
||||
"""Get the custom resource detail similar to: kubectl describe <resource> JOB_NAME -n NAMESPACE.
|
||||
Returns:
|
||||
None or object: None if the resource doesn't exist on the server, otherwise the
|
||||
|
@ -16,22 +16,34 @@ def _get_resource(k8s_client, job_name, kvars):
|
|||
_api = client.CustomObjectsApi(k8s_client)
|
||||
namespace = os.environ.get("NAMESPACE")
|
||||
job_description = _api.get_namespaced_custom_object(
|
||||
kvars["group"].lower(),
|
||||
kvars["version"].lower(),
|
||||
"sagemaker.services.k8s.aws",
|
||||
"v1alpha1",
|
||||
namespace.lower(),
|
||||
kvars["plural"].lower(),
|
||||
plural,
|
||||
job_name.lower(),
|
||||
)
|
||||
return job_description
|
||||
|
||||
|
||||
def describe_training_job(k8s_client, training_job_name):
|
||||
training_vars = {
|
||||
"group": "sagemaker.services.k8s.aws",
|
||||
"version": "v1alpha1",
|
||||
"plural": "trainingjobs",
|
||||
}
|
||||
return _get_resource(k8s_client, training_job_name, training_vars)
|
||||
def _delete_resource(k8s_client, job_name, plural):
|
||||
"""Delete the custom resource.
|
||||
Returns:
|
||||
bool: True if the delete request succeeded, False if it raised an
|
||||
exception (for example, because the resource doesn't exist).
|
||||
"""
|
||||
_api = client.CustomObjectsApi(k8s_client)
|
||||
namespace = os.environ.get("NAMESPACE")
|
||||
try:
|
||||
_api.delete_namespaced_custom_object(
|
||||
"sagemaker.services.k8s.aws",
|
||||
"v1alpha1",
|
||||
namespace.lower(),
|
||||
plural,
|
||||
job_name.lower(),
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# TODO: Make this a generalized function for non-job resources.
|
||||
|
@ -39,8 +51,43 @@ def wait_for_trainingjob_status(
|
|||
k8s_client, training_job_name, desiredStatuses, wait_periods, period_length
|
||||
):
|
||||
for _ in range(wait_periods):
|
||||
response = describe_training_job(k8s_client, training_job_name)
|
||||
response = _get_resource(k8s_client, training_job_name, "trainingjobs")
|
||||
if response["status"]["trainingJobStatus"] in desiredStatuses:
|
||||
return True
|
||||
sleep(period_length)
|
||||
return False
|
||||
|
||||
|
||||
def wait_for_condition(
|
||||
k8s_client, resource_name, validator_function, wait_periods=10, period_length=8
|
||||
):
|
||||
for _ in range(wait_periods):
|
||||
if not validator_function(k8s_client, resource_name):
|
||||
sleep(period_length)
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def does_endpoint_exist(k8s_client, endpoint_name):
|
||||
try:
|
||||
response = _get_resource(k8s_client, endpoint_name, "endpoints")
|
||||
if response:
|
||||
return True
|
||||
if response is None: # kubernetes module error
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_endpoint_deleted(k8s_client, endpoint_name):
|
||||
try:
|
||||
response = _get_resource(k8s_client, endpoint_name, "endpoints")
|
||||
if response:
|
||||
return False
|
||||
if (
|
||||
response is None
|
||||
):  # kubernetes module error; a 404 would mean the resource doesn't exist
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
|
|
|
@ -2,7 +2,12 @@ import logging
|
|||
import re
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import gzip
|
||||
import io
|
||||
import numpy
|
||||
import json
|
||||
|
||||
def describe_training_job(client, training_job_name):
|
||||
return client.describe_training_job(TrainingJobName=training_job_name)
|
||||
|
@ -85,3 +90,33 @@ def stop_labeling_job(client, labeling_job_name):
|
|||
|
||||
def describe_processing_job(client, processing_job_name):
|
||||
return client.describe_processing_job(ProcessingJobName=processing_job_name)
|
||||
|
||||
def run_predict_mnist(boto3_session, endpoint_name, download_dir):
|
||||
"""https://github.com/awslabs/amazon-sagemaker-examples/blob/a8c20eeb72dc7d3e94aaaf28be5bf7d7cd5695cb/sagemaker-python-sdk/1P_kmeans_lowlevel/kmeans_mnist_lowlevel.ipynb
|
||||
"""
|
||||
# Download and load dataset
|
||||
region = boto3_session.region_name
|
||||
download_path = os.path.join(download_dir, "mnist.pkl.gz")
|
||||
boto3_session.resource("s3", region_name=region).Bucket(
|
||||
"sagemaker-sample-data-{}".format(region)
|
||||
).download_file("algorithms/kmeans/mnist/mnist.pkl.gz", download_path)
|
||||
with gzip.open(download_path, "rb") as f:
|
||||
train_set, valid_set, test_set = pickle.load(f, encoding="latin1")
|
||||
|
||||
# Function to create a csv from numpy array
|
||||
def np2csv(arr):
|
||||
csv = io.BytesIO()
|
||||
numpy.savetxt(csv, arr, delimiter=",", fmt="%g")
|
||||
return csv.getvalue().decode().rstrip()
|
||||
|
||||
# Run prediction on an image
|
||||
runtime = boto3_session.client("sagemaker-runtime")
|
||||
payload = np2csv(train_set[0][30:31])
|
||||
|
||||
response = runtime.invoke_endpoint(
|
||||
EndpointName=endpoint_name, ContentType="text/csv", Body=payload,
|
||||
)
|
||||
return json.loads(response["Body"].read().decode())
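The dictionary returned by `run_predict_mnist` has the same shape as the `ExpectedPrediction` block in the test config, so a test could compare the two with a float tolerance. How the actual test performs this comparison is not shown here; `predictions_match`, `raw_body`, and the tolerance value below are hypothetical names and choices for illustration only.

```python
import json
import math

# Mirrors the ExpectedPrediction block from the test config.
expected = {
    "predictions": [
        {"distance_to_cluster": 7.448746204376221, "closest_cluster": 2.0}
    ]
}

# Hypothetical response body, as bytes, like response["Body"].read() would return.
raw_body = json.dumps(expected).encode()


def predictions_match(actual, expected, rel_tol=1e-6):
    """Compare two prediction payloads row by row, tolerating small float drift."""
    actual_rows = actual.get("predictions", [])
    expected_rows = expected.get("predictions", [])
    if len(actual_rows) != len(expected_rows):
        return False
    for got, want in zip(actual_rows, expected_rows):
        if got.keys() != want.keys():
            return False
        if not all(math.isclose(got[key], want[key], rel_tol=rel_tol) for key in want):
            return False
    return True


actual = json.loads(raw_body.decode())
matches = predictions_match(actual, expected)
print(matches)  # True
```

A tolerance-based comparison avoids spurious failures if the serving container returns a distance that differs in the last few digits; strict equality would also be a reasonable choice when the model artifact and container image are pinned.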
|