pipelines/sdk/python/kfp/_client.py

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import tarfile
import time
from datetime import datetime

import six
import yaml


class Client(object):
  """API client for Kubeflow Pipelines."""

  def __init__(self, host='ml-pipeline.kubeflow.svc.cluster.local:8888'):
    """Create a new instance of the kfp client.

    Args:
      host: the API host. If running inside the cluster as a Pod, the default
        value should work.
    """
    try:
      import kfp_experiment
    except ImportError:
      raise Exception('This module requires installation of kfp_experiment')

    try:
      import kfp_run
    except ImportError:
      raise Exception('This module requires installation of kfp_run')

    config = kfp_run.configuration.Configuration()
    config.host = host
    api_client = kfp_run.api_client.ApiClient(config)
    self._run_api = kfp_run.api.run_service_api.RunServiceApi(api_client)

    config = kfp_experiment.configuration.Configuration()
    config.host = host
    api_client = kfp_experiment.api_client.ApiClient(config)
    self._experiment_api = \
        kfp_experiment.api.experiment_service_api.ExperimentServiceApi(api_client)

  def _is_ipython(self):
    """Returns whether we are running inside a notebook."""
    try:
      import IPython
    except ImportError:
      return False
    return True

  def create_experiment(self, name):
    """Create a new experiment.

    Args:
      name: the name of the experiment.
    Returns:
      An Experiment object. Most important field is id.
    """
    import kfp_experiment

    exp = kfp_experiment.models.ApiExperiment(name=name)
    response = self._experiment_api.create_experiment(body=exp)
    if self._is_ipython():
      import IPython
      html = \
          ('Experiment link <a href="/pipeline/#/experiments/details/%s" target="_blank" >here</a>'
           % response.id)
      IPython.display.display(IPython.display.HTML(html))
    return response

  def list_experiments(self, page_token='', page_size=10, sort_by=''):
    """List experiments.

    Args:
      page_token: token for the starting page.
      page_size: size of the page.
      sort_by: can be '[field_name]' or '[field_name] desc'. For example, 'name desc'.
    Returns:
      A response object including a list of experiments and next page token.
    """
    response = self._experiment_api.list_experiment(
        page_token=page_token, page_size=page_size, sort_by=sort_by)
    return response

  def get_experiment(self, experiment_id):
    """Get details of an experiment.

    Args:
      experiment_id: id of the experiment.
    Returns:
      A response object including details of an experiment.
    Throws:
      Exception if experiment is not found.
    """
    return self._experiment_api.get_experiment(id=experiment_id)

  def _extract_pipeline_yaml(self, tar_file):
    with tarfile.open(tar_file, "r:gz") as tar:
      all_yaml_files = [m for m in tar if m.isfile() and
                        os.path.splitext(m.name)[-1] in ('.yaml', '.yml')]
      if len(all_yaml_files) == 0:
        raise ValueError('Invalid package. Missing pipeline yaml file in the package.')
      if len(all_yaml_files) > 1:
        raise ValueError('Invalid package. Multiple yaml files in the package.')
      with tar.extractfile(all_yaml_files[0]) as f:
        return yaml.safe_load(f)
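
  # Note: a valid package is a .tar.gz archive containing exactly one compiled
  # pipeline YAML file. A minimal sketch of producing one with the KFP compiler
  # (the pipeline function 'my_pipeline' and the output path are hypothetical):
  #
  #   import kfp.compiler as compiler
  #   compiler.Compiler().compile(my_pipeline, 'my_pipeline.tar.gz')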

  def run_pipeline(self, experiment_id, job_name, pipeline_package_path, params=None):
    """Run a specified pipeline.

    Args:
      experiment_id: The string id of an experiment.
      job_name: name of the job.
      pipeline_package_path: local path of the pipeline package (tar.gz file).
      params: a dictionary with key (string) as param name and value (string) as param value.
    Returns:
      A run object. Most important field is id.
    """
    import kfp_run

    params = params or {}
    pipeline_obj = self._extract_pipeline_yaml(pipeline_package_path)
    pipeline_json_string = json.dumps(pipeline_obj)
    api_params = [kfp_run.ApiParameter(name=k, value=str(v))
                  for k, v in six.iteritems(params)]
    key = kfp_run.models.ApiResourceKey(id=experiment_id,
                                        type=kfp_run.models.ApiResourceType.EXPERIMENT)
    reference = kfp_run.models.ApiResourceReference(key, kfp_run.models.ApiRelationship.OWNER)
    spec = kfp_run.models.ApiPipelineSpec(
        workflow_manifest=pipeline_json_string, parameters=api_params)
    run_body = kfp_run.models.ApiRun(
        pipeline_spec=spec, resource_references=[reference], name=job_name)
    response = self._run_api.create_run(body=run_body)
    if self._is_ipython():
      import IPython
      html = ('Job link <a href="/pipeline/#/runs/details/%s" target="_blank" >here</a>'
              % response.run.id)
      IPython.display.display(IPython.display.HTML(html))
    return response.run

  def list_runs(self, page_token='', page_size=10, sort_by=''):
    """List runs.

    Args:
      page_token: token for the starting page.
      page_size: size of the page.
      sort_by: one of 'field_name' or 'field_name desc'. For example, 'name desc'.
    Returns:
      A response object including a list of runs and next page token.
    """
    response = self._run_api.list_runs(
        page_token=page_token, page_size=page_size, sort_by=sort_by)
    return response

  def get_run(self, run_id):
    """Get run details.

    Args:
      run_id: id of the run.
    Returns:
      A response object including details of a run.
    Throws:
      Exception if run is not found.
    """
    return self._run_api.get_run(run_id=run_id)

  def wait_for_run_completion(self, run_id, timeout):
    """Wait for a run to complete.

    Args:
      run_id: run id, returned from run_pipeline.
      timeout: timeout in seconds.
    Returns:
      A run detail object. Most important fields are run and pipeline_runtime.
    """
    status = 'Running'
    start_time = datetime.now()
    while status is None or status.lower() not in ['succeeded', 'failed', 'skipped', 'error']:
      get_run_response = self._run_api.get_run(run_id=run_id)
      status = get_run_response.run.status
      elapsed_time = (datetime.now() - start_time).total_seconds()
      logging.info('Waiting for the job to complete...')
      if elapsed_time > timeout:
        raise TimeoutError('Run timeout')
      time.sleep(5)
    return get_run_response

  def _get_workflow_json(self, run_id):
    """Get the workflow json.

    Args:
      run_id: run id, returned from run_pipeline.
    Returns:
      workflow: json workflow
    """
    get_run_response = self._run_api.get_run(run_id=run_id)
    workflow = get_run_response.pipeline_runtime.workflow_manifest
    workflow_json = json.loads(workflow)
    return workflow_json
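

# A minimal end-to-end usage sketch (assumptions: the client runs in-cluster so
# the default host resolves, and 'my_pipeline.tar.gz' is a hypothetical compiled
# pipeline package on the local filesystem):
#
#   client = Client()
#   exp = client.create_experiment('my-experiment')
#   run = client.run_pipeline(exp.id, 'my-job', 'my_pipeline.tar.gz',
#                             params={'learning-rate': '0.1'})
#   result = client.wait_for_run_completion(run.id, timeout=600)
#   print(result.run.status)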