pipelines/components/aws/sagemaker/tests/integration_tests/conftest.py

263 lines
8.3 KiB
Python

import pytest
import boto3
import kfp
import os
from utils import sagemaker_utils
import utils
from datetime import datetime
from filelock import FileLock
from sagemaker import image_uris
from botocore.config import Config
def pytest_addoption(parser):
    """Register the command-line options used by the integration tests.

    The role ARNs (``--sagemaker-role-arn``, ``--robomaker-role-arn``,
    ``--assume-role-arn``) and ``--s3-data-bucket`` are mandatory; every other
    option carries a default value.

    Args:
        parser: Object exposing ``addoption(flag, **settings)``; pytest passes
            its option parser here.
    """
    # One (flag, keyword-settings) entry per option, in registration order.
    option_specs = (
        (
            "--region",
            {
                "default": "us-west-2",
                "required": False,
                "help": "AWS region where test will run",
            },
        ),
        (
            "--sagemaker-role-arn",
            {
                "required": True,
                "help": "SageMaker execution IAM role ARN",
            },
        ),
        (
            "--robomaker-role-arn",
            {
                "required": True,
                "help": "RoboMaker execution IAM role ARN",
            },
        ),
        (
            "--assume-role-arn",
            {
                "required": True,
                "help": "The ARN of a role which the assume role tests will assume to access SageMaker.",
            },
        ),
        (
            "--s3-data-bucket",
            {
                "required": True,
                "help": "Regional S3 bucket name in which test data is hosted",
            },
        ),
        (
            "--minio-service-port",
            {
                "default": "9000",
                "required": False,
                "help": "Localhost port to which minio service is mapped to",
            },
        ),
        (
            "--kfp-namespace",
            {
                "default": "kubeflow",
                "required": False,
                "help": "Cluster namespace where kubeflow pipelines is installed",
            },
        ),
        (
            "--fsx-subnet",
            {
                "required": False,
                "help": "The subnet in which FSx is installed",
                "default": "",
            },
        ),
        (
            "--fsx-security-group",
            {
                "required": False,
                "help": "The security group SageMaker should use when running the FSx test",
                "default": "",
            },
        ),
        (
            "--fsx-id",
            {
                "required": False,
                "help": "The file system ID of the FSx instance",
                "default": "",
            },
        ),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
@pytest.fixture(scope="session", autouse=True)
def region(request):
    """Export the ``--region`` option as ``AWS_REGION`` and return its value."""
    selected_region = request.config.getoption("--region")
    os.environ["AWS_REGION"] = selected_region
    return selected_region
@pytest.fixture(scope="session", autouse=True)
def assume_role_arn(request):
    """Export the ``--assume-role-arn`` option as ``ASSUME_ROLE_ARN`` and return it."""
    role_arn = request.config.getoption("--assume-role-arn")
    os.environ["ASSUME_ROLE_ARN"] = role_arn
    return role_arn
@pytest.fixture(scope="session", autouse=True)
def sagemaker_role_arn(request):
    """Export the ``--sagemaker-role-arn`` option as ``SAGEMAKER_ROLE_ARN`` and return it."""
    role_arn = request.config.getoption("--sagemaker-role-arn")
    os.environ["SAGEMAKER_ROLE_ARN"] = role_arn
    return role_arn
@pytest.fixture(scope="session", autouse=True)
def robomaker_role_arn(request):
    """Export the ``--robomaker-role-arn`` option as ``ROBOMAKER_ROLE_ARN`` and return it."""
    role_arn = request.config.getoption("--robomaker-role-arn")
    os.environ["ROBOMAKER_ROLE_ARN"] = role_arn
    return role_arn
@pytest.fixture(scope="session", autouse=True)
def s3_data_bucket(request):
    """Export the ``--s3-data-bucket`` option as ``S3_DATA_BUCKET`` and return it."""
    bucket_name = request.config.getoption("--s3-data-bucket")
    os.environ["S3_DATA_BUCKET"] = bucket_name
    return bucket_name
@pytest.fixture(scope="session", autouse=True)
def minio_service_port(request):
    """Export the ``--minio-service-port`` option as ``MINIO_SERVICE_PORT`` and return it."""
    port = request.config.getoption("--minio-service-port")
    os.environ["MINIO_SERVICE_PORT"] = port
    return port
@pytest.fixture(scope="session", autouse=True)
def kfp_namespace(request):
    """Export the ``--kfp-namespace`` option as ``NAMESPACE`` and return it."""
    namespace = request.config.getoption("--kfp-namespace")
    os.environ["NAMESPACE"] = namespace
    return namespace
@pytest.fixture(scope="session", autouse=True)
def fsx_subnet(request):
    """Export the ``--fsx-subnet`` option as ``FSX_SUBNET`` and return it."""
    subnet = request.config.getoption("--fsx-subnet")
    os.environ["FSX_SUBNET"] = subnet
    return subnet
@pytest.fixture(scope="session", autouse=True)
def fsx_security_group(request):
    """Export the ``--fsx-security-group`` option as ``FSX_SECURITY_GROUP`` and return it."""
    security_group = request.config.getoption("--fsx-security-group")
    os.environ["FSX_SECURITY_GROUP"] = security_group
    return security_group
@pytest.fixture(scope="session", autouse=True)
def fsx_id(request):
    """Export the ``--fsx-id`` option as ``FSX_ID`` and return it."""
    file_system_id = request.config.getoption("--fsx-id")
    os.environ["FSX_ID"] = file_system_id
    return file_system_id
@pytest.fixture(scope="session")
def boto3_session(region):
    """Create the boto3 session, bound to the selected region, shared by the test session."""
    session = boto3.Session(region_name=region)
    return session
@pytest.fixture(scope="session")
def sagemaker_client(boto3_session):
    """Return a SageMaker client created from the shared boto3 session."""
    client = boto3_session.client(service_name="sagemaker")
    return client
@pytest.fixture(scope="session")
def robomaker_client(boto3_session):
    """Return a RoboMaker client created from the shared boto3 session."""
    client = boto3_session.client(service_name="robomaker")
    return client
@pytest.fixture(scope="session")
def s3_client(boto3_session):
    """Return an S3 client created from the shared boto3 session."""
    client = boto3_session.client(service_name="s3")
    return client
@pytest.fixture(scope="session")
def kfp_client():
    """Return a ``kfp.Client`` for the namespace reported by ``utils.get_kfp_namespace()``."""
    return kfp.Client(namespace=utils.get_kfp_namespace())
def get_experiment_id(kfp_client):
    """Return the id of the experiment named after the current minute.

    The experiment name is the local time formatted as ``%Y-%m-%d-%H-%M``.
    The experiment is looked up first; when the lookup raises ``ValueError``
    a new experiment with that name is created instead.

    Args:
        kfp_client: Client exposing ``get_experiment(experiment_name=...)``
            and ``create_experiment(name=...)``.

    Returns:
        The ``id`` attribute of the found or newly created experiment.
    """
    minute_stamp = datetime.now().strftime("%Y-%m-%d-%H-%M")
    try:
        existing = kfp_client.get_experiment(experiment_name=minute_stamp)
    except ValueError:
        # Lookup failed: fall back to creating the experiment.
        return kfp_client.create_experiment(name=minute_stamp).id
    return existing.id
@pytest.fixture(scope="session")
def experiment_id(kfp_client, tmp_path_factory, worker_id):
    """Return one experiment id shared by every worker of the test run.

    When ``worker_id`` is ``"master"`` the id is obtained directly. Otherwise
    the id is cached in a file under the parent of the base temp directory,
    guarded by a file lock, so only the first worker to arrive calls
    ``get_experiment_id``.
    """
    if worker_id != "master":
        # Locking taken as an example from
        # https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
        # The parent of the base temp directory is shared by all workers.
        shared_dir = tmp_path_factory.getbasetemp().parent
        cache_file = shared_dir / "experiment_id"
        with FileLock(str(cache_file) + ".lock"):
            if not cache_file.is_file():
                # First worker here: resolve the id and publish it for the others.
                cached_id = get_experiment_id(kfp_client)
                cache_file.write_text(cached_id)
            else:
                cached_id = cache_file.read_text()
        return cached_id
    return get_experiment_id(kfp_client)
@pytest.fixture(scope="session")
def deploy_endpoint(sagemaker_client, s3_data_bucket, sagemaker_role_arn, region):
    """Deploy a SageMaker endpoint for testing model monitoring.

    Creates a model, an endpoint config with input/output data capture
    enabled, and an endpoint, then waits for the endpoint to be in service.
    The endpoint name is yielded to the tests; on teardown the endpoint,
    endpoint config and model are deleted.

    Args:
        sagemaker_client: boto3 SageMaker client used for every API call.
        s3_data_bucket: Bucket holding the model artifact and receiving the
            captured data.
        sagemaker_role_arn: Execution role ARN passed to ``create_model``.
        region: Region used to resolve the xgboost container image.

    Yields:
        The name of the deployed endpoint.

    Raises:
        Exception: If the endpoint does not end up in the ``InService`` state.
    """
    model_name = "model-monitor-" + utils.generate_random_string(5) + "-model"
    endpoint_name = "model-monitor-" + utils.generate_random_string(5) + "-endpoint-v2"
    endpoint_config_name = (
        "model-monitor-" + utils.generate_random_string(5) + "-endpoint-config"
    )

    # create sagemaker model
    # TODO upgrade to a newer xgboost version
    image_uri = image_uris.retrieve("xgboost", region, "0.90-1")
    sagemaker_client.create_model(
        ModelName=model_name,
        PrimaryContainer={
            "Image": image_uri,
            "ModelDataUrl": f"s3://{s3_data_bucket}/model-monitor/xgb-churn-prediction-model.tar.gz",
            "Environment": {},
        },
        ExecutionRoleArn=sagemaker_role_arn,
    )

    # create sagemaker endpoint config
    sagemaker_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=[
            {
                "VariantName": "variant-1",
                "ModelName": model_name,
                "InitialInstanceCount": 1,
                "InstanceType": "ml.m5.large",
            },
        ],
        DataCaptureConfig={
            "EnableCapture": True,
            "CaptureOptions": [{"CaptureMode": "Input"}, {"CaptureMode": "Output"}],
            "InitialSamplingPercentage": 100,
            "DestinationS3Uri": f"s3://{s3_data_bucket}/model-monitor/datacapture",
        },
    )

    # create sagemaker endpoint
    sagemaker_client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
    )

    try:
        sagemaker_client.get_waiter("endpoint_in_service").wait(
            EndpointName=endpoint_name
        )
    finally:
        # Runs even when the waiter raises, so the final status is always reported.
        resp = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
        endpoint_status = resp["EndpointStatus"]
        endpoint_arn = resp["EndpointArn"]
        print(f"Deployed endpoint {endpoint_arn}, ended with status {endpoint_status}")
        if endpoint_status != "InService":
            # FailureReason is an optional response field; indexing it directly
            # would raise KeyError and hide the real failure when the endpoint
            # is not InService for another reason (e.g. still creating).
            message = resp.get("FailureReason", "<no FailureReason reported>")
            print(f"Endpoint deployment failed with the following error: {message}")
            raise Exception("Endpoint deployment failed")

    yield endpoint_name

    # delete model and endpoint config
    print("deleting endpoint.................")
    sagemaker_utils.delete_endpoint(sagemaker_client, endpoint_name)
    sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    sagemaker_client.delete_model(ModelName=model_name)