pipelines/samples/contrib/aws-samples/sagemaker_debugger_demo/debugger-training-pipeline.py

#!/usr/bin/env python3
import kfp
import json
import os
import copy
from kfp import components
from kfp import dsl

cur_file_dir = os.path.dirname(__file__)
components_dir = os.path.join(cur_file_dir, "../../../../components/aws/sagemaker/")

# Load the SageMaker training component definition relative to this sample's
# location in the repo.
sagemaker_train_op = components.load_component_from_file(
    components_dir + "/train/component.yaml"
)


def training_input(input_name, s3_uri, content_type):
    """Build a SageMaker input channel definition for an S3 prefix."""
    return {
        "ChannelName": input_name,
        "DataSource": {"S3DataSource": {"S3Uri": s3_uri, "S3DataType": "S3Prefix"}},
        "ContentType": content_type,
    }


def training_debug_hook(s3_uri, collection_dict):
    """Build the debug hook config that tells SageMaker Debugger which tensor
    collections to save and where in S3 to write them."""
    return {
        "S3OutputPath": s3_uri,
        "CollectionConfigurations": format_collection_config(collection_dict),
    }


def format_collection_config(collection_dict):
    """Convert a {collection_name: parameters} dict into the list form expected
    by the CollectionConfigurations field."""
    output = []
    for key, val in collection_dict.items():
        output.append({"CollectionName": key, "CollectionParameters": val})
    return output


def training_debug_rules(rule_name, parameters):
    """Build a debug rule configuration that runs one of SageMaker Debugger's
    built-in rules via the rule evaluator image (us-east-1 registry)."""
    return {
        "RuleConfigurationName": rule_name,
        "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest",
        "RuleParameters": parameters,
    }


# Tensor collections to save and how often (in steps) to save each one.
collections = {
    "feature_importance": {"save_interval": "5"},
    "losses": {"save_interval": "10"},
    "average_shap": {"save_interval": "5"},
    "metrics": {"save_interval": "3"},
}

# Intentionally poor XGBoost hyperparameters; the training job is expected to
# trip the Debugger rules configured below.
bad_hyperparameters = {
    "max_depth": "5",
    "eta": "0",
    "gamma": "4",
    "min_child_weight": "6",
    "silent": "0",
    "subsample": "0.7",
    "num_round": "50",
}


@dsl.pipeline(
    name="XGBoost Training Pipeline with bad hyperparameters",
    description="SageMaker training job test with debugger",
)
def training(role_arn="", bucket_name="my-bucket"):
    train_channels = [
        training_input(
            "train",
            f"s3://{bucket_name}/mnist_kmeans_example/input/valid_data.csv",
            "text/csv",
        )
    ]
    train_debug_rules = [
        training_debug_rules(
            "LossNotDecreasing",
            {"rule_to_invoke": "LossNotDecreasing", "tensor_regex": ".*"},
        ),
        training_debug_rules(
            "Overtraining",
            {
                "rule_to_invoke": "Overtraining",
                "patience_train": "10",
                "patience_validation": "20",
            },
        ),
    ]
    training = sagemaker_train_op(
        region="us-east-1",
        image="683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3",
        hyperparameters=bad_hyperparameters,
        channels=train_channels,
        instance_type="ml.m5.2xlarge",
        model_artifact_path=f"s3://{bucket_name}/mnist_kmeans_example/output/model",
        debug_hook_config=training_debug_hook(
            f"s3://{bucket_name}/mnist_kmeans_example/hook_config", collections
        ),
        debug_rule_config=train_debug_rules,
        role=role_arn,
    )


if __name__ == "__main__":
    # Compile the pipeline into a package that can be uploaded to Kubeflow Pipelines.
    kfp.compiler.Compiler().compile(training, __file__ + ".zip")
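
# A minimal sketch of submitting the compiled package to a Kubeflow Pipelines
# endpoint with the kfp SDK; the host URL, experiment name, and role ARN below
# are illustrative placeholders, not values shipped with this sample:
#
#   client = kfp.Client(host="http://localhost:8080")
#   client.create_run_from_pipeline_package(
#       __file__ + ".zip",
#       arguments={
#           "role_arn": "arn:aws:iam::111122223333:role/my-sagemaker-execution-role",
#           "bucket_name": "my-bucket",
#       },
#       experiment_name="sagemaker-debugger-demo",
#   )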