# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from time import gmtime, strftime
from distutils.util import strtobool
import time
import string
import random
import json
import yaml
import re

import boto3
import botocore
from botocore.exceptions import ClientError
from sagemaker.amazon.amazon_estimator import get_image_uri

import logging
logging.getLogger().setLevel(logging.INFO)

# Mappings are extracted from the first table in https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
built_in_algos = {
    'blazingtext': 'blazingtext',
    'deepar forecasting': 'forecasting-deepar',
    'factorization machines': 'factorization-machines',
    'image classification': 'image-classification',
    'ip insights': 'ipinsights',
    'k-means': 'kmeans',
    'k-nearest neighbors': 'knn',
    'k-nn': 'knn',
    'lda': 'lda',
    'linear learner': 'linear-learner',
    'neural topic model': 'ntm',
    'object2vec': 'object2vec',
    'object detection': 'object-detection',
    'pca': 'pca',
    'random cut forest': 'randomcutforest',
    'semantic segmentation': 'semantic-segmentation',
    'sequence to sequence': 'seq2seq',
    'seq2seq modeling': 'seq2seq',
    'xgboost': 'xgboost'
}

# Get the current directory to open the request templates
__cwd__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))


def nullable_string_argument(value):
    value = value.strip()
    if not value:
        return None
    return value


def add_default_client_arguments(parser):
    parser.add_argument('--region', type=str, required=True, help='The region where the training job launches.')
    parser.add_argument('--endpoint_url', type=nullable_string_argument, required=False, help='The URL to use when communicating with the SageMaker service.')


def get_component_version():
    """Get the component version from the first line of the third-party licenses file."""
    component_version = 'NULL'

    with open('THIRD-PARTY-LICENSES.txt', 'r') as license_file:
        version_match = re.search('Amazon SageMaker Components for Kubeflow Pipelines; version (([0-9]+[.])+[0-9]+)',
                                  license_file.readline())
        if version_match is not None:
            component_version = version_match.group(1)

    return component_version


def get_sagemaker_client(region, endpoint_url=None):
    """Builds a client to the AWS SageMaker API."""
    session_config = botocore.config.Config(
        user_agent='sagemaker-on-kubeflow-pipelines-v{}'.format(get_component_version())
    )
    client = boto3.client('sagemaker', region_name=region, endpoint_url=endpoint_url, config=session_config)
    return client


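# A minimal usage sketch (hypothetical region and endpoint URL; endpoint_url is
# only needed for non-standard endpoints such as VPC interface endpoints):
#
#   client = get_sagemaker_client('us-east-1')
#   client = get_sagemaker_client('us-east-1', endpoint_url='https://example-vpce.sagemaker.us-east-1.amazonaws.com')

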
def create_training_job_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
    with open(os.path.join(__cwd__, 'train.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    job_name = args['job_name'] if args['job_name'] else 'TrainingJob-' + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator()

    request['TrainingJobName'] = job_name
    request['RoleArn'] = args['role']
    request['HyperParameters'] = args['hyperparameters']
    request['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

    ### Update training image (for BYOC and built-in algorithms) or algorithm resource name
    if not args['image'] and not args['algorithm_name']:
        logging.error('Please specify either a training image or an algorithm name.')
        raise Exception('Could not create job request')
    if args['image'] and args['algorithm_name']:
        logging.warning('Both image and algorithm name were specified; only one should be provided. Proceeding with the image.')

    if args['image']:
        request['AlgorithmSpecification']['TrainingImage'] = args['image']
        request['AlgorithmSpecification'].pop('AlgorithmName')
    else:
        # TODO: Adjust this implementation to account for custom algorithm resources names that are the same as built-in algorithm names
        algo_name = args['algorithm_name'].lower().strip()
        if algo_name in built_in_algos.keys():
            request['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], built_in_algos[algo_name])
            request['AlgorithmSpecification'].pop('AlgorithmName')
            logging.warning('Algorithm name matches an Amazon built-in algorithm. Using the built-in algorithm.')
        # Just to give the user more leeway for built-in algorithm name inputs
        elif algo_name in built_in_algos.values():
            request['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], algo_name)
            request['AlgorithmSpecification'].pop('AlgorithmName')
            logging.warning('Algorithm name matches an Amazon built-in algorithm. Using the built-in algorithm.')
        else:
            request['AlgorithmSpecification']['AlgorithmName'] = args['algorithm_name']
            request['AlgorithmSpecification'].pop('TrainingImage')

    ### Update metric definitions
    if args['metric_definitions']:
        for key, val in args['metric_definitions'].items():
            request['AlgorithmSpecification']['MetricDefinitions'].append({'Name': key, 'Regex': val})
    else:
        request['AlgorithmSpecification'].pop('MetricDefinitions')

    ### Update or pop VPC configs
    if args['vpc_security_group_ids'] and args['vpc_subnets']:
        request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
        request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
    else:
        request.pop('VpcConfig')

    ### Update input channels; at least one must be specified
    if len(args['channels']) > 0:
        request['InputDataConfig'] = args['channels']
    else:
        logging.error("Must specify at least one input channel.")
        raise Exception('Could not create job request')

    request['OutputDataConfig']['S3OutputPath'] = args['model_artifact_path']
    request['OutputDataConfig']['KmsKeyId'] = args['output_encryption_key']
    request['ResourceConfig']['InstanceType'] = args['instance_type']
    request['ResourceConfig']['VolumeKmsKeyId'] = args['resource_encryption_key']
    request['EnableNetworkIsolation'] = args['network_isolation']
    request['EnableInterContainerTrafficEncryption'] = args['traffic_encryption']

    ### Update InstanceCount, VolumeSizeInGB, and MaxRuntimeInSeconds if input is non-empty and > 0, otherwise use default values
    if args['instance_count']:
        request['ResourceConfig']['InstanceCount'] = args['instance_count']

    if args['volume_size']:
        request['ResourceConfig']['VolumeSizeInGB'] = args['volume_size']

    if args['max_run_time']:
        request['StoppingCondition']['MaxRuntimeInSeconds'] = args['max_run_time']

    enable_spot_instance_support(request, args)

    ### Update tags
    for key, val in args['tags'].items():
        request['Tags'].append({'Key': key, 'Value': val})

    return request


def create_training_job(client, args):
    """Create a SageMaker training job."""
    request = create_training_job_request(args)
    try:
        client.create_training_job(**request)
        training_job_name = request['TrainingJobName']
        logging.info("Created Training Job with name: " + training_job_name)
        logging.info("Training job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}"
                     .format(args['region'], args['region'], training_job_name))
        logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix"
                     .format(args['region'], args['region'], training_job_name))
        return training_job_name
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


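# A minimal end-to-end training sketch (hypothetical `args` dict; the real
# components build it from argparse via the functions above):
#
#   client = get_sagemaker_client(args['region'], args['endpoint_url'])
#   job_name = create_training_job(client, args)
#   wait_for_training_job(client, job_name)
#   model_uri = get_model_artifacts_from_job(client, job_name)

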
def wait_for_training_job(client, training_job_name):
    while True:
        response = client.describe_training_job(TrainingJobName=training_job_name)
        status = response['TrainingJobStatus']
        if status == 'Completed':
            logging.info("Training job ended with status: " + status)
            break
        if status == 'Failed':
            message = response['FailureReason']
            logging.info('Training failed with the following error: {}'.format(message))
            raise Exception('Training job failed')
        logging.info("Training job is still in status: " + status)
        time.sleep(30)


def get_model_artifacts_from_job(client, job_name):
    info = client.describe_training_job(TrainingJobName=job_name)
    model_artifact_url = info['ModelArtifacts']['S3ModelArtifacts']
    return model_artifact_url


def get_image_from_job(client, job_name):
    info = client.describe_training_job(TrainingJobName=job_name)
    try:
        image = info['AlgorithmSpecification']['TrainingImage']
    except KeyError:
        # Jobs created from an algorithm resource carry an AlgorithmName instead of a TrainingImage
        algorithm_name = info['AlgorithmSpecification']['AlgorithmName']
        image = client.describe_algorithm(AlgorithmName=algorithm_name)['TrainingSpecification']['TrainingImage']

    return image


def create_model(client, args):
    request = create_model_request(args)
    try:
        create_model_response = client.create_model(**request)
        logging.info("Model Config Arn: " + create_model_response['ModelArn'])
        return create_model_response['ModelArn']
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


def create_model_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model
    with open(os.path.join(__cwd__, 'model.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    request['ModelName'] = args['model_name']
    request['PrimaryContainer']['Environment'] = args['environment']

    if args['secondary_containers']:
        request['Containers'] = args['secondary_containers']
        request.pop('PrimaryContainer')
    else:
        request.pop('Containers')
        ### Update primary container and handle input errors
        if args['container_host_name']:
            request['PrimaryContainer']['ContainerHostname'] = args['container_host_name']
        else:
            request['PrimaryContainer'].pop('ContainerHostname')

        if (args['image'] or args['model_artifact_url']) and args['model_package']:
            logging.error("Please specify an image AND model artifact url, OR a model package name.")
            raise Exception("Could not make create model request.")
        elif args['model_package']:
            request['PrimaryContainer']['ModelPackageName'] = args['model_package']
            request['PrimaryContainer'].pop('Image')
            request['PrimaryContainer'].pop('ModelDataUrl')
        else:
            if args['image'] and args['model_artifact_url']:
                request['PrimaryContainer']['Image'] = args['image']
                request['PrimaryContainer']['ModelDataUrl'] = args['model_artifact_url']
                request['PrimaryContainer'].pop('ModelPackageName')
            else:
                logging.error("Please specify an image AND model artifact url.")
                raise Exception("Could not make create model request.")

    request['ExecutionRoleArn'] = args['role']
    request['EnableNetworkIsolation'] = args['network_isolation']

    ### Update or pop VPC configs
    if args['vpc_security_group_ids'] and args['vpc_subnets']:
        request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
        request['VpcConfig']['Subnets'] = [args['vpc_subnets']]
    else:
        request.pop('VpcConfig')

    ### Update tags
    for key, val in args['tags'].items():
        request['Tags'].append({'Key': key, 'Value': val})

    return request


def deploy_model(client, args):
    endpoint_config_name = create_endpoint_config(client, args)
    endpoint_name = create_endpoint(client, args['region'], args['endpoint_name'], endpoint_config_name, args['endpoint_tags'])
    return endpoint_name


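# Hosting flow sketch (hypothetical args): deploy_model creates the endpoint
# config and the endpoint, and wait_for_endpoint_creation blocks until the
# endpoint is InService:
#
#   endpoint_name = deploy_model(client, args)
#   wait_for_endpoint_creation(client, endpoint_name)

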
def create_endpoint_config_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config
    with open(os.path.join(__cwd__, 'endpoint_config.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    # Check this before deriving a config name from model_name_1 below
    if not args['model_name_1']:
        logging.error("Must specify at least one model (model name) to host.")
        raise Exception("Could not create endpoint config.")

    endpoint_config_name = args['endpoint_config_name'] if args['endpoint_config_name'] else 'EndpointConfig' + args['model_name_1'][args['model_name_1'].index('-'):]
    request['EndpointConfigName'] = endpoint_config_name

    if args['resource_encryption_key']:
        request['KmsKeyId'] = args['resource_encryption_key']
    else:
        request.pop('KmsKeyId')

    # Iterate in reverse so popping unused variants does not shift the indices
    for i in range(len(request['ProductionVariants']), 0, -1):
        if args['model_name_' + str(i)]:
            request['ProductionVariants'][i - 1]['ModelName'] = args['model_name_' + str(i)]
            if args['variant_name_' + str(i)]:
                request['ProductionVariants'][i - 1]['VariantName'] = args['variant_name_' + str(i)]
            if args['initial_instance_count_' + str(i)]:
                request['ProductionVariants'][i - 1]['InitialInstanceCount'] = args['initial_instance_count_' + str(i)]
            if args['instance_type_' + str(i)]:
                request['ProductionVariants'][i - 1]['InstanceType'] = args['instance_type_' + str(i)]
            if args['initial_variant_weight_' + str(i)]:
                request['ProductionVariants'][i - 1]['InitialVariantWeight'] = args['initial_variant_weight_' + str(i)]
            if args['accelerator_type_' + str(i)]:
                request['ProductionVariants'][i - 1]['AcceleratorType'] = args['accelerator_type_' + str(i)]
            else:
                request['ProductionVariants'][i - 1].pop('AcceleratorType')
        else:
            request['ProductionVariants'].pop(i - 1)

    ### Update tags
    for key, val in args['endpoint_config_tags'].items():
        request['Tags'].append({'Key': key, 'Value': val})

    return request


def create_endpoint_config(client, args):
    request = create_endpoint_config_request(args)
    try:
        create_endpoint_config_response = client.create_endpoint_config(**request)
        logging.info("Endpoint configuration in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/endpointConfig/{}"
                     .format(args['region'], args['region'], request['EndpointConfigName']))
        logging.info("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])
        return request['EndpointConfigName']
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


def create_endpoint(client, region, endpoint_name, endpoint_config_name, endpoint_tags):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint
    endpoint_name = endpoint_name if endpoint_name else 'Endpoint' + endpoint_config_name[endpoint_config_name.index('-'):]

    ### Update tags
    tags = []
    for key, val in endpoint_tags.items():
        tags.append({'Key': key, 'Value': val})

    try:
        client.create_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint_config_name,
            Tags=tags)
        logging.info("Created endpoint with name: " + endpoint_name)
        logging.info("Endpoint in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/endpoints/{}"
                     .format(region, region, endpoint_name))
        logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/Endpoints/{};streamFilter=typeLogStreamPrefix"
                     .format(region, region, endpoint_name))
        return endpoint_name
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


def wait_for_endpoint_creation(client, endpoint_name):
    status = client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
    logging.info("Status: " + status)

    try:
        client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
    finally:
        resp = client.describe_endpoint(EndpointName=endpoint_name)
        status = resp['EndpointStatus']
        logging.info("Endpoint Arn: " + resp['EndpointArn'])
        logging.info("Create endpoint ended with status: " + status)

        if status != 'InService':
            message = client.describe_endpoint(EndpointName=endpoint_name)['FailureReason']
            logging.info('Create endpoint failed with the following error: {}'.format(message))
            raise Exception('Endpoint creation did not succeed')


def create_transform_job_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_transform_job
    with open(os.path.join(__cwd__, 'transform.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    job_name = args['job_name'] if args['job_name'] else 'BatchTransform' + args['model_name'][args['model_name'].index('-'):]

    request['TransformJobName'] = job_name
    request['ModelName'] = args['model_name']

    if args['max_concurrent']:
        request['MaxConcurrentTransforms'] = args['max_concurrent']

    # 0 is a valid value for MaxPayloadInMB, so check for it explicitly
    if args['max_payload'] or args['max_payload'] == 0:
        request['MaxPayloadInMB'] = args['max_payload']

    if args['batch_strategy']:
        request['BatchStrategy'] = args['batch_strategy']
    else:
        request.pop('BatchStrategy')

    request['Environment'] = args['environment']

    if args['data_type']:
        request['TransformInput']['DataSource']['S3DataSource']['S3DataType'] = args['data_type']

    request['TransformInput']['DataSource']['S3DataSource']['S3Uri'] = args['input_location']
    request['TransformInput']['ContentType'] = args['content_type']

    if args['compression_type']:
        request['TransformInput']['CompressionType'] = args['compression_type']

    if args['split_type']:
        request['TransformInput']['SplitType'] = args['split_type']

    request['TransformOutput']['S3OutputPath'] = args['output_location']
    request['TransformOutput']['Accept'] = args['accept']
    request['TransformOutput']['KmsKeyId'] = args['output_encryption_key']

    if args['assemble_with']:
        request['TransformOutput']['AssembleWith'] = args['assemble_with']
    else:
        request['TransformOutput'].pop('AssembleWith')

    request['TransformResources']['InstanceType'] = args['instance_type']
    request['TransformResources']['InstanceCount'] = args['instance_count']
    request['TransformResources']['VolumeKmsKeyId'] = args['resource_encryption_key']
    request['DataProcessing']['InputFilter'] = args['input_filter']
    request['DataProcessing']['OutputFilter'] = args['output_filter']

    if args['join_source']:
        request['DataProcessing']['JoinSource'] = args['join_source']

    ### Update tags
    if args['tags'] is not None:
        for key, val in args['tags'].items():
            request['Tags'].append({'Key': key, 'Value': val})

    return request


def create_transform_job(client, args):
    request = create_transform_job_request(args)
    try:
        client.create_transform_job(**request)
        batch_job_name = request['TransformJobName']
        logging.info("Created Transform Job with name: " + batch_job_name)
        logging.info("Transform job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/transform-jobs/{}"
                     .format(args['region'], args['region'], batch_job_name))
        logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TransformJobs;prefix={};streamFilter=typeLogStreamPrefix"
                     .format(args['region'], args['region'], batch_job_name))
        return batch_job_name
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


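# Batch transform flow sketch (hypothetical args):
#
#   batch_job_name = create_transform_job(client, args)
#   wait_for_transform_job(client, batch_job_name)

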
def wait_for_transform_job(client, batch_job_name):
    ### Wait until the job finishes
    while True:
        response = client.describe_transform_job(TransformJobName=batch_job_name)
        status = response['TransformJobStatus']
        if status == 'Completed':
            logging.info("Transform job ended with status: " + status)
            break
        if status == 'Failed':
            message = response['FailureReason']
            logging.info('Transform failed with the following error: {}'.format(message))
            raise Exception('Transform job failed')
        logging.info("Transform job is still in status: " + status)
        time.sleep(30)


def create_hyperparameter_tuning_job_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
    with open(os.path.join(__cwd__, 'hpo.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    ### Create a hyperparameter tuning job
    request['HyperParameterTuningJobName'] = args['job_name'] if args['job_name'] else "HPOJob-" + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator()

    request['HyperParameterTuningJobConfig']['Strategy'] = args['strategy']
    request['HyperParameterTuningJobConfig']['HyperParameterTuningJobObjective']['Type'] = args['metric_type']
    request['HyperParameterTuningJobConfig']['HyperParameterTuningJobObjective']['MetricName'] = args['metric_name']
    request['HyperParameterTuningJobConfig']['ResourceLimits']['MaxNumberOfTrainingJobs'] = args['max_num_jobs']
    request['HyperParameterTuningJobConfig']['ResourceLimits']['MaxParallelTrainingJobs'] = args['max_parallel_jobs']
    request['HyperParameterTuningJobConfig']['ParameterRanges']['IntegerParameterRanges'] = args['integer_parameters']
    request['HyperParameterTuningJobConfig']['ParameterRanges']['ContinuousParameterRanges'] = args['continuous_parameters']
    request['HyperParameterTuningJobConfig']['ParameterRanges']['CategoricalParameterRanges'] = args['categorical_parameters']
    request['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = args['early_stopping_type']

    request['TrainingJobDefinition']['StaticHyperParameters'] = args['static_parameters']
    request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode']

    ### Update training image (for BYOC) or algorithm resource name
    if not args['image'] and not args['algorithm_name']:
        logging.error('Please specify either a training image or an algorithm name.')
        raise Exception('Could not create job request')
    if args['image'] and args['algorithm_name']:
        logging.warning('Both image and algorithm name were specified; only one should be provided. Proceeding with the image.')

    if args['image']:
        request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = args['image']
        request['TrainingJobDefinition']['AlgorithmSpecification'].pop('AlgorithmName')
    else:
        # TODO: Adjust this implementation to account for custom algorithm resources names that are the same as built-in algorithm names
        algo_name = args['algorithm_name'].lower().strip()
        if algo_name in built_in_algos.keys():
            request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], built_in_algos[algo_name])
            request['TrainingJobDefinition']['AlgorithmSpecification'].pop('AlgorithmName')
            logging.warning('Algorithm name matches an Amazon built-in algorithm. Using the built-in algorithm.')
        # To give the user more leeway for built-in algorithm name inputs
        elif algo_name in built_in_algos.values():
            request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], algo_name)
            request['TrainingJobDefinition']['AlgorithmSpecification'].pop('AlgorithmName')
            logging.warning('Algorithm name matches an Amazon built-in algorithm. Using the built-in algorithm.')
        else:
            request['TrainingJobDefinition']['AlgorithmSpecification']['AlgorithmName'] = args['algorithm_name']
            request['TrainingJobDefinition']['AlgorithmSpecification'].pop('TrainingImage')

    ### Update metric definitions
    if args['metric_definitions']:
        for key, val in args['metric_definitions'].items():
            request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'].append({'Name': key, 'Regex': val})
    else:
        request['TrainingJobDefinition']['AlgorithmSpecification'].pop('MetricDefinitions')

    ### Update or pop VPC configs
    if args['vpc_security_group_ids'] and args['vpc_subnets']:
        request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']]
        request['TrainingJobDefinition']['VpcConfig']['Subnets'] = [args['vpc_subnets']]
    else:
        request['TrainingJobDefinition'].pop('VpcConfig')

    ### Update input channels; at least one must be specified
    if len(args['channels']) > 0:
        request['TrainingJobDefinition']['InputDataConfig'] = args['channels']
    else:
        logging.error("Must specify at least one input channel.")
        raise Exception('Could not make job request')

    request['TrainingJobDefinition']['OutputDataConfig']['S3OutputPath'] = args['output_location']
    request['TrainingJobDefinition']['OutputDataConfig']['KmsKeyId'] = args['output_encryption_key']
    request['TrainingJobDefinition']['ResourceConfig']['InstanceType'] = args['instance_type']
    request['TrainingJobDefinition']['ResourceConfig']['VolumeKmsKeyId'] = args['resource_encryption_key']
    request['TrainingJobDefinition']['EnableNetworkIsolation'] = args['network_isolation']
    request['TrainingJobDefinition']['EnableInterContainerTrafficEncryption'] = args['traffic_encryption']
    request['TrainingJobDefinition']['RoleArn'] = args['role']

    ### Update InstanceCount, VolumeSizeInGB, and MaxRuntimeInSeconds if input is non-empty and > 0, otherwise use default values
    if args['instance_count']:
        request['TrainingJobDefinition']['ResourceConfig']['InstanceCount'] = args['instance_count']

    if args['volume_size']:
        request['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB'] = args['volume_size']

    if args['max_run_time']:
        request['TrainingJobDefinition']['StoppingCondition']['MaxRuntimeInSeconds'] = args['max_run_time']

    ### Update or pop warm start configs
    if args['warm_start_type'] and args['parent_hpo_jobs']:
        request['WarmStartConfig']['WarmStartType'] = args['warm_start_type']
        parent_jobs = [n.strip() for n in args['parent_hpo_jobs'].split(',')]
        for i in range(len(parent_jobs)):
            request['WarmStartConfig']['ParentHyperParameterTuningJobs'].append({'HyperParameterTuningJobName': parent_jobs[i]})
    else:
        if args['warm_start_type'] or args['parent_hpo_jobs']:
            if not args['warm_start_type']:
                logging.error('Must specify warm start type as either "IdenticalDataAndAlgorithm" or "TransferLearning".')
            if not args['parent_hpo_jobs']:
                logging.error("Must specify at least one parent hyperparameter tuning job")
            raise Exception('Could not make job request')
        request.pop('WarmStartConfig')

    enable_spot_instance_support(request['TrainingJobDefinition'], args)

    ### Update tags
    for key, val in args['tags'].items():
        request['Tags'].append({'Key': key, 'Value': val})

    return request


def create_hyperparameter_tuning_job(client, args):
    """Create a SageMaker HPO job."""
    request = create_hyperparameter_tuning_job_request(args)
    try:
        client.create_hyper_parameter_tuning_job(**request)
        hpo_job_name = request['HyperParameterTuningJobName']
        logging.info("Created Hyperparameter Tuning Job with name: " + hpo_job_name)
        logging.info("HPO job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/hyper-tuning-jobs/{}"
                     .format(args['region'], args['region'], hpo_job_name))
        logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix"
                     .format(args['region'], args['region'], hpo_job_name))
        return hpo_job_name
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


def wait_for_hyperparameter_training_job(client, hpo_job_name):
    ### Wait until the job finishes
    while True:
        response = client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=hpo_job_name)
        status = response['HyperParameterTuningJobStatus']
        if status == 'Completed':
            logging.info("Hyperparameter tuning job ended with status: " + status)
            break
        if status == 'Failed':
            message = response['FailureReason']
            logging.error('Hyperparameter tuning failed with the following error: {}'.format(message))
            raise Exception('Hyperparameter tuning job failed')
        logging.info("Hyperparameter tuning job is still in status: " + status)
        time.sleep(30)


def get_best_training_job_and_hyperparameters(client, hpo_job_name):
    ### Get and return the best training job and its hyperparameters, without the objective metric
    info = client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=hpo_job_name)
    best_job = info['BestTrainingJob']['TrainingJobName']
    training_info = client.describe_training_job(TrainingJobName=best_job)
    train_hyperparameters = training_info['HyperParameters']
    train_hyperparameters.pop('_tuning_objective_metric')
    return best_job, train_hyperparameters


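# HPO flow sketch (hypothetical args):
#
#   hpo_job_name = create_hyperparameter_tuning_job(client, args)
#   wait_for_hyperparameter_training_job(client, hpo_job_name)
#   best_job, best_hyperparameters = get_best_training_job_and_hyperparameters(client, hpo_job_name)

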
def create_workteam(client, args):
    try:
        request = create_workteam_request(args)
        response = client.create_workteam(**request)
        portal = client.describe_workteam(WorkteamName=args['team_name'])['Workteam']['SubDomain']
        logging.info("Labeling portal: " + portal)
        return response['WorkteamArn']
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


def create_workteam_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_workteam
    """Create a workteam."""
    with open(os.path.join(__cwd__, 'workteam.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    request['WorkteamName'] = args['team_name']
    request['Description'] = args['description']

    if args['sns_topic']:
        request['NotificationConfiguration']['NotificationTopicArn'] = args['sns_topic']
    else:
        request.pop('NotificationConfiguration')

    for group in [n.strip() for n in args['user_groups'].split(',')]:
        request['MemberDefinitions'].append({'CognitoMemberDefinition': {'UserPool': args['user_pool'], 'UserGroup': group, 'ClientId': args['client_id']}})

    for key, val in args['tags'].items():
        request['Tags'].append({'Key': key, 'Value': val})

    return request


def create_labeling_job_request(args):
    ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_labeling_job
    with open(os.path.join(__cwd__, 'gt.template.yaml'), 'r') as f:
        request = yaml.safe_load(f)

    # Mappings are extracted from the ARNs listed in https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_labeling_job
    algorithm_arn_map = {'us-west-2': '081040173940',
                         'us-east-1': '432418664414',
                         'us-east-2': '266458841044',
                         'eu-west-1': '568282634449',
                         'ap-northeast-1': '477331159723',
                         'ap-southeast-1': '454466003867'}

    task_map = {'bounding box': 'BoundingBox',
                'image classification': 'ImageMultiClass',
                'semantic segmentation': 'SemanticSegmentation',
                'text classification': 'TextMultiClass'}

    auto_labeling_map = {'bounding box': 'object-detection',
                         'image classification': 'image-classification',
                         'text classification': 'text-classification'}

    task = args['task_type'].lower()

    request['LabelingJobName'] = args['job_name'] if args['job_name'] else "LabelingJob-" + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator()

    if args['label_attribute_name']:
        name_check = args['label_attribute_name'].split('-')[-1]
        if (task == 'semantic segmentation' and name_check == 'ref') or (task != 'semantic segmentation' and name_check != 'metadata' and name_check != 'ref'):
            request['LabelAttributeName'] = args['label_attribute_name']
        else:
            logging.error('Invalid label attribute name. If the task type is semantic segmentation, the name must end in "-ref". Otherwise, the name must not end in "-ref" or "-metadata".')
            raise Exception('Could not create job request')
    else:
        request['LabelAttributeName'] = args['job_name']

    request['InputConfig']['DataSource']['S3DataSource']['ManifestS3Uri'] = args['manifest_location']
    request['OutputConfig']['S3OutputPath'] = args['output_location']
    request['OutputConfig']['KmsKeyId'] = args['output_encryption_key']
    request['RoleArn'] = args['role']
    request['LabelCategoryConfigS3Uri'] = args['label_category_config']

    ### Update or pop stopping conditions
    if not args['max_human_labeled_objects'] and not args['max_percent_objects']:
        request.pop('StoppingConditions')
    else:
        if args['max_human_labeled_objects']:
            request['StoppingConditions']['MaxHumanLabeledObjectCount'] = args['max_human_labeled_objects']
        else:
            request['StoppingConditions'].pop('MaxHumanLabeledObjectCount')
        if args['max_percent_objects']:
            request['StoppingConditions']['MaxPercentageOfInputDatasetLabeled'] = args['max_percent_objects']
        else:
            request['StoppingConditions'].pop('MaxPercentageOfInputDatasetLabeled')

    ### Update or pop automatic labeling configs
    if args['enable_auto_labeling']:
        if task == 'image classification' or task == 'bounding box' or task == 'text classification':
            labeling_algorithm_arn = 'arn:aws:sagemaker:{}:027400017018:labeling-job-algorithm-specification/{}'.format(args['region'], auto_labeling_map[task])
            request['LabelingJobAlgorithmsConfig']['LabelingJobAlgorithmSpecificationArn'] = labeling_algorithm_arn
            if args['initial_model_arn']:
                request['LabelingJobAlgorithmsConfig']['InitialActiveLearningModelArn'] = args['initial_model_arn']
            else:
                request['LabelingJobAlgorithmsConfig'].pop('InitialActiveLearningModelArn')
            request['LabelingJobAlgorithmsConfig']['LabelingJobResourceConfig']['VolumeKmsKeyId'] = args['resource_encryption_key']
        else:
            logging.error("Automated data labeling is not available for semantic segmentation or custom algorithms. Proceeding without automated data labeling.")
            request.pop('LabelingJobAlgorithmsConfig')
    else:
        request.pop('LabelingJobAlgorithmsConfig')

    ### Update pre-human and annotation consolidation task lambda functions
    if task == 'image classification' or task == 'bounding box' or task == 'text classification' or task == 'semantic segmentation':
        prehuman_arn = 'arn:aws:lambda:{}:{}:function:PRE-{}'.format(args['region'], algorithm_arn_map[args['region']], task_map[task])
        acs_arn = 'arn:aws:lambda:{}:{}:function:ACS-{}'.format(args['region'], algorithm_arn_map[args['region']], task_map[task])
        request['HumanTaskConfig']['PreHumanTaskLambdaArn'] = prehuman_arn
        request['HumanTaskConfig']['AnnotationConsolidationConfig']['AnnotationConsolidationLambdaArn'] = acs_arn
    elif task == 'custom' or task == '':
        if args['pre_human_task_function'] and args['post_human_task_function']:
            request['HumanTaskConfig']['PreHumanTaskLambdaArn'] = args['pre_human_task_function']
            request['HumanTaskConfig']['AnnotationConsolidationConfig']['AnnotationConsolidationLambdaArn'] = args['post_human_task_function']
        else:
            logging.error("Must specify a pre-human task lambda arn and an annotation consolidation post-human task lambda arn.")
            raise Exception('Could not create job request')
    else:
        logging.error("Task type must be Bounding Box, Image Classification, Semantic Segmentation, Text Classification, or Custom.")
        raise Exception('Could not create job request')

    request['HumanTaskConfig']['UiConfig']['UiTemplateS3Uri'] = args['ui_template']
    request['HumanTaskConfig']['TaskTitle'] = args['title']
    request['HumanTaskConfig']['TaskDescription'] = args['description']
    request['HumanTaskConfig']['NumberOfHumanWorkersPerDataObject'] = args['num_workers_per_object']
    request['HumanTaskConfig']['TaskTimeLimitInSeconds'] = args['time_limit']

    if args['task_availibility']:
        request['HumanTaskConfig']['TaskAvailabilityLifetimeInSeconds'] = args['task_availibility']
    else:
        request['HumanTaskConfig'].pop('TaskAvailabilityLifetimeInSeconds')

    if args['max_concurrent_tasks']:
        request['HumanTaskConfig']['MaxConcurrentTaskCount'] = args['max_concurrent_tasks']
    else:
        request['HumanTaskConfig'].pop('MaxConcurrentTaskCount')

    if args['task_keywords']:
        for word in [n.strip() for n in args['task_keywords'].split(',')]:
            request['HumanTaskConfig']['TaskKeywords'].append(word)
    else:
        request['HumanTaskConfig'].pop('TaskKeywords')

    ### Update worker configurations
    if args['worker_type'].lower() == 'public':
        if args['no_adult_content']:
            request['InputConfig']['DataAttributes']['ContentClassifiers'].append('FreeOfAdultContent')
        if args['no_ppi']:
            request['InputConfig']['DataAttributes']['ContentClassifiers'].append('FreeOfPersonallyIdentifiableInformation')

        request['HumanTaskConfig']['WorkteamArn'] = 'arn:aws:sagemaker:{}:394669845002:workteam/public-crowd/default'.format(args['region'])

        # Split the task price into whole dollars, cents, and tenths of a cent.
        # Note that int() truncates, so float rounding can shift a cent:
        # e.g. 1.23 decomposes as Dollars=1, Cents=22, TenthFractionsOfACent=9.
        dollars = int(args['workforce_task_price'])
        cents = int(100 * (args['workforce_task_price'] - dollars))
        tenth_of_cents = int((args['workforce_task_price'] * 1000) - (dollars * 1000) - (cents * 10))
        request['HumanTaskConfig']['PublicWorkforceTaskPrice']['AmountInUsd']['Dollars'] = dollars
        request['HumanTaskConfig']['PublicWorkforceTaskPrice']['AmountInUsd']['Cents'] = cents
        request['HumanTaskConfig']['PublicWorkforceTaskPrice']['AmountInUsd']['TenthFractionsOfACent'] = tenth_of_cents
    else:
        request['InputConfig'].pop('DataAttributes')
        request['HumanTaskConfig']['WorkteamArn'] = args['workteam_arn']
        request['HumanTaskConfig'].pop('PublicWorkforceTaskPrice')

    for key, val in args['tags'].items():
        request['Tags'].append({'Key': key, 'Value': val})

    return request


def create_labeling_job(client, args):
    """Create a SageMaker Ground Truth job."""
    request = create_labeling_job_request(args)
    try:
        client.create_labeling_job(**request)
        gt_job_name = request['LabelingJobName']
        logging.info("Created Ground Truth Labeling Job with name: " + gt_job_name)
        logging.info("Ground Truth job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/groundtruth?region={}#/labeling-jobs/details/{}"
                     .format(args['region'], args['region'], gt_job_name))
        return gt_job_name
    except ClientError as e:
        raise Exception(e.response['Error']['Message'])


def wait_for_labeling_job(client, labeling_job_name):
    ### Wait until the job finishes
    status = 'InProgress'
    while status == 'InProgress':
        response = client.describe_labeling_job(LabelingJobName=labeling_job_name)
        status = response['LabelingJobStatus']
        if status == 'Failed':
            message = response['FailureReason']
            logging.info('Labeling failed with the following error: {}'.format(message))
            raise Exception('Labeling job failed')
        logging.info("Labeling job is still in status: " + status)
        time.sleep(30)

    if status == 'Completed':
        logging.info("Labeling job ended with status: " + status)
    else:
        raise Exception('Labeling job stopped')


def get_labeling_job_outputs(client, labeling_job_name, auto_labeling):
    ### Get and return the labeling job outputs
    info = client.describe_labeling_job(LabelingJobName=labeling_job_name)
    output_manifest = info['LabelingJobOutput']['OutputDatasetS3Uri']
    if auto_labeling:
        active_learning_model_arn = info['LabelingJobOutput']['FinalActiveLearningModelArn']
    else:
        active_learning_model_arn = ' '
    return output_manifest, active_learning_model_arn


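# Ground Truth flow sketch (hypothetical args):
#
#   gt_job_name = create_labeling_job(client, args)
#   wait_for_labeling_job(client, gt_job_name)
#   output_manifest, model_arn = get_labeling_job_outputs(client, gt_job_name, args['enable_auto_labeling'])

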
def enable_spot_instance_support(training_job_config, args):
    if args['spot_instance']:
        training_job_config['EnableManagedSpotTraining'] = args['spot_instance']
        if args['max_wait_time'] >= training_job_config['StoppingCondition']['MaxRuntimeInSeconds']:
            training_job_config['StoppingCondition']['MaxWaitTimeInSeconds'] = args['max_wait_time']
        else:
            logging.error("Max wait time must be greater than or equal to max run time.")
            raise Exception('Could not create job request.')

        if args['checkpoint_config'] and 'S3Uri' in args['checkpoint_config']:
            training_job_config['CheckpointConfig'] = args['checkpoint_config']
        else:
            logging.error("EnableManagedSpotTraining requires checkpoint config with an S3 uri.")
            raise Exception('Could not create job request.')
    else:
        # Remove any artifacts that require spot instance support
        del training_job_config['StoppingCondition']['MaxWaitTimeInSeconds']
        del training_job_config['CheckpointConfig']


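# For example (hypothetical values): with spot_instance=True, max_wait_time=7200,
# checkpoint_config={'S3Uri': 's3://bucket/checkpoints'}, and a config whose
# StoppingCondition.MaxRuntimeInSeconds is 3600, the config gains
# EnableManagedSpotTraining=True, MaxWaitTimeInSeconds=7200, and CheckpointConfig.

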
def id_generator(size=4, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))


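# e.g. id_generator() returns a random 4-character suffix such as 'A1B2'
# (uppercase letters and digits), used to de-duplicate generated job names.

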
def yaml_or_json_str(s):
    if s == "" or s is None:
        return None
    try:
        return json.loads(s)
    except ValueError:
        # Fall back to YAML, which is a superset of JSON
        return yaml.safe_load(s)


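# e.g. yaml_or_json_str('{"k1": "v1"}') == {'k1': 'v1'} and
#      yaml_or_json_str('k1: v1')      == {'k1': 'v1'}; empty input returns None.

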
def str_to_bool(s):
    # strtobool returns an integer representation of the boolean
    # rather than a True/False value, so hard-cast it to bool.
    return bool(strtobool(s))
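
# e.g. str_to_bool('True') == True and str_to_bool('0') == False; values outside
# the set accepted by distutils.util.strtobool raise ValueError.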