[Component] Add VPC Interface Endpoint Support for SageMaker (#2299)
* Added Private Link Components * Updated Component Dockerfile * Added endpoint_url to Samples
This commit is contained in:
parent
9a9bd904ac
commit
2fe8c0de61
|
|
@ -64,6 +64,9 @@ inputs:
|
|||
- name: resource_encryption_key
|
||||
description: 'The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s).'
|
||||
default: ''
|
||||
- name: endpoint_url
|
||||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -71,11 +74,12 @@ outputs:
|
|||
- {name: output_location, description: 'S3 URI of the transform job results.'}
|
||||
implementation:
|
||||
container:
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20190930
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20191003
|
||||
command: ['python']
|
||||
args: [
|
||||
batch_transform.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--job_name, {inputValue: job_name},
|
||||
--model_name, {inputValue: model_name},
|
||||
--max_concurrent, {inputValue: max_concurrent},
|
||||
|
|
|
|||
|
|
@ -22,9 +22,10 @@ except NameError:
|
|||
unicode = str
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description='SageMaker Batch Transformation Job')
|
||||
parser.add_argument('--region', type=str.strip, required=True, help='The region where the cluster launches.')
|
||||
_utils.add_default_client_arguments(parser)
|
||||
|
||||
parser.add_argument('--job_name', type=str.strip, required=False, help='The name of the transform job.', default='')
|
||||
parser.add_argument('--model_name', type=str.strip, required=True, help='The name of the model that you want to use for the transform job.')
|
||||
parser.add_argument('--max_concurrent', type=_utils.str_to_int, required=False, help='The maximum number of parallel requests that can be sent to each instance in a transform job.', default='0')
|
||||
|
|
@ -51,10 +52,14 @@ def main(argv=None):
|
|||
parser.add_argument('--tags', type=_utils.str_to_json_dict, required=False, help='An array of key-value pairs, to categorize AWS resources.', default='{}')
|
||||
parser.add_argument('--output_location_file', type=str.strip, required=True, help='File path where the program will write the Amazon S3 URI of the transform job results.')
|
||||
|
||||
return parser
|
||||
|
||||
def main(argv=None):
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_client(args.region)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
logging.info('Submitting Batch Transformation request to SageMaker...')
|
||||
batch_job_name = _utils.create_transform_job(client, vars(args))
|
||||
logging.info('Batch Job request submitted. Waiting for completion...')
|
||||
|
|
|
|||
|
|
@ -52,11 +52,14 @@ built_in_algos = {
|
|||
# Get current directory to open templates
|
||||
__cwd__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
|
||||
def get_client(region=None):
|
||||
"""Builds a client to the AWS SageMaker API."""
|
||||
client = boto3.client('sagemaker', region_name=region)
|
||||
return client
|
||||
def add_default_client_arguments(parser):
    """Register the arguments every component needs to build a SageMaker client.

    Adds ``--region`` (required) and ``--endpoint_url`` (optional) to *parser*.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    A blank or whitespace-only ``--endpoint_url`` is parsed as ``None``. The
    component specs always pass the flag and default its value to an empty
    string, and an empty string is not a usable endpoint URL, so it must mean
    "use the default service endpoint" rather than be forwarded as-is.
    """
    def _optional_url(value):
        # Strip surrounding whitespace; collapse an empty result to None.
        return value.strip() or None

    parser.add_argument('--region', type=str.strip, required=True, help='The region where the training job launches.')
    parser.add_argument('--endpoint_url', type=_optional_url, required=False, help='The URL to use when communicating with the Sagemaker service.')
|
||||
|
||||
def get_sagemaker_client(region, endpoint_url=None):
    """Build a boto3 client for the AWS SageMaker API.

    Args:
        region: AWS region name passed to boto3 as ``region_name``.
        endpoint_url: Optional URL to use instead of the default service
            endpoint, e.g. a PrivateLink VPC interface endpoint. ``None`` or
            an empty string selects the default endpoint.

    Returns:
        The boto3 SageMaker client.
    """
    # Component inputs default endpoint_url to '', which is not a valid URL.
    # Collapse any falsy value to None so boto3 resolves the regional endpoint.
    return boto3.client('sagemaker', region_name=region, endpoint_url=endpoint_url or None)
|
||||
|
||||
def create_training_job_request(args):
|
||||
### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
|
||||
|
|
|
|||
|
|
@ -63,6 +63,9 @@ inputs:
|
|||
- name: resource_encryption_key
|
||||
description: 'The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint.'
|
||||
default: ''
|
||||
- name: endpoint_url
|
||||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
- name: endpoint_config_tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -76,11 +79,12 @@ outputs:
|
|||
- {name: endpoint_name, description: 'Endpoint name'}
|
||||
implementation:
|
||||
container:
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20190930
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20191003
|
||||
command: ['python']
|
||||
args: [
|
||||
deploy.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--endpoint_config_name, {inputValue: endpoint_config_name},
|
||||
--variant_name_1,{inputValue: variant_name_1},
|
||||
--model_name_1, {inputValue: model_name_1},
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ import logging
|
|||
|
||||
from common import _utils
|
||||
|
||||
def main(argv=None):
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description='SageMaker Training Job')
|
||||
parser.add_argument('--region', type=str.strip, required=True, help='The region where the cluster launches.')
|
||||
_utils.add_default_client_arguments(parser)
|
||||
|
||||
parser.add_argument('--endpoint_config_name', type=str.strip, required=False, help='The name of the endpoint configuration.', default='')
|
||||
parser.add_argument('--variant_name_1', type=str.strip, required=False, help='The name of the production variant.', default='variant-name-1')
|
||||
parser.add_argument('--model_name_1', type=str.strip, required=True, help='The model name used for endpoint deployment.')
|
||||
|
|
@ -48,10 +49,15 @@ def main(argv=None):
|
|||
|
||||
parser.add_argument('--endpoint_name', type=str.strip, required=False, help='The name of the endpoint.', default='')
|
||||
parser.add_argument('--endpoint_tags', type=_utils.str_to_json_dict, required=False, help='An array of key-value pairs, to categorize AWS resources.', default='{}')
|
||||
|
||||
return parser
|
||||
|
||||
def main(argv=None):
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_client(args.region)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
logging.info('Submitting Endpoint request to SageMaker...')
|
||||
endpoint_name = _utils.deploy_model(client, vars(args))
|
||||
logging.info('Endpoint creation request submitted. Waiting for completion...')
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ time_limit | The maximum run time in seconds per training job | No | Int | [30,
|
|||
task_availibility | The length of time that a task remains available for labeling by human workers | Yes | Int | Public workforce: [1, 43200], other: [1, 864000] | |
|
||||
max_concurrent_tasks | The maximum number of data objects that can be labeled by human workers at the same time | Yes | Int | [1, 1000] | |
|
||||
workforce_task_price | The price that you pay for each task performed by a public worker in USD; Specify to the tenth fractions of a cent; Format as "0.000" | Yes | Float | | 0.000 |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
|
||||
tags | Key-value pairs to categorize AWS resources | Yes | Dict | | {} |
|
||||
|
||||
## Outputs
|
||||
|
|
|
|||
|
|
@ -77,6 +77,9 @@ inputs:
|
|||
- name: workforce_task_price
|
||||
description: 'The price that you pay for each task performed by a public worker in USD. Specify to the tenth fractions of a cent. Format as "0.000".'
|
||||
default: '0.000'
|
||||
- name: endpoint_url
|
||||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -85,11 +88,12 @@ outputs:
|
|||
- {name: active_learning_model_arn, description: 'The ARN for the most recent Amazon SageMaker model trained as part of automated data labeling.'}
|
||||
implementation:
|
||||
container:
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20190930
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20191003
|
||||
command: ['python']
|
||||
args: [
|
||||
ground_truth.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--role, {inputValue: role},
|
||||
--job_name, {inputValue: job_name},
|
||||
--label_attribute_name, {inputValue: label_attribute_name},
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ import logging
|
|||
|
||||
from common import _utils
|
||||
|
||||
def main(argv=None):
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description='SageMaker Ground Truth Job')
|
||||
parser.add_argument('--region', type=str.strip, required=True, help='The region where the resources are.')
|
||||
_utils.add_default_client_arguments(parser)
|
||||
|
||||
parser.add_argument('--role', type=str.strip, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
|
||||
parser.add_argument('--job_name', type=str.strip, required=True, help='The name of the labeling job.')
|
||||
parser.add_argument('--label_attribute_name', type=str.strip, required=False, help='The attribute name to use for the label in the output manifest file. Default is the job name.', default='')
|
||||
|
|
@ -48,10 +49,14 @@ def main(argv=None):
|
|||
parser.add_argument('--workforce_task_price', type=_utils.str_to_float, required=False, help='The price that you pay for each task performed by a public worker in USD. Specify to the tenth fractions of a cent. Format as "0.000".', default=0.000)
|
||||
parser.add_argument('--tags', type=_utils.str_to_json_dict, required=False, help='An array of key-value pairs, to categorize AWS resources.', default='{}')
|
||||
|
||||
return parser
|
||||
|
||||
def main(argv=None):
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_client(args.region)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
logging.info('Submitting Ground Truth Job request to SageMaker...')
|
||||
_utils.create_labeling_job(client, vars(args))
|
||||
logging.info('Ground Truth labeling job request submitted. Waiting for completion...')
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ max_wait_time | The maximum time in seconds you are willing to wait for a manage
|
|||
checkpoint_config | Dictionary of information about the output location for managed spot training checkpoint data | Yes | Yes | Dict | | {} |
|
||||
warm_start_type | Specifies the type of warm start used | Yes | No | String | IdenticalDataAndAlgorithm, TransferLearning | |
|
||||
parent_hpo_jobs | List of previously completed or stopped hyperparameter tuning jobs to be used as a starting point | Yes | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | Yes | String | | |
|
||||
tags | Key-value pairs to categorize AWS resources | Yes | Yes | Dict | | {} |
|
||||
|
||||
Notes:
|
||||
|
|
|
|||
|
|
@ -120,6 +120,9 @@ inputs:
|
|||
- name: parent_hpo_jobs
|
||||
description: 'List of previously completed or stopped hyperparameter tuning jobs to be used as a starting point.'
|
||||
default: ''
|
||||
- name: endpoint_url
|
||||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
- name: tags
|
||||
description: 'Key-value pairs, to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -136,11 +139,12 @@ outputs:
|
|||
description: 'The registry path of the Docker image that contains the training algorithm'
|
||||
implementation:
|
||||
container:
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20190930
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20191003
|
||||
command: ['python']
|
||||
args: [
|
||||
hyperparameter_tuning.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--job_name, {inputValue: job_name},
|
||||
--role, {inputValue: role},
|
||||
--image, {inputValue: image},
|
||||
|
|
|
|||
|
|
@ -18,7 +18,8 @@ from common import _utils
|
|||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description='SageMaker Hyperparameter Tuning Job')
|
||||
parser.add_argument('--region', type=str.strip, required=True, help='The region where the cluster launches.')
|
||||
_utils.add_default_client_arguments(parser)
|
||||
|
||||
parser.add_argument('--job_name', type=str.strip, required=False, help='The name of the tuning job. Must be unique within the same AWS account and AWS region.')
|
||||
parser.add_argument('--role', type=str.strip, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
|
||||
parser.add_argument('--image', type=str.strip, required=True, help='The registry path of the Docker image that contains the training algorithm.', default='')
|
||||
|
|
@ -75,7 +76,7 @@ def main(argv=None):
|
|||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_client(args.region)
|
||||
# Forward the user-supplied endpoint URL so tuning jobs use the VPC endpoint as well.
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
logging.info('Submitting HyperParameter Tuning Job request to SageMaker...')
|
||||
hpo_job_name = _utils.create_hyperparameter_tuning_job(client, vars(args))
|
||||
logging.info('HyperParameter Tuning Job request submitted. Waiting for completion...')
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ inputs:
|
|||
- name: network_isolation
|
||||
description: 'Isolates the training container.'
|
||||
default: 'True'
|
||||
- name: endpoint_url
|
||||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -42,11 +45,12 @@ outputs:
|
|||
- {name: model_name, description: 'The model name Sagemaker created'}
|
||||
implementation:
|
||||
container:
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20190930
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20191003
|
||||
command: ['python']
|
||||
args: [
|
||||
create_model.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--model_name, {inputValue: model_name},
|
||||
--role, {inputValue: role},
|
||||
--container_host_name, {inputValue: container_host_name},
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ import logging
|
|||
|
||||
from common import _utils
|
||||
|
||||
def main(argv=None):
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description='SageMaker Training Job')
|
||||
parser.add_argument('--region', type=str.strip, required=True, help='The region where the cluster launches.')
|
||||
_utils.add_default_client_arguments(parser)
|
||||
|
||||
parser.add_argument('--model_name', type=str.strip, required=True, help='The name of the new model.')
|
||||
parser.add_argument('--role', type=str.strip, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
|
||||
parser.add_argument('--container_host_name', type=str.strip, required=False, help='When a ContainerDefinition is part of an inference pipeline, this value uniquely identifies the container for the purposes of logging and metrics.', default='')
|
||||
|
|
@ -30,10 +31,15 @@ def main(argv=None):
|
|||
parser.add_argument('--vpc_subnets', type=str.strip, required=False, help='The ID of the subnets in the VPC to which you want to connect your hpo job.', default='')
|
||||
parser.add_argument('--network_isolation', type=_utils.str_to_bool, required=False, help='Isolates the training container.', default=True)
|
||||
parser.add_argument('--tags', type=_utils.str_to_json_dict, required=False, help='An array of key-value pairs, to categorize AWS resources.', default='{}')
|
||||
|
||||
return parser
|
||||
|
||||
def main(argv=None):
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_client(args.region)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
|
||||
logging.info('Submitting model creation request to SageMaker...')
|
||||
_utils.create_model(client, vars(args))
|
||||
|
|
|
|||
|
|
@ -91,6 +91,9 @@ inputs:
|
|||
- name: checkpoint_config
|
||||
description: 'Dictionary of information about the output location for managed spot training checkpoint data.'
|
||||
default: '{}'
|
||||
- name: endpoint_url
|
||||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
- name: tags
|
||||
description: 'Key-value pairs, to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -100,11 +103,12 @@ outputs:
|
|||
- {name: training_image, description: 'The registry path of the Docker image that contains the training algorithm'}
|
||||
implementation:
|
||||
container:
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20190930
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20191003
|
||||
command: ['python']
|
||||
args: [
|
||||
train.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--job_name, {inputValue: job_name},
|
||||
--role, {inputValue: role},
|
||||
--image, {inputValue: image},
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ from common import _utils
|
|||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description='SageMaker Training Job')
|
||||
parser.add_argument('--region', type=str.strip, required=True, help='The region where the training job launches.')
|
||||
_utils.add_default_client_arguments(parser)
|
||||
|
||||
parser.add_argument('--job_name', type=str.strip, required=False, help='The name of the training job.', default='')
|
||||
parser.add_argument('--role', type=str.strip, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
|
||||
parser.add_argument('--image', type=str.strip, required=True, help='The registry path of the Docker image that contains the training algorithm.', default='')
|
||||
|
|
@ -63,7 +64,7 @@ def main(argv=None):
|
|||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_client(args.region)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
|
||||
logging.info('Submitting Training Job to SageMaker...')
|
||||
job_name = _utils.create_training_job(client, vars(args))
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ user_pool | An identifier for a user pool, which must be in the same region as t
|
|||
user_groups | An identifier for user groups separated by commas | No | String | | |
|
||||
client_id | An identifier for an application client, which you must create using Amazon Cognito | No | String | | |
|
||||
sns_topic | The ARN for the SNS topic to which notifications should be published | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
|
||||
tags | Key-value pairs to categorize AWS resources | Yes | Dict | | {} |
|
||||
|
||||
Notes:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ inputs:
|
|||
- name: sns_topic
|
||||
description: 'The ARN for the SNS topic to which notifications should be published.'
|
||||
default: ''
|
||||
- name: endpoint_url
|
||||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -24,11 +27,12 @@ outputs:
|
|||
- {name: workteam_arn, description: 'The ARN of the workteam.'}
|
||||
implementation:
|
||||
container:
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20190930
|
||||
image: redbackthomson/aws-kubeflow-sagemaker:20191003
|
||||
command: ['python']
|
||||
args: [
|
||||
workteam.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--team_name, {inputValue: team_name},
|
||||
--description, {inputValue: description},
|
||||
--user_pool, {inputValue: user_pool},
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ import logging
|
|||
|
||||
from common import _utils
|
||||
|
||||
def main(argv=None):
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description='SageMaker Hyperparameter Tuning Job')
|
||||
parser.add_argument('--region', type=str.strip, required=True, help='The region where the cluster launches.')
|
||||
_utils.add_default_client_arguments(parser)
|
||||
|
||||
parser.add_argument('--team_name', type=str.strip, required=True, help='The name of your work team.')
|
||||
parser.add_argument('--description', type=str.strip, required=True, help='A description of the work team.')
|
||||
parser.add_argument('--user_pool', type=str.strip, required=False, help='An identifier for a user pool. The user pool must be in the same region as the service that you are calling.', default='')
|
||||
|
|
@ -26,10 +27,14 @@ def main(argv=None):
|
|||
parser.add_argument('--sns_topic', type=str.strip, required=False, help='The ARN for the SNS topic to which notifications should be published.', default='')
|
||||
parser.add_argument('--tags', type=_utils.str_to_json_dict, required=False, help='An array of key-value pairs, to categorize AWS resources.', default='{}')
|
||||
|
||||
return parser
|
||||
|
||||
def main(argv=None):
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_client(args.region)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
logging.info('Submitting a create workteam request to SageMaker...')
|
||||
workteam_arn = _utils.create_workteam(client, vars(args))
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ def hpo_test(region='us-west-2',
|
|||
max_run_time='3600',
|
||||
vpc_security_group_ids='',
|
||||
vpc_subnets='',
|
||||
endpoint_url='',
|
||||
network_isolation='True',
|
||||
traffic_encryption='False',
|
||||
warm_start_type='',
|
||||
|
|
@ -74,6 +75,7 @@ def hpo_test(region='us-west-2',
|
|||
|
||||
training = sagemaker_hpo_op(
|
||||
region=region,
|
||||
endpoint_url=endpoint_url,
|
||||
job_name=hpo_job_name,
|
||||
image=image,
|
||||
training_input_mode=training_input_mode,
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ def mnist_classification(region='us-west-2',
|
|||
hpo_max_num_jobs='9',
|
||||
hpo_max_parallel_jobs='3',
|
||||
max_run_time='3600',
|
||||
endpoint_url='',
|
||||
network_isolation='True',
|
||||
traffic_encryption='False',
|
||||
train_channels='[{"ChannelName": "train", \
|
||||
|
|
@ -93,6 +94,7 @@ def mnist_classification(region='us-west-2',
|
|||
|
||||
hpo = sagemaker_hpo_op(
|
||||
region=region,
|
||||
endpoint_url=endpoint_url,
|
||||
image=image,
|
||||
training_input_mode=training_input_mode,
|
||||
strategy=hpo_strategy,
|
||||
|
|
@ -122,6 +124,7 @@ def mnist_classification(region='us-west-2',
|
|||
|
||||
training = sagemaker_train_op(
|
||||
region=region,
|
||||
endpoint_url=endpoint_url,
|
||||
image=image,
|
||||
training_input_mode=training_input_mode,
|
||||
hyperparameters=hpo.outputs['best_hyperparameters'],
|
||||
|
|
@ -142,6 +145,7 @@ def mnist_classification(region='us-west-2',
|
|||
|
||||
create_model = sagemaker_model_op(
|
||||
region=region,
|
||||
endpoint_url=endpoint_url,
|
||||
model_name=training.outputs['job_name'],
|
||||
image=training.outputs['training_image'],
|
||||
model_artifact_url=training.outputs['model_artifact_url'],
|
||||
|
|
@ -151,11 +155,13 @@ def mnist_classification(region='us-west-2',
|
|||
|
||||
prediction = sagemaker_deploy_op(
|
||||
region=region,
|
||||
endpoint_url=endpoint_url,
|
||||
model_name_1=create_model.output,
|
||||
).apply(use_aws_secret('aws-secret', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'))
|
||||
|
||||
batch_transform = sagemaker_batch_transform_op(
|
||||
region=region,
|
||||
endpoint_url=endpoint_url,
|
||||
model_name=create_model.output,
|
||||
instance_type=batch_transform_instance_type,
|
||||
instance_count=instance_count,
|
||||
|
|
|
|||
Loading…
Reference in New Issue