#!/usr/bin/env python3
|
|
|
|
import kfp
|
|
import json
|
|
import copy
|
|
from kfp import components
|
|
from kfp import dsl
|
|
from kfp.aws import use_aws_secret
|
|
|
|
# Component factories for the three SageMaker steps used by the pipeline
# below. Each one is loaded from the component.yaml at the given path.
# NOTE(review): the paths are relative -- presumably they are resolved against
# the current working directory, not this file's location; confirm before
# running from another directory.
_load_component = components.load_component_from_file

sagemaker_workteam_op = _load_component(
    "../../../../components/aws/sagemaker/workteam/component.yaml"
)
sagemaker_gt_op = _load_component(
    "../../../../components/aws/sagemaker/ground_truth/component.yaml"
)
sagemaker_train_op = _load_component(
    "../../../../components/aws/sagemaker/train/component.yaml"
)
|
|
|
|
# List of training input channel definitions; starts out empty.
channelObjList = []

# S3 data source settings shared by every channel. "S3Uri" is left blank here
# and is expected to be filled in per channel.
_s3_data_source = {
    "S3Uri": "",
    "S3DataType": "AugmentedManifestFile",
    "S3DataDistributionType": "FullyReplicated",
    "AttributeNames": ["source-ref", "category"],
}

# Template for a single channel definition. "ChannelName" is blank until a
# concrete channel is derived from this template.
channelObj = {
    "ChannelName": "",
    "DataSource": {"S3DataSource": _s3_data_source},
    "ContentType": "application/x-recordio",
    "CompressionType": "None",
    "RecordWrapperType": "RecordIO",
}
|
|
|
|
|
|
@dsl.pipeline(
    name="Ground Truth image classification test pipeline",
    description="SageMaker Ground Truth job test",
)
def ground_truth_test(
    region="us-west-2",
    team_name="ground-truth-demo-team",
    team_description="Team for mini image classification labeling job",
    user_pool="",
    user_groups="",
    client_id="",
    ground_truth_train_job_name="mini-image-classification-demo-train",
    ground_truth_validation_job_name="mini-image-classification-demo-validation",
    ground_truth_label_attribute_name="category",
    ground_truth_train_manifest_location="s3://your-bucket-name/mini-image-classification/ground-truth-demo/train.manifest",
    ground_truth_validation_manifest_location="s3://your-bucket-name/mini-image-classification/ground-truth-demo/validation.manifest",
    ground_truth_output_location="s3://your-bucket-name/mini-image-classification/ground-truth-demo/output",
    ground_truth_task_type="image classification",
    ground_truth_worker_type="private",
    ground_truth_label_category_config="s3://your-bucket-name/mini-image-classification/ground-truth-demo/class_labels.json",
    ground_truth_ui_template="s3://your-bucket-name/mini-image-classification/ground-truth-demo/instructions.template",
    ground_truth_title="Mini image classification",
    ground_truth_description="Test for Ground Truth KFP component",
    ground_truth_num_workers_per_object=1,
    ground_truth_time_limit=30,
    ground_truth_task_availibility=3600,
    ground_truth_max_concurrent_tasks=20,
    training_algorithm_name="image classification",
    training_input_mode="Pipe",
    training_hyperparameters={
        "num_classes": "2",
        "num_training_samples": "14",
        "mini_batch_size": "2",
    },
    training_output_location="s3://your-bucket-name/mini-image-classification/training-output",
    training_instance_type="ml.m5.2xlarge",
    training_instance_count=1,
    training_volume_size=50,
    training_max_run_time=3600,
    role_arn="",
):
    """Define the workteam -> Ground Truth labeling -> training pipeline.

    Steps, in order:

    1. Create a workteam with ``sagemaker_workteam_op``.
    2. Start two Ground Truth jobs with ``sagemaker_gt_op`` -- one for the
       train manifest, one for the validation manifest. They differ only in
       job name and manifest location, and both receive the workteam step's
       output as ``workteam_arn``.
    3. Start a training job with ``sagemaker_train_op`` whose ``channels``
       argument is a JSON list of two channels ("train" and "validation"),
       each pointing at the corresponding labeling job's
       ``output_manifest_location`` output.

    Parameter groups (every value is forwarded to a component unchanged):

    * ``region`` and ``role_arn`` are shared by the steps that accept them.
    * ``team_name``, ``team_description``, ``user_pool``, ``user_groups`` and
      ``client_id`` go to the workteam step.
    * ``ground_truth_*`` go to the labeling steps. The spelling of
      ``ground_truth_task_availibility`` is kept as-is because it is part of
      this pipeline's parameter interface.
    * ``training_*`` go to the training step.

    The ``s3://your-bucket-name/...`` defaults and the empty ``role_arn`` are
    placeholders (presumably to be overridden per run).
    """
    workteam = sagemaker_workteam_op(
        region=region,
        team_name=team_name,
        description=team_description,
        user_pool=user_pool,
        user_groups=user_groups,
        client_id=client_id,
    )

    # Arguments that are identical for the train and validation labeling jobs.
    shared_labeling_args = dict(
        region=region,
        role=role_arn,
        label_attribute_name=ground_truth_label_attribute_name,
        output_location=ground_truth_output_location,
        task_type=ground_truth_task_type,
        worker_type=ground_truth_worker_type,
        workteam_arn=workteam.output,
        label_category_config=ground_truth_label_category_config,
        ui_template=ground_truth_ui_template,
        title=ground_truth_title,
        description=ground_truth_description,
        num_workers_per_object=ground_truth_num_workers_per_object,
        time_limit=ground_truth_time_limit,
        task_availibility=ground_truth_task_availibility,
        max_concurrent_tasks=ground_truth_max_concurrent_tasks,
    )

    ground_truth_train = sagemaker_gt_op(
        job_name=ground_truth_train_job_name,
        manifest_location=ground_truth_train_manifest_location,
        **shared_labeling_args,
    )
    ground_truth_validation = sagemaker_gt_op(
        job_name=ground_truth_validation_job_name,
        manifest_location=ground_truth_validation_manifest_location,
        **shared_labeling_args,
    )

    # Build the channel list locally from deep copies of the module-level
    # `channelObj` template. Mutating the template or appending to a
    # module-level list would make a second invocation of this function
    # (e.g. compiling twice in one process) emit duplicated channels.
    channels = []
    for channel_name, labeling_job in (
        ("train", ground_truth_train),
        ("validation", ground_truth_validation),
    ):
        channel = copy.deepcopy(channelObj)
        channel["ChannelName"] = channel_name
        # NOTE(review): str() of a task output presumably yields a placeholder
        # that is resolved when the pipeline runs -- confirm against KFP docs.
        channel["DataSource"]["S3DataSource"]["S3Uri"] = str(
            labeling_job.outputs["output_manifest_location"]
        )
        channels.append(channel)

    sagemaker_train_op(
        region=region,
        algorithm_name=training_algorithm_name,
        training_input_mode=training_input_mode,
        hyperparameters=training_hyperparameters,
        channels=json.dumps(channels),
        instance_type=training_instance_type,
        instance_count=training_instance_count,
        volume_size=training_volume_size,
        max_run_time=training_max_run_time,
        model_artifact_path=training_output_location,
        role=role_arn,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, compile the pipeline into a package whose path is
    # this file's path with ".zip" appended.
    package_path = f"{__file__}.zip"
    kfp.compiler.Compiler().compile(ground_truth_test, package_path)
|