[AWS SageMaker] De-hardcode output paths in AWS components (#4119)
* Update input arguments
* Remove fileOutputs
* Update outputs to new paths
* Modify integ test artifact path
* Add unit test for new output format
* Add unit test for write_output
* Migrate tests into test_utils
* Add clarifying comment
* Remove output path file extension
* Update license to 0.5.2
* Update component to 0.5.2
* Add 0.5.2 to changelog
* Remove JSON
This commit is contained in: parent f1d90407a8, commit f0f8e5d178

@@ -4,6 +4,12 @@ The version of the AWS SageMaker Components is determined by the docker image ta
 Repository: https://hub.docker.com/repository/docker/amazon/aws-sagemaker-kfp-components
 
 ---------------------------------------------
+
+**Change log for version 0.5.2**
+- Modified outputs to use newer `outputPath` syntax
+
+> Pull requests : [#4119](https://github.com/kubeflow/pipelines/pull/4119)
+
 **Change log for version 0.5.1**
 - Update region support for GroudTruth component
 - Make `label_category_config` an optional parameter in Ground Truth component

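In practice, the `outputPath` migration replaces each component's hardcoded `fileOutputs` entry (a fixed file such as `/tmp/endpoint_name.txt`) with a `--*_output_path` argument that Kubeflow Pipelines fills in through `{outputPath: ...}` in the component spec. A minimal sketch of the new pattern, condensed from the diffs below (the component and argument names here are illustrative, not taken from the repository):

```python
import argparse

from common import _utils

def create_parser():
  parser = argparse.ArgumentParser(description='Example SageMaker component')
  # KFP substitutes {outputPath: result} with a path of its own choosing at
  # runtime; the default only matters when the script is run by hand.
  parser.add_argument('--result_output_path', type=str, default='/tmp/result',
                      help='Local output path for the file containing the result.')
  return parser

def main(argv=None):
  args = create_parser().parse_args(argv)
  result = 'example-value'  # a real component computes this from the SageMaker API
  # write_output (added to common/_utils.py in this commit) creates parent
  # directories before writing the value to the supplied path.
  _utils.write_output(args.result_output_path, result)
```
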
@@ -1,4 +1,4 @@
-** Amazon SageMaker Components for Kubeflow Pipelines; version 0.5.1 --
+** Amazon SageMaker Components for Kubeflow Pipelines; version 0.5.2 --
 https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker
 Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 ** boto3; version 1.12.33 -- https://github.com/boto/boto3/

@@ -98,7 +98,7 @@ outputs:
   - {name: output_location, description: 'S3 URI of the transform job results.'}
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       batch_transform.py,

@@ -126,5 +126,5 @@ implementation:
       --instance_count, {inputValue: instance_count},
       --resource_encryption_key, {inputValue: resource_encryption_key},
       --tags, {inputValue: tags},
-      --output_location_file, {outputPath: output_location}
+      --output_location_output_path, {outputPath: output_location}
     ]

@@ -13,16 +13,9 @@
 import sys
 import argparse
 import logging
-from pathlib2 import Path
 
 from common import _utils
 
-try:
-  unicode
-except NameError:
-  unicode = str
-
-
 def create_parser():
   parser = argparse.ArgumentParser(description='SageMaker Batch Transformation Job')
   _utils.add_default_client_arguments(parser)

@@ -49,7 +42,7 @@ def create_parser():
   parser.add_argument('--instance_count', type=int, required=False, help='The number of ML compute instances to use in the transform job.')
   parser.add_argument('--resource_encryption_key', type=str, required=False, help='The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s).', default='')
   parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
-  parser.add_argument('--output_location_file', type=str, required=True, help='File path where the program will write the Amazon S3 URI of the transform job results.')
+  parser.add_argument('--output_location_output_path', type=str, default='/tmp/output-location', help='Local output path for the file containing the Amazon S3 URI of the transform job results.')
 
   return parser

@@ -71,9 +64,7 @@ def main(argv=None):
   cw_client = _utils.get_cloudwatch_client(args.region)
   _utils.print_logs_for_job(cw_client, '/aws/sagemaker/TransformJobs', batch_job_name)
 
-  Path(args.output_location_file).parent.mkdir(parents=True, exist_ok=True)
-  with open(args.output_location_file, 'w') as f:
-    f.write(unicode(args.output_location))
+  _utils.write_output(args.output_location_output_path, args.output_location)
 
   logging.info('Batch Transformation creation completed.')

@@ -20,6 +20,8 @@ import random
 import json
 import yaml
 import re
+import json
+from pathlib2 import Path
 
 import boto3
 import botocore

@@ -65,7 +67,7 @@ def nullable_string_argument(value):
 
 def add_default_client_arguments(parser):
   parser.add_argument('--region', type=str, required=True, help='The region where the training job launches.')
-  parser.add_argument('--endpoint_url', type=nullable_string_argument, required=False, help='The URL to use when communicating with the Sagemaker service.')
+  parser.add_argument('--endpoint_url', type=nullable_string_argument, required=False, help='The URL to use when communicating with the SageMaker service.')
 
 
 def get_component_version():

@@ -207,7 +209,7 @@ def create_training_job_request(args):
 
 
 def create_training_job(client, args):
-  """Create a Sagemaker training job."""
+  """Create a SageMaker training job."""
   request = create_training_job_request(args)
   try:
     client.create_training_job(**request)

@@ -614,7 +616,7 @@ def create_hyperparameter_tuning_job_request(args):
 
 
 def create_hyperparameter_tuning_job(client, args):
-  """Create a Sagemaker HPO job"""
+  """Create a SageMaker HPO job"""
  request = create_hyperparameter_tuning_job_request(args)
 try:
    client.create_hyper_parameter_tuning_job(**request)

@@ -1027,3 +1029,17 @@ def str_to_bool(str):
   # This distutils function returns an integer representation of the boolean
   # rather than a True/False value. This simply hard casts it.
   return bool(strtobool(str))
+
+def write_output(output_path, output_value, json_encode=False):
+  """Write an output value to the associated path, dumping as a JSON object
+  if specified.
+  Arguments:
+  - output_path: The file path of the output.
+  - output_value: The output value to write to the file.
+  - json_encode: True if the value should be encoded as a JSON object.
+  """
+
+  write_value = json.dumps(output_value) if json_encode else output_value
+
+  Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+  Path(output_path).write_text(write_value)

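The helper has two modes: plain values are written verbatim, while `json_encode=True` serializes the value with `json.dumps` first, which is how the HPO and processing components emit their dictionary outputs. A quick usage sketch (the paths here are arbitrary):

```python
from common import _utils

# Plain string output: written as-is, parent directories created as needed.
_utils.write_output('/tmp/example/output-location', 's3://my-bucket/results')

# Structured output: serialized to JSON before writing.
_utils.write_output('/tmp/example/output-artifacts',
                    {'train': 's3://my-bucket/train'}, json_encode=True)
```
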
@@ -1,4 +1,4 @@
-name: 'Sagemaker - Deploy Model'
+name: 'SageMaker - Deploy Model'
 description: |
   Deploy Machine Learning Model Endpoint in SageMaker
 inputs:

@@ -104,7 +104,7 @@ outputs:
   - {name: endpoint_name, description: 'Endpoint name'}
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       deploy.py,

@@ -132,7 +132,6 @@ implementation:
       --resource_encryption_key, {inputValue: resource_encryption_key},
       --endpoint_config_tags, {inputValue: endpoint_config_tags},
       --endpoint_name, {inputValue: endpoint_name},
-      --endpoint_tags, {inputValue: endpoint_tags}
+      --endpoint_tags, {inputValue: endpoint_tags},
+      --endpoint_name_output_path, {outputPath: endpoint_name}
     ]
-  fileOutputs:
-    endpoint_name: /tmp/endpoint_name.txt

@@ -43,6 +43,7 @@ def create_parser():
   parser.add_argument('--endpoint_config_tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
   parser.add_argument('--endpoint_name', type=str, required=False, help='The name of the endpoint.', default='')
   parser.add_argument('--endpoint_tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
+  parser.add_argument('--endpoint_name_output_path', type=str, default='/tmp/endpoint-name', help='Local output path for the file containing the name of the created endpoint.')
 
   return parser

@@ -57,8 +58,7 @@ def main(argv=None):
   logging.info('Endpoint creation request submitted. Waiting for completion...')
   _utils.wait_for_endpoint_creation(client, endpoint_name)
 
-  with open('/tmp/endpoint_name.txt', 'w') as f:
-    f.write(endpoint_name)
+  _utils.write_output(args.endpoint_name_output_path, endpoint_name)
 
   logging.info('Endpoint creation completed.')

@@ -119,7 +119,7 @@ outputs:
   - {name: active_learning_model_arn, description: 'The ARN for the most recent Amazon SageMaker model trained as part of automated data labeling.'}
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       ground_truth.py,

@@ -153,8 +153,7 @@ implementation:
       --task_availibility, {inputValue: task_availibility},
       --max_concurrent_tasks, {inputValue: max_concurrent_tasks},
       --workforce_task_price, {inputValue: workforce_task_price},
-      --tags, {inputValue: tags}
-    ]
-  fileOutputs:
-    output_manifest_location: /tmp/output_manifest_location.txt
-    active_learning_model_arn: /tmp/active_learning_model_arn.txt
+      --tags, {inputValue: tags},
+      --output_manifest_location_output_path, {outputPath: output_manifest_location},
+      --active_learning_model_arn_output_path, {outputPath: active_learning_model_arn}
+    ]

@@ -49,6 +49,8 @@ def create_parser():
   parser.add_argument('--max_concurrent_tasks', type=int, required=False, help='The maximum number of data objects that can be labeled by human workers at the same time.', default=0)
   parser.add_argument('--workforce_task_price', type=float, required=False, help='The price that you pay for each task performed by a public worker in USD. Specify to the tenth fractions of a cent. Format as "0.000".', default=0.000)
   parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
+  parser.add_argument('--output_manifest_location_output_path', type=str, default='/tmp/manifest-location', help='Local output path for the file containing the Amazon S3 bucket location of the manifest file for labeled data.')
+  parser.add_argument('--active_learning_model_arn_output_path', type=str, default='/tmp/active-model-arn', help='Local output path for the file containing the ARN for the most recent Amazon SageMaker model trained as part of automated data labeling.')
 
   return parser

@@ -66,10 +68,8 @@ def main(argv=None):
 
   logging.info('Ground Truth Labeling Job completed.')
 
-  with open('/tmp/output_manifest_location.txt', 'w') as f:
-    f.write(output_manifest)
-  with open('/tmp/active_learning_model_arn.txt', 'w') as f:
-    f.write(active_learning_model_arn)
+  _utils.write_output(args.output_manifest_location_output_path, output_manifest)
+  _utils.write_output(args.active_learning_model_arn_output_path, active_learning_model_arn)
 
 
 if __name__== "__main__":

@@ -150,7 +150,7 @@ outputs:
     description: 'The registry path of the Docker image that contains the training algorithm'
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       hyperparameter_tuning.py,

@@ -189,11 +189,10 @@ implementation:
       --checkpoint_config, {inputValue: checkpoint_config},
       --warm_start_type, {inputValue: warm_start_type},
       --parent_hpo_jobs, {inputValue: parent_hpo_jobs},
-      --tags, {inputValue: tags}
-    ]
-  fileOutputs:
-    hpo_job_name: /tmp/hpo_job_name.txt
-    model_artifact_url: /tmp/model_artifact_url.txt
-    best_job_name: /tmp/best_job_name.txt
-    best_hyperparameters: /tmp/best_hyperparameters.txt
-    training_image: /tmp/training_image.txt
+      --tags, {inputValue: tags},
+      --hpo_job_name_output_path, {outputPath: hpo_job_name},
+      --model_artifact_url_output_path, {outputPath: model_artifact_url},
+      --best_job_name_output_path, {outputPath: best_job_name},
+      --best_hyperparameters_output_path, {outputPath: best_hyperparameters},
+      --training_image_output_path, {outputPath: training_image}
+    ]

@@ -13,7 +13,6 @@
 import sys
 import argparse
 import logging
-import json
 
 from common import _utils
 

@@ -60,6 +59,14 @@ def create_parser():
 
   parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
 
+  ### Start outputs
+  parser.add_argument('--hpo_job_name_output_path', type=str, default='/tmp/hpo-job-name', help='Local output path for the file containing the name of the hyper parameter tuning job')
+  parser.add_argument('--model_artifact_url_output_path', type=str, default='/tmp/artifact-url', help='Local output path for the file containing the model artifacts url')
+  parser.add_argument('--best_job_name_output_path', type=str, default='/tmp/best-job-name', help='Local output path for the file containing the name of the best training job in the hyper parameter tuning job')
+  parser.add_argument('--best_hyperparameters_output_path', type=str, default='/tmp/best-hyperparams', help='Local output path for the file containing the final tuned hyperparameters')
+  parser.add_argument('--training_image_output_path', type=str, default='/tmp/training-image', help='Local output path for the file containing the registry path of the Docker image that contains the training algorithm')
+  ### End outputs
+
   return parser
 
 def main(argv=None):

@@ -78,17 +85,11 @@ def main(argv=None):
 
   logging.info('HyperParameter Tuning Job completed.')
 
-  with open('/tmp/hpo_job_name.txt', 'w') as f:
-    f.write(hpo_job_name)
-  with open('/tmp/best_job_name.txt', 'w') as f:
-    f.write(best_job)
-  with open('/tmp/best_hyperparameters.txt', 'w') as f:
-    f.write(json.dumps(best_hyperparameters))
-  with open('/tmp/model_artifact_url.txt', 'w') as f:
-    f.write(model_artifact_url)
-  with open('/tmp/training_image.txt', 'w') as f:
-    f.write(image)
-
+  _utils.write_output(args.hpo_job_name_output_path, hpo_job_name)
+  _utils.write_output(args.model_artifact_url_output_path, model_artifact_url)
+  _utils.write_output(args.best_job_name_output_path, best_job)
+  _utils.write_output(args.best_hyperparameters_output_path, best_hyperparameters, json_encode=True)
+  _utils.write_output(args.training_image_output_path, image)
 
 if __name__== "__main__":
   main(sys.argv[1:])

@@ -1,4 +1,4 @@
-name: 'Sagemaker - Create Model'
+name: 'SageMaker - Create Model'
 description: |
   Create Models in SageMaker
 inputs:

@@ -56,10 +56,10 @@ inputs:
     default: '{}'
     type: JsonObject
 outputs:
-  - {name: model_name, description: 'The model name Sagemaker created'}
+  - {name: model_name, description: 'The model name SageMaker created'}
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       create_model.py,

@@ -76,7 +76,6 @@ implementation:
       --vpc_security_group_ids, {inputValue: vpc_security_group_ids},
       --vpc_subnets, {inputValue: vpc_subnets},
       --network_isolation, {inputValue: network_isolation},
-      --tags, {inputValue: tags}
-    ]
-  fileOutputs:
-    model_name: /tmp/model_name.txt
+      --tags, {inputValue: tags},
+      --model_name_output_path, {outputPath: model_name}
+    ]

@@ -32,6 +32,7 @@ def create_parser():
   parser.add_argument('--vpc_subnets', type=str, required=False, help='The ID of the subnets in the VPC to which you want to connect your hpo job.', default='')
   parser.add_argument('--network_isolation', type=_utils.str_to_bool, required=False, help='Isolates the training container.', default=True)
   parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
+  parser.add_argument('--model_name_output_path', type=str, default='/tmp/model-name', help='Local output path for the file containing the name of the model SageMaker created.')
 
   return parser

@@ -46,8 +47,8 @@ def main(argv=None):
   _utils.create_model(client, vars(args))
 
   logging.info('Model creation completed.')
-  with open('/tmp/model_name.txt', 'w') as f:
-    f.write(args.model_name)
+
+  _utils.write_output(args.model_name_output_path, args.model_name)
 
 
 if __name__== "__main__":

@@ -89,7 +89,7 @@ outputs:
   - {name: output_artifacts, description: 'A dictionary containing the output S3 artifacts'}
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       process.py,

@@ -113,8 +113,7 @@ implementation:
       --vpc_subnets, {inputValue: vpc_subnets},
       --network_isolation, {inputValue: network_isolation},
       --traffic_encryption, {inputValue: traffic_encryption},
-      --tags, {inputValue: tags}
-    ]
-  fileOutputs:
-    job_name: /tmp/job_name.txt
-    output_artifacts: /tmp/output_artifacts.txt
+      --tags, {inputValue: tags},
+      --job_name_output_path, {outputPath: job_name},
+      --output_artifacts_output_path, {outputPath: output_artifacts}
+    ]

@@ -13,7 +13,6 @@
 import sys
 import argparse
 import logging
-import json
 
 from common import _utils
 

@@ -40,6 +39,8 @@ def create_parser():
   parser.add_argument('--network_isolation', type=_utils.str_to_bool, required=False, help='Isolates the processing container.', default=True)
   parser.add_argument('--traffic_encryption', type=_utils.str_to_bool, required=False, help='Encrypts all communications between ML compute instances in distributed training.', default=False)
   parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
+  parser.add_argument('--job_name_output_path', type=str, default='/tmp/job-name', help='Local output path for the file containing the name of the processing job.')
+  parser.add_argument('--output_artifacts_output_path', type=str, default='/tmp/output-artifacts', help='Local output path for the file containing the dictionary describing the output S3 artifacts.')
 
   return parser

@@ -64,11 +65,8 @@ def main(argv=None):
 
   outputs = _utils.get_processing_job_outputs(client, job_name)
 
-  with open('/tmp/job_name.txt', 'w') as f:
-    f.write(job_name)
-
-  with open('/tmp/output_artifacts.txt', 'w') as f:
-    f.write(json.dumps(outputs))
+  _utils.write_output(args.job_name_output_path, job_name)
+  _utils.write_output(args.output_artifacts_output_path, outputs, json_encode=True)
 
   logging.info('Job completed.')

@@ -22,5 +22,4 @@
 1. Navigate to the root of this github directory.
 1. Run `docker build . -f components/aws/sagemaker/tests/integration_tests/Dockerfile -t amazon/integration_test`
 1. Run the image, injecting your environment variable files:
-    1. Navigate to the `components/aws` directory.
     1. Run `docker run --env-file components/aws/sagemaker/tests/integration_tests/.env amazon/integration_test`

@@ -70,7 +70,7 @@ def test_transform_job(
 
     # Verify output location from pipeline matches job output and that the transformed file exists
     output_location = utils.read_from_file_in_tar(
-        output_files["sagemaker-batch-transformation"]["output_location"], "data",
+        output_files["sagemaker-batch-transformation"]["output_location"]
     )
     print(f"output location: {output_location}")
     assert output_location == response["TransformOutput"]["S3OutputPath"]

@@ -87,7 +87,7 @@ def test_create_endpoint(
     )
 
     output_endpoint_name = utils.read_from_file_in_tar(
-        output_files["sagemaker-deploy-model"]["endpoint_name"], "endpoint_name.txt"
+        output_files["sagemaker-deploy-model"]["endpoint_name"]
     )
     print(f"endpoint name: {output_endpoint_name}")
 

@@ -57,8 +57,7 @@ def test_hyperparameter_tuning(
 
     # Verify HPO job was successful on SageMaker
     hpo_job_name = utils.read_from_file_in_tar(
-        output_files["sagemaker-hyperparameter-tuning"]["hpo_job_name"],
-        "hpo_job_name.txt",
+        output_files["sagemaker-hyperparameter-tuning"]["hpo_job_name"]
     )
     print(f"HPO job name: {hpo_job_name}")
     hpo_response = sagemaker_utils.describe_hpo_job(sagemaker_client, hpo_job_name)

@@ -68,8 +67,7 @@ def test_hyperparameter_tuning(
 
     # Verify training image output is an ECR image
     training_image = utils.read_from_file_in_tar(
-        output_files["sagemaker-hyperparameter-tuning"]["training_image"],
-        "training_image.txt",
+        output_files["sagemaker-hyperparameter-tuning"]["training_image"]
     )
     print(f"Training image used: {training_image}")
     if "ExpectedTrainingImage" in test_params.keys():

@@ -79,8 +77,7 @@ def test_hyperparameter_tuning(
 
     # Verify Training job was part of HPO job, returned as best and was successful
     best_training_job_name = utils.read_from_file_in_tar(
-        output_files["sagemaker-hyperparameter-tuning"]["best_job_name"],
-        "best_job_name.txt",
+        output_files["sagemaker-hyperparameter-tuning"]["best_job_name"]
     )
     print(f"best training job name: {best_training_job_name}")
     train_response = sagemaker_utils.describe_training_job(

@@ -95,8 +92,7 @@ def test_hyperparameter_tuning(
 
     # Verify model artifacts output was generated from this run
     model_artifact_url = utils.read_from_file_in_tar(
-        output_files["sagemaker-hyperparameter-tuning"]["model_artifact_url"],
-        "model_artifact_url.txt",
+        output_files["sagemaker-hyperparameter-tuning"]["model_artifact_url"]
     )
     print(f"model_artifact_url: {model_artifact_url}")
     assert model_artifact_url == train_response["ModelArtifacts"]["S3ModelArtifacts"]

@@ -105,8 +101,7 @@ def test_hyperparameter_tuning(
     # Verify hyper_parameters output is not empty
     hyper_parameters = json.loads(
         utils.read_from_file_in_tar(
-            output_files["sagemaker-hyperparameter-tuning"]["best_hyperparameters"],
-            "best_hyperparameters.txt",
+            output_files["sagemaker-hyperparameter-tuning"]["best_hyperparameters"]
         )
     )
     print(f"HPO best hyperparameters: {json.dumps(hyper_parameters, indent = 2)}")

@@ -48,7 +48,7 @@ def test_createmodel(kfp_client, experiment_id, sagemaker_client, test_file_dir)
     )
 
     output_model_name = utils.read_from_file_in_tar(
-        output_files["sagemaker-create-model"]["model_name"], "model_name.txt"
+        output_files["sagemaker-create-model"]["model_name"]
     )
     print(f"model_name: {output_model_name}")
     assert output_model_name == input_model_name

@@ -64,7 +64,7 @@ def test_processingjob(
 
     # Verify Processing job was successful on SageMaker
     processing_job_name = utils.read_from_file_in_tar(
-        output_files["sagemaker-processing-job"]["job_name"], "job_name.txt"
+        output_files["sagemaker-processing-job"]["job_name"]
     )
     print(f"processing job name: {processing_job_name}")
     process_response = sagemaker_utils.describe_processing_job(

@@ -77,7 +77,6 @@ def test_processingjob(
     processing_outputs = json.loads(
         utils.read_from_file_in_tar(
-            output_files["sagemaker-processing-job"]["output_artifacts"],
-            "output_artifacts.txt",
+            output_files["sagemaker-processing-job"]["output_artifacts"]
         )
     )
     print(f"processing job outputs: {json.dumps(processing_outputs, indent = 2)}")

@@ -51,7 +51,7 @@ def test_trainingjob(
 
     # Verify Training job was successful on SageMaker
     training_job_name = utils.read_from_file_in_tar(
-        output_files["sagemaker-training-job"]["job_name"], "job_name.txt"
+        output_files["sagemaker-training-job"]["job_name"]
     )
     print(f"training job name: {training_job_name}")
     train_response = sagemaker_utils.describe_training_job(

@@ -61,8 +61,7 @@ def test_trainingjob(
 
     # Verify model artifacts output was generated from this run
     model_artifact_url = utils.read_from_file_in_tar(
-        output_files["sagemaker-training-job"]["model_artifact_url"],
-        "model_artifact_url.txt",
+        output_files["sagemaker-training-job"]["model_artifact_url"]
     )
     print(f"model_artifact_url: {model_artifact_url}")
     assert model_artifact_url == train_response["ModelArtifacts"]["S3ModelArtifacts"]

@@ -70,7 +69,7 @@ def test_trainingjob(
 
     # Verify training image output is an ECR image
     training_image = utils.read_from_file_in_tar(
-        output_files["sagemaker-training-job"]["training_image"], "training_image.txt",
+        output_files["sagemaker-training-job"]["training_image"]
     )
     print(f"Training image used: {training_image}")
     if "ExpectedTrainingImage" in test_params.keys():

@@ -17,7 +17,7 @@ def create_workteamjob(
         )
     )
 
-    # Get the account, region specific user_pool and client_id for the Sagemaker Workforce.
+    # Get the account, region specific user_pool and client_id for the SageMaker Workforce.
     (
         test_params["Arguments"]["user_pool"],
         test_params["Arguments"]["client_id"],

@@ -70,8 +70,7 @@ def test_workteamjob(
 
     # Verify WorkTeam arn artifact was created in Minio and matches the one in SageMaker
     workteam_arn = utils.read_from_file_in_tar(
-        output_files["sagemaker-private-workforce"]["workteam_arn"],
-        "workteam_arn.txt",
+        output_files["sagemaker-private-workforce"]["workteam_arn"]
     )
     assert response["Workteam"]["WorkteamArn"] == workteam_arn
 

@@ -58,7 +58,15 @@ def run_command(cmd, *popenargs, **kwargs):
         pytest.fail(f"Command failed. Error code: {e.returncode}, Log: {e.output}")
 
 
-def read_from_file_in_tar(file_path, file_name, decode=True):
+def read_from_file_in_tar(file_path, file_name="data", decode=True):
+    """Opens a local tarball and reads the contents of the file as specified.
+    Arguments:
+    - file_path: The local path of the tarball file.
+    - file_name: The name of the file inside the tarball to be read. (Default `"data"`)
+    - decode: Ensures the contents of the file is decoded to type `str`. (Default `True`)
+
+    See: https://github.com/kubeflow/pipelines/blob/2e14fe732b3f878a710b16d1a63beece6c19330a/sdk/python/kfp/components/_components.py#L182
+    """
     with tarfile.open(file_path).extractfile(file_name) as f:
         if decode:
             return f.read().decode()

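With the `"data"` default in place, test call sites no longer have to name the member inside each artifact tarball: per the KFP source linked in the docstring, artifacts produced through `outputPath` are archived with a single member named `data`. A sketch of the simplified call, mirroring the updated test call sites in the hunks above:

```python
# Before: the member name had to be spelled out per component.
# model_name = utils.read_from_file_in_tar(tar_path, "model_name.txt")

# After: the default member name "data" is assumed.
model_name = utils.read_from_file_in_tar(
    output_files["sagemaker-create-model"]["model_name"]
)
```
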
@@ -35,7 +35,7 @@ def list_workteams(client):
 
 
 def get_cognito_member_definitions(client):
-    # This is one way to get the user_pool and client_id for the Sagemaker Workforce.
+    # This is one way to get the user_pool and client_id for the SageMaker Workforce.
     # An alternative would be to take these values as user input via params or a config file.
     # The current mechanism expects that there exists atleast one private workteam in the region.
     default_workteam = list_workteams(client)["Workteams"][0]["MemberDefinitions"][0][

@@ -16,7 +16,7 @@ required_args = [
   '--input_location', 's3://fake-bucket/data',
   '--output_location', 's3://fake-bucket/output',
   '--instance_type', 'ml.c5.18xlarge',
-  '--output_location_file', 'tmp/output.txt'
+  '--output_location_output_path', '/tmp/output'
 ]
 
 class BatchTransformTestCase(unittest.TestCase):

@@ -38,8 +38,7 @@ class BatchTransformTestCase(unittest.TestCase):
     # Set some static returns
     batch_transform._utils.create_transform_job.return_value = 'test-batch-job'
 
-    with patch('builtins.open', mock_open()) as file_open:
-      batch_transform.main(required_args)
+    batch_transform.main(required_args)
 
     # Check if correct requests were created and triggered
     batch_transform._utils.create_transform_job.assert_called()

@@ -48,12 +47,8 @@ class BatchTransformTestCase(unittest.TestCase):
 
 
     # Check the file outputs
-    file_open.assert_has_calls([
-      call('tmp/output.txt', 'w')
-    ])
-
-    file_open().write.assert_has_calls([
-      call('s3://fake-bucket/output')
+    batch_transform._utils.write_output.assert_has_calls([
+      call('/tmp/output', 's3://fake-bucket/output')
     ])
 
 

@@ -9,7 +9,8 @@ from common import _utils
 
 required_args = [
   '--region', 'us-west-2',
-  '--model_name_1', 'model-test'
+  '--model_name_1', 'model-test',
+  '--endpoint_name_output_path', '/tmp/output'
 ]
 
 class DeployTestCase(unittest.TestCase):

@@ -29,20 +30,15 @@ class DeployTestCase(unittest.TestCase):
     # Set some static returns
     deploy._utils.deploy_model.return_value = 'test-endpoint-name'
 
-    with patch('builtins.open', mock_open()) as file_open:
-      deploy.main(required_args)
+    deploy.main(required_args)
 
     # Check if correct requests were created and triggered
     deploy._utils.deploy_model.assert_called()
     deploy._utils.wait_for_endpoint_creation.assert_called()
 
     # Check the file outputs
-    file_open.assert_has_calls([
-      call('/tmp/endpoint_name.txt', 'w')
-    ])
-
-    file_open().write.assert_has_calls([
-      call('test-endpoint-name')
+    deploy._utils.write_output.assert_has_calls([
+      call('/tmp/output', 'test-endpoint-name')
     ])
 
   def test_deploy_model(self):

@@ -20,6 +20,8 @@ required_args = [
   '--description', 'fake job',
   '--num_workers_per_object', '1',
   '--time_limit', '180',
+  '--output_manifest_location_output_path', '/tmp/manifest-output',
+  '--active_learning_model_arn_output_path', '/tmp/model-output'
 ]
 
 class GroundTruthTestCase(unittest.TestCase):

@@ -39,8 +41,7 @@ class GroundTruthTestCase(unittest.TestCase):
     # Set some static returns
     ground_truth._utils.get_labeling_job_outputs.return_value = ('s3://fake-bucket/output', 'arn:aws:sagemaker:us-east-1:999999999999:labeling-job')
 
-    with patch('builtins.open', mock_open()) as file_open:
-      ground_truth.main(required_args)
+    ground_truth.main(required_args)
 
     # Check if correct requests were created and triggered
     ground_truth._utils.create_labeling_job.assert_called()

@@ -48,15 +49,10 @@ class GroundTruthTestCase(unittest.TestCase):
     ground_truth._utils.get_labeling_job_outputs.assert_called()
 
     # Check the file outputs
-    file_open.assert_has_calls([
-      call('/tmp/output_manifest_location.txt', 'w'),
-      call('/tmp/active_learning_model_arn.txt', 'w')
-    ], any_order=True)
-
-    file_open().write.assert_has_calls([
-      call('s3://fake-bucket/output'),
-      call('arn:aws:sagemaker:us-east-1:999999999999:labeling-job')
-    ], any_order=False)
+    ground_truth._utils.write_output.assert_has_calls([
+      call('/tmp/manifest-output', 's3://fake-bucket/output'),
+      call('/tmp/model-output', 'arn:aws:sagemaker:us-east-1:999999999999:labeling-job')
+    ])
 
   def test_ground_truth(self):
     mock_client = MagicMock()

@@ -16,7 +16,12 @@ required_args = [
   '--channels', '[{"ChannelName": "train", "DataSource": {"S3DataSource":{"S3Uri": "s3://fake-bucket/data","S3DataType":"S3Prefix","S3DataDistributionType": "FullyReplicated"}},"ContentType":"","CompressionType": "None","RecordWrapperType":"None","InputMode": "File"}]',
   '--output_location', 'test-output-location',
   '--max_num_jobs', '5',
-  '--max_parallel_jobs', '2'
+  '--max_parallel_jobs', '2',
+  '--hpo_job_name_output_path', '/tmp/hpo_job_name_output_path',
+  '--model_artifact_url_output_path', '/tmp/model_artifact_url_output_path',
+  '--best_job_name_output_path', '/tmp/best_job_name_output_path',
+  '--best_hyperparameters_output_path', '/tmp/best_hyperparameters_output_path',
+  '--training_image_output_path', '/tmp/training_image_output_path'
 ]
 
 class HyperparameterTestCase(unittest.TestCase):

@@ -62,33 +67,24 @@ class HyperparameterTestCase(unittest.TestCase):
 
     # Set some static returns
     hpo._utils.create_hyperparameter_tuning_job.return_value = 'job-name'
-    hpo._utils.get_best_training_job_and_hyperparameters.return_value = 'best_job', 'best_hyperparameters'
+    hpo._utils.get_best_training_job_and_hyperparameters.return_value = 'best_job', {"key_1": "best_hp_1"}
     hpo._utils.get_image_from_job.return_value = 'training-image'
     hpo._utils.get_model_artifacts_from_job.return_value = 'model-artifacts'
 
-    with patch('builtins.open', mock_open()) as file_open:
-      hpo.main(required_args)
+    hpo.main(required_args)
 
     # Check if correct requests were created and triggered
     hpo._utils.create_hyperparameter_tuning_job.assert_called()
     hpo._utils.wait_for_hyperparameter_training_job.assert_called()
 
     # Check the file outputs
-    file_open.assert_has_calls([
-      call('/tmp/hpo_job_name.txt', 'w'),
-      call('/tmp/best_job_name.txt', 'w'),
-      call('/tmp/best_hyperparameters.txt', 'w'),
-      call('/tmp/model_artifact_url.txt', 'w'),
-      call('/tmp/training_image.txt', 'w')
-    ], any_order=True)
-
-    file_open().write.assert_has_calls([
-      call('job-name'),
-      call('best_job'),
-      call('"best_hyperparameters"'),
-      call('model-artifacts'),
-      call('training-image'),
-    ], any_order=False)
+    hpo._utils.write_output.assert_has_calls([
+      call('/tmp/hpo_job_name_output_path', 'job-name'),
+      call('/tmp/model_artifact_url_output_path', 'model-artifacts'),
+      call('/tmp/best_job_name_output_path', 'best_job'),
+      call('/tmp/best_hyperparameters_output_path', {"key_1": "best_hp_1"}, json_encode=True),
+      call('/tmp/training_image_output_path', 'training-image')
+    ])
 
   def test_create_hyperparameter_tuning_job(self):
     mock_client = MagicMock()

@@ -12,7 +12,8 @@ required_args = [
   '--model_name', 'model_test',
   '--role', 'arn:aws:iam::123456789012:user/Development/product_1234/*',
   '--image', 'test-image',
-  '--model_artifact_url', 's3://fake-bucket/model_artifact'
+  '--model_artifact_url', 's3://fake-bucket/model_artifact',
+  '--model_name_output_path', '/tmp/output'
 ]
 
 class ModelTestCase(unittest.TestCase):

@@ -32,19 +33,14 @@ class ModelTestCase(unittest.TestCase):
     # Set some static returns
     create_model._utils.create_model.return_value = 'model_test'
 
-    with patch('builtins.open', mock_open()) as file_open:
-      create_model.main(required_args)
+    create_model.main(required_args)
 
     # Check if correct requests were created and triggered
     create_model._utils.create_model.assert_called()
 
     # Check the file outputs
-    file_open.assert_has_calls([
-      call('/tmp/model_name.txt', 'w')
-    ])
-
-    file_open().write.assert_has_calls([
-      call('model_test')
+    create_model._utils.write_output.assert_has_calls([
+      call('/tmp/output', 'model_test')
     ])
 
   def test_create_model(self):

@@ -30,7 +30,9 @@ required_args = [
                     'LocalPath': "/opt/ml/processing/output/train",
                     'S3UploadMode': "Continuous"
                 }
-            }])
+            }]),
+  '--job_name_output_path', '/tmp/job_name_output_path',
+  '--output_artifacts_output_path', '/tmp/output_artifacts_output_path'
 ]
 
 class ProcessTestCase(unittest.TestCase):

@@ -51,8 +53,7 @@ class ProcessTestCase(unittest.TestCase):
     process._utils.create_processing_job.return_value = 'job-name'
     process._utils.get_processing_job_outputs.return_value = mock_outputs = {'val1': 's3://1', 'val2': 's3://2'}
 
-    with patch('builtins.open', mock_open()) as file_open:
-      process.main(required_args)
+    process.main(required_args)
 
     # Check if correct requests were created and triggered
     process._utils.create_processing_job.assert_called()

@ -60,15 +61,10 @@ class ProcessTestCase(unittest.TestCase):
|
|||
process._utils.print_logs_for_job.assert_called()
|
||||
|
||||
# Check the file outputs
|
||||
file_open.assert_has_calls([
|
||||
call('/tmp/job_name.txt', 'w'),
|
||||
call('/tmp/output_artifacts.txt', 'w')
|
||||
], any_order=True)
|
||||
|
||||
file_open().write.assert_has_calls([
|
||||
call('job-name'),
|
||||
call(json.dumps(mock_outputs))
|
||||
], any_order=False) # Must be in the same order as called
|
||||
process._utils.write_output.assert_has_calls([
|
||||
call('/tmp/job_name_output_path', 'job-name'),
|
||||
call('/tmp/output_artifacts_output_path', mock_outputs, json_encode=True)
|
||||
])
|
||||
|
||||
def test_create_processing_job(self):
|
||||
mock_client = MagicMock()
|
||||
|
|
|
@@ -16,7 +16,10 @@ required_args = [
   '--instance_count', '1',
   '--volume_size', '50',
   '--max_run_time', '3600',
-  '--model_artifact_path', 'test-path'
+  '--model_artifact_path', 'test-path',
+  '--model_artifact_url_output_path', '/tmp/model_artifact_url_output_path',
+  '--job_name_output_path', '/tmp/job_name_output_path',
+  '--training_image_output_path', '/tmp/training_image_output_path',
 ]
 
 class TrainTestCase(unittest.TestCase):

@@ -38,8 +41,7 @@ class TrainTestCase(unittest.TestCase):
     train._utils.get_image_from_job.return_value = 'training-image'
     train._utils.get_model_artifacts_from_job.return_value = 'model-artifacts'
 
-    with patch('builtins.open', mock_open()) as file_open:
-      train.main(required_args)
+    train.main(required_args)
 
     # Check if correct requests were created and triggered
     train._utils.create_training_job.assert_called()

@@ -47,17 +49,11 @@ class TrainTestCase(unittest.TestCase):
     train._utils.print_logs_for_job.assert_called()
 
     # Check the file outputs
-    file_open.assert_has_calls([
-      call('/tmp/model_artifact_url.txt', 'w'),
-      call('/tmp/job_name.txt', 'w'),
-      call('/tmp/training_image.txt', 'w')
-    ], any_order=True)
-
-    file_open().write.assert_has_calls([
-      call('model-artifacts'),
-      call('job-name'),
-      call('training-image'),
-    ], any_order=False) # Must be in the same order as called
+    train._utils.write_output.assert_has_calls([
+      call('/tmp/model_artifact_url_output_path', 'model-artifacts'),
+      call('/tmp/job_name_output_path', 'job-name'),
+      call('/tmp/training_image_output_path', 'training-image')
+    ])
 
   def test_create_training_job(self):
     mock_client = MagicMock()

@@ -1,4 +1,5 @@
 import unittest
+import json
 
 from unittest.mock import patch, call, Mock, MagicMock, mock_open
 from botocore.exceptions import ClientError

@@ -42,4 +43,26 @@ class UtilsTestCase(unittest.TestCase):
 
     with patch('logging.Logger.error') as errorLog:
       _utils.print_logs_for_job(mock_cw_client, '/aws/sagemaker/FakeJobs', 'fake_job_name')
-      errorLog.assert_called()
+      errorLog.assert_called()
+
+  def test_write_output_string(self):
+    with patch("common._utils.Path", MagicMock()) as mock_path:
+      _utils.write_output("/tmp/output-path", "output-value")
+
+      mock_path.assert_called_with("/tmp/output-path")
+      mock_path("/tmp/output-path").parent.mkdir.assert_called()
+      mock_path("/tmp/output-path").write_text.assert_called_with("output-value")
+
+  def test_write_output_json(self):
+    # Ensure working versions of each type of JSON input
+    test_cases = [{"key1": "value1"}, ["val1", "val2"], "string-val"]
+
+    for case in test_cases:
+      with patch("common._utils.Path", MagicMock()) as mock_path:
+        _utils.write_output("/tmp/test-output", case, json_encode=True)
+
+        mock_path.assert_called_with("/tmp/test-output")
+        mock_path("/tmp/test-output").parent.mkdir.assert_called()
+        mock_path("/tmp/test-output").write_text.assert_called_with(
+          json.dumps(case)
+        )

@@ -10,7 +10,8 @@ from common import _utils
 required_args = [
   '--region', 'us-west-2',
   '--team_name', 'test-team',
-  '--description', 'fake team'
+  '--description', 'fake team',
+  '--workteam_arn_output_path', '/tmp/output'
 ]
 
 class WorkTeamTestCase(unittest.TestCase):

@@ -30,19 +31,14 @@ class WorkTeamTestCase(unittest.TestCase):
     # Set some static returns
     workteam._utils.create_workteam.return_value = 'arn:aws:sagemaker:us-east-1:999999999999:work-team'
 
-    with patch('builtins.open', mock_open()) as file_open:
-      workteam.main(required_args)
+    workteam.main(required_args)
 
     # Check if correct requests were created and triggered
     workteam._utils.create_workteam.assert_called()
 
     # Check the file outputs
-    file_open.assert_has_calls([
-      call('/tmp/workteam_arn.txt', 'w')
-    ])
-
-    file_open().write.assert_has_calls([
-      call('arn:aws:sagemaker:us-east-1:999999999999:work-team')
+    workteam._utils.write_output.assert_has_calls([
+      call('/tmp/output', 'arn:aws:sagemaker:us-east-1:999999999999:work-team')
     ])
 
   def test_workteam(self):

@@ -1,4 +1,4 @@
-name: 'Sagemaker - Training Job'
+name: 'SageMaker - Training Job'
 description: |
   Train Machine Learning and Deep Learning Models using SageMaker
 inputs:

@@ -99,12 +99,12 @@ inputs:
     default: '{}'
     type: JsonObject
 outputs:
-  - {name: model_artifact_url, description: 'Model artifacts url'}
+  - {name: model_artifact_url, description: 'Model artifacts URL'}
   - {name: job_name, description: 'Training job name'}
   - {name: training_image, description: 'The registry path of the Docker image that contains the training algorithm'}
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       train.py,

@@ -132,9 +132,8 @@ implementation:
       --spot_instance, {inputValue: spot_instance},
       --max_wait_time, {inputValue: max_wait_time},
       --checkpoint_config, {inputValue: checkpoint_config},
-      --tags, {inputValue: tags}
-    ]
-  fileOutputs:
-    model_artifact_url: /tmp/model_artifact_url.txt
-    job_name: /tmp/job_name.txt
-    training_image: /tmp/training_image.txt
+      --tags, {inputValue: tags},
+      --model_artifact_url_output_path, {outputPath: model_artifact_url},
+      --job_name_output_path, {outputPath: job_name},
+      --training_image_output_path, {outputPath: training_image}
+    ]

@@ -48,6 +48,12 @@ def create_parser():
 
   parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
 
+  ### Start outputs
+  parser.add_argument('--model_artifact_url_output_path', type=str, default='/tmp/model-artifact-url', help='Local output path for the file containing the model artifacts URL.')
+  parser.add_argument('--job_name_output_path', type=str, default='/tmp/job-name', help='Local output path for the file containing the training job name.')
+  parser.add_argument('--training_image_output_path', type=str, default='/tmp/training-image', help='Local output path for the file containing the registry path of the Docker image that contains the training algorithm.')
+  ### End outputs
+
   return parser
 
 def main(argv=None):

@@ -72,12 +78,9 @@ def main(argv=None):
   model_artifact_url = _utils.get_model_artifacts_from_job(client, job_name)
   logging.info('Get model artifacts %s from training job %s.', model_artifact_url, job_name)
 
-  with open('/tmp/model_artifact_url.txt', 'w') as f:
-    f.write(model_artifact_url)
-  with open('/tmp/job_name.txt', 'w') as f:
-    f.write(job_name)
-  with open('/tmp/training_image.txt', 'w') as f:
-    f.write(image)
+  _utils.write_output(args.model_artifact_url_output_path, model_artifact_url)
+  _utils.write_output(args.job_name_output_path, job_name)
+  _utils.write_output(args.training_image_output_path, image)
 
   logging.info('Job completed.')

@@ -36,7 +36,7 @@ outputs:
   - {name: workteam_arn, description: 'The ARN of the workteam.'}
 implementation:
   container:
-    image: amazon/aws-sagemaker-kfp-components:0.5.1
+    image: amazon/aws-sagemaker-kfp-components:0.5.2
     command: ['python3']
     args: [
       workteam.py,

@@ -49,6 +49,5 @@ implementation:
       --client_id, {inputValue: client_id},
       --sns_topic, {inputValue: sns_topic},
       --tags, {inputValue: tags},
-    ]
-  fileOutputs:
-    workteam_arn: /tmp/workteam_arn.txt
+      --workteam_arn_output_path, {outputPath: workteam_arn}
+    ]

@@ -27,6 +27,7 @@ def create_parser():
   parser.add_argument('--client_id', type=str, required=False, help='An identifier for an application client. You must create the app client ID using Amazon Cognito.', default='')
   parser.add_argument('--sns_topic', type=str, required=False, help='The ARN for the SNS topic to which notifications should be published.', default='')
   parser.add_argument('--tags', type=_utils.yaml_or_json_str, required=False, help='An array of key-value pairs, to categorize AWS resources.', default={})
+  parser.add_argument('--workteam_arn_output_path', type=str, default='/tmp/workteam-arn', help='Local output path for the file containing the ARN of the workteam.')
 
   return parser

@@ -41,8 +42,7 @@ def main(argv=None):
 
   logging.info('Workteam created.')
 
-  with open('/tmp/workteam_arn.txt', 'w') as f:
-    f.write(workteam_arn)
+  _utils.write_output(args.workteam_arn_output_path, workteam_arn)
 
 
 if __name__== "__main__":