# pipelines/components/aws/sagemaker/tests/integration_tests/utils/sagemaker_utils.py
import logging
import re
from datetime import datetime
from time import sleep
import os
import pickle
import gzip
import io
import numpy
import json
from utils import get_s3_data_bucket
def describe_training_job(client, training_job_name):
    """Return the full description of a SageMaker training job."""
    response = client.describe_training_job(TrainingJobName=training_job_name)
    return response
def describe_model(client, model_name):
    """Return the full description of a SageMaker model."""
    response = client.describe_model(ModelName=model_name)
    return response
def describe_endpoint(client, endpoint_name):
    """Return the full description of a SageMaker endpoint."""
    response = client.describe_endpoint(EndpointName=endpoint_name)
    return response
def list_endpoints(client, name_contains):
    """List SageMaker endpoints whose names contain the given substring."""
    response = client.list_endpoints(NameContains=name_contains)
    return response
def describe_endpoint_config(client, endpoint_config_name):
    """Return the full description of a SageMaker endpoint configuration."""
    response = client.describe_endpoint_config(
        EndpointConfigName=endpoint_config_name
    )
    return response
def delete_endpoint(client, endpoint_name):
    """Delete a SageMaker endpoint and block until deletion completes."""
    client.delete_endpoint(EndpointName=endpoint_name)
    # Poll via the boto3 waiter so callers can rely on the endpoint being gone.
    deletion_waiter = client.get_waiter("endpoint_deleted")
    deletion_waiter.wait(EndpointName=endpoint_name)
def describe_monitoring_schedule(client, monitoring_schedule_name):
    """Return the full description of a SageMaker monitoring schedule."""
    response = client.describe_monitoring_schedule(
        MonitoringScheduleName=monitoring_schedule_name
    )
    return response
def describe_data_quality_job_definition(client, job_definition_name):
    """Return the full description of a data-quality job definition."""
    response = client.describe_data_quality_job_definition(
        JobDefinitionName=job_definition_name
    )
    return response
def describe_hpo_job(client, job_name):
    """Return the full description of a hyperparameter tuning (HPO) job."""
    response = client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName=job_name
    )
    return response
def describe_transform_job(client, job_name):
    """Return the full description of a SageMaker batch transform job."""
    response = client.describe_transform_job(TransformJobName=job_name)
    return response
def describe_workteam(client, workteam_name):
    """Return the full description of a SageMaker workteam."""
    response = client.describe_workteam(WorkteamName=workteam_name)
    return response
def list_workteams(client):
    """List all SageMaker workteams visible to the client."""
    response = client.list_workteams()
    return response
def get_cognito_member_definitions(client):
    """Return (user_pool, client_id, user_group) for the default workforce.

    This is one way to obtain the Cognito user pool and client id for the
    SageMaker workforce; an alternative would be to accept them as params or
    via a config file. The current mechanism expects at least one private
    workteam to exist in the region — the first workteam's first member
    definition is treated as the default.
    """
    workteams = list_workteams(client)["Workteams"]
    cognito = workteams[0]["MemberDefinitions"][0]["CognitoMemberDefinition"]
    return (cognito["UserPool"], cognito["ClientId"], cognito["UserGroup"])
def list_labeling_jobs_for_workteam(client, workteam_arn):
    """List labeling jobs assigned to the workteam with the given ARN."""
    response = client.list_labeling_jobs_for_workteam(WorkteamArn=workteam_arn)
    return response
def describe_labeling_job(client, labeling_job_name):
    """Return the full description of a SageMaker Ground Truth labeling job."""
    response = client.describe_labeling_job(LabelingJobName=labeling_job_name)
    return response
def get_workteam_arn(client, workteam_name):
    """Look up a workteam by name and return its ARN."""
    workteam_info = describe_workteam(client, workteam_name)
    return workteam_info["Workteam"]["WorkteamArn"]
def delete_workteam(client, workteam_name):
    """Delete the named SageMaker workteam (fire-and-forget, no waiter)."""
    client.delete_workteam(WorkteamName=workteam_name)
def stop_labeling_job(client, labeling_job_name):
    """Request that the named labeling job be stopped (does not wait)."""
    client.stop_labeling_job(LabelingJobName=labeling_job_name)
def describe_processing_job(client, processing_job_name):
    """Return the full description of a SageMaker processing job."""
    response = client.describe_processing_job(ProcessingJobName=processing_job_name)
    return response
def run_predict_mnist(boto3_session, endpoint_name, download_dir):
    """Invoke a deployed MNIST endpoint with one training sample.

    Downloads the pickled MNIST dataset from the shared S3 data bucket into
    *download_dir*, serializes a single training image as CSV, and returns
    the endpoint's JSON-decoded prediction.

    Adapted from:
    https://github.com/awslabs/amazon-sagemaker-examples/blob/a8c20eeb72dc7d3e94aaaf28be5bf7d7cd5695cb/sagemaker-python-sdk/1P_kmeans_lowlevel/kmeans_mnist_lowlevel.ipynb
    """
    # Fetch the dataset archive from the test data bucket.
    region = boto3_session.region_name
    local_path = os.path.join(download_dir, "mnist.pkl.gz")
    s3 = boto3_session.resource("s3", region_name=region)
    s3.Bucket(get_s3_data_bucket()).download_file(
        "algorithms/mnist.pkl.gz", local_path
    )

    # NOTE: pickle.load is acceptable here because the archive comes from a
    # bucket the test suite controls; only the training split is used.
    with gzip.open(local_path, "rb") as archive:
        train_set, _valid_set, _test_set = pickle.load(archive, encoding="latin1")

    # Serialize a single sample (image index 30) as a CSV payload.
    buffer = io.BytesIO()
    numpy.savetxt(buffer, train_set[0][30:31], delimiter=",", fmt="%g")
    payload = buffer.getvalue().decode().rstrip()

    # Send the payload to the endpoint and decode the JSON prediction.
    runtime = boto3_session.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="text/csv",
        Body=payload,
    )
    return json.loads(response["Body"].read().decode())