Dataflow SDK to support launch beam python code or template (#833)
* initial files for dataflow commands * dataflow commands * add launch python command * Support launch_python command * Use display API and default outputs to /tmp/output folder
This commit is contained in:
parent
7775692adf
commit
ea26a574bf
|
|
@ -12,4 +12,4 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import launcher, core
|
||||
from . import launcher, core, google
|
||||
|
|
|
|||
|
|
@ -11,3 +11,5 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import ml_engine, dataflow
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._launch_template import launch_template
|
||||
from ._launch_python import launch_python
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import googleapiclient.discovery as discovery
|
||||
from googleapiclient import errors
|
||||
|
||||
class DataflowClient:
|
||||
def __init__(self):
|
||||
self._df = discovery.build('dataflow', 'v1b3')
|
||||
|
||||
def launch_template(self, project_id, gcs_path, location,
|
||||
validate_only, launch_parameters):
|
||||
return self._df.projects().templates().launch(
|
||||
projectId = project_id,
|
||||
gcsPath = gcs_path,
|
||||
location = location,
|
||||
validateOnly = validate_only,
|
||||
body = launch_parameters
|
||||
).execute()
|
||||
|
||||
def get_job(self, project_id, job_id, location=None, view=None):
|
||||
return self._df.projects().jobs().get(
|
||||
projectId = project_id,
|
||||
jobId = job_id,
|
||||
location = location,
|
||||
view = view
|
||||
).execute()
|
||||
|
||||
def cancel_job(self, project_id, job_id, location):
|
||||
return self._df.projects().jobs().update(
|
||||
projectId = project_id,
|
||||
jobId = job_id,
|
||||
location = location,
|
||||
body = {
|
||||
'requestedState': 'JOB_STATE_CANCELLED'
|
||||
}
|
||||
).execute()
|
||||
|
||||
def list_aggregated_jobs(self, project_id, filter=None,
|
||||
view=None, page_size=None, page_token=None, location=None):
|
||||
return self._df.projects().jobs().aggregated(
|
||||
projectId = project_id,
|
||||
filter = filter,
|
||||
view = view,
|
||||
pageSize = page_size,
|
||||
pageToken = page_token,
|
||||
location = location).execute()
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from kfp_component.core import display
|
||||
from .. import common as gcp_common
|
||||
from ..storage import download_blob, parse_blob_path, is_gcs_path
|
||||
|
||||
_JOB_SUCCESSFUL_STATES = ['JOB_STATE_DONE', 'JOB_STATE_UPDATED', 'JOB_STATE_DRAINED']
|
||||
_JOB_FAILED_STATES = ['JOB_STATE_STOPPED', 'JOB_STATE_FAILED', 'JOB_STATE_CANCELLED']
|
||||
_JOB_TERMINATED_STATES = _JOB_SUCCESSFUL_STATES + _JOB_FAILED_STATES
|
||||
|
||||
def generate_job_name(job_name, context_id):
|
||||
"""Generates a stable job name in the job context.
|
||||
|
||||
If user provided ``job_name`` has value, the function will use it
|
||||
as a prefix and appends first 8 characters of ``context_id`` to
|
||||
make the name unique across contexts. If the ``job_name`` is empty,
|
||||
it will use ``job-{context_id}`` as the job name.
|
||||
"""
|
||||
if job_name:
|
||||
return '{}-{}'.format(
|
||||
gcp_common.normalize_name(job_name),
|
||||
context_id[:8])
|
||||
|
||||
return 'job-{}'.format(context_id)
|
||||
|
||||
def get_job_by_name(df_client, project_id, job_name, location=None):
|
||||
"""Gets a job by its name.
|
||||
|
||||
The function lists all jobs under a project or a region location.
|
||||
Compares their names with the ``job_name`` and return the job
|
||||
once it finds a match. If none of the jobs matches, it returns
|
||||
``None``.
|
||||
"""
|
||||
page_token = None
|
||||
while True:
|
||||
response = df_client.list_aggregated_jobs(project_id,
|
||||
page_size=50, page_token=page_token, location=location)
|
||||
for job in response.get('jobs', []):
|
||||
name = job.get('name', None)
|
||||
if job_name == name:
|
||||
return job
|
||||
page_token = response.get('nextPageToken', None)
|
||||
if not page_token:
|
||||
return None
|
||||
|
||||
def wait_for_job_done(df_client, project_id, job_id, location=None, wait_interval=30):
|
||||
while True:
|
||||
job = df_client.get_job(project_id, job_id, location=location)
|
||||
state = job.get('currentState', None)
|
||||
if is_job_done(state):
|
||||
return job
|
||||
elif is_job_terminated(state):
|
||||
# Terminated with error state
|
||||
raise RuntimeError('Job {} failed with error state: {}.'.format(
|
||||
job_id,
|
||||
state
|
||||
))
|
||||
else:
|
||||
logging.info('Job {} is in pending state {}.'
|
||||
' Waiting for {} seconds for next poll.'.format(
|
||||
job_id,
|
||||
state,
|
||||
wait_interval
|
||||
))
|
||||
time.sleep(wait_interval)
|
||||
|
||||
def wait_and_dump_job(df_client, project_id, location, job,
|
||||
wait_interval):
|
||||
display_job_link(project_id, job)
|
||||
job_id = job.get('id')
|
||||
job = wait_for_job_done(df_client, project_id, job_id,
|
||||
location, wait_interval)
|
||||
dump_job(job)
|
||||
return job
|
||||
|
||||
def is_job_terminated(job_state):
|
||||
return job_state in _JOB_TERMINATED_STATES
|
||||
|
||||
def is_job_done(job_state):
|
||||
return job_state in _JOB_SUCCESSFUL_STATES
|
||||
|
||||
def display_job_link(project_id, job):
|
||||
location = job.get('location')
|
||||
job_id = job.get('id')
|
||||
display.display(display.Link(
|
||||
href = 'https://console.cloud.google.com/dataflow/'
|
||||
'jobsDetail/locations/{}/jobs/{}?project={}'.format(
|
||||
location, job_id, project_id),
|
||||
text = 'Job Details'
|
||||
))
|
||||
|
||||
def dump_job(job):
|
||||
gcp_common.dump_file('/tmp/output/job.json', json.dumps(job))
|
||||
|
||||
def stage_file(local_or_gcs_path):
|
||||
if not is_gcs_path(local_or_gcs_path):
|
||||
return local_or_gcs_path
|
||||
_, blob_path = parse_blob_path(local_or_gcs_path)
|
||||
file_name = os.path.basename(blob_path)
|
||||
local_file_path = os.path.join(tempfile.mkdtemp(), file_name)
|
||||
download_blob(local_or_gcs_path, local_file_path)
|
||||
return local_file_path
|
||||
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import subprocess
|
||||
import re
|
||||
import logging
|
||||
|
||||
from kfp_component.core import KfpExecutionContext
|
||||
from ._client import DataflowClient
|
||||
from .. import common as gcp_common
|
||||
from ._common_ops import (generate_job_name, get_job_by_name,
|
||||
wait_and_dump_job, stage_file)
|
||||
from ._process import Process
|
||||
|
||||
def launch_python(python_file_path, project_id, requirements_file_path=None,
|
||||
location=None, job_name_prefix=None, args=[], wait_interval=30):
|
||||
"""Launch a self-executing beam python file.
|
||||
|
||||
Args:
|
||||
python_file_path (str): The gcs or local path to the python file to run.
|
||||
project_id (str): The ID of the parent project.
|
||||
requirements_file_path (str): Optional, the gcs or local path to the pip
|
||||
requirements file.
|
||||
location (str): The regional endpoint to which to direct the
|
||||
request.
|
||||
job_name_prefix (str): Optional. The prefix of the genrated job
|
||||
name. If not provided, the method will generated a random name.
|
||||
args (list): The list of args to pass to the python file.
|
||||
wait_interval (int): The wait seconds between polling.
|
||||
Returns:
|
||||
The completed job.
|
||||
"""
|
||||
df_client = DataflowClient()
|
||||
job_id = None
|
||||
def cancel():
|
||||
if job_id:
|
||||
df_client.cancel_job(
|
||||
project_id,
|
||||
job_id,
|
||||
location
|
||||
)
|
||||
with KfpExecutionContext(on_cancel=cancel) as ctx:
|
||||
job_name = generate_job_name(
|
||||
job_name_prefix,
|
||||
ctx.context_id())
|
||||
# We will always generate unique name for the job. We expect
|
||||
# job with same name was created in previous tries from the same
|
||||
# pipeline run.
|
||||
job = get_job_by_name(df_client, project_id, job_name,
|
||||
location)
|
||||
if job:
|
||||
return wait_and_dump_job(df_client, project_id, location, job,
|
||||
wait_interval)
|
||||
|
||||
_install_requirements(requirements_file_path)
|
||||
python_file_path = stage_file(python_file_path)
|
||||
cmd = _prepare_cmd(project_id, location, job_name, python_file_path,
|
||||
args)
|
||||
sub_process = Process(cmd)
|
||||
for line in sub_process.read_lines():
|
||||
job_id = _extract_job_id(line)
|
||||
if job_id:
|
||||
logging.info('Found job id {}'.format(job_id))
|
||||
break
|
||||
sub_process.wait_and_check()
|
||||
if not job_id:
|
||||
logging.warning('No dataflow job was found when '
|
||||
'running the python file.')
|
||||
return None
|
||||
job = df_client.get_job(project_id, job_id,
|
||||
location=location)
|
||||
return wait_and_dump_job(df_client, project_id, location, job,
|
||||
wait_interval)
|
||||
|
||||
def _prepare_cmd(project_id, location, job_name, python_file_path, args):
|
||||
dataflow_args = [
|
||||
'--runner', 'dataflow',
|
||||
'--project', project_id,
|
||||
'--job-name', job_name]
|
||||
if location:
|
||||
dataflow_args += ['--location', location]
|
||||
return (['python2', '-u', python_file_path] +
|
||||
dataflow_args + args)
|
||||
|
||||
def _extract_job_id(line):
|
||||
job_id_pattern = re.compile(
|
||||
br'.*console.cloud.google.com/dataflow.*/jobs/([a-z|0-9|A-Z|\-|\_]+).*')
|
||||
matched_job_id = job_id_pattern.search(line or '')
|
||||
if matched_job_id:
|
||||
return matched_job_id.group(1).decode()
|
||||
return None
|
||||
|
||||
def _install_requirements(requirements_file_path):
|
||||
if not requirements_file_path:
|
||||
return
|
||||
requirements_file_path = stage_file(requirements_file_path)
|
||||
subprocess.run(['pip2', 'install', '-r', requirements_file_path])
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
|
||||
from kfp_component.core import KfpExecutionContext
|
||||
from ._client import DataflowClient
|
||||
from .. import common as gcp_common
|
||||
from ._common_ops import (generate_job_name, get_job_by_name,
|
||||
wait_and_dump_job)
|
||||
|
||||
def launch_template(project_id, gcs_path, launch_parameters,
|
||||
location=None, job_name_prefix=None, validate_only=None,
|
||||
wait_interval=30):
|
||||
"""Launchs a dataflow job from template.
|
||||
|
||||
Args:
|
||||
project_id (str): Required. The ID of the Cloud Platform project
|
||||
that the job belongs to.
|
||||
gcs_path (str): Required. A Cloud Storage path to the template
|
||||
from which to create the job. Must be valid Cloud
|
||||
Storage URL, beginning with 'gs://'.
|
||||
launch_parameters (dict): Parameters to provide to the template
|
||||
being launched. Schema defined in
|
||||
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/LaunchTemplateParameters.
|
||||
`jobName` will be replaced by generated name.
|
||||
location (str): The regional endpoint to which to direct the
|
||||
request.
|
||||
job_name_prefix (str): Optional. The prefix of the genrated job
|
||||
name. If not provided, the method will generated a random name.
|
||||
validate_only (boolean): If true, the request is validated but
|
||||
not actually executed. Defaults to false.
|
||||
wait_interval (int): The wait seconds between polling.
|
||||
|
||||
Returns:
|
||||
The completed job.
|
||||
"""
|
||||
df_client = DataflowClient()
|
||||
job_id = None
|
||||
def cancel():
|
||||
if job_id:
|
||||
df_client.cancel_job(
|
||||
project_id,
|
||||
job_id,
|
||||
location
|
||||
)
|
||||
with KfpExecutionContext(on_cancel=cancel) as ctx:
|
||||
job_name = generate_job_name(
|
||||
job_name_prefix,
|
||||
ctx.context_id())
|
||||
print(job_name)
|
||||
job = get_job_by_name(df_client, project_id, job_name,
|
||||
location)
|
||||
if not job:
|
||||
launch_parameters['jobName'] = job_name
|
||||
response = df_client.launch_template(project_id, gcs_path,
|
||||
location, validate_only, launch_parameters)
|
||||
job = response.get('job')
|
||||
return wait_and_dump_job(df_client, project_id, location, job,
|
||||
wait_interval)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import subprocess
|
||||
import logging
|
||||
|
||||
class Process:
|
||||
def __init__(self, cmd):
|
||||
self._cmd = cmd
|
||||
self.process = subprocess.Popen(cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
close_fds=True,
|
||||
shell=False)
|
||||
|
||||
def read_lines(self):
|
||||
# stdout will end with empty bytes when process exits.
|
||||
for line in iter(self.process.stdout.readline, b''):
|
||||
logging.info('subprocess: {}'.format(line))
|
||||
yield line
|
||||
|
||||
def wait_and_check(self):
|
||||
for _ in self.read_lines():
|
||||
pass
|
||||
self.process.stdout.close()
|
||||
return_code = self.process.wait()
|
||||
logging.info('Subprocess exit with code {}.'.format(
|
||||
return_code))
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, self._cmd)
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._download_blob import download_blob
|
||||
from ._common_ops import parse_blob_path, is_gcs_path
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
|
||||
def is_gcs_path(path):
|
||||
"""Check if the path is a gcs path"""
|
||||
return path.startswith('gs://')
|
||||
|
||||
def parse_blob_path(path):
|
||||
"""Parse a gcs path into bucket name and blob name
|
||||
|
||||
Args:
|
||||
path (str): the path to parse.
|
||||
|
||||
Returns:
|
||||
(bucket name in the path, blob name in the path)
|
||||
|
||||
Raises:
|
||||
ValueError if the path is not a valid gcs blob path.
|
||||
|
||||
Example:
|
||||
|
||||
`bucket_name, blob_name = parse_blob_path('gs://foo/bar')`
|
||||
`bucket_name` is `foo` and `blob_name` is `bar`
|
||||
"""
|
||||
match = re.match('gs://([^/]+)/(.+)$', path)
|
||||
if match:
|
||||
return match.group(1), match.group(2)
|
||||
raise ValueError('Path {} is invalid blob path.'.format(
|
||||
path))
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from google.cloud import storage
|
||||
from ._common_ops import parse_blob_path
|
||||
|
||||
def download_blob(source_blob_path, destination_file_path):
|
||||
"""Downloads a blob from the bucket.
|
||||
|
||||
Args:
|
||||
source_blob_path (str): the source blob path to download from.
|
||||
destination_file_path (str): the local file path to download to.
|
||||
"""
|
||||
bucket_name, blob_name = parse_blob_path(source_blob_path)
|
||||
storage_client = storage.Client()
|
||||
bucket = storage_client.get_bucket(bucket_name)
|
||||
blob = bucket.blob(blob_name)
|
||||
|
||||
dirname = os.path.dirname(destination_file_path)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
with open(destination_file_path, 'wb+') as f:
|
||||
blob.download_to_file(f)
|
||||
|
||||
logging.info('Blob {} downloaded to {}.'.format(
|
||||
source_blob_path,
|
||||
destination_file_path))
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import mock
|
||||
import unittest
|
||||
import os
|
||||
|
||||
from kfp_component.google.dataflow import launch_python
|
||||
|
||||
MODULE = 'kfp_component.google.dataflow._launch_python'
|
||||
|
||||
@mock.patch('kfp_component.google.dataflow._common_ops.display')
|
||||
@mock.patch(MODULE + '.stage_file')
|
||||
@mock.patch(MODULE + '.KfpExecutionContext')
|
||||
@mock.patch(MODULE + '.DataflowClient')
|
||||
@mock.patch(MODULE + '.Process')
|
||||
@mock.patch(MODULE + '.subprocess')
|
||||
class LaunchPythonTest(unittest.TestCase):
|
||||
|
||||
def test_launch_python_succeed(self, mock_subprocess, mock_process,
|
||||
mock_client, mock_context, mock_stage_file, mock_display):
|
||||
mock_context().__enter__().context_id.return_value = 'ctx-1'
|
||||
mock_client().list_aggregated_jobs.return_value = {
|
||||
'jobs': []
|
||||
}
|
||||
mock_process().read_lines.return_value = [
|
||||
b'https://console.cloud.google.com/dataflow/locations/us-central1/jobs/job-1?project=project-1'
|
||||
]
|
||||
expected_job = {
|
||||
'currentState': 'JOB_STATE_DONE'
|
||||
}
|
||||
mock_client().get_job.return_value = expected_job
|
||||
|
||||
result = launch_python('/tmp/test.py', 'project-1')
|
||||
|
||||
self.assertEqual(expected_job, result)
|
||||
|
||||
def test_launch_python_retry_succeed(self, mock_subprocess, mock_process,
|
||||
mock_client, mock_context, mock_stage_file, mock_display):
|
||||
mock_context().__enter__().context_id.return_value = 'ctx-1'
|
||||
mock_client().list_aggregated_jobs.return_value = {
|
||||
'jobs': [{
|
||||
'id': 'job-1',
|
||||
'name': 'test_job-ctx-1'
|
||||
}]
|
||||
}
|
||||
expected_job = {
|
||||
'currentState': 'JOB_STATE_DONE'
|
||||
}
|
||||
mock_client().get_job.return_value = expected_job
|
||||
|
||||
result = launch_python('/tmp/test.py', 'project-1', job_name_prefix='test-job')
|
||||
|
||||
self.assertEqual(expected_job, result)
|
||||
mock_process.assert_not_called()
|
||||
|
||||
def test_launch_python_no_job_created(self, mock_subprocess, mock_process,
|
||||
mock_client, mock_context, mock_stage_file, mock_display):
|
||||
mock_context().__enter__().context_id.return_value = 'ctx-1'
|
||||
mock_client().list_aggregated_jobs.return_value = {
|
||||
'jobs': []
|
||||
}
|
||||
mock_process().read_lines.return_value = [
|
||||
b'no job id',
|
||||
b'no job id'
|
||||
]
|
||||
|
||||
result = launch_python('/tmp/test.py', 'project-1')
|
||||
|
||||
self.assertEqual(None, result)
|
||||
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import mock
|
||||
import unittest
|
||||
import os
|
||||
|
||||
from kfp_component.google.dataflow import launch_template
|
||||
|
||||
MODULE = 'kfp_component.google.dataflow._launch_template'
|
||||
|
||||
@mock.patch('kfp_component.google.dataflow._common_ops.display')
|
||||
@mock.patch(MODULE + '.KfpExecutionContext')
|
||||
@mock.patch(MODULE + '.DataflowClient')
|
||||
class LaunchTemplateTest(unittest.TestCase):
|
||||
|
||||
def test_launch_template_succeed(self, mock_client, mock_context, mock_display):
|
||||
mock_context().__enter__().context_id.return_value = 'context-1'
|
||||
mock_client().list_aggregated_jobs.return_value = {
|
||||
'jobs': []
|
||||
}
|
||||
mock_client().launch_template.return_value = {
|
||||
'job': { 'id': 'job-1' }
|
||||
}
|
||||
expected_job = {
|
||||
'currentState': 'JOB_STATE_DONE'
|
||||
}
|
||||
mock_client().get_job.return_value = expected_job
|
||||
|
||||
result = launch_template('project-1', 'gs://foo/bar', {
|
||||
"parameters": {
|
||||
"foo": "bar"
|
||||
},
|
||||
"environment": {
|
||||
"zone": "us-central1"
|
||||
}
|
||||
})
|
||||
|
||||
self.assertEqual(expected_job, result)
|
||||
mock_client().launch_template.assert_called_once()
|
||||
|
||||
def test_launch_template_retry_succeed(self,
|
||||
mock_client, mock_context, mock_display):
|
||||
mock_context().__enter__().context_id.return_value = 'ctx-1'
|
||||
# The job with same name already exists.
|
||||
mock_client().list_aggregated_jobs.return_value = {
|
||||
'jobs': [{
|
||||
'id': 'job-1',
|
||||
'name': 'test_job-ctx-1'
|
||||
}]
|
||||
}
|
||||
pending_job = {
|
||||
'currentState': 'JOB_STATE_PENDING'
|
||||
}
|
||||
expected_job = {
|
||||
'currentState': 'JOB_STATE_DONE'
|
||||
}
|
||||
mock_client().get_job.side_effect = [pending_job, expected_job]
|
||||
|
||||
result = launch_template('project-1', 'gs://foo/bar', {
|
||||
"parameters": {
|
||||
"foo": "bar"
|
||||
},
|
||||
"environment": {
|
||||
"zone": "us-central1"
|
||||
}
|
||||
}, job_name_prefix='test-job', wait_interval=0)
|
||||
|
||||
self.assertEqual(expected_job, result)
|
||||
mock_client().launch_template.assert_not_called()
|
||||
|
||||
def test_launch_template_fail(self, mock_client, mock_context, mock_display):
|
||||
mock_context().__enter__().context_id.return_value = 'context-1'
|
||||
mock_client().list_aggregated_jobs.return_value = {
|
||||
'jobs': []
|
||||
}
|
||||
mock_client().launch_template.return_value = {
|
||||
'job': { 'id': 'job-1' }
|
||||
}
|
||||
failed_job = {
|
||||
'currentState': 'JOB_STATE_FAILED'
|
||||
}
|
||||
mock_client().get_job.return_value = failed_job
|
||||
|
||||
self.assertRaises(RuntimeError,
|
||||
lambda: launch_template('project-1', 'gs://foo/bar', {
|
||||
"parameters": {
|
||||
"foo": "bar"
|
||||
},
|
||||
"environment": {
|
||||
"zone": "us-central1"
|
||||
}
|
||||
}))
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from kfp_component.google.storage import is_gcs_path, parse_blob_path
|
||||
|
||||
class CommonOpsTest(unittest.TestCase):
|
||||
|
||||
def test_is_gcs_path(self):
|
||||
self.assertTrue(is_gcs_path('gs://foo'))
|
||||
self.assertTrue(is_gcs_path('gs://foo/bar'))
|
||||
self.assertFalse(is_gcs_path('gs:/foo/bar'))
|
||||
self.assertFalse(is_gcs_path('foo/bar'))
|
||||
|
||||
def test_parse_blob_path_valid(self):
|
||||
bucket_name, blob_name = parse_blob_path('gs://foo/bar/baz/')
|
||||
|
||||
self.assertEqual('foo', bucket_name)
|
||||
self.assertEqual('bar/baz/', blob_name)
|
||||
|
||||
def test_parse_blob_path_invalid(self):
|
||||
# No blob name
|
||||
self.assertRaises(ValueError, lambda: parse_blob_path('gs://foo'))
|
||||
self.assertRaises(ValueError, lambda: parse_blob_path('gs://foo/'))
|
||||
|
||||
# Invalid GCS path
|
||||
self.assertRaises(ValueError, lambda: parse_blob_path('foo'))
|
||||
self.assertRaises(ValueError, lambda: parse_blob_path('gs:///foo'))
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import mock
|
||||
import unittest
|
||||
import os
|
||||
|
||||
from kfp_component.google.storage import download_blob
|
||||
|
||||
DOWNLOAD_BLOB_MODULE = 'kfp_component.google.storage._download_blob'
|
||||
|
||||
@mock.patch(DOWNLOAD_BLOB_MODULE + '.os')
|
||||
@mock.patch(DOWNLOAD_BLOB_MODULE + '.open')
|
||||
@mock.patch(DOWNLOAD_BLOB_MODULE + '.storage.Client')
|
||||
class DownloadBlobTest(unittest.TestCase):
|
||||
|
||||
def test_download_blob_succeed(self, mock_storage_client,
|
||||
mock_open, mock_os):
|
||||
mock_os.path.dirname.return_value = '/foo'
|
||||
mock_os.path.exists.return_value = False
|
||||
|
||||
download_blob('gs://foo/bar.py',
|
||||
'/foo/bar.py')
|
||||
|
||||
mock_blob = mock_storage_client().get_bucket().blob()
|
||||
mock_blob.download_to_file.assert_called_once()
|
||||
mock_os.makedirs.assert_called_once()
|
||||
Loading…
Reference in New Issue