Dataflow SDK: support launching Beam Python code or a template (#833)

* initial files for dataflow commands

* dataflow commands

* add launch python command

* Support launch_python command

* Use display API and default outputs to /tmp/output folder
hongye-sun 2019-02-26 16:18:43 -08:00 committed by Kubernetes Prow Robot
parent 7775692adf
commit ea26a574bf
17 changed files with 811 additions and 2 deletions

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import launcher, core
+from . import launcher, core, google

View File

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import ml_engine, dataflow

View File

@ -0,0 +1,16 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._launch_template import launch_template
from ._launch_python import launch_python

View File

@ -0,0 +1,58 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import googleapiclient.discovery as discovery
from googleapiclient import errors
class DataflowClient:
def __init__(self):
self._df = discovery.build('dataflow', 'v1b3')
def launch_template(self, project_id, gcs_path, location,
validate_only, launch_parameters):
return self._df.projects().templates().launch(
projectId = project_id,
gcsPath = gcs_path,
location = location,
validateOnly = validate_only,
body = launch_parameters
).execute()
def get_job(self, project_id, job_id, location=None, view=None):
return self._df.projects().jobs().get(
projectId = project_id,
jobId = job_id,
location = location,
view = view
).execute()
def cancel_job(self, project_id, job_id, location):
return self._df.projects().jobs().update(
projectId = project_id,
jobId = job_id,
location = location,
body = {
'requestedState': 'JOB_STATE_CANCELLED'
}
).execute()
def list_aggregated_jobs(self, project_id, filter=None,
view=None, page_size=None, page_token=None, location=None):
return self._df.projects().jobs().aggregated(
projectId = project_id,
filter = filter,
view = view,
pageSize = page_size,
pageToken = page_token,
location = location).execute()
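
A minimal usage sketch for this client, assuming the module is importable as kfp_component.google.dataflow._client (as the launcher modules do via ``from ._client import DataflowClient``) and that application-default credentials are configured; the project, template path, and parameters below are illustrative:

from kfp_component.google.dataflow._client import DataflowClient

client = DataflowClient()
response = client.launch_template(
    project_id='my-project',                        # hypothetical project
    gcs_path='gs://my-bucket/templates/wordcount',  # hypothetical template
    location='us-central1',
    validate_only=False,
    launch_parameters={'jobName': 'wordcount-1'})
job = response.get('job', {})
# Poll the launched job with the same client.
print(client.get_job('my-project', job.get('id'), location='us-central1'))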

View File

@ -0,0 +1,121 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import json
import os
import tempfile
from kfp_component.core import display
from .. import common as gcp_common
from ..storage import download_blob, parse_blob_path, is_gcs_path
_JOB_SUCCESSFUL_STATES = ['JOB_STATE_DONE', 'JOB_STATE_UPDATED', 'JOB_STATE_DRAINED']
_JOB_FAILED_STATES = ['JOB_STATE_STOPPED', 'JOB_STATE_FAILED', 'JOB_STATE_CANCELLED']
_JOB_TERMINATED_STATES = _JOB_SUCCESSFUL_STATES + _JOB_FAILED_STATES
def generate_job_name(job_name, context_id):
    """Generates a stable job name within the job context.

    If the user-provided ``job_name`` is non-empty, the function uses it
    as a prefix and appends the first 8 characters of ``context_id`` to
    make the name unique across contexts. If ``job_name`` is empty,
    it uses ``job-{context_id}`` as the job name.
    """
    if job_name:
        return '{}-{}'.format(
            gcp_common.normalize_name(job_name),
            context_id[:8])
    return 'job-{}'.format(context_id)
def get_job_by_name(df_client, project_id, job_name, location=None):
    """Gets a job by its name.

    The function lists all jobs under a project or a regional location,
    compares their names with ``job_name``, and returns the job as soon
    as it finds a match. If no job matches, it returns ``None``.
    """
page_token = None
while True:
response = df_client.list_aggregated_jobs(project_id,
page_size=50, page_token=page_token, location=location)
for job in response.get('jobs', []):
name = job.get('name', None)
if job_name == name:
return job
page_token = response.get('nextPageToken', None)
if not page_token:
return None
def wait_for_job_done(df_client, project_id, job_id, location=None, wait_interval=30):
while True:
job = df_client.get_job(project_id, job_id, location=location)
state = job.get('currentState', None)
if is_job_done(state):
return job
elif is_job_terminated(state):
# Terminated with error state
raise RuntimeError('Job {} failed with error state: {}.'.format(
job_id,
state
))
else:
logging.info('Job {} is in pending state {}.'
' Waiting for {} seconds for next poll.'.format(
job_id,
state,
wait_interval
))
time.sleep(wait_interval)
def wait_and_dump_job(df_client, project_id, location, job,
wait_interval):
display_job_link(project_id, job)
job_id = job.get('id')
job = wait_for_job_done(df_client, project_id, job_id,
location, wait_interval)
dump_job(job)
return job
def is_job_terminated(job_state):
return job_state in _JOB_TERMINATED_STATES
def is_job_done(job_state):
return job_state in _JOB_SUCCESSFUL_STATES
def display_job_link(project_id, job):
location = job.get('location')
job_id = job.get('id')
display.display(display.Link(
href = 'https://console.cloud.google.com/dataflow/'
'jobsDetail/locations/{}/jobs/{}?project={}'.format(
location, job_id, project_id),
text = 'Job Details'
))
def dump_job(job):
gcp_common.dump_file('/tmp/output/job.json', json.dumps(job))
def stage_file(local_or_gcs_path):
if not is_gcs_path(local_or_gcs_path):
return local_or_gcs_path
_, blob_path = parse_blob_path(local_or_gcs_path)
file_name = os.path.basename(blob_path)
local_file_path = os.path.join(tempfile.mkdtemp(), file_name)
download_blob(local_or_gcs_path, local_file_path)
return local_file_path
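
Illustrative behavior of the helpers above; the exact output of ``generate_job_name`` depends on ``gcp_common.normalize_name``, which is defined outside this diff and assumed here to lowercase names and replace invalid characters with hyphens:

generate_job_name('Sample Job', '0123456789abcdef')  # -> e.g. 'sample-job-01234567'
generate_job_name(None, 'ctx-1')                     # -> 'job-ctx-1'
stage_file('/tmp/main.py')            # local path: returned unchanged
stage_file('gs://my-bucket/main.py')  # GCS path: downloaded, e.g. '/tmp/tmpXXXX/main.py'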

View File

@ -0,0 +1,107 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import re
import logging
from kfp_component.core import KfpExecutionContext
from ._client import DataflowClient
from .. import common as gcp_common
from ._common_ops import (generate_job_name, get_job_by_name,
wait_and_dump_job, stage_file)
from ._process import Process
def launch_python(python_file_path, project_id, requirements_file_path=None,
    location=None, job_name_prefix=None, args=None, wait_interval=30):
    """Launches a self-executing Beam Python file.

    Args:
        python_file_path (str): The GCS or local path to the Python file to run.
        project_id (str): The ID of the parent project.
        requirements_file_path (str): Optional. The GCS or local path to the
            pip requirements file.
        location (str): The regional endpoint to which to direct the request.
        job_name_prefix (str): Optional. The prefix of the generated job
            name. If not provided, the method generates a random name.
        args (list): Optional. The list of args to pass to the Python file.
        wait_interval (int): The number of seconds to wait between polls.

    Returns:
        The completed job.
    """
    # Avoid a mutable default argument.
    if args is None:
        args = []
df_client = DataflowClient()
job_id = None
def cancel():
if job_id:
df_client.cancel_job(
project_id,
job_id,
location
)
with KfpExecutionContext(on_cancel=cancel) as ctx:
job_name = generate_job_name(
job_name_prefix,
ctx.context_id())
        # We always generate a unique name for the job. A job with the same
        # name can only have been created by a previous attempt of the same
        # pipeline run, in which case we resume waiting on it instead of
        # launching a duplicate.
job = get_job_by_name(df_client, project_id, job_name,
location)
if job:
return wait_and_dump_job(df_client, project_id, location, job,
wait_interval)
_install_requirements(requirements_file_path)
python_file_path = stage_file(python_file_path)
cmd = _prepare_cmd(project_id, location, job_name, python_file_path,
args)
sub_process = Process(cmd)
for line in sub_process.read_lines():
job_id = _extract_job_id(line)
if job_id:
logging.info('Found job id {}'.format(job_id))
break
sub_process.wait_and_check()
if not job_id:
            logging.warning('No Dataflow job was found when '
                'running the Python file.')
return None
job = df_client.get_job(project_id, job_id,
location=location)
return wait_and_dump_job(df_client, project_id, location, job,
wait_interval)
def _prepare_cmd(project_id, location, job_name, python_file_path, args):
dataflow_args = [
'--runner', 'dataflow',
'--project', project_id,
'--job-name', job_name]
if location:
dataflow_args += ['--location', location]
return (['python2', '-u', python_file_path] +
dataflow_args + args)
def _extract_job_id(line):
    job_id_pattern = re.compile(
        br'.*console\.cloud\.google\.com/dataflow.*/jobs/([a-zA-Z0-9_-]+).*')
    # ``line`` is raw bytes from the subprocess; fall back to empty *bytes*
    # so the bytes pattern is never matched against a str.
    matched_job_id = job_id_pattern.search(line or b'')
    if matched_job_id:
        return matched_job_id.group(1).decode()
    return None
def _install_requirements(requirements_file_path):
    if not requirements_file_path:
        return
    requirements_file_path = stage_file(requirements_file_path)
    # check=True makes a failed pip install fail the step instead of being
    # silently ignored.
    subprocess.run(['pip2', 'install', '-r', requirements_file_path], check=True)
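
A usage sketch for ``launch_python`` from inside a KFP component container; all paths, the project ID, and the pipeline arguments below are illustrative:

from kfp_component.google.dataflow import launch_python

job = launch_python(
    python_file_path='gs://my-bucket/pipelines/wordcount.py',  # hypothetical
    project_id='my-project',                                   # hypothetical
    requirements_file_path='gs://my-bucket/pipelines/requirements.txt',
    location='us-central1',
    job_name_prefix='wordcount',
    args=['--output', 'gs://my-bucket/output/'],
    wait_interval=30)
# ``job`` is the final Dataflow job resource; its JSON is also dumped to
# /tmp/output/job.json by wait_and_dump_job.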

View File

@ -0,0 +1,74 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
import time
from kfp_component.core import KfpExecutionContext
from ._client import DataflowClient
from .. import common as gcp_common
from ._common_ops import (generate_job_name, get_job_by_name,
wait_and_dump_job)
def launch_template(project_id, gcs_path, launch_parameters,
location=None, job_name_prefix=None, validate_only=None,
wait_interval=30):
"""Launchs a dataflow job from template.
Args:
project_id (str): Required. The ID of the Cloud Platform project
that the job belongs to.
gcs_path (str): Required. A Cloud Storage path to the template
from which to create the job. Must be valid Cloud
Storage URL, beginning with 'gs://'.
launch_parameters (dict): Parameters to provide to the template
being launched. Schema defined in
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/LaunchTemplateParameters.
`jobName` will be replaced by generated name.
location (str): The regional endpoint to which to direct the
request.
job_name_prefix (str): Optional. The prefix of the genrated job
name. If not provided, the method will generated a random name.
validate_only (boolean): If true, the request is validated but
not actually executed. Defaults to false.
wait_interval (int): The wait seconds between polling.
Returns:
The completed job.
"""
df_client = DataflowClient()
job_id = None
def cancel():
if job_id:
df_client.cancel_job(
project_id,
job_id,
location
)
with KfpExecutionContext(on_cancel=cancel) as ctx:
job_name = generate_job_name(
job_name_prefix,
ctx.context_id())
        logging.info('Generated job name: {}'.format(job_name))
job = get_job_by_name(df_client, project_id, job_name,
location)
if not job:
launch_parameters['jobName'] = job_name
response = df_client.launch_template(project_id, gcs_path,
location, validate_only, launch_parameters)
job = response.get('job')
return wait_and_dump_job(df_client, project_id, location, job,
wait_interval)
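
A usage sketch for ``launch_template``; the template path here is the Google-provided Word Count template, and the project and bucket names are illustrative:

from kfp_component.google.dataflow import launch_template

job = launch_template(
    project_id='my-project',                               # hypothetical
    gcs_path='gs://dataflow-templates/latest/Word_Count',
    launch_parameters={
        'parameters': {
            'inputFile': 'gs://my-bucket/input.txt',
            'output': 'gs://my-bucket/output/wc'
        },
        'environment': {'zone': 'us-central1-f'}
    },
    location='us-central1',
    job_name_prefix='wordcount')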

View File

@ -0,0 +1,40 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import logging
class Process:
def __init__(self, cmd):
self._cmd = cmd
self.process = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True,
shell=False)
def read_lines(self):
# stdout will end with empty bytes when process exits.
for line in iter(self.process.stdout.readline, b''):
logging.info('subprocess: {}'.format(line))
yield line
def wait_and_check(self):
for _ in self.read_lines():
pass
self.process.stdout.close()
return_code = self.process.wait()
        logging.info('Subprocess exited with code {}.'.format(
            return_code))
if return_code:
raise subprocess.CalledProcessError(return_code, self._cmd)
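
A minimal sketch of how this wrapper is meant to be driven (the command is illustrative):

p = Process(['echo', 'hello'])
for line in p.read_lines():
    pass  # each ``line`` is raw bytes and has already been logged
p.wait_and_check()  # raises CalledProcessError on a non-zero exit code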

View File

@ -0,0 +1,16 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._download_blob import download_blob
from ._common_ops import parse_blob_path, is_gcs_path

View File

@ -0,0 +1,41 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def is_gcs_path(path):
    """Checks whether ``path`` is a GCS path."""
    return path.startswith('gs://')
def parse_blob_path(path):
    """Parses a GCS path into bucket name and blob name.

    Args:
        path (str): the path to parse.

    Returns:
        A tuple of (bucket name, blob name).

    Raises:
        ValueError: if the path is not a valid GCS blob path.

    Example:
        `bucket_name, blob_name = parse_blob_path('gs://foo/bar')`
        `bucket_name` is `foo` and `blob_name` is `bar`.
    """
match = re.match('gs://([^/]+)/(.+)$', path)
if match:
return match.group(1), match.group(2)
    raise ValueError('Path {} is not a valid GCS blob path.'.format(path))
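
Concrete behavior of the two helpers, following directly from the code above:

is_gcs_path('gs://my-bucket/data/file.txt')  # True
is_gcs_path('/local/file.txt')               # False
bucket, blob = parse_blob_path('gs://my-bucket/data/file.txt')
# bucket == 'my-bucket', blob == 'data/file.txt'
parse_blob_path('gs://my-bucket')            # raises ValueError (no blob name)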

View File

@ -0,0 +1,42 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from google.cloud import storage
from ._common_ops import parse_blob_path
def download_blob(source_blob_path, destination_file_path):
"""Downloads a blob from the bucket.
Args:
source_blob_path (str): the source blob path to download from.
destination_file_path (str): the local file path to download to.
"""
bucket_name, blob_name = parse_blob_path(source_blob_path)
storage_client = storage.Client()
bucket = storage_client.get_bucket(bucket_name)
blob = bucket.blob(blob_name)
    dirname = os.path.dirname(destination_file_path)
    # Guard against an empty dirname when the destination is a bare filename.
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)
with open(destination_file_path, 'wb+') as f:
blob.download_to_file(f)
logging.info('Blob {} downloaded to {}.'.format(
source_blob_path,
destination_file_path))
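
A usage sketch; the bucket and paths are illustrative, and valid GCP credentials are assumed:

from kfp_component.google.storage import download_blob

download_blob('gs://my-bucket/scripts/main.py', '/tmp/staged/main.py')
# Creates /tmp/staged if missing, then writes the blob contents there.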

View File

@ -0,0 +1,13 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,82 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
import unittest
import os
from kfp_component.google.dataflow import launch_python
MODULE = 'kfp_component.google.dataflow._launch_python'
@mock.patch('kfp_component.google.dataflow._common_ops.display')
@mock.patch(MODULE + '.stage_file')
@mock.patch(MODULE + '.KfpExecutionContext')
@mock.patch(MODULE + '.DataflowClient')
@mock.patch(MODULE + '.Process')
@mock.patch(MODULE + '.subprocess')
class LaunchPythonTest(unittest.TestCase):
def test_launch_python_succeed(self, mock_subprocess, mock_process,
mock_client, mock_context, mock_stage_file, mock_display):
mock_context().__enter__().context_id.return_value = 'ctx-1'
mock_client().list_aggregated_jobs.return_value = {
'jobs': []
}
mock_process().read_lines.return_value = [
b'https://console.cloud.google.com/dataflow/locations/us-central1/jobs/job-1?project=project-1'
]
expected_job = {
'currentState': 'JOB_STATE_DONE'
}
mock_client().get_job.return_value = expected_job
result = launch_python('/tmp/test.py', 'project-1')
self.assertEqual(expected_job, result)
def test_launch_python_retry_succeed(self, mock_subprocess, mock_process,
mock_client, mock_context, mock_stage_file, mock_display):
mock_context().__enter__().context_id.return_value = 'ctx-1'
mock_client().list_aggregated_jobs.return_value = {
'jobs': [{
'id': 'job-1',
'name': 'test_job-ctx-1'
}]
}
expected_job = {
'currentState': 'JOB_STATE_DONE'
}
mock_client().get_job.return_value = expected_job
result = launch_python('/tmp/test.py', 'project-1', job_name_prefix='test-job')
self.assertEqual(expected_job, result)
mock_process.assert_not_called()
def test_launch_python_no_job_created(self, mock_subprocess, mock_process,
mock_client, mock_context, mock_stage_file, mock_display):
mock_context().__enter__().context_id.return_value = 'ctx-1'
mock_client().list_aggregated_jobs.return_value = {
'jobs': []
}
mock_process().read_lines.return_value = [
b'no job id',
b'no job id'
]
result = launch_python('/tmp/test.py', 'project-1')
self.assertEqual(None, result)

View File

@ -0,0 +1,106 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
import unittest
import os
from kfp_component.google.dataflow import launch_template
MODULE = 'kfp_component.google.dataflow._launch_template'
@mock.patch('kfp_component.google.dataflow._common_ops.display')
@mock.patch(MODULE + '.KfpExecutionContext')
@mock.patch(MODULE + '.DataflowClient')
class LaunchTemplateTest(unittest.TestCase):
def test_launch_template_succeed(self, mock_client, mock_context, mock_display):
mock_context().__enter__().context_id.return_value = 'context-1'
mock_client().list_aggregated_jobs.return_value = {
'jobs': []
}
mock_client().launch_template.return_value = {
'job': { 'id': 'job-1' }
}
expected_job = {
'currentState': 'JOB_STATE_DONE'
}
mock_client().get_job.return_value = expected_job
result = launch_template('project-1', 'gs://foo/bar', {
"parameters": {
"foo": "bar"
},
"environment": {
"zone": "us-central1"
}
})
self.assertEqual(expected_job, result)
mock_client().launch_template.assert_called_once()
def test_launch_template_retry_succeed(self,
mock_client, mock_context, mock_display):
mock_context().__enter__().context_id.return_value = 'ctx-1'
# The job with same name already exists.
mock_client().list_aggregated_jobs.return_value = {
'jobs': [{
'id': 'job-1',
'name': 'test_job-ctx-1'
}]
}
pending_job = {
'currentState': 'JOB_STATE_PENDING'
}
expected_job = {
'currentState': 'JOB_STATE_DONE'
}
mock_client().get_job.side_effect = [pending_job, expected_job]
result = launch_template('project-1', 'gs://foo/bar', {
"parameters": {
"foo": "bar"
},
"environment": {
"zone": "us-central1"
}
}, job_name_prefix='test-job', wait_interval=0)
self.assertEqual(expected_job, result)
mock_client().launch_template.assert_not_called()
def test_launch_template_fail(self, mock_client, mock_context, mock_display):
mock_context().__enter__().context_id.return_value = 'context-1'
mock_client().list_aggregated_jobs.return_value = {
'jobs': []
}
mock_client().launch_template.return_value = {
'job': { 'id': 'job-1' }
}
failed_job = {
'currentState': 'JOB_STATE_FAILED'
}
mock_client().get_job.return_value = failed_job
self.assertRaises(RuntimeError,
lambda: launch_template('project-1', 'gs://foo/bar', {
"parameters": {
"foo": "bar"
},
"environment": {
"zone": "us-central1"
}
}))

View File

@ -0,0 +1,13 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,40 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from kfp_component.google.storage import is_gcs_path, parse_blob_path
class CommonOpsTest(unittest.TestCase):
def test_is_gcs_path(self):
self.assertTrue(is_gcs_path('gs://foo'))
self.assertTrue(is_gcs_path('gs://foo/bar'))
self.assertFalse(is_gcs_path('gs:/foo/bar'))
self.assertFalse(is_gcs_path('foo/bar'))
def test_parse_blob_path_valid(self):
bucket_name, blob_name = parse_blob_path('gs://foo/bar/baz/')
self.assertEqual('foo', bucket_name)
self.assertEqual('bar/baz/', blob_name)
def test_parse_blob_path_invalid(self):
# No blob name
self.assertRaises(ValueError, lambda: parse_blob_path('gs://foo'))
self.assertRaises(ValueError, lambda: parse_blob_path('gs://foo/'))
# Invalid GCS path
self.assertRaises(ValueError, lambda: parse_blob_path('foo'))
self.assertRaises(ValueError, lambda: parse_blob_path('gs:///foo'))

View File

@ -0,0 +1,38 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
import unittest
import os
from kfp_component.google.storage import download_blob
DOWNLOAD_BLOB_MODULE = 'kfp_component.google.storage._download_blob'
@mock.patch(DOWNLOAD_BLOB_MODULE + '.os')
@mock.patch(DOWNLOAD_BLOB_MODULE + '.open')
@mock.patch(DOWNLOAD_BLOB_MODULE + '.storage.Client')
class DownloadBlobTest(unittest.TestCase):
def test_download_blob_succeed(self, mock_storage_client,
mock_open, mock_os):
mock_os.path.dirname.return_value = '/foo'
mock_os.path.exists.return_value = False
download_blob('gs://foo/bar.py',
'/foo/bar.py')
mock_blob = mock_storage_client().get_bucket().blob()
mock_blob.download_to_file.assert_called_once()
mock_os.makedirs.assert_called_once()