# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from time import gmtime, strftime import time import string import random import json import yaml import boto3 from botocore.exceptions import ClientError from sagemaker.amazon.amazon_estimator import get_image_uri import logging logging.getLogger().setLevel(logging.INFO) # Mappings are extracted from the first table in https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html built_in_algos = { 'blazingtext': 'blazingtext', 'deepar forecasting': 'forecasting-deepar', 'factorization machines': 'factorization-machines', 'image classification': 'image-classification', 'ip insights': 'ipinsights', 'k-means': 'kmeans', 'k-nearest neighbors': 'knn', 'k-nn': 'knn', 'lda': 'lda', 'linear learner': 'linear-learner', 'neural topic model': 'ntm', 'object2vec': 'object2vec', 'object detection': 'object-detection', 'pca': 'pca', 'random cut forest': 'randomcutforest', 'semantic segmentation': 'semantic-segmentation', 'sequence to sequence': 'seq2seq', 'seq2seq modeling': 'seq2seq', 'xgboost': 'xgboost' } def get_client(region=None): """Builds a client to the AWS SageMaker API.""" client = boto3.client('sagemaker', region_name=region) return client def create_training_job_request(args): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job with open('/app/common/train.template.yaml', 'r') as f: request = 
yaml.safe_load(f) job_name = args['job_name'] if args['job_name'] else 'TrainingJob-' + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator() request['TrainingJobName'] = job_name request['RoleArn'] = args['role'] request['HyperParameters'] = args['hyperparameters'] request['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode'] ### Update training image (for BYOC and built-in algorithms) or algorithm resource name if not args['image'] and not args['algorithm_name']: logging.error('Please specify training image or algorithm name.') raise Exception('Could not create job request') if args['image'] and args['algorithm_name']: logging.error('Both image and algorithm name inputted, only one should be specified. Proceeding with image.') if args['image']: request['AlgorithmSpecification']['TrainingImage'] = args['image'] request['AlgorithmSpecification'].pop('AlgorithmName') else: # TODO: Adjust this implementation to account for custom algorithm resources names that are the same as built-in algorithm names algo_name = args['algorithm_name'].lower().strip() if algo_name in built_in_algos.keys(): request['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], built_in_algos[algo_name]) request['AlgorithmSpecification'].pop('AlgorithmName') logging.warning('Algorithm name is found as an Amazon built-in algorithm. Using built-in algorithm.') # Just to give the user more leeway for built-in algorithm name inputs elif algo_name in built_in_algos.values(): request['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], algo_name) request['AlgorithmSpecification'].pop('AlgorithmName') logging.warning('Algorithm name is found as an Amazon built-in algorithm. 
Using built-in algorithm.') else: request['AlgorithmSpecification']['AlgorithmName'] = args['algorithm_name'] request['AlgorithmSpecification'].pop('TrainingImage') ### Update metric definitions if args['metric_definitions']: for key, val in args['metric_definitions'].items(): request['AlgorithmSpecification']['MetricDefinitions'].append({'Name': key, 'Regex': val}) else: request['AlgorithmSpecification'].pop('MetricDefinitions') ### Update or pop VPC configs if args['vpc_security_group_ids'] and args['vpc_subnets']: request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']] request['VpcConfig']['Subnets'] = [args['vpc_subnets']] else: request.pop('VpcConfig') ### Update input channels, must have at least one specified if len(args['channels']) > 0: request['InputDataConfig'] = args['channels'] # Max number of input channels/data locations is 20, but currently only 8 data location parameters are exposed separately. # Source: Input data configuration description in the SageMaker create training job form for i in range(1, len(args['channels']) + 1): if args['data_location_' + str(i)]: request['InputDataConfig'][i-1]['DataSource']['S3DataSource']['S3Uri'] = args['data_location_' + str(i)] else: logging.error("Must specify at least one input channel.") raise Exception('Could not make job request') request['OutputDataConfig']['S3OutputPath'] = args['model_artifact_path'] request['OutputDataConfig']['KmsKeyId'] = args['output_encryption_key'] request['ResourceConfig']['InstanceType'] = args['instance_type'] request['ResourceConfig']['VolumeKmsKeyId'] = args['resource_encryption_key'] request['EnableNetworkIsolation'] = args['network_isolation'] request['EnableInterContainerTrafficEncryption'] = args['traffic_encryption'] ### Update InstanceCount, VolumeSizeInGB, and MaxRuntimeInSeconds if input is non-empty and > 0, otherwise use default values if args['instance_count']: request['ResourceConfig']['InstanceCount'] = args['instance_count'] if 
args['volume_size']: request['ResourceConfig']['VolumeSizeInGB'] = args['volume_size'] if args['max_run_time']: request['StoppingCondition']['MaxRuntimeInSeconds'] = args['max_run_time'] ### Update tags for key, val in args['tags'].items(): request['Tags'].append({'Key': key, 'Value': val}) return request def create_training_job(client, args): """Create a Sagemaker training job.""" request = create_training_job_request(args) try: client.create_training_job(**request) training_job_name = request['TrainingJobName'] logging.info("Created Training Job with name: " + training_job_name) logging.info("Training job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}" .format(args['region'], args['region'], training_job_name)) logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix" .format(args['region'], args['region'], training_job_name)) return training_job_name except ClientError as e: raise Exception(e.response['Error']['Message']) def wait_for_training_job(client, training_job_name): while(True): response = client.describe_training_job(TrainingJobName=training_job_name) status = response['TrainingJobStatus'] if status == 'Completed': logging.info("Training job ended with status: " + status) break if status == 'Failed': message = response['FailureReason'] logging.info('Training failed with the following error: {}'.format(message)) raise Exception('Training job failed') logging.info("Training job is still in status: " + status) time.sleep(30) def get_model_artifacts_from_job(client, job_name): info = client.describe_training_job(TrainingJobName=job_name) model_artifact_url = info['ModelArtifacts']['S3ModelArtifacts'] return model_artifact_url def get_image_from_job(client, job_name): info = client.describe_training_job(TrainingJobName=job_name) try: image = info['AlgorithmSpecification']['TrainingImage'] except: 
algorithm_name = info['AlgorithmSpecification']['AlgorithmName'] image = client.describe_algorithm(AlgorithmName=algorithm_name)['TrainingSpecification']['TrainingImage'] return image def create_model(client, args): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model with open('/app/common/model.template.yaml', 'r') as f: request = yaml.safe_load(f) request['ModelName'] = args['model_name'] request['PrimaryContainer']['Environment'] = args['environment'] if args['secondary_containers']: request['Containers'] = args['secondary_containers'] request.pop('PrimaryContainer') else: request.pop('Containers') ### Update primary container and handle input errors if args['container_host_name']: request['PrimaryContainer']['ContainerHostname'] = args['container_host_name'] else: request['PrimaryContainer'].pop('ContainerHostname') if (args['image'] or args['model_artifact_url']) and args['model_package']: logging.error("Please specify an image AND model artifact url, OR a model package name.") raise Exception("Could not make create model request.") elif args['model_package']: request['PrimaryContainer']['ModelPackageName'] = args['model_package'] request['PrimaryContainer'].pop('Image') request['PrimaryContainer'].pop('ModelDataUrl') else: if args['image'] and args['model_artifact_url']: request['PrimaryContainer']['Image'] = args['image'] request['PrimaryContainer']['ModelDataUrl'] = args['model_artifact_url'] request['PrimaryContainer'].pop('ModelPackageName') else: logging.error("Please specify an image AND model artifact url.") raise Exception("Could not make create model request.") request['ExecutionRoleArn'] = args['role'] request['EnableNetworkIsolation'] = args['network_isolation'] ### Update or pop VPC configs if args['vpc_security_group_ids'] and args['vpc_subnets']: request['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']] request['VpcConfig']['Subnets'] = 
[args['vpc_subnets']] else: request.pop('VpcConfig') ### Update tags for key, val in args['tags'].items(): request['Tags'].append({'Key': key, 'Value': val}) create_model_response = client.create_model(**request) logging.info("Model Config Arn: " + create_model_response['ModelArn']) return create_model_response['ModelArn'] def deploy_model(client, args): endpoint_config_name = create_endpoint_config(client, args) endpoint_name = create_endpoint(client, args['region'], args['endpoint_name'], endpoint_config_name, args['endpoint_tags']) return endpoint_name def create_endpoint_config(client, args): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config with open('/app/common/endpoint_config.template.yaml', 'r') as f: request = yaml.safe_load(f) endpoint_config_name = args['endpoint_config_name'] if args['endpoint_config_name'] else 'EndpointConfig' + args['model_name_1'][args['model_name_1'].index('-'):] request['EndpointConfigName'] = endpoint_config_name if args['resource_encryption_key']: request['KmsKeyId'] = args['resource_encryption_key'] else: request.pop('KmsKeyId') if not args['model_name_1']: logging.error("Must specify at least one model (model name) to host.") raise Exception("Could not create endpoint config.") for i in range(len(request['ProductionVariants']), 0, -1): if args['model_name_' + str(i)]: request['ProductionVariants'][i-1]['ModelName'] = args['model_name_' + str(i)] if args['variant_name_' + str(i)]: request['ProductionVariants'][i-1]['VariantName'] = args['variant_name_' + str(i)] if args['initial_instance_count_' + str(i)]: request['ProductionVariants'][i-1]['InitialInstanceCount'] = args['initial_instance_count_' + str(i)] if args['instance_type_' + str(i)]: request['ProductionVariants'][i-1]['InstanceType'] = args['instance_type_' + str(i)] if args['initial_variant_weight_' + str(i)]: request['ProductionVariants'][i-1]['InitialVariantWeight'] = 
args['initial_variant_weight_' + str(i)] if args['accelerator_type_' + str(i)]: request['ProductionVariants'][i-1]['AcceleratorType'] = args['accelerator_type_' + str(i)] else: request['ProductionVariants'][i-1].pop('AcceleratorType') else: request['ProductionVariants'].pop(i-1) ### Update tags for key, val in args['endpoint_config_tags'].items(): request['Tags'].append({'Key': key, 'Value': val}) try: create_endpoint_config_response = client.create_endpoint_config(**request) logging.info("Endpoint configuration in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/endpointConfig/{}" .format(args['region'], args['region'], endpoint_config_name)) logging.info("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn']) return endpoint_config_name except ClientError as e: raise Exception(e.response['Error']['Message']) def create_endpoint(client, region, endpoint_name, endpoint_config_name, endpoint_tags): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint endpoint_name = endpoint_name if endpoint_name else 'Endpoint' + endpoint_config_name[endpoint_config_name.index('-'):] ### Update tags tags=[] for key, val in endpoint_tags.items(): tags.append({'Key': key, 'Value': val}) try: create_endpoint_response = client.create_endpoint( EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name, Tags=tags) logging.info("Created endpoint with name: " + endpoint_name) logging.info("Endpoint in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/endpoints/{}" .format(region, region, endpoint_name)) logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/Endpoints/{};streamFilter=typeLogStreamPrefix" .format(region, region, endpoint_name)) return endpoint_name except ClientError as e: raise Exception(e.response['Error']['Message']) def 
wait_for_endpoint_creation(client, endpoint_name): status = client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus'] logging.info("Status: " + status) try: client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name) finally: resp = client.describe_endpoint(EndpointName=endpoint_name) status = resp['EndpointStatus'] logging.info("Endpoint Arn: " + resp['EndpointArn']) logging.info("Create endpoint ended with status: " + status) if status != 'InService': message = client.describe_endpoint(EndpointName=endpoint_name)['FailureReason'] logging.info('Create endpoint failed with the following error: {}'.format(message)) raise Exception('Endpoint creation did not succeed') def create_transform_job_request(args): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_transform_job with open('/app/common/transform.template.yaml', 'r') as f: request = yaml.safe_load(f) job_name = args['job_name'] if args['job_name'] else 'BatchTransform' + args['model_name'][args['model_name'].index('-'):] request['TransformJobName'] = job_name request['ModelName'] = args['model_name'] if args['max_concurrent']: request['MaxConcurrentTransforms'] = args['max_concurrent'] if args['max_payload'] or args['max_payload'] == 0: request['MaxPayloadInMB'] = args['max_payload'] if args['batch_strategy']: request['BatchStrategy'] = args['batch_strategy'] else: request.pop('BatchStrategy') request['Environment'] = args['environment'] if args['data_type']: request['TransformInput']['DataSource']['S3DataSource']['S3DataType'] = args['data_type'] request['TransformInput']['DataSource']['S3DataSource']['S3Uri'] = args['input_location'] request['TransformInput']['ContentType'] = args['content_type'] if args['compression_type']: request['TransformInput']['CompressionType'] = args['compression_type'] if args['split_type']: request['TransformInput']['SplitType'] = args['split_type'] 
request['TransformOutput']['S3OutputPath'] = args['output_location'] request['TransformOutput']['Accept'] = args['accept'] request['TransformOutput']['KmsKeyId'] = args['output_encryption_key'] if args['assemble_with']: request['TransformOutput']['AssembleWith'] = args['assemble_with'] else: request['TransformOutput'].pop('AssembleWith') request['TransformResources']['InstanceType'] = args['instance_type'] request['TransformResources']['InstanceCount'] = args['instance_count'] request['TransformResources']['VolumeKmsKeyId'] = args['resource_encryption_key'] request['DataProcessing']['InputFilter'] = args['input_filter'] request['DataProcessing']['OutputFilter'] = args['output_filter'] if args['join_source']: request['DataProcessing']['JoinSource'] = args['join_source'] ### Update tags if not args['tags'] is None: for key, val in args['tags'].items(): request['Tags'].append({'Key': key, 'Value': val}) return request def create_transform_job(client, args): request = create_transform_job_request(args) try: client.create_transform_job(**request) batch_job_name = request['TransformJobName'] logging.info("Created Transform Job with name: " + batch_job_name) logging.info("Transform job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/transform-jobs/{}" .format(args['region'], args['region'], batch_job_name)) logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TransformJobs;prefix={};streamFilter=typeLogStreamPrefix" .format(args['region'], args['region'], batch_job_name)) return batch_job_name except ClientError as e: raise Exception(e.response['Error']['Message']) def wait_for_transform_job(client, batch_job_name): ### Wait until the job finishes while(True): response = client.describe_transform_job(TransformJobName=batch_job_name) status = response['TransformJobStatus'] if status == 'Completed': logging.info("Transform job ended with status: " + status) break if status == 
'Failed': message = response['FailureReason'] logging.info('Transform failed with the following error: {}'.format(message)) raise Exception('Transform job failed') logging.info("Transform job is still in status: " + status) time.sleep(30) def create_hyperparameter_tuning_job_request(args): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job with open('/app/common/hpo.template.yaml', 'r') as f: request = yaml.safe_load(f) ### Create a hyperparameter tuning job request['HyperParameterTuningJobName'] = args['job_name'] if args['job_name'] else "HPOJob-" + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator() request['HyperParameterTuningJobConfig']['Strategy'] = args['strategy'] request['HyperParameterTuningJobConfig']['HyperParameterTuningJobObjective']['Type'] = args['metric_type'] request['HyperParameterTuningJobConfig']['HyperParameterTuningJobObjective']['MetricName'] = args['metric_name'] request['HyperParameterTuningJobConfig']['ResourceLimits']['MaxNumberOfTrainingJobs'] = args['max_num_jobs'] request['HyperParameterTuningJobConfig']['ResourceLimits']['MaxParallelTrainingJobs'] = args['max_parallel_jobs'] request['HyperParameterTuningJobConfig']['ParameterRanges']['IntegerParameterRanges'] = args['integer_parameters'] request['HyperParameterTuningJobConfig']['ParameterRanges']['ContinuousParameterRanges'] = args['continuous_parameters'] request['HyperParameterTuningJobConfig']['ParameterRanges']['CategoricalParameterRanges'] = args['categorical_parameters'] request['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = args['early_stopping_type'] request['TrainingJobDefinition']['StaticHyperParameters'] = args['static_parameters'] request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingInputMode'] = args['training_input_mode'] ### Update training image (for BYOC) or algorithm resource name if not args['image'] and not 
args['algorithm_name']: logging.error('Please specify training image or algorithm name.') raise Exception('Could not create job request') if args['image'] and args['algorithm_name']: logging.error('Both image and algorithm name inputted, only one should be specified. Proceeding with image.') if args['image']: request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = args['image'] request['TrainingJobDefinition']['AlgorithmSpecification'].pop('AlgorithmName') else: # TODO: Adjust this implementation to account for custom algorithm resources names that are the same as built-in algorithm names algo_name = args['algorithm_name'].lower().strip() if algo_name in built_in_algos.keys(): request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], built_in_algos[algo_name]) request['TrainingJobDefinition']['AlgorithmSpecification'].pop('AlgorithmName') logging.warning('Algorithm name is found as an Amazon built-in algorithm. Using built-in algorithm.') # To give the user more leeway for built-in algorithm name inputs elif algo_name in built_in_algos.values(): request['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = get_image_uri(args['region'], algo_name) request['TrainingJobDefinition']['AlgorithmSpecification'].pop('AlgorithmName') logging.warning('Algorithm name is found as an Amazon built-in algorithm. 
Using built-in algorithm.') else: request['TrainingJobDefinition']['AlgorithmSpecification']['AlgorithmName'] = args['algorithm_name'] request['TrainingJobDefinition']['AlgorithmSpecification'].pop('TrainingImage') ### Update metric definitions if args['metric_definitions']: for key, val in args['metric_definitions'].items(): request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'].append({'Name': key, 'Regex': val}) else: request['TrainingJobDefinition']['AlgorithmSpecification'].pop('MetricDefinitions') ### Update or pop VPC configs if args['vpc_security_group_ids'] and args['vpc_subnets']: request['TrainingJobDefinition']['VpcConfig']['SecurityGroupIds'] = [args['vpc_security_group_ids']] request['TrainingJobDefinition']['VpcConfig']['Subnets'] = [args['vpc_subnets']] else: request['TrainingJobDefinition'].pop('VpcConfig') ### Update input channels, must have at least one specified if len(args['channels']) > 0: request['TrainingJobDefinition']['InputDataConfig'] = args['channels'] # Max number of input channels/data locations is 20, but currently only 8 data location parameters are exposed separately. 
# Source: Input data configuration description in the SageMaker create hyperparameter tuning job form for i in range(1, len(args['channels']) + 1): if args['data_location_' + str(i)]: request['TrainingJobDefinition']['InputDataConfig'][i-1]['DataSource']['S3DataSource']['S3Uri'] = args['data_location_' + str(i)] else: logging.error("Must specify at least one input channel.") raise Exception('Could not make job request') request['TrainingJobDefinition']['OutputDataConfig']['S3OutputPath'] = args['output_location'] request['TrainingJobDefinition']['OutputDataConfig']['KmsKeyId'] = args['output_encryption_key'] request['TrainingJobDefinition']['ResourceConfig']['InstanceType'] = args['instance_type'] request['TrainingJobDefinition']['ResourceConfig']['VolumeKmsKeyId'] = args['resource_encryption_key'] request['TrainingJobDefinition']['EnableNetworkIsolation'] = args['network_isolation'] request['TrainingJobDefinition']['EnableInterContainerTrafficEncryption'] = args['traffic_encryption'] request['TrainingJobDefinition']['RoleArn'] = args['role'] ### Update InstanceCount, VolumeSizeInGB, and MaxRuntimeInSeconds if input is non-empty and > 0, otherwise use default values if args['instance_count']: request['TrainingJobDefinition']['ResourceConfig']['InstanceCount'] = args['instance_count'] if args['volume_size']: request['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB'] = args['volume_size'] if args['max_run_time']: request['TrainingJobDefinition']['StoppingCondition']['MaxRuntimeInSeconds'] = args['max_run_time'] ### Update or pop warm start configs if args['warm_start_type'] and args['parent_hpo_jobs']: request['WarmStartConfig']['WarmStartType'] = args['warm_start_type'] parent_jobs = [n.strip() for n in args['parent_hpo_jobs'].split(',')] for i in range(len(parent_jobs)): request['WarmStartConfig']['ParentHyperParameterTuningJobs'].append({'HyperParameterTuningJobName': parent_jobs[i]}) else: if args['warm_start_type'] or args['parent_hpo_jobs']: if not 
args['warm_start_type']: logging.error('Must specify warm start type as either "IdenticalDataAndAlgorithm" or "TransferLearning".') if not args['parent_hpo_jobs']: logging.error("Must specify at least one parent hyperparameter tuning job") raise Exception('Could not make job request') request.pop('WarmStartConfig') ### Update tags for key, val in args['tags'].items(): request['Tags'].append({'Key': key, 'Value': val}) return request def create_hyperparameter_tuning_job(client, args): """Create a Sagemaker HPO job""" request = create_hyperparameter_tuning_job_request(args) try: job_arn = client.create_hyper_parameter_tuning_job(**request) hpo_job_name = request['HyperParameterTuningJobName'] logging.info("Created Hyperparameter Training Job with name: " + hpo_job_name) logging.info("HPO job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/hyper-tuning-jobs/{}" .format(args['region'], args['region'], hpo_job_name)) logging.info("CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix" .format(args['region'], args['region'], hpo_job_name)) return hpo_job_name except ClientError as e: raise Exception(e.response['Error']['Message']) def wait_for_hyperparameter_training_job(client, hpo_job_name): ### Wait until the job finishes while(True): response = client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=hpo_job_name) status = response['HyperParameterTuningJobStatus'] if status == 'Completed': logging.info("Hyperparameter tuning job ended with status: " + status) break if status == 'Failed': message = response['FailureReason'] logging.error('Hyperparameter tuning failed with the following error: {}'.format(message)) raise Exception('Hyperparameter tuning job failed') logging.info("Hyperparameter tuning job is still in status: " + status) time.sleep(30) def get_best_training_job_and_hyperparameters(client, hpo_job_name): ### 
Get and return best training job and its hyperparameters, without the objective metric info = client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=hpo_job_name) best_job = info['BestTrainingJob']['TrainingJobName'] training_info = client.describe_training_job(TrainingJobName=best_job) train_hyperparameters = training_info['HyperParameters'] train_hyperparameters.pop('_tuning_objective_metric') return best_job, train_hyperparameters def create_workteam(client, args): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_workteam """Create a workteam""" with open('/app/common/workteam.template.yaml', 'r') as f: request = yaml.safe_load(f) request['WorkteamName'] = args['team_name'] request['Description'] = args['description'] if args['sns_topic']: request['NotificationConfiguration']['NotificationTopicArn'] = args['sns_topic'] else: request.pop('NotificationConfiguration') for group in [n.strip() for n in args['user_groups'].split(',')]: request['MemberDefinitions'].append({'CognitoMemberDefinition': {'UserPool': args['user_pool'], 'UserGroup': group, 'ClientId': args['client_id']}}) for key, val in args['tags'].items(): request['Tags'].append({'Key': key, 'Value': val}) try: response = client.create_workteam(**request) portal = client.describe_workteam(WorkteamName=args['team_name'])['Workteam']['SubDomain'] logging.info("Labeling portal: " + portal) return response['WorkteamArn'] except ClientError as e: raise Exception(e.response['Error']['Message']) def create_labeling_job_request(args): ### Documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_labeling_job with open('/app/common/gt.template.yaml', 'r') as f: request = yaml.safe_load(f) # Mapping are extracted from ARNs listed in 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_labeling_job algorithm_arn_map = {'us-west-2': '081040173940', 'us-east-1': '432418664414', 'us-east-2': '266458841044', 'eu-west-1': '568282634449', 'ap-northeast-1': '477331159723', 'ap-southeast-1': '454466003867'} task_map = {'bounding box': 'BoundingBox', 'image classification': 'ImageMultiClass', 'semantic segmentation': 'SemanticSegmentation', 'text classification': 'TextMultiClass'} auto_labeling_map = {'bounding box': 'object-detection', 'image classification': 'image-classification', 'text classification': 'text-classification'} task = args['task_type'].lower() request['LabelingJobName'] = args['job_name'] if args['job_name'] else "LabelingJob-" + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator() if args['label_attribute_name']: name_check = args['label_attribute_name'].split('-')[-1] if task == 'semantic segmentation' and name_check == 'ref' or task != 'semantic segmentation' and name_check != 'metadata' and name_check != 'ref': request['LabelAttributeName'] = args['label_attribute_name'] else: logging.error('Invalid label attribute name. If task type is semantic segmentation, name must end in "-ref". 
Else, name must not end in "-ref" or "-metadata".') else: request['LabelAttributeName'] = args['job_name'] request['InputConfig']['DataSource']['S3DataSource']['ManifestS3Uri'] = args['manifest_location'] request['OutputConfig']['S3OutputPath'] = args['output_location'] request['OutputConfig']['KmsKeyId'] = args['output_encryption_key'] request['RoleArn'] = args['role'] request['LabelCategoryConfigS3Uri'] = args['label_category_config'] ### Update or pop stopping conditions if not args['max_human_labeled_objects'] and not args['max_percent_objects']: request.pop('StoppingConditions') else: if args['max_human_labeled_objects']: request['StoppingConditions']['MaxHumanLabeledObjectCount'] = args['max_human_labeled_objects'] else: request['StoppingConditions'].pop('MaxHumanLabeledObjectCount') if args['max_percent_objects']: request['StoppingConditions']['MaxPercentageOfInputDatasetLabeled'] = args['max_percent_objects'] else: request['StoppingConditions'].pop('MaxPercentageOfInputDatasetLabeled') ### Update or pop automatic labeling configs if args['enable_auto_labeling']: if task == 'image classification' or task == 'bounding box' or task == 'text classification': labeling_algorithm_arn = 'arn:aws:sagemaker:{}:027400017018:labeling-job-algorithm-specification/image-classification'.format(args['region'], auto_labeling_map[task]) request['LabelingJobAlgorithmsConfig']['LabelingJobAlgorithmSpecificationArn'] = labeling_algorithm_arn if args['initial_model_arn']: request['LabelingJobAlgorithmsConfig']['InitialActiveLearningModelArn'] = args['initial_model_arn'] else: request['LabelingJobAlgorithmsConfig'].pop('InitialActiveLearningModelArn') request['LabelingJobAlgorithmsConfig']['LabelingJobResourceConfig']['VolumeKmsKeyId'] = args['resource_encryption_key'] else: logging.error("Automated data labeling not available for semantic segmentation or custom algorithms. 
Proceeding without automated data labeling.") else: request.pop('LabelingJobAlgorithmsConfig') ### Update pre-human and annotation consolidation task lambda functions if task == 'image classification' or task == 'bounding box' or task == 'text classification' or task == 'semantic segmentation': prehuman_arn = 'arn:aws:lambda:{}:{}:function:PRE-{}'.format(args['region'], algorithm_arn_map[args['region']], task_map[task]) acs_arn = 'arn:aws:lambda:{}:{}:function:ACS-{}'.format(args['region'], algorithm_arn_map[args['region']], task_map[task]) request['HumanTaskConfig']['PreHumanTaskLambdaArn'] = prehuman_arn request['HumanTaskConfig']['AnnotationConsolidationConfig']['AnnotationConsolidationLambdaArn'] = acs_arn elif task == 'custom' or task == '': if args['pre_human_task_function'] and args['post_human_task_function']: request['HumanTaskConfig']['PreHumanTaskLambdaArn'] = args['pre_human_task_function'] request['HumanTaskConfig']['AnnotationConsolidationConfig']['AnnotationConsolidationLambdaArn'] = args['post_human_task_function'] else: logging.error("Must specify pre-human task lambda arn and annotation consolidation post-human task lambda arn.") else: logging.error("Task type must be Bounding Box, Image Classification, Semantic Segmentation, Text Classification, or Custom.") request['HumanTaskConfig']['UiConfig']['UiTemplateS3Uri'] = args['ui_template'] request['HumanTaskConfig']['TaskTitle'] = args['title'] request['HumanTaskConfig']['TaskDescription'] = args['description'] request['HumanTaskConfig']['NumberOfHumanWorkersPerDataObject'] = args['num_workers_per_object'] request['HumanTaskConfig']['TaskTimeLimitInSeconds'] = args['time_limit'] if args['task_availibility']: request['HumanTaskConfig']['TaskAvailabilityLifetimeInSeconds'] = args['task_availibility'] else: request['HumanTaskConfig'].pop('TaskAvailabilityLifetimeInSeconds') if args['max_concurrent_tasks']: request['HumanTaskConfig']['MaxConcurrentTaskCount'] = args['max_concurrent_tasks'] else: 
request['HumanTaskConfig'].pop('MaxConcurrentTaskCount') if args['task_keywords']: for word in [n.strip() for n in args['task_keywords'].split(',')]: request['HumanTaskConfig']['TaskKeywords'].append(word) else: request['HumanTaskConfig'].pop('TaskKeywords') ### Update worker configurations if args['worker_type'].lower() == 'public': if args['no_adult_content']: request['InputConfig']['DataAttributes']['ContentClassifiers'].append('FreeOfAdultContent') if args['no_ppi']: request['InputConfig']['DataAttributes']['ContentClassifiers'].append('FreeOfPersonallyIdentifiableInformation') request['HumanTaskConfig']['WorkteamArn'] = 'arn:aws:sagemaker:{}:394669845002:workteam/public-crowd/default'.format(args['region']) dollars = int(args['workforce_task_price']) cents = int(100 * (args['workforce_task_price'] - dollars)) tenth_of_cents = int((args['workforce_task_price'] * 1000) - (dollars * 1000) - (cents * 10)) request['HumanTaskConfig']['PublicWorkforceTaskPrice']['AmountInUsd']['Dollars'] = dollars request['HumanTaskConfig']['PublicWorkforceTaskPrice']['AmountInUsd']['Cents'] = cents request['HumanTaskConfig']['PublicWorkforceTaskPrice']['AmountInUsd']['TenthFractionsOfACent'] = tenth_of_cents else: request['InputConfig'].pop('DataAttributes') request['HumanTaskConfig']['WorkteamArn'] = args['workteam_arn'] request['HumanTaskConfig'].pop('PublicWorkforceTaskPrice') for key, val in args['tags'].items(): request['Tags'].append({'Key': key, 'Value': val}) return request def create_labeling_job(client, args): """Create a SageMaker Ground Truth job""" request = create_labeling_job_request(args) try: client.create_labeling_job(**request) gt_job_name = request['LabelingJobName'] logging.info("Created Ground Truth Labeling Job with name: " + gt_job_name) logging.info("Ground Truth job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/groundtruth?region={}#/labeling-jobs/details/{}" .format(args['region'], args['region'], gt_job_name)) return gt_job_name except 
# NOTE(review): this source line opens with the tail of the `except` clause of
# create_labeling_job, whose `except` keyword ends the previous line:
#     ClientError as e: raise Exception(e.response['Error']['Message'])


def wait_for_labeling_job(client, labeling_job_name):
    """Block until the labeling job leaves its pending states.

    Calls ``client.describe_labeling_job`` every 30 seconds while the job
    reports ``Initializing`` or ``InProgress``.

    Args:
        client: Object exposing ``describe_labeling_job(LabelingJobName=...)``.
        labeling_job_name: Name of the labeling job to poll.

    Raises:
        Exception: ``'Labeling job failed'`` when the job status is ``Failed``;
            ``'Labeling job stopped'`` for any other status that is not
            ``Completed``.
    """
    # A job that was just created reports 'Initializing' before 'InProgress';
    # both mean "keep waiting".
    pending_states = ('Initializing', 'InProgress')

    while True:
        response = client.describe_labeling_job(LabelingJobName=labeling_job_name)
        status = response['LabelingJobStatus']
        if status not in pending_states:
            # Leave before logging/sleeping so a finished job returns at once.
            break
        logging.info("Labeling job is still in status: %s", status)
        time.sleep(30)

    if status == 'Failed':
        # FailureReason is read defensively so a missing key cannot mask the
        # real failure with a KeyError.
        message = response.get('FailureReason', '')
        logging.error('Labeling failed with the following error: {}'.format(message))
        raise Exception('Labeling job failed')

    if status != 'Completed':
        raise Exception('Labeling job stopped')

    logging.info("Labeling job ended with status: %s", status)


def get_labeling_job_outputs(client, labeling_job_name, auto_labeling):
    """Return the output manifest URI and the active learning model ARN.

    Args:
        client: Object exposing ``describe_labeling_job(LabelingJobName=...)``.
        labeling_job_name: Name of the labeling job to describe.
        auto_labeling: When truthy, the final active learning model ARN is
            looked up in the job output.

    Returns:
        Tuple ``(output_manifest, active_learning_model_arn)``. The ARN is
        ``''`` when ``auto_labeling`` is falsy or the job output carries no
        ``FinalActiveLearningModelArn``.
    """
    info = client.describe_labeling_job(LabelingJobName=labeling_job_name)
    job_output = info['LabelingJobOutput']
    output_manifest = job_output['OutputDatasetS3Uri']

    active_learning_model_arn = ''
    if auto_labeling:
        # The model ARN is optional in the response; fall back to '' instead
        # of raising KeyError.
        active_learning_model_arn = job_output.get('FinalActiveLearningModelArn', '')
    return output_manifest, active_learning_model_arn


def id_generator(size=4, chars=string.ascii_uppercase + string.digits):
    """Return ``size`` characters drawn at random from ``chars``."""
    return ''.join(random.choice(chars) for _ in range(size))


def str_to_bool(s):
    """Convert ``'true'`` / ``'false'`` (any case, padded or not) to a bool.

    Raises:
        argparse.ArgumentTypeError: For any other text.
    """
    normalized = s.lower().strip()
    if normalized == 'true':
        return True
    if normalized == 'false':
        return False
    raise argparse.ArgumentTypeError('"True" or "False" expected.')


def str_to_int(s):
    """Return ``int(s)``, or ``0`` when ``s`` is empty/falsy."""
    return int(s) if s else 0


def str_to_float(s):
    """Return ``float(s)``, or ``0.0`` when ``s`` is empty/falsy."""
    return float(s) if s else 0.0


def str_to_json_dict(s):
    """Return ``json.loads(s)``, or an empty dict when ``s`` is ``''``."""
    return json.loads(s) if s != '' else {}


def str_to_json_list(s):
    """Return ``json.loads(s)``, or an empty list when ``s`` is ``''``."""
    return json.loads(s) if s != '' else []