[AWS SageMaker] Processing job component (#3944)

* Add TDD processing definition
* Update README
* Update temporary image
* Update component entrypoint
* Add WORKDIR to fix Docker 18 support
* integration test for processing job
* Remove job links
* Add container outputs and tests
* Update default properties
* Remove max_run_time if none provided
* Update integration readme steps
* Updated README with more resources
* Add CloudWatch link back to logs
* Update input and output config to arrays
* Update processing integration test
* Update process README
* Update unit tests
* Updated license version
* Update component image versions
* Update changelog

Co-authored-by: Suraj Kota <surakota@amazon.com>

This commit is contained in: parent 53d0244538 · commit bea63652e1
@@ -4,10 +4,21 @@ The version of the AWS SageMaker Components is determined by the docker image tag
Repository: https://hub.docker.com/repository/docker/amazon/aws-sagemaker-kfp-components

---------------------------------------------

**Change log for version 0.4.0**
- Add new component for SageMaker Processing Jobs

> Pull requests : [#3944](https://github.com/kubeflow/pipelines/pull/3944)


**Change log for version 0.3.1**
- Explicitly specify component field types

> Pull requests : [#3683](https://github.com/kubeflow/pipelines/pull/3683)


**Change log for version 0.3.0**
- Remove data_location parameters from all components
-  (Use "channes" parameter instead)
+  (Use "channels" parameter instead)

> Pull requests : [#3518](https://github.com/kubeflow/pipelines/pull/3518)

@@ -23,11 +23,13 @@ RUN yum update -y \
    unzip

RUN pip3 install \
-    boto3==1.12.33 \
+    boto3==1.13.19 \
     sagemaker==1.54.0 \
     pathlib2==2.3.5 \
     pyyaml==3.12

+WORKDIR /app

COPY LICENSE.txt .
COPY NOTICE.txt .
COPY THIRD-PARTY-LICENSES.txt .

@@ -35,6 +37,7 @@ COPY hyperparameter_tuning/src/hyperparameter_tuning.py .
COPY train/src/train.py .
COPY deploy/src/deploy.py .
COPY model/src/create_model.py .
+COPY process/src/process.py .
COPY batch_transform/src/batch_transform.py .
COPY workteam/src/workteam.py .
COPY ground_truth/src/ground_truth.py .

@@ -1,4 +1,4 @@
-** Amazon SageMaker Components for Kubeflow Pipelines; version 0.3.1 --
+** Amazon SageMaker Components for Kubeflow Pipelines; version 0.4.0 --
https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker
Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
** boto3; version 1.12.33 -- https://github.com/boto/boto3/

@@ -98,7 +98,7 @@ outputs:
  - {name: output_location, description: 'S3 URI of the transform job results.'}
implementation:
  container:
-    image: amazon/aws-sagemaker-kfp-components:0.3.1
+    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      batch_transform.py,

@@ -861,6 +861,116 @@ def enable_spot_instance_support(training_job_config, args):
        del training_job_config['StoppingCondition']['MaxWaitTimeInSeconds']
        del training_job_config['CheckpointConfig']


def create_processing_job_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_processing_job
    with open(os.path.join(__cwd__, 'process.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    job_name = args['job_name'] if args['job_name'] else 'ProcessingJob-' + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator()

    request['ProcessingJobName'] = job_name
    request['RoleArn'] = args['role']

    ### Update processing container settings
    request['AppSpecification']['ImageUri'] = args['image']

    if args['container_entrypoint']:
        request['AppSpecification']['ContainerEntrypoint'] = args['container_entrypoint']
    else:
        request['AppSpecification'].pop('ContainerEntrypoint')
    if args['container_arguments']:
        request['AppSpecification']['ContainerArguments'] = args['container_arguments']
    else:
        request['AppSpecification'].pop('ContainerArguments')

    ### Update or pop VPC configs
    if args['vpc_security_group_ids'] and args['vpc_subnets']:
        request['NetworkConfig']['VpcConfig']['SecurityGroupIds'] = args['vpc_security_group_ids'].split(',')
        request['NetworkConfig']['VpcConfig']['Subnets'] = args['vpc_subnets'].split(',')
    else:
        request['NetworkConfig'].pop('VpcConfig')
    request['NetworkConfig']['EnableNetworkIsolation'] = args['network_isolation']
    request['NetworkConfig']['EnableInterContainerTrafficEncryption'] = args['traffic_encryption']

    ### Update input channels, not a required field
    if args['input_config']:
        request['ProcessingInputs'] = args['input_config']
    else:
        request.pop('ProcessingInputs')

    ### Update output channels, must have at least one specified
    if len(args['output_config']) > 0:
        request['ProcessingOutputConfig']['Outputs'] = args['output_config']
    else:
        logging.error("Must specify at least one output channel.")
        raise Exception('Could not create job request')

    if args['output_encryption_key']:
        request['ProcessingOutputConfig']['KmsKeyId'] = args['output_encryption_key']
    else:
        request['ProcessingOutputConfig'].pop('KmsKeyId')

    ### Update cluster config resources
    request['ProcessingResources']['ClusterConfig']['InstanceType'] = args['instance_type']
    request['ProcessingResources']['ClusterConfig']['InstanceCount'] = args['instance_count']
    request['ProcessingResources']['ClusterConfig']['VolumeSizeInGB'] = args['volume_size']

    if args['resource_encryption_key']:
        request['ProcessingResources']['ClusterConfig']['VolumeKmsKeyId'] = args['resource_encryption_key']
    else:
        request['ProcessingResources']['ClusterConfig'].pop('VolumeKmsKeyId')

    if args['max_run_time']:
        request['StoppingCondition']['MaxRuntimeInSeconds'] = args['max_run_time']
    else:
        request['StoppingCondition'].pop('MaxRuntimeInSeconds')

    request['Environment'] = args['environment']

    ### Update tags
    for key, val in args['tags'].items():
        request['Tags'].append({'Key': key, 'Value': val})

    return request


def create_processing_job(client, args):
    """Create a SageMaker processing job."""
    request = create_processing_job_request(args)
    try:
        client.create_processing_job(**request)
        processing_job_name = request['ProcessingJobName']
        logging.info("Created Processing Job with name: " + processing_job_name)
        logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix"
            .format(args['region'], args['region'], processing_job_name))
        return processing_job_name
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


def wait_for_processing_job(client, processing_job_name, poll_interval=30):
    while True:
        response = client.describe_processing_job(ProcessingJobName=processing_job_name)
        status = response['ProcessingJobStatus']
        if status == 'Completed':
            logging.info("Processing job ended with status: " + status)
            break
        if status == 'Failed':
            message = response['FailureReason']
            logging.info('Processing failed with the following error: {}'.format(message))
            raise Exception('Processing job failed')
        logging.info("Processing job is still in status: " + status)
        time.sleep(poll_interval)


def get_processing_job_outputs(client, processing_job_name):
    """Map the S3 outputs of a processing job to a dictionary object."""
    response = client.describe_processing_job(ProcessingJobName=processing_job_name)
    outputs = {}
    for output in response['ProcessingOutputConfig']['Outputs']:
        outputs[output['OutputName']] = output['S3Output']['S3Uri']

    return outputs


def id_generator(size=4, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

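For orientation, these helpers are meant to be chained: build the request from the template, submit it, block until the job finishes, then collect the S3 outputs. A minimal sketch of that flow, assuming it runs where `common._utils` and `process.template.yaml` are importable (for example, inside the component image) and using placeholder values for the role, image, and bucket:

```
from common import _utils

# Every key consumed by create_processing_job_request must be present; values below are placeholders.
args = {
    'region': 'us-west-2', 'endpoint_url': None, 'job_name': '',
    'role': 'arn:aws:iam::111122223333:role/example-sagemaker-role',              # hypothetical role
    'image': '111122223333.dkr.ecr.us-west-2.amazonaws.com/my-processor:latest',  # hypothetical image
    'instance_type': 'ml.m5.xlarge', 'instance_count': 1, 'volume_size': 30,
    'resource_encryption_key': '', 'output_encryption_key': '', 'max_run_time': 3600,
    'environment': {}, 'container_entrypoint': [], 'container_arguments': [],
    'input_config': [],
    'output_config': [{'OutputName': 'out',
                       'S3Output': {'S3Uri': 's3://my-bucket/out',
                                    'LocalPath': '/opt/ml/processing/output',
                                    'S3UploadMode': 'EndOfJob'}}],
    'vpc_security_group_ids': '', 'vpc_subnets': '',
    'network_isolation': True, 'traffic_encryption': False, 'tags': {},
}

client = _utils.get_sagemaker_client(args['region'], args['endpoint_url'])
job_name = _utils.create_processing_job(client, args)       # submits the job and logs a CloudWatch link
_utils.wait_for_processing_job(client, job_name)            # polls describe_processing_job until Completed or Failed
print(_utils.get_processing_job_outputs(client, job_name))  # e.g. {'out': 's3://my-bucket/out'}
```

This is essentially what the `process.py` entrypoint added later in this change does with its parsed command-line arguments.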
@@ -0,0 +1,26 @@
ProcessingJobName: ''
ProcessingInputs: []
ProcessingOutputConfig:
  Outputs: []
  KmsKeyId: ''
RoleArn: ''
ProcessingResources:
  ClusterConfig:
    InstanceType: ''
    InstanceCount: 1
    VolumeSizeInGB: 1
    VolumeKmsKeyId: ''
NetworkConfig:
  EnableInterContainerTrafficEncryption: False
  EnableNetworkIsolation: False
  VpcConfig:
    SecurityGroupIds: []
    Subnets: []
StoppingCondition:
  MaxRuntimeInSeconds: 86400
AppSpecification:
  ImageUri: ''
  ContainerEntrypoint: []
  ContainerArguments: []
Environment: {}
Tags: []

@@ -104,7 +104,7 @@ outputs:
  - {name: endpoint_name, description: 'Endpoint name'}
implementation:
  container:
-    image: amazon/aws-sagemaker-kfp-components:0.3.1
+    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      deploy.py,

@@ -119,7 +119,7 @@ outputs:
  - {name: active_learning_model_arn, description: 'The ARN for the most recent Amazon SageMaker model trained as part of automated data labeling.'}
implementation:
  container:
-    image: amazon/aws-sagemaker-kfp-components:0.3.1
+    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      ground_truth.py,

@@ -150,7 +150,7 @@ outputs:
    description: 'The registry path of the Docker image that contains the training algorithm'
implementation:
  container:
-    image: amazon/aws-sagemaker-kfp-components:0.3.1
+    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      hyperparameter_tuning.py,

@@ -59,7 +59,7 @@ outputs:
  - {name: model_name, description: 'The model name Sagemaker created'}
implementation:
  container:
-    image: amazon/aws-sagemaker-kfp-components:0.3.1
+    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      create_model.py,

@@ -0,0 +1,80 @@
# SageMaker Processing Kubeflow Pipelines component

## Summary
Component to submit SageMaker Processing jobs directly from a Kubeflow Pipelines workflow.
https://docs.aws.amazon.com/sagemaker/latest/dg/processing-job.html

## Intended Use
For running your data processing workloads, such as feature engineering, data validation, model evaluation, and model interpretation, using AWS SageMaker.

## Runtime Arguments
Argument | Description | Optional | Data type | Accepted values | Default |
:--- | :---------- | :----------| :----------| :---------- | :----------|
region | The region where the cluster launches | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
job_name | The name of the Processing job. Must be unique within the same AWS account and AWS region | Yes | String | | ProcessingJob-[datetime]-[random id] |
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | String | | |
image | The registry path of the Docker image that contains the processing script | Yes | String | | |
instance_type | The ML compute instance type | Yes | String | ml.m4.xlarge, ml.m4.2xlarge, ml.m4.4xlarge, ml.m4.10xlarge, ml.m4.16xlarge, ml.m5.large, ml.m5.xlarge, ml.m5.2xlarge, ml.m5.4xlarge, ml.m5.12xlarge, ml.m5.24xlarge, ml.c4.xlarge, ml.c4.2xlarge, ml.c4.4xlarge, ml.c4.8xlarge, ml.p2.xlarge, ml.p2.8xlarge, ml.p2.16xlarge, ml.p3.2xlarge, ml.p3.8xlarge, ml.p3.16xlarge, ml.c5.xlarge, ml.c5.2xlarge, ml.c5.4xlarge, ml.c5.9xlarge, ml.c5.18xlarge [and many more](https://aws.amazon.com/sagemaker/pricing/instance-types/) | ml.m4.xlarge |
instance_count | The number of ML compute instances to use in each processing job | Yes | Int | ≥ 1 | 1 |
volume_size | The size of the ML storage volume that you want to provision in GB | Yes | Int | ≥ 1 | 30 |
resource_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) | Yes | String | | |
output_encryption_key | The AWS KMS key that Amazon SageMaker uses to encrypt the processing artifacts | Yes | String | | |
max_run_time | The maximum run time in seconds per processing job | Yes | Int | ≤ 432000 (5 days) | 86400 (1 day) |
environment | The environment variables to set in the Docker container | Yes | Dict | Each key and value may be up to 1024 characters. Key pattern: `[a-zA-Z_][a-zA-Z0-9_]*`. Value pattern: `[\S\s]*`. Up to 16 key-value entries in the map | {} |
container_entrypoint | The entrypoint for the processing job. This is in the form of a list of strings that make a command | Yes | List of Strings | | [] |
container_arguments | A list of string arguments to be passed to a processing job | Yes | List of Strings | | [] |
input_config | Parameters that specify Amazon S3 inputs for a processing job | Yes | List of Dicts | | [] |
output_config | Parameters that specify Amazon S3 outputs for a processing job | No | List of Dicts | | [] |
vpc_security_group_ids | A comma-delimited list of security group IDs, in the form sg-xxxxxxxx | Yes | String | | |
vpc_subnets | A comma-delimited list of subnet IDs in the VPC to which you want to connect your processing job | Yes | String | | |
network_isolation | Isolates the processing container if true | No | Boolean | False, True | True |
traffic_encryption | Encrypts all communications between ML compute instances in distributed processing if true | No | Boolean | False, True | False |
tags | Key-value pairs to categorize AWS resources | Yes | Dict | | {} |

Notes:
* You can find more information about how the container entrypoint and arguments are used in the [Build Your Own Processing Container](https://docs.aws.amazon.com/sagemaker/latest/dg/build-your-own-processing-container.html#byoc-run-image) documentation.
* Each key and value in the `environment` parameter string-to-string map can have a length of up to 1024 characters. SageMaker supports up to 16 entries in the map.
* The format for the [`input_config`](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProcessingInput.html) field is:
```
[
    {
        'InputName': 'string',
        'S3Input': {
            'S3Uri': 'string',
            'LocalPath': 'string',
            'S3DataType': 'ManifestFile'|'S3Prefix',
            'S3InputMode': 'Pipe'|'File',
            'S3DataDistributionType': 'FullyReplicated'|'ShardedByS3Key',
            'S3CompressionType': 'None'|'Gzip'
        }
    },
]
```
* The format for the [`output_config`](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProcessingS3Output.html) field is:
```
[
    {
        'OutputName': 'string',
        'S3Output': {
            'S3Uri': 'string',
            'LocalPath': 'string',
            'S3UploadMode': 'Continuous'|'EndOfJob'
        }
    },
]
```

## Outputs
Name | Description
:--- | :----------
job_name | Processing job name
output_artifacts | A dictionary mapping with `output_config` `OutputName` as the key and `S3Uri` as the value

## Requirements
* [Kubeflow pipelines SDK](https://www.kubeflow.org/docs/pipelines/sdk/install-sdk/)
* [Kubeflow set-up](https://www.kubeflow.org/docs/aws/deploy/install-kubeflow/)

## Resources
* [Create Processing Job API documentation](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html)
* [Boto3 API reference](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_processing_job)

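In a pipeline, the component is loaded from its `component.yaml` and invoked like any other step. A minimal sketch, assuming the component file is fetched from the kubeflow/pipelines master-branch layout shown in this change and substituting placeholder values for the region, role, image, and bucket (none of these resources are real):

```
import json
import kfp
from kfp import components, dsl

# Component definition URL assumes the repository layout used by this change.
sagemaker_process_op = components.load_component_from_url(
    'https://raw.githubusercontent.com/kubeflow/pipelines/master/components/aws/sagemaker/process/component.yaml'
)

@dsl.pipeline(name='SageMaker processing example', description='Single SageMaker Processing job')
def example_pipeline():
    sagemaker_process_op(
        region='us-west-2',
        role='arn:aws:iam::111122223333:role/example-sagemaker-role',              # hypothetical role
        image='111122223333.dkr.ecr.us-west-2.amazonaws.com/my-processor:latest',  # hypothetical image
        container_entrypoint=json.dumps(['python3', '/opt/ml/processing/code/preprocess.py']),
        input_config=json.dumps([{
            'InputName': 'code',
            'S3Input': {'S3Uri': 's3://my-bucket/code/preprocess.py',
                        'LocalPath': '/opt/ml/processing/code',
                        'S3DataType': 'S3Prefix', 'S3InputMode': 'File',
                        'S3CompressionType': 'None'}
        }]),
        output_config=json.dumps([{
            'OutputName': 'out',
            'S3Output': {'S3Uri': 's3://my-bucket/out',
                         'LocalPath': '/opt/ml/processing/output',
                         'S3UploadMode': 'EndOfJob'}
        }]),
    )

if __name__ == '__main__':
    kfp.compiler.Compiler().compile(example_pipeline, 'example_pipeline.yaml')
```

Downstream steps can consume the step's `output_artifacts` output, which is the JSON dictionary described in the Outputs table above, and `job_name` for cross-referencing the job in the SageMaker console.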
@@ -0,0 +1,120 @@
name: 'SageMaker - Processing Job'
description: |
  Perform data pre-processing, post-processing, feature engineering, data validation, model evaluation, and model interpretation on SageMaker
inputs:
  - name: region
    description: 'The region where the processing job launches.'
    type: String
  - name: job_name
    description: 'The name of the processing job.'
    default: ''
    type: String
  - name: role
    description: 'The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.'
    type: String
  - name: image
    description: 'The registry path of the Docker image that contains the processing container.'
    default: ''
    type: String
  - name: instance_type
    description: 'The ML compute instance type.'
    default: 'ml.m4.xlarge'
    type: String
  - name: instance_count
    description: 'The number of ML compute instances to use in each processing job.'
    default: '1'
    type: Integer
  - name: volume_size
    description: 'The size of the ML storage volume that you want to provision.'
    default: '30'
    type: Integer
  - name: resource_encryption_key
    description: 'The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s).'
    default: ''
    type: String
  - name: max_run_time
    description: 'The maximum run time in seconds for the processing job.'
    default: '86400'
    type: Integer
  - name: environment
    description: 'The environment variables to set in the Docker container. Up to 16 key-value entries in the map.'
    default: '{}'
    type: JsonObject
  - name: container_entrypoint
    description: 'The entrypoint for the processing job. This is in the form of a list of strings that make a command.'
    default: '[]'
    type: JsonArray
  - name: container_arguments
    description: 'A list of string arguments to be passed to a processing job.'
    default: '[]'
    type: JsonArray
  - name: output_config
    description: 'Parameters that specify Amazon S3 outputs for a processing job.'
    default: '[]'
    type: JsonArray
  - name: input_config
    description: 'Parameters that specify Amazon S3 inputs for a processing job.'
    default: '[]'
    type: JsonArray
  - name: output_encryption_key
    description: 'The AWS KMS key that Amazon SageMaker uses to encrypt the processing artifacts.'
    default: ''
    type: String
  - name: vpc_security_group_ids
    description: 'The VPC security group IDs, in the form sg-xxxxxxxx.'
    default: ''
    type: String
  - name: vpc_subnets
    description: 'The IDs of the subnets in the VPC to which you want to connect your processing job.'
    default: ''
    type: String
  - name: network_isolation
    description: 'Isolates the processing job container.'
    default: 'True'
    type: Bool
  - name: traffic_encryption
    description: 'Encrypts all communications between ML compute instances in distributed processing.'
    default: 'False'
    type: Bool
  - name: endpoint_url
    description: 'The endpoint URL for the private link VPC endpoint.'
    default: ''
    type: String
  - name: tags
    description: 'Key-value pairs to categorize AWS resources.'
    default: '{}'
    type: JsonObject
outputs:
  - {name: job_name, description: 'Processing job name'}
  - {name: output_artifacts, description: 'A dictionary containing the output S3 artifacts'}
implementation:
  container:
    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      process.py,
      --region, {inputValue: region},
      --endpoint_url, {inputValue: endpoint_url},
      --job_name, {inputValue: job_name},
      --role, {inputValue: role},
      --image, {inputValue: image},
      --instance_type, {inputValue: instance_type},
      --instance_count, {inputValue: instance_count},
      --volume_size, {inputValue: volume_size},
      --resource_encryption_key, {inputValue: resource_encryption_key},
      --output_encryption_key, {inputValue: output_encryption_key},
      --max_run_time, {inputValue: max_run_time},
      --environment, {inputValue: environment},
      --container_entrypoint, {inputValue: container_entrypoint},
      --container_arguments, {inputValue: container_arguments},
      --output_config, {inputValue: output_config},
      --input_config, {inputValue: input_config},
      --vpc_security_group_ids, {inputValue: vpc_security_group_ids},
      --vpc_subnets, {inputValue: vpc_subnets},
      --network_isolation, {inputValue: network_isolation},
      --traffic_encryption, {inputValue: traffic_encryption},
      --tags, {inputValue: tags}
    ]
    fileOutputs:
      job_name: /tmp/job_name.txt
      output_artifacts: /tmp/output_artifacts.txt

@@ -0,0 +1,70 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import argparse
import logging
import json

from common import _utils

def create_parser():
    parser = argparse.ArgumentParser(description='SageMaker Processing Job')
    _utils.add_default_client_arguments(parser)

    parser.add_argument('--job_name', type=str, required=False, help='The name of the processing job.', default='')
    parser.add_argument('--role', type=str, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
    parser.add_argument('--image', type=str, required=True, help='The registry path of the Docker image that contains the processing container.', default='')
    parser.add_argument('--instance_type', required=True, type=str, help='The ML compute instance type.', default='ml.m4.xlarge')
    parser.add_argument('--instance_count', required=True, type=int, help='The number of ML compute instances to use in each processing job.', default=1)
    parser.add_argument('--volume_size', type=int, required=False, help='The size of the ML storage volume that you want to provision.', default=30)
    parser.add_argument('--resource_encryption_key', type=str, required=False, help='The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s).', default='')
    parser.add_argument('--output_encryption_key', type=str, required=False, help='The AWS KMS key that Amazon SageMaker uses to encrypt the processing artifacts.', default='')
    parser.add_argument('--max_run_time', type=int, required=False, help='The maximum run time in seconds for the processing job.', default=86400)
    parser.add_argument('--environment', type=_utils.yaml_or_json_str, required=False, help='The dictionary of the environment variables to set in the Docker container. Up to 16 key-value entries in the map.', default={})
    parser.add_argument('--container_entrypoint', type=_utils.yaml_or_json_str, required=False, help='The entrypoint for the processing job. This is in the form of a list of strings that make a command.', default=[])
    parser.add_argument('--container_arguments', type=_utils.yaml_or_json_str, required=False, help='A list of string arguments to be passed to a processing job.', default=[])
    parser.add_argument('--input_config', type=_utils.yaml_or_json_str, required=False, help='Parameters that specify Amazon S3 inputs for a processing job.', default=[])
    parser.add_argument('--output_config', type=_utils.yaml_or_json_str, required=True, help='Parameters that specify Amazon S3 outputs for a processing job.', default=[])
    parser.add_argument('--vpc_security_group_ids', type=str, required=False, help='The VPC security group IDs, in the form sg-xxxxxxxx.')
    parser.add_argument('--vpc_subnets', type=str, required=False, help='The IDs of the subnets in the VPC to which you want to connect your processing job.')
    parser.add_argument('--network_isolation', type=_utils.str_to_bool, required=False, help='Isolates the processing container.', default=True)
    parser.add_argument('--traffic_encryption', type=_utils.str_to_bool, required=False, help='Encrypts all communications between ML compute instances in distributed processing.', default=False)
    parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})

    return parser

def main(argv=None):
    parser = create_parser()
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    client = _utils.get_sagemaker_client(args.region, args.endpoint_url)

    logging.info('Submitting Processing Job to SageMaker...')
    job_name = _utils.create_processing_job(client, vars(args))
    logging.info('Job request submitted. Waiting for completion...')
    _utils.wait_for_processing_job(client, job_name)

    outputs = _utils.get_processing_job_outputs(client, job_name)

    with open('/tmp/job_name.txt', 'w') as f:
        f.write(job_name)

    with open('/tmp/output_artifacts.txt', 'w') as f:
        f.write(json.dumps(outputs))

    logging.info('Job completed.')


if __name__ == "__main__":
    main(sys.argv[1:])

@@ -9,6 +9,7 @@

1. In the following Python script, change the bucket name and run the [`s3_sample_data_creator.py`](https://github.com/kubeflow/pipelines/tree/master/samples/contrib/aws-samples/mnist-kmeans-sagemaker#the-sample-dataset) to create an S3 bucket with the sample mnist dataset in the region where you want to run the tests.
2. To prepare the dataset for the SageMaker GroundTruth Component test, follow the steps in the [GroundTruth Sample README](https://github.com/kubeflow/pipelines/tree/master/samples/contrib/aws-samples/ground_truth_pipeline_demo#prep-the-dataset-label-categories-and-ui-template).
+3. To prepare the processing script for the SageMaker Processing Component tests, upload the `scripts/kmeans_preprocessing.py` script to your bucket. This can be done by replacing `<my-bucket>` with your bucket name and running `aws s3 cp scripts/kmeans_preprocessing.py s3://<my-bucket>/mnist_kmeans_example/processing_code/kmeans_preprocessing.py` (a boto3 equivalent is sketched below).
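If you prefer to stage the script programmatically rather than with the AWS CLI, a minimal boto3 sketch (assumes credentials are configured, the bucket already exists, and `my-bucket` is replaced with your bucket name):

```
import boto3

s3 = boto3.client("s3")
# Upload the preprocessing script to the key the integration test config expects.
s3.upload_file(
    "scripts/kmeans_preprocessing.py",
    "my-bucket",  # placeholder bucket name
    "mnist_kmeans_example/processing_code/kmeans_preprocessing.py",
)
```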

## Step to run integration tests

@@ -0,0 +1,89 @@
import pytest
import os
import json
import utils
from utils import kfp_client_utils
from utils import minio_utils
from utils import sagemaker_utils


@pytest.mark.parametrize(
    "test_file_dir",
    [
        pytest.param(
            "resources/config/kmeans-algo-mnist-processing",
            marks=pytest.mark.canary_test,
        )
    ],
)
def test_processingjob(
    kfp_client, experiment_id, region, sagemaker_client, test_file_dir
):

    download_dir = utils.mkdir(os.path.join(test_file_dir + "/generated"))
    test_params = utils.load_params(
        utils.replace_placeholders(
            os.path.join(test_file_dir, "config.yaml"),
            os.path.join(download_dir, "config.yaml"),
        )
    )

    # Generate random prefix for job name to avoid errors if a job with the same name exists
    test_params["Arguments"]["job_name"] = input_job_name = (
        utils.generate_random_string(5) + "-" + test_params["Arguments"]["job_name"]
    )
    print(f"running test with job_name: {input_job_name}")

    # Suffix each output S3 location with the job name before serializing the configs
    for index, output in enumerate(test_params["Arguments"]["output_config"]):
        if "S3Output" in output:
            test_params["Arguments"]["output_config"][index]["S3Output"][
                "S3Uri"
            ] = os.path.join(output["S3Output"]["S3Uri"], input_job_name)

    test_params["Arguments"]["input_config"] = json.dumps(
        test_params["Arguments"]["input_config"]
    )
    test_params["Arguments"]["output_config"] = json.dumps(
        test_params["Arguments"]["output_config"]
    )

    _, _, workflow_json = kfp_client_utils.compile_run_monitor_pipeline(
        kfp_client,
        experiment_id,
        test_params["PipelineDefinition"],
        test_params["Arguments"],
        download_dir,
        test_params["TestName"],
        test_params["Timeout"],
    )

    outputs = {"sagemaker-processing-job": ["job_name", "output_artifacts"]}
    output_files = minio_utils.artifact_download_iterator(
        workflow_json, outputs, download_dir
    )

    # Verify Processing job was successful on SageMaker
    processing_job_name = utils.read_from_file_in_tar(
        output_files["sagemaker-processing-job"]["job_name"], "job_name.txt"
    )
    print(f"processing job name: {processing_job_name}")
    process_response = sagemaker_utils.describe_processing_job(
        sagemaker_client, processing_job_name
    )
    assert process_response["ProcessingJobStatus"] == "Completed"
    assert process_response["ProcessingJobArn"].split("/")[1] == input_job_name

    # Verify processing job produced the correct outputs
    processing_outputs = json.loads(
        utils.read_from_file_in_tar(
            output_files["sagemaker-processing-job"]["output_artifacts"],
            "output_artifacts.txt",
        )
    )
    print(f"processing job outputs: {json.dumps(processing_outputs, indent = 2)}")
    assert processing_outputs is not None

    for output in process_response["ProcessingOutputConfig"]["Outputs"]:
        assert processing_outputs[output["OutputName"]] == output["S3Output"]["S3Uri"]

    utils.remove_dir(download_dir)

@@ -0,0 +1,46 @@
PipelineDefinition: resources/definition/processing_pipeline.py
TestName: kmeans-algo-mnist-processing
Timeout: 1800
Arguments:
  region: ((REGION))
  job_name: process-mnist-for-kmeans
  image: 763104351884.dkr.ecr.((REGION)).amazonaws.com/pytorch-training:1.5.0-cpu-py36-ubuntu16.04
  container_entrypoint:
    - python
    - /opt/ml/processing/code/kmeans_preprocessing.py
  input_config:
    - InputName: mnist_tar
      S3Input:
        S3Uri: s3://sagemaker-sample-data-((REGION))/algorithms/kmeans/mnist/mnist.pkl.gz
        LocalPath: /opt/ml/processing/input
        S3DataType: S3Prefix
        S3InputMode: File
        S3CompressionType: None
    - InputName: source_code
      S3Input:
        S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/processing_code/kmeans_preprocessing.py
        LocalPath: /opt/ml/processing/code
        S3DataType: S3Prefix
        S3InputMode: File
        S3CompressionType: None
  output_config:
    - OutputName: train_data
      S3Output:
        S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
        LocalPath: /opt/ml/processing/output_train/
        S3UploadMode: EndOfJob
    - OutputName: test_data
      S3Output:
        S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
        LocalPath: /opt/ml/processing/output_test/
        S3UploadMode: EndOfJob
    - OutputName: valid_data
      S3Output:
        S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
        LocalPath: /opt/ml/processing/output_valid/
        S3UploadMode: EndOfJob
  instance_type: ml.m5.xlarge
  instance_count: 1
  volume_size: 50
  max_run_time: 1800
  role: ((ROLE_ARN))

@@ -0,0 +1,48 @@
import kfp
from kfp import components
from kfp import dsl

sagemaker_process_op = components.load_component_from_file(
    "../../process/component.yaml"
)


@dsl.pipeline(name="SageMaker Processing", description="SageMaker processing job test")
def processing_pipeline(
    region="",
    job_name="",
    image="",
    instance_type="",
    instance_count="",
    volume_size="",
    max_run_time="",
    environment={},
    container_entrypoint=[],
    container_arguments=[],
    input_config={},
    output_config={},
    network_isolation=False,
    role="",
):
    sagemaker_process_op(
        region=region,
        job_name=job_name,
        image=image,
        instance_type=instance_type,
        instance_count=instance_count,
        volume_size=volume_size,
        max_run_time=max_run_time,
        environment=environment,
        container_entrypoint=container_entrypoint,
        container_arguments=container_arguments,
        input_config=input_config,
        output_config=output_config,
        network_isolation=network_isolation,
        role=role,
    )


if __name__ == "__main__":
    kfp.compiler.Compiler().compile(
        processing_pipeline, "SageMaker_processing_pipeline" + ".yaml"
    )

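To submit this test pipeline ad hoc, outside the integration test harness, a short sketch assuming a reachable Kubeflow Pipelines endpoint and placeholder values for the AWS resources (role, image, and bucket below are not real):

```
import kfp
from processing_pipeline import processing_pipeline  # the definition above

client = kfp.Client(host="http://localhost:8080")  # assumed port-forwarded KFP endpoint
client.create_run_from_pipeline_func(
    processing_pipeline,
    arguments={
        "region": "us-west-2",
        "role": "arn:aws:iam::111122223333:role/example-sagemaker-role",             # hypothetical role
        "image": "111122223333.dkr.ecr.us-west-2.amazonaws.com/my-processor:latest",  # hypothetical image
        "instance_type": "ml.m5.xlarge",
        "instance_count": "1",
        "volume_size": "50",
        "max_run_time": "1800",
        # JSON string following the output_config format documented in the process README
        "output_config": '[{"OutputName": "out", "S3Output": {"S3Uri": "s3://my-bucket/out", "LocalPath": "/opt/ml/processing/output", "S3UploadMode": "EndOfJob"}}]',
    },
)
```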
@@ -0,0 +1,24 @@
import pickle
import gzip
import numpy
import io
from sagemaker.amazon.common import write_numpy_to_dense_tensor

print("Extracting MNIST data set")
# Load the dataset
with gzip.open('/opt/ml/processing/input/mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

# Process the data
# Convert the training data into the format required by the SageMaker KMeans algorithm
print("Writing training data")
with open('/opt/ml/processing/output_train/train_data', 'wb') as train_file:
    write_numpy_to_dense_tensor(train_file, train_set[0], train_set[1])

print("Writing test data")
with open('/opt/ml/processing/output_test/test_data', 'wb') as test_file:
    write_numpy_to_dense_tensor(test_file, test_set[0], test_set[1])

print("Writing validation data")
# Convert the validation data into the format required by the SageMaker KMeans algorithm
numpy.savetxt('/opt/ml/processing/output_valid/valid-data.csv', valid_set[0], delimiter=',', fmt='%g')

@@ -67,3 +67,7 @@ def delete_workteam(client, workteam_name):

def stop_labeling_job(client, labeling_job_name):
    client.stop_labeling_job(LabelingJobName=labeling_job_name)


+def describe_processing_job(client, processing_job_name):
+    return client.describe_processing_job(ProcessingJobName=processing_job_name)

@@ -0,0 +1,238 @@
import json
import unittest

from unittest.mock import patch, call, Mock, MagicMock, mock_open
from botocore.exceptions import ClientError
from datetime import datetime

from process.src import process
from common import _utils
from . import test_utils

required_args = [
    '--region', 'us-west-2',
    '--role', 'arn:aws:iam::123456789012:user/Development/product_1234/*',
    '--image', 'test-image',
    '--instance_type', 'ml.m4.xlarge',
    '--instance_count', '1',
    '--input_config', json.dumps([{
        'InputName': "dataset-input",
        'S3Input': {
            'S3Uri': "s3://my-bucket/dataset.csv",
            'LocalPath': "/opt/ml/processing/input",
            'S3DataType': "S3Prefix",
            'S3InputMode': "File"
        }
    }]),
    '--output_config', json.dumps([{
        'OutputName': "training-outputs",
        'S3Output': {
            'S3Uri': "s3://my-bucket/outputs/train.csv",
            'LocalPath': "/opt/ml/processing/output/train",
            'S3UploadMode': "Continuous"
        }
    }])
]

class ProcessTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        parser = process.create_parser()
        cls.parser = parser

    def test_create_parser(self):
        self.assertIsNotNone(self.parser)

    def test_main(self):
        # Mock out all of utils except parser
        process._utils = MagicMock()
        process._utils.add_default_client_arguments = _utils.add_default_client_arguments

        # Set some static returns
        process._utils.create_processing_job.return_value = 'job-name'
        process._utils.get_processing_job_outputs.return_value = mock_outputs = {'val1': 's3://1', 'val2': 's3://2'}

        with patch('builtins.open', mock_open()) as file_open:
            process.main(required_args)

        # Check if correct requests were created and triggered
        process._utils.create_processing_job.assert_called()
        process._utils.wait_for_processing_job.assert_called()

        # Check the file outputs
        file_open.assert_has_calls([
            call('/tmp/job_name.txt', 'w'),
            call('/tmp/output_artifacts.txt', 'w')
        ], any_order=True)

        file_open().write.assert_has_calls([
            call('job-name'),
            call(json.dumps(mock_outputs))
        ], any_order=False)  # Must be in the same order as called

    def test_create_processing_job(self):
        mock_client = MagicMock()
        mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job'])
        response = _utils.create_processing_job(mock_client, vars(mock_args))

        mock_client.create_processing_job.assert_called_once_with(
            AppSpecification={"ImageUri": "test-image"},
            Environment={},
            NetworkConfig={
                "EnableInterContainerTrafficEncryption": False,
                "EnableNetworkIsolation": True,
            },
            ProcessingInputs=[
                {
                    "InputName": "dataset-input",
                    "S3Input": {
                        "S3Uri": "s3://my-bucket/dataset.csv",
                        "LocalPath": "/opt/ml/processing/input",
                        "S3DataType": "S3Prefix",
                        "S3InputMode": "File"
                    },
                }
            ],
            ProcessingJobName="test-job",
            ProcessingOutputConfig={
                "Outputs": [
                    {
                        "OutputName": "training-outputs",
                        "S3Output": {
                            "S3Uri": "s3://my-bucket/outputs/train.csv",
                            "LocalPath": "/opt/ml/processing/output/train",
                            "S3UploadMode": "Continuous"
                        },
                    }
                ]
            },
            ProcessingResources={
                "ClusterConfig": {
                    "InstanceType": "ml.m4.xlarge",
                    "InstanceCount": 1,
                    "VolumeSizeInGB": 30,
                }
            },
            RoleArn="arn:aws:iam::123456789012:user/Development/product_1234/*",
            StoppingCondition={"MaxRuntimeInSeconds": 86400},
            Tags=[],
        )
        self.assertEqual(response, 'test-job')

    def test_sagemaker_exception_in_create_processing_job(self):
        mock_client = MagicMock()
        mock_exception = ClientError({"Error": {"Message": "SageMaker broke"}}, "create_processing_job")
        mock_client.create_processing_job.side_effect = mock_exception
        mock_args = self.parser.parse_args(required_args)

        with self.assertRaises(Exception):
            response = _utils.create_processing_job(mock_client, vars(mock_args))

    def test_wait_for_processing_job(self):
        mock_client = MagicMock()
        mock_client.describe_processing_job.side_effect = [
            {"ProcessingJobStatus": "Starting"},
            {"ProcessingJobStatus": "InProgress"},
            {"ProcessingJobStatus": "Downloading"},
            {"ProcessingJobStatus": "Completed"},
            {"ProcessingJobStatus": "Should not be called"}
        ]

        _utils.wait_for_processing_job(mock_client, 'processing-job', 0)
        self.assertEqual(mock_client.describe_processing_job.call_count, 4)

    def test_wait_for_failed_job(self):
        mock_client = MagicMock()
        mock_client.describe_processing_job.side_effect = [
            {"ProcessingJobStatus": "Starting"},
            {"ProcessingJobStatus": "InProgress"},
            {"ProcessingJobStatus": "Downloading"},
            {"ProcessingJobStatus": "Failed", "FailureReason": "Something broke lol"},
            {"ProcessingJobStatus": "Should not be called"}
        ]

        with self.assertRaises(Exception):
            _utils.wait_for_processing_job(mock_client, 'processing-job', 0)

        self.assertEqual(mock_client.describe_processing_job.call_count, 4)

    def test_reasonable_required_args(self):
        response = _utils.create_processing_job_request(vars(self.parser.parse_args(required_args)))

        # Ensure all of the optional arguments have reasonable default values
        self.assertNotIn('VpcConfig', response['NetworkConfig'])
        self.assertEqual(response['Tags'], [])
        ## TODO

    def test_no_defined_image(self):
        # Pass the image to pass the parser
        no_image_args = required_args.copy()
        image_index = no_image_args.index('--image')
        # Cut out --image and its associated value
        no_image_args = no_image_args[:image_index] + no_image_args[image_index+2:]

        with self.assertRaises(SystemExit):
            parsed_args = self.parser.parse_args(no_image_args)

    def test_container_entrypoint(self):
        entrypoint, arguments = ['/bin/bash'], ['arg1', 'arg2']

        container_args = self.parser.parse_args(required_args + ['--container_entrypoint', json.dumps(entrypoint),
                                                                  '--container_arguments', json.dumps(arguments)])
        response = _utils.create_processing_job_request(vars(container_args))

        self.assertEqual(response['AppSpecification']['ContainerEntrypoint'], entrypoint)
        self.assertEqual(response['AppSpecification']['ContainerArguments'], arguments)

    def test_environment_variables(self):
        env_vars = {
            'key1': 'val1',
            'key2': 'val2'
        }

        environment_args = self.parser.parse_args(required_args + ['--environment', json.dumps(env_vars)])
        response = _utils.create_processing_job_request(vars(environment_args))

        self.assertEqual(response['Environment'], env_vars)

    def test_vpc_configuration(self):
        required_vpc_args = self.parser.parse_args(required_args + ['--vpc_security_group_ids', 'sg1,sg2', '--vpc_subnets', 'subnet1,subnet2'])
        response = _utils.create_processing_job_request(vars(required_vpc_args))

        self.assertIn('VpcConfig', response['NetworkConfig'])
        self.assertIn('sg1', response['NetworkConfig']['VpcConfig']['SecurityGroupIds'])
        self.assertIn('sg2', response['NetworkConfig']['VpcConfig']['SecurityGroupIds'])
        self.assertIn('subnet1', response['NetworkConfig']['VpcConfig']['Subnets'])
        self.assertIn('subnet2', response['NetworkConfig']['VpcConfig']['Subnets'])

    def test_tags(self):
        args = self.parser.parse_args(required_args + ['--tags', '{"key1": "val1", "key2": "val2"}'])
        response = _utils.create_processing_job_request(vars(args))
        self.assertIn({'Key': 'key1', 'Value': 'val1'}, response['Tags'])
        self.assertIn({'Key': 'key2', 'Value': 'val2'}, response['Tags'])

    def test_get_processing_job_output(self):
        mock_client = MagicMock()
        mock_client.describe_processing_job.return_value = {
            'ProcessingOutputConfig': {
                'Outputs': [{
                    'OutputName': 'train',
                    'S3Output': {
                        'S3Uri': 's3://train'
                    }
                }, {
                    'OutputName': 'valid',
                    'S3Output': {
                        'S3Uri': 's3://valid'
                    }
                }]
            }
        }

        response = _utils.get_processing_job_outputs(mock_client, 'processing-job')

        self.assertIn('train', response)
        self.assertIn('valid', response)
        self.assertEqual(response['train'], 's3://train')
        self.assertEqual(response['valid'], 's3://valid')

@@ -104,7 +104,7 @@ outputs:
  - {name: training_image, description: 'The registry path of the Docker image that contains the training algorithm'}
implementation:
  container:
-    image: amazon/aws-sagemaker-kfp-components:0.3.1
+    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      train.py,

@@ -36,7 +36,7 @@ outputs:
  - {name: workteam_arn, description: 'The ARN of the workteam.'}
implementation:
  container:
-    image: amazon/aws-sagemaker-kfp-components:0.3.1
+    image: amazon/aws-sagemaker-kfp-components:0.4.0
    command: ['python3']
    args: [
      workteam.py,