pipelines/components/aws/sagemaker/common/_utils.py

226 lines
7.5 KiB
Python

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
import subprocess
from time import gmtime, strftime
import time
import string
import random
import json
from urlparse import urlparse
import boto3
from botocore.exceptions import ClientError
from sagemaker import get_execution_role
def get_client(region=None):
"""Builds a client to the AWS SageMaker API."""
client = boto3.client('sagemaker', region_name=region)
return client
def get_sagemaker_role():
return get_execution_role()
def create_training_job(client, image, instance_type, instance_count, volume_size, data_location, output_location, role):
"""Create a Sagemaker training job."""
job_name = 'TrainingJob-' + strftime("%Y%m%d%H%M%S", gmtime()) + '-' + id_generator()
create_training_params = \
{
"AlgorithmSpecification": {
"TrainingImage": image,
"TrainingInputMode": "File"
},
"RoleArn": role,
"OutputDataConfig": {
"S3OutputPath": output_location
},
"ResourceConfig": {
"InstanceCount": instance_count,
"InstanceType": instance_type,
"VolumeSizeInGB": volume_size
},
"TrainingJobName": job_name,
"HyperParameters": {
"k": "10",
"feature_dim": "784",
"mini_batch_size": "500"
},
"StoppingCondition": {
"MaxRuntimeInSeconds": 60 * 60
},
"InputDataConfig": [
{
"ChannelName": "train",
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3Uri": data_location,
"S3DataDistributionType": "FullyReplicated"
}
},
"CompressionType": "None",
"RecordWrapperType": "None"
}
]
}
client.create_training_job(**create_training_params)
return job_name
def deploy_model(client, model_name):
endpoint_config_name = create_endpoint_config(client, model_name)
endpoint_name = create_endpoint(client, endpoint_config_name)
return endpoint_name
def get_model_artifacts_from_job(client, job_name):
info = client.describe_training_job(TrainingJobName=job_name)
model_artifact_url = info['ModelArtifacts']['S3ModelArtifacts']
return model_artifact_url
def create_model(client, model_artifact_url, model_name, image, role):
primary_container = {
'Image': image,
'ModelDataUrl': model_artifact_url
}
create_model_response = client.create_model(
ModelName = model_name,
ExecutionRoleArn = role,
PrimaryContainer = primary_container)
print("Model Config Arn: " + create_model_response['ModelArn'])
return create_model_response['ModelArn']
def create_endpoint_config(client, model_name):
endpoint_config_name = 'EndpointConfig' + model_name[model_name.index('-'):]
print(endpoint_config_name)
create_endpoint_config_response = client.create_endpoint_config(
EndpointConfigName = endpoint_config_name,
ProductionVariants=[{
'InstanceType':'ml.m4.xlarge',
'InitialInstanceCount':1,
'ModelName': model_name,
'VariantName':'AllTraffic'}])
print("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])
return endpoint_config_name
def create_endpoint(client, endpoint_config_name):
endpoint_name = 'Endpoint' + endpoint_config_name[endpoint_config_name.index('-'):]
print(endpoint_name)
create_endpoint_response = client.create_endpoint(
EndpointName=endpoint_name,
EndpointConfigName=endpoint_config_name)
print(create_endpoint_response['EndpointArn'])
resp = client.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Status: " + status)
try:
client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
finally:
resp = client.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Arn: " + resp['EndpointArn'])
print("Create endpoint ended with status: " + status)
return endpoint_name
if status != 'InService':
message = client.describe_endpoint(EndpointName=endpoint_name)['FailureReason']
print('Create endpoint failed with the following error: {}'.format(message))
raise Exception('Endpoint creation did not succeed')
def wait_for_training_job(client, job_name):
status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print(status)
try:
client.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName=job_name)
finally:
status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print("Training job ended with status: " + status)
if status == 'Failed':
message = client.describe_training_job(TrainingJobName=job_name)['FailureReason']
print('Training failed with the following error: {}'.format(message))
raise Exception('Training job failed')
def create_transform_job(client, model_name, input_location, output_location):
batch_job_name = 'BatchTransform' + model_name[model_name.index('-'):]
### Create a transform job
request = \
{
"TransformJobName": batch_job_name,
"ModelName": model_name,
"MaxConcurrentTransforms": 4,
"MaxPayloadInMB": 6,
"BatchStrategy": "MultiRecord",
"TransformOutput": {
"S3OutputPath": output_location
},
"TransformInput": {
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3Uri": input_location
}
},
"ContentType": "text/csv",
"SplitType": "Line",
"CompressionType": "None"
},
"TransformResources": {
"InstanceType": "ml.m4.xlarge",
"InstanceCount": 1
}
}
client.create_transform_job(**request)
print("Created Transform job with name: ", batch_job_name)
return batch_job_name
def wait_for_transform_job(client, batch_job_name):
### Wait until the job finishes
while(True):
response = client.describe_transform_job(TransformJobName=batch_job_name)
status = response['TransformJobStatus']
if status == 'Completed':
print("Transform job ended with status: " + status)
break
if status == 'Failed':
message = response['FailureReason']
print('Transform failed with the following error: {}'.format(message))
raise Exception('Transform job failed')
print("Transform job is still in status: " + status)
time.sleep(30)
def print_tranformation_job_result(output_location):
### Fetch the transform output
bucket = urlparse(output_location).netloc
output_key = "{}/valid_data.csv.out".format(urlparse(output_location).path.lstrip('/'))
s3_client = boto3.client('s3')
s3_client.download_file(bucket, output_key, 'valid-result')
with open('valid-result') as f:
results = f.readlines()
print("Sample transform result: {}".format(results[0]))
def id_generator(size=4, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))