263 lines
8.3 KiB
Python
263 lines
8.3 KiB
Python
import pytest
|
|
import boto3
|
|
import kfp
|
|
import os
|
|
from utils import sagemaker_utils
|
|
import utils
|
|
|
|
from datetime import datetime
|
|
from filelock import FileLock
|
|
from sagemaker import image_uris
|
|
from botocore.config import Config
|
|
|
|
|
|
def pytest_addoption(parser):
|
|
parser.addoption(
|
|
"--region",
|
|
default="us-west-2",
|
|
required=False,
|
|
help="AWS region where test will run",
|
|
)
|
|
parser.addoption(
|
|
"--sagemaker-role-arn",
|
|
required=True,
|
|
help="SageMaker execution IAM role ARN",
|
|
)
|
|
parser.addoption(
|
|
"--robomaker-role-arn",
|
|
required=True,
|
|
help="RoboMaker execution IAM role ARN",
|
|
)
|
|
parser.addoption(
|
|
"--assume-role-arn",
|
|
required=True,
|
|
help="The ARN of a role which the assume role tests will assume to access SageMaker.",
|
|
)
|
|
parser.addoption(
|
|
"--s3-data-bucket",
|
|
required=True,
|
|
help="Regional S3 bucket name in which test data is hosted",
|
|
)
|
|
parser.addoption(
|
|
"--minio-service-port",
|
|
default="9000",
|
|
required=False,
|
|
help="Localhost port to which minio service is mapped to",
|
|
)
|
|
parser.addoption(
|
|
"--kfp-namespace",
|
|
default="kubeflow",
|
|
required=False,
|
|
help="Cluster namespace where kubeflow pipelines is installed",
|
|
)
|
|
parser.addoption(
|
|
"--fsx-subnet",
|
|
required=False,
|
|
help="The subnet in which FSx is installed",
|
|
default="",
|
|
)
|
|
parser.addoption(
|
|
"--fsx-security-group",
|
|
required=False,
|
|
help="The security group SageMaker should use when running the FSx test",
|
|
default="",
|
|
)
|
|
parser.addoption(
|
|
"--fsx-id",
|
|
required=False,
|
|
help="The file system ID of the FSx instance",
|
|
default="",
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def region(request):
|
|
os.environ["AWS_REGION"] = request.config.getoption("--region")
|
|
return request.config.getoption("--region")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def assume_role_arn(request):
|
|
os.environ["ASSUME_ROLE_ARN"] = request.config.getoption("--assume-role-arn")
|
|
return request.config.getoption("--assume-role-arn")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def sagemaker_role_arn(request):
|
|
os.environ["SAGEMAKER_ROLE_ARN"] = request.config.getoption("--sagemaker-role-arn")
|
|
return request.config.getoption("--sagemaker-role-arn")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def robomaker_role_arn(request):
|
|
os.environ["ROBOMAKER_ROLE_ARN"] = request.config.getoption("--robomaker-role-arn")
|
|
return request.config.getoption("--robomaker-role-arn")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def s3_data_bucket(request):
|
|
os.environ["S3_DATA_BUCKET"] = request.config.getoption("--s3-data-bucket")
|
|
return request.config.getoption("--s3-data-bucket")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def minio_service_port(request):
|
|
os.environ["MINIO_SERVICE_PORT"] = request.config.getoption("--minio-service-port")
|
|
return request.config.getoption("--minio-service-port")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def kfp_namespace(request):
|
|
os.environ["NAMESPACE"] = request.config.getoption("--kfp-namespace")
|
|
return request.config.getoption("--kfp-namespace")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def fsx_subnet(request):
|
|
os.environ["FSX_SUBNET"] = request.config.getoption("--fsx-subnet")
|
|
return request.config.getoption("--fsx-subnet")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def fsx_security_group(request):
|
|
os.environ["FSX_SECURITY_GROUP"] = request.config.getoption("--fsx-security-group")
|
|
return request.config.getoption("--fsx-security-group")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def fsx_id(request):
|
|
os.environ["FSX_ID"] = request.config.getoption("--fsx-id")
|
|
return request.config.getoption("--fsx-id")
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def boto3_session(region):
|
|
return boto3.Session(region_name=region)
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def sagemaker_client(boto3_session):
|
|
return boto3_session.client(service_name="sagemaker")
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def robomaker_client(boto3_session):
|
|
return boto3_session.client(service_name="robomaker")
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def s3_client(boto3_session):
|
|
return boto3_session.client(service_name="s3")
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def kfp_client():
|
|
kfp_installed_namespace = utils.get_kfp_namespace()
|
|
return kfp.Client(namespace=kfp_installed_namespace)
|
|
|
|
|
|
def get_experiment_id(kfp_client):
|
|
exp_name = datetime.now().strftime("%Y-%m-%d-%H-%M")
|
|
try:
|
|
experiment = kfp_client.get_experiment(experiment_name=exp_name)
|
|
except ValueError:
|
|
experiment = kfp_client.create_experiment(name=exp_name)
|
|
return experiment.id
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def experiment_id(kfp_client, tmp_path_factory, worker_id):
|
|
if worker_id == "master":
|
|
return get_experiment_id(kfp_client)
|
|
|
|
# Locking taking as an example from
|
|
# https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
|
|
# get the temp directory shared by all workers
|
|
root_tmp_dir = tmp_path_factory.getbasetemp().parent
|
|
|
|
fn = root_tmp_dir / "experiment_id"
|
|
with FileLock(str(fn) + ".lock"):
|
|
if fn.is_file():
|
|
data = fn.read_text()
|
|
else:
|
|
data = get_experiment_id(kfp_client)
|
|
fn.write_text(data)
|
|
return data
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
# Deploy endpoint for testing model monitoring
|
|
def deploy_endpoint(sagemaker_client, s3_data_bucket, sagemaker_role_arn, region):
|
|
model_name = "model-monitor-" + utils.generate_random_string(5) + "-model"
|
|
endpoint_name = "model-monitor-" + utils.generate_random_string(5) + "-endpoint-v2"
|
|
endpoint_config_name = (
|
|
"model-monitor-" + utils.generate_random_string(5) + "-endpoint-config"
|
|
)
|
|
|
|
# create sagemaker model
|
|
# TODO upgrade to a newer xgboost version
|
|
image_uri = image_uris.retrieve("xgboost", region, "0.90-1")
|
|
create_model_api_response = sagemaker_client.create_model(
|
|
ModelName=model_name,
|
|
PrimaryContainer={
|
|
"Image": image_uri,
|
|
"ModelDataUrl": f"s3://{s3_data_bucket}/model-monitor/xgb-churn-prediction-model.tar.gz",
|
|
"Environment": {},
|
|
},
|
|
ExecutionRoleArn=sagemaker_role_arn,
|
|
)
|
|
|
|
# create sagemaker endpoint config
|
|
create_endpoint_config_api_response = sagemaker_client.create_endpoint_config(
|
|
EndpointConfigName=endpoint_config_name,
|
|
ProductionVariants=[
|
|
{
|
|
"VariantName": "variant-1",
|
|
"ModelName": model_name,
|
|
"InitialInstanceCount": 1,
|
|
"InstanceType": "ml.m5.large",
|
|
},
|
|
],
|
|
DataCaptureConfig={
|
|
"EnableCapture": True,
|
|
"CaptureOptions": [{"CaptureMode": "Input"}, {"CaptureMode": "Output"}],
|
|
"InitialSamplingPercentage": 100,
|
|
"DestinationS3Uri": f"s3://{s3_data_bucket}/model-monitor/datacapture",
|
|
},
|
|
)
|
|
|
|
# create sagemaker endpoint
|
|
create_endpoint_api_response = sagemaker_client.create_endpoint(
|
|
EndpointName=endpoint_name,
|
|
EndpointConfigName=endpoint_config_name,
|
|
)
|
|
|
|
try:
|
|
sagemaker_client.get_waiter("endpoint_in_service").wait(
|
|
EndpointName=endpoint_name
|
|
)
|
|
finally:
|
|
resp = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
|
|
endpoint_status = resp["EndpointStatus"]
|
|
endpoint_arn = resp["EndpointArn"]
|
|
print(f"Deployed endpoint {endpoint_arn}, ended with status {endpoint_status}")
|
|
|
|
if endpoint_status != "InService":
|
|
message = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)[
|
|
"FailureReason"
|
|
]
|
|
print(
|
|
"Endpoint deployment failed with the following error: {}".format(
|
|
message
|
|
)
|
|
)
|
|
raise Exception("Endpoint deployment failed")
|
|
|
|
yield endpoint_name
|
|
|
|
# delete model and endpoint config
|
|
print("deleting endpoint.................")
|
|
sagemaker_utils.delete_endpoint(sagemaker_client, endpoint_name)
|
|
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
|
|
sagemaker_client.delete_model(ModelName=model_name)
|