pipelines/components/contrib/google-cloud/Optimizer/Create_study/component.yaml

161 lines
7.0 KiB
YAML

name: Create study in gcp ai platform optimizer
description: Creates a Google Cloud AI Platform Optimizer study.
# Component inputs. study_id and parameter_specs are required; every other input
# is marked optional and is only forwarded to the program when it is set
# (see the `if`/`isPresent` entries in `args`).
inputs:
- {name: study_id, type: String, description: Name of the study.}
- {name: parameter_specs, type: JsonArray, description: 'List of parameter specs.
See https://cloud.google.com/ai-platform/optimizer/docs/reference/rest/v1/projects.locations.studies#parameterspec'}
- {name: optimization_goal, type: String, description: Optimization goal when optimizing
a single metric. Can be MAXIMIZE (default) or MINIMIZE. Ignored if metric_specs
list is provided., default: MAXIMIZE, optional: true}
- {name: metric_specs, type: JsonArray, description: 'List of metric specs. See https://cloud.google.com/ai-platform/optimizer/docs/reference/rest/v1/projects.locations.studies#metricspec',
optional: true}
- {name: gcp_project_id, type: String, optional: true}
- {name: gcp_region, type: String, default: us-central1, optional: true}
# Single output: the `name` field of the create-study response, written by the
# program to the file passed via `----output-paths`.
outputs:
- {name: study_name, type: String}
# Authorship and upstream location of this component definition.
metadata:
annotations:
author: Alexey Volkov <alexey.volkov@ark-kun.com>
canonical_location: 'https://raw.githubusercontent.com/Ark-kun/pipeline_components/master/components/google-cloud/Optimizer/Create_study/component.yaml'
implementation:
container:
image: python:3.8
# The command runs through `sh -c`. The shell script (first word after `-c`)
# installs the pinned Google client libraries with pip, retrying with `--user`
# if the plain install fails, and then executes "$0" "$@" - i.e. the words that
# follow the script (`python3 -u -c <program> ...`).
command:
- sh
- -c
- (PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --quiet --no-warn-script-location
'google-api-python-client==1.12.3' 'google-cloud-storage==1.31.2' 'google-auth==1.21.3'
|| PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --quiet --no-warn-script-location
'google-api-python-client==1.12.3' 'google-cloud-storage==1.31.2' 'google-auth==1.21.3'
--user) && "$0" "$@"
- python3
- -u
- -c
# Inline Python program handed to `python3 -c`.
- |
def create_study_in_gcp_ai_platform_optimizer(
    study_id,
    parameter_specs,
    optimization_goal = 'MAXIMIZE',
    metric_specs = None,
    gcp_project_id = None,
    gcp_region = "us-central1",
):
    """Creates a Google Cloud AI Platform Optimizer study.

    See https://cloud.google.com/ai-platform/optimizer/docs

    Annotations:
        author: Alexey Volkov <alexey.volkov@ark-kun.com>

    Args:
        study_id: Name of the study.
        parameter_specs: List of parameter specs. See https://cloud.google.com/ai-platform/optimizer/docs/reference/rest/v1/projects.locations.studies#parameterspec
        optimization_goal: Optimization goal when optimizing a single metric. Can be MAXIMIZE (default) or MINIMIZE. Ignored if metric_specs list is provided.
        metric_specs: List of metric specs. See https://cloud.google.com/ai-platform/optimizer/docs/reference/rest/v1/projects.locations.studies#metricspec
        gcp_project_id: Project used in the study's parent path. When empty, the
            project reported by `google.auth.default()` is used.
        gcp_region: Location used in the study's parent path.

    Returns:
        A one-element tuple holding the `name` field of the create-study response.

    Raises:
        ValueError: If no project ID was passed and none could be inferred from
            the default credentials.
        RuntimeError: If the Optimizer API discovery document is missing from
            its bucket.
    """
    import logging
    import google.auth

    logging.getLogger().setLevel(logging.INFO)

    # Validating and inferring the arguments
    if not gcp_project_id:
        _, gcp_project_id = google.auth.default()
    if not gcp_project_id:
        # google.auth.default() may report no project; fail here instead of
        # sending a request for "projects/None/...".
        raise ValueError(
            'gcp_project_id was not specified and could not be determined '
            'from the default credentials.'
        )

    # Building the API client.
    # The main API does not work, so we need to build from the published discovery document.
    def create_caip_optimizer_client(project_id):
        from google.cloud import storage
        from googleapiclient import discovery

        _OPTIMIZER_API_DOCUMENT_BUCKET = 'caip-optimizer-public'
        _OPTIMIZER_API_DOCUMENT_FILE = 'api/ml_public_google_rest_v1.json'

        client = storage.Client(project_id)
        bucket = client.get_bucket(_OPTIMIZER_API_DOCUMENT_BUCKET)
        blob = bucket.get_blob(_OPTIMIZER_API_DOCUMENT_FILE)
        if blob is None:
            # get_blob() returns None for a missing object; report that
            # explicitly instead of failing with an AttributeError below.
            raise RuntimeError(
                'Optimizer API discovery document "gs://{}/{}" was not found.'.format(
                    _OPTIMIZER_API_DOCUMENT_BUCKET,
                    _OPTIMIZER_API_DOCUMENT_FILE,
                )
            )
        discovery_document = blob.download_as_bytes()
        return discovery.build_from_document(service=discovery_document)

    ml_api = create_caip_optimizer_client(gcp_project_id)

    # Without explicit metric specs, optimize a single metric named "metric"
    # towards the requested goal.
    if not metric_specs:
        metric_specs = [{
            'metric': 'metric',
            'goal': optimization_goal,
        }]

    study_config = {
        'algorithm': 'ALGORITHM_UNSPECIFIED',  # Let the service choose the `default` algorithm.
        'parameters': parameter_specs,
        'metrics': metric_specs,
    }
    study = {'study_config': study_config}

    create_study_request = ml_api.projects().locations().studies().create(
        parent=f'projects/{gcp_project_id}/locations/{gcp_region}',
        studyId=study_id,
        body=study,
    )
    create_study_response = create_study_request.execute()
    study_name = create_study_response['name']
    return (study_name,)
def _serialize_str(str_value: str) -> str:
if not isinstance(str_value, str):
raise TypeError('Value "{}" has type "{}" instead of str.'.format(str(str_value), str(type(str_value))))
return str_value
import json
import argparse

# Command-line entry point: parse the flags, create the study, then write each
# output to the file path supplied through ----output-paths.
_parser = argparse.ArgumentParser(prog='Create study in gcp ai platform optimizer', description='Creates a Google Cloud AI Plaform Optimizer study.')
# default=argparse.SUPPRESS keeps omitted optional flags out of the namespace,
# so the function's own defaults apply to them.
_parser.add_argument("--study-id", dest="study_id", type=str, required=True, default=argparse.SUPPRESS)
_parser.add_argument("--parameter-specs", dest="parameter_specs", type=json.loads, required=True, default=argparse.SUPPRESS)
_parser.add_argument("--optimization-goal", dest="optimization_goal", type=str, required=False, default=argparse.SUPPRESS)
_parser.add_argument("--metric-specs", dest="metric_specs", type=json.loads, required=False, default=argparse.SUPPRESS)
_parser.add_argument("--gcp-project-id", dest="gcp_project_id", type=str, required=False, default=argparse.SUPPRESS)
_parser.add_argument("--gcp-region", dest="gcp_region", type=str, required=False, default=argparse.SUPPRESS)
_parser.add_argument("----output-paths", dest="_output_paths", type=str, nargs=1)
_parsed_args = vars(_parser.parse_args())
# argparse stores None (not a missing key) when ----output-paths is omitted,
# so normalize it to an empty list before iterating.
_output_files = _parsed_args.pop("_output_paths", None) or []

_outputs = create_study_in_gcp_ai_platform_optimizer(**_parsed_args)

# One serializer per output, in the same order as the output paths.
_output_serializers = [
    _serialize_str,
]

import os
for idx, output_file in enumerate(_output_files):
    _output_dir = os.path.dirname(output_file)
    # A bare file name has no directory part; os.makedirs('') would raise.
    if _output_dir:
        os.makedirs(_output_dir, exist_ok=True)
    with open(output_file, 'w') as f:
        f.write(_output_serializers[idx](_outputs[idx]))
# Arguments for the inline program. The two required inputs are always passed;
# each optional input is passed only when it is present, so omitted ones fall
# back to the defaults of the Python function.
args:
- --study-id
- {inputValue: study_id}
- --parameter-specs
- {inputValue: parameter_specs}
- if:
cond: {isPresent: optimization_goal}
then:
- --optimization-goal
- {inputValue: optimization_goal}
- if:
cond: {isPresent: metric_specs}
then:
- --metric-specs
- {inputValue: metric_specs}
- if:
cond: {isPresent: gcp_project_id}
then:
- --gcp-project-id
- {inputValue: gcp_project_id}
- if:
cond: {isPresent: gcp_region}
then:
- --gcp-region
- {inputValue: gcp_region}
# File path for the study_name output; matches the program's four-dash
# `----output-paths` option.
- '----output-paths'
- {outputPath: study_name}