feat(sdk): Implement Registry client (#7597)
* Implement registry client * Update registry client code * Add test skeleton * Add some tests * Update code * add tests * update tests * update tests * Rename Client -> RegistryClient * Update wrt comments * add type annotations * fix renaming in __init__.py * remove unused imports * extract host variable in test * format using yapf * remove locals and use arg keywords * remove json conversion * fix header * write bytes when downloading file * fix create_tag; fix tests * fix request_body for update_tag and create_tag using json.dumps * simply return json for delete_tag * rename files * format files * update return types and format double quotes * add comments and format files * add todos * update credentials and change open to use context * format using yapf * move request into context * Update comments * Update release notes * Update release notes
This commit is contained in:
parent
6e905c459e
commit
25e4c58820
|
@ -1,6 +1,7 @@
|
|||
# Current Version (Still in Development)
|
||||
|
||||
## Major Features and Improvements
|
||||
* feat(sdk): Implement Registry Client [\#7597](https://github.com/kubeflow/pipelines/pull/7597)
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from kfp.registry.registry_client import ApiAuth
|
||||
from kfp.registry.registry_client import RegistryClient
|
|
@ -0,0 +1,512 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Class for KFP Registry Client."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import google.auth
|
||||
import requests
|
||||
from google.auth import credentials
|
||||
|
||||
_KNOWN_HOSTS_REGEX = {
|
||||
'kfp_pkg_dev':
|
||||
r'(^https\:\/\/(?P<location>[\w\-]+)\-kfp\.pkg\.dev\/(?P<project_id>.*)\/(?P<repo_id>.*))',
|
||||
}
|
||||
|
||||
_DEFAULT_JSON_HEADER = {
|
||||
'Content-type': 'application/json',
|
||||
}
|
||||
|
||||
|
||||
class _SafeDict(dict):
|
||||
"""Class for safely handling missing keys in .format_map."""
|
||||
|
||||
def __missing__(self, key: str) -> str:
|
||||
"""Handle missing keys by adding them back.
|
||||
|
||||
Args:
|
||||
key: The key itself.
|
||||
|
||||
Returns:
|
||||
The key with curly braces.
|
||||
"""
|
||||
return '{' + key + '}'
|
||||
|
||||
|
||||
class ApiAuth(requests.auth.AuthBase):
|
||||
"""Class for authentication using API token."""
|
||||
|
||||
def __init__(self, token: str) -> None:
|
||||
self._token = token
|
||||
|
||||
def __call__(self,
|
||||
request: requests.PreparedRequest) -> requests.PreparedRequest:
|
||||
request.headers['authorization'] = 'Bearer ' + self._token
|
||||
return request
|
||||
|
||||
|
||||
class RegistryClient:
|
||||
"""Registry Client class for communicating with registry hosts."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
auth: Optional[Union[requests.auth.AuthBase,
|
||||
credentials.Credentials]] = None
|
||||
) -> None:
|
||||
"""Initializes the RegistryClient.
|
||||
|
||||
Args:
|
||||
host: The address of the registry host.
|
||||
auth: Authentication using python requests or google.auth credentials.
|
||||
"""
|
||||
self._host = host.rstrip('/')
|
||||
self._known_host_key = ''
|
||||
for key in _KNOWN_HOSTS_REGEX.keys():
|
||||
if re.match(_KNOWN_HOSTS_REGEX[key], self._host):
|
||||
self._known_host_key = key
|
||||
break
|
||||
self._config = self.load_config()
|
||||
if auth:
|
||||
self._auth = auth
|
||||
elif self._is_ar_host():
|
||||
self._auth, _ = google.auth.default()
|
||||
|
||||
def _request(self,
|
||||
request_url: str,
|
||||
request_body: str = '',
|
||||
http_request: str = 'get',
|
||||
extra_headers: dict = None) -> requests.Response:
|
||||
"""Calls the HTTP request.
|
||||
|
||||
Args:
|
||||
request_url: The address of the API endpoint to send the request to.
|
||||
request_body: Body of the request.
|
||||
http_request: Type of HTTP request (post, get, delete etc, defaults to get).
|
||||
extra_headers: Any extra headers required.
|
||||
|
||||
Returns:
|
||||
Response from the request.
|
||||
"""
|
||||
self._refresh_creds()
|
||||
auth = self._get_auth()
|
||||
http_request_fn = getattr(requests, http_request)
|
||||
|
||||
response = http_request_fn(
|
||||
url=request_url,
|
||||
data=request_body,
|
||||
headers=extra_headers,
|
||||
auth=auth)
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
def _is_ar_host(self) -> bool:
|
||||
"""Checks if the host is on Artifact Registry.
|
||||
|
||||
Returns:
|
||||
Whether the host is on Artifact Registry.
|
||||
"""
|
||||
# TODO: Handle multiple known hosts.
|
||||
return self._known_host_key == 'kfp_pkg_dev'
|
||||
|
||||
def _is_known_host(self) -> bool:
|
||||
"""Checks if the host is a known host.
|
||||
|
||||
Returns:
|
||||
Whether the host is a known host.
|
||||
"""
|
||||
return bool(self._known_host_key)
|
||||
|
||||
def load_config(self) -> dict:
|
||||
"""Loads the config.
|
||||
|
||||
Returns:
|
||||
The loaded config if it's a known host, otherwise None.
|
||||
"""
|
||||
# TODO: Move config file code to its own file.
|
||||
config = {}
|
||||
if self._is_ar_host():
|
||||
repo_resource_format = ''
|
||||
matched = re.match(_KNOWN_HOSTS_REGEX[self._known_host_key],
|
||||
self._host)
|
||||
if matched:
|
||||
repo_resource_format = ('projects/'
|
||||
'{project_id}/locations/{location}/'
|
||||
'repositories/{repo_id}'.format_map(
|
||||
_SafeDict(matched.groupdict())))
|
||||
else:
|
||||
raise ValueError(f'Invalid host URL: {self._host}.')
|
||||
registry_endpoint = 'https://artifactregistry.googleapis.com/v1'
|
||||
api_endpoint = f'{registry_endpoint}/{repo_resource_format}'
|
||||
package_endpoint = f'{api_endpoint}/packages'
|
||||
package_name_endpoint = f'{package_endpoint}/{{package_name}}'
|
||||
tags_endpoint = f'{package_name_endpoint}/tags'
|
||||
versions_endpoint = f'{package_name_endpoint}/versions'
|
||||
config = {
|
||||
'host':
|
||||
self._host,
|
||||
'upload_url':
|
||||
self._host,
|
||||
'download_version_url':
|
||||
f'{self._host}/{{package_name}}/{{version}}',
|
||||
'download_tag_url':
|
||||
f'{self._host}/{{package_name}}/{{tag}}',
|
||||
'get_package_url':
|
||||
f'{package_name_endpoint}',
|
||||
'list_packages_url':
|
||||
package_endpoint,
|
||||
'delete_package_url':
|
||||
f'{package_name_endpoint}',
|
||||
'get_tag_url':
|
||||
f'{tags_endpoint}/{{tag}}',
|
||||
'list_tags_url':
|
||||
f'{tags_endpoint}',
|
||||
'delete_tag_url':
|
||||
f'{tags_endpoint}/{{tag}}',
|
||||
'create_tag_url':
|
||||
f'{tags_endpoint}?tagId={{tag}}',
|
||||
'update_tag_url':
|
||||
f'{tags_endpoint}/{{tag}}?updateMask=version',
|
||||
'get_version_url':
|
||||
f'{versions_endpoint}/{{version}}',
|
||||
'list_versions_url':
|
||||
f'{versions_endpoint}',
|
||||
'delete_version_url':
|
||||
f'{versions_endpoint}/{{version}}',
|
||||
'package_format':
|
||||
f'{repo_resource_format}/packages/{{package_name}}',
|
||||
'tag_format':
|
||||
f'{repo_resource_format}/packages/{{package_name}}/tags/{{tag}}',
|
||||
'version_format':
|
||||
f'{repo_resource_format}/packages/{{package_name}}/versions/{{version}}',
|
||||
}
|
||||
else:
|
||||
logging.info(f'load_config not implemented for host: {self._host}')
|
||||
return config
|
||||
|
||||
def _get_auth(self) -> requests.auth.AuthBase:
|
||||
"""Helper function to convert google credentials to AuthBase class if
|
||||
needed.
|
||||
|
||||
Returns:
|
||||
An instance of the AuthBase class
|
||||
"""
|
||||
auth = self._auth
|
||||
if isinstance(auth, credentials.Credentials):
|
||||
auth = ApiAuth(auth.token)
|
||||
return auth
|
||||
|
||||
def _refresh_creds(self) -> None:
|
||||
"""Helper function to refresh google credentials if needed."""
|
||||
if self._is_ar_host() and isinstance(
|
||||
self._auth, credentials.Credentials) and not self._auth.valid:
|
||||
self._auth.refresh(google.auth.transport.requests.Request())
|
||||
|
||||
def upload_pipeline(self, file_name: str, tags: Optional[Union[str,
|
||||
List[str]]],
|
||||
extra_headers: Optional[dict]) -> Tuple[str, str]:
|
||||
"""Uploads the pipeline.
|
||||
|
||||
Args:
|
||||
file_name: The name of the file to be uploaded.
|
||||
tags: Tags to be attached to the uploaded pipeline.
|
||||
extra_headers: Any extra headers required.
|
||||
|
||||
Returns:
|
||||
A tuple representing the package name and the version
|
||||
"""
|
||||
url = self._config['upload_url']
|
||||
self._refresh_creds()
|
||||
auth = self._get_auth()
|
||||
request_body = {}
|
||||
if tags:
|
||||
if isinstance(tags, str):
|
||||
request_body = {'tags': tags}
|
||||
elif isinstance(tags, List):
|
||||
request_body = {'tags': ','.join(tags)}
|
||||
|
||||
with open(file_name, 'rb') as f:
|
||||
files = {'content': f}
|
||||
response = requests.post(
|
||||
url=url,
|
||||
data=request_body,
|
||||
headers=extra_headers,
|
||||
files=files,
|
||||
auth=auth)
|
||||
response.raise_for_status()
|
||||
[package_name, version] = response.text.split('/')
|
||||
|
||||
return (package_name, version)
|
||||
|
||||
def _get_download_url(self,
|
||||
package_name: str,
|
||||
version: Optional[str] = None,
|
||||
tag: Optional[str] = None) -> str:
|
||||
"""Gets the download url based on version or tag (either one must be
|
||||
specified).
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
version: Version of the package.
|
||||
tag: Tag attached to the package.
|
||||
|
||||
Returns:
|
||||
A url for downloading the package.
|
||||
"""
|
||||
if (not version) and (not tag):
|
||||
raise ValueError('Either version or tag must be specified.')
|
||||
if version:
|
||||
url = self._config['download_version_url'].format(
|
||||
package_name=package_name, version=version)
|
||||
if tag:
|
||||
if version:
|
||||
logging.info(
|
||||
'Both version and tag are specified, using version only.')
|
||||
else:
|
||||
url = self._config['download_tag_url'].format(
|
||||
package_name=package_name, tag=tag)
|
||||
return url
|
||||
|
||||
def download_pipeline(self,
|
||||
package_name: str,
|
||||
version: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
file_name: str = None) -> str:
|
||||
"""Downloads a pipeline - either version or tag must be specified.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
version: Version of the package.
|
||||
tag: Tag attached to the package.
|
||||
file_name: File name to be saved as. If not specified, the
|
||||
file name will be based on the package name and version/tag.
|
||||
|
||||
Returns:
|
||||
The file name of the downloaded pipeline.
|
||||
"""
|
||||
url = self._get_download_url(package_name, version, tag)
|
||||
response = self._request(request_url=url)
|
||||
|
||||
if not file_name:
|
||||
file_name = package_name + '_'
|
||||
if version:
|
||||
file_name += version[len('sha256:'):]
|
||||
elif tag:
|
||||
file_name += tag
|
||||
file_name += '.yaml'
|
||||
|
||||
with open(file_name, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
return file_name
|
||||
|
||||
def get_package(self, package_name: str) -> Dict[str, Any]:
|
||||
"""Gets package metadata.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
|
||||
Returns:
|
||||
The package metadata.
|
||||
"""
|
||||
url = self._config['get_package_url'].format(package_name=package_name)
|
||||
response = self._request(request_url=url)
|
||||
|
||||
return response.json()
|
||||
|
||||
def list_packages(self) -> List[dict]:
|
||||
"""Lists packages.
|
||||
|
||||
Returns:
|
||||
List of packages in the repository.
|
||||
"""
|
||||
url = self._config['list_packages_url']
|
||||
response = self._request(request_url=url)
|
||||
response_json = response.json()
|
||||
|
||||
return response_json['packages']
|
||||
|
||||
def delete_package(self, package_name: str) -> bool:
|
||||
"""Deletes a package.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
|
||||
Returns:
|
||||
Whether the package was deleted successfully.
|
||||
"""
|
||||
url = self._config['delete_package_url'].format(
|
||||
package_name=package_name)
|
||||
response = self._request(request_url=url, http_request='delete')
|
||||
response_json = response.json()
|
||||
|
||||
return response_json['done']
|
||||
|
||||
def get_version(self, package_name: str, version: str) -> Dict[str, Any]:
|
||||
"""Gets package version metadata.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
version: Version of the package.
|
||||
|
||||
Returns:
|
||||
The version metadata.
|
||||
"""
|
||||
url = self._config['get_version_url'].format(
|
||||
package_name=package_name, version=version)
|
||||
response = self._request(request_url=url)
|
||||
|
||||
return response.json()
|
||||
|
||||
def list_versions(self, package_name: str) -> List[dict]:
|
||||
"""Lists package versions.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
|
||||
Returns:
|
||||
List of package versions.
|
||||
"""
|
||||
url = self._config['list_versions_url'].format(
|
||||
package_name=package_name)
|
||||
response = self._request(request_url=url)
|
||||
response_json = response.json()
|
||||
|
||||
return response_json['versions']
|
||||
|
||||
def delete_version(self, package_name: str, version: str) -> bool:
|
||||
"""Deletes package version.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
version: Version of the package.
|
||||
|
||||
Returns:
|
||||
Whether the version was deleted successfully.
|
||||
"""
|
||||
url = self._config['delete_version_url'].format(
|
||||
package_name=package_name, version=version)
|
||||
response = self._request(request_url=url, http_request='delete')
|
||||
response_json = response.json()
|
||||
|
||||
return response_json['done']
|
||||
|
||||
def create_tag(self, package_name: str, version: str,
|
||||
tag: str) -> Dict[str, Any]:
|
||||
"""Creates a tag on a package version.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
version: Version of the package.
|
||||
tag: Tag to be attached to the package version.
|
||||
|
||||
Returns:
|
||||
The metadata for the created tag.
|
||||
"""
|
||||
url = self._config['create_tag_url'].format(
|
||||
package_name=package_name, tag=tag)
|
||||
new_tag = {
|
||||
'name':
|
||||
'',
|
||||
'version':
|
||||
self._config['version_format'].format(
|
||||
package_name=package_name, version=version)
|
||||
}
|
||||
response = self._request(
|
||||
request_url=url,
|
||||
request_body=json.dumps(new_tag),
|
||||
http_request='post',
|
||||
extra_headers=_DEFAULT_JSON_HEADER)
|
||||
|
||||
return response.json()
|
||||
|
||||
def get_tag(self, package_name: str, tag: str) -> Dict[str, Any]:
|
||||
"""Gets tag metadata.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
tag: Tag attached to the package version.
|
||||
|
||||
Returns:
|
||||
The metadata for the tag.
|
||||
"""
|
||||
url = self._config['get_tag_url'].format(
|
||||
package_name=package_name, tag=tag)
|
||||
response = self._request(request_url=url)
|
||||
|
||||
return response.json()
|
||||
|
||||
def update_tag(self, package_name: str, version: str,
|
||||
tag: str) -> Dict[str, Any]:
|
||||
"""Updates a tag to another package version.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
version: Version of the package.
|
||||
tag: Tag to be attached to the new package version.
|
||||
|
||||
Returns:
|
||||
The metadata for the updated tag.
|
||||
"""
|
||||
url = self._config['update_tag_url'].format(
|
||||
package_name=package_name, tag=tag)
|
||||
new_tag = {
|
||||
'name':
|
||||
'',
|
||||
'version':
|
||||
self._config['version_format'].format(
|
||||
package_name=package_name, version=version)
|
||||
}
|
||||
response = self._request(
|
||||
request_url=url,
|
||||
request_body=json.dumps(new_tag),
|
||||
http_request='patch',
|
||||
extra_headers=_DEFAULT_JSON_HEADER)
|
||||
|
||||
return response.json()
|
||||
|
||||
def list_tags(self, package_name: str) -> List[dict]:
|
||||
"""Lists package tags.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
|
||||
Returns:
|
||||
List of tags.
|
||||
"""
|
||||
url = self._config['list_tags_url'].format(package_name=package_name)
|
||||
response = self._request(request_url=url)
|
||||
response_json = response.json()
|
||||
|
||||
return response_json['tags']
|
||||
|
||||
def delete_tag(self, package_name: str, tag: str) -> Dict[str, Any]:
|
||||
"""Deletes package tag.
|
||||
|
||||
Args:
|
||||
package_name: Name of the package.
|
||||
tag: Tag to be deleted.
|
||||
|
||||
Returns:
|
||||
Response from the delete request.
|
||||
"""
|
||||
url = self._config['delete_tag_url'].format(
|
||||
package_name=package_name, tag=tag)
|
||||
response = self._request(request_url=url, http_request='delete')
|
||||
|
||||
return response.json()
|
|
@ -0,0 +1,351 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for KFP Registry RegistryClient."""
|
||||
|
||||
import builtins
|
||||
import json
|
||||
|
||||
import mock
|
||||
from absl.testing import parameterized
|
||||
from kfp.registry import ApiAuth
|
||||
from kfp.registry import RegistryClient
|
||||
|
||||
_DEFAULT_HOST = 'https://us-central1-kfp.pkg.dev/proj/repo'
|
||||
|
||||
|
||||
class RegistryClientTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'host': 'https://us-central1-kfp.pkg.dev/proj/repo',
|
||||
'expected': True,
|
||||
},
|
||||
{
|
||||
'host': 'https://hub.docker.com/r/google/cloud-sdk',
|
||||
'expected': False,
|
||||
},
|
||||
)
|
||||
def test_is_ar_host(self, host, expected):
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
self.assertEqual(client._is_ar_host(), expected)
|
||||
|
||||
def test_load_config(self):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
expected_config = {
|
||||
'host':
|
||||
host,
|
||||
'upload_url':
|
||||
host,
|
||||
'download_version_url':
|
||||
f'{host}/{{package_name}}/{{version}}',
|
||||
'download_tag_url':
|
||||
f'{host}/{{package_name}}/{{tag}}',
|
||||
'get_package_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}'),
|
||||
'list_packages_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages'),
|
||||
'delete_package_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}'),
|
||||
'get_tag_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/tags/{tag}'),
|
||||
'list_tags_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/tags'),
|
||||
'delete_tag_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/tags/{tag}'),
|
||||
'create_tag_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/tags?tagId={tag}'),
|
||||
'update_tag_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/tags/{tag}?updateMask=version'),
|
||||
'get_version_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/versions/{version}'),
|
||||
'list_versions_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/versions'),
|
||||
'delete_version_url':
|
||||
('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/versions/{version}'),
|
||||
'package_format':
|
||||
('projects/proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}'),
|
||||
'tag_format': ('projects/proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/tags/{tag}'),
|
||||
'version_format':
|
||||
('projects/proj/locations/us-central1/repositories'
|
||||
'/repo/packages/{package_name}/versions/{version}')
|
||||
}
|
||||
self.assertEqual(expected_config, client._config)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'version':
|
||||
'sha256:abcde12345',
|
||||
'tag':
|
||||
None,
|
||||
'file_name':
|
||||
None,
|
||||
'expected_url':
|
||||
'https://us-central1-kfp.pkg.dev/proj/repo/pack/sha256:abcde12345',
|
||||
'expected_file_name':
|
||||
'pack_abcde12345.yaml'
|
||||
},
|
||||
{
|
||||
'version':
|
||||
None,
|
||||
'tag':
|
||||
'tag1',
|
||||
'file_name':
|
||||
None,
|
||||
'expected_url':
|
||||
'https://us-central1-kfp.pkg.dev/proj/repo/pack/tag1',
|
||||
'expected_file_name':
|
||||
'pack_tag1.yaml'
|
||||
},
|
||||
{
|
||||
'version':
|
||||
None,
|
||||
'tag':
|
||||
'tag1',
|
||||
'file_name':
|
||||
'pipeline.yaml',
|
||||
'expected_url':
|
||||
'https://us-central1-kfp.pkg.dev/proj/repo/pack/tag1',
|
||||
'expected_file_name':
|
||||
'pipeline.yaml'
|
||||
},
|
||||
)
|
||||
@mock.patch('requests.get', autospec=True)
|
||||
@mock.patch('builtins.open', new_callable=mock.mock_open())
|
||||
def test_download_pipeline(self, mock_open, mock_get, version, tag,
|
||||
file_name, expected_url, expected_file_name):
|
||||
mock_open.reset_mock()
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.download_pipeline(
|
||||
package_name='pack', version=version, tag=tag, file_name=file_name)
|
||||
mock_get.assert_called_once_with(
|
||||
url=expected_url, data='', headers=None, auth=mock.ANY)
|
||||
mock_open.assert_called_once_with(expected_file_name, 'wb')
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'tags': 'tag1',
|
||||
'expected_tags': 'tag1'
|
||||
},
|
||||
{
|
||||
'tags': ['tag1', 'tag2'],
|
||||
'expected_tags': 'tag1,tag2'
|
||||
},
|
||||
)
|
||||
@mock.patch('requests.post', autospec=True)
|
||||
@mock.patch('builtins.open', new_callable=mock.mock_open())
|
||||
def test_upload_pipeline(self, mock_open, mock_post, tags, expected_tags):
|
||||
mock_open.reset_mock()
|
||||
mock_post.return_value.text = 'package_name/sha256:abcde12345'
|
||||
mock_open.return_value.__enter__.return_value = 'file_content'
|
||||
mock_open.return_value.__exit__.return_value = False
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
package_name, version = client.upload_pipeline(
|
||||
file_name='pipeline.yaml',
|
||||
tags=tags,
|
||||
extra_headers={'description': 'nothing'})
|
||||
mock_post.assert_called_once_with(
|
||||
url=host,
|
||||
data={'tags': expected_tags},
|
||||
headers={'description': 'nothing'},
|
||||
files={'content': 'file_content'},
|
||||
auth=mock.ANY)
|
||||
mock_open.assert_called_once_with('pipeline.yaml', 'rb')
|
||||
self.assertEqual(package_name, 'package_name')
|
||||
self.assertEqual(version, 'sha256:abcde12345')
|
||||
|
||||
@mock.patch('requests.get', autospec=True)
|
||||
def test_get_package(self, mock_get):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.get_package('pack')
|
||||
mock_get.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.get', autospec=True)
|
||||
def test_list_packages(self, mock_get):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.list_packages()
|
||||
mock_get.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.delete', autospec=True)
|
||||
def test_delete_package(self, mock_delete):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.delete_package('pack')
|
||||
mock_delete.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.get', autospec=True)
|
||||
def test_get_version(self, mock_get):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.get_version('pack', 'v1')
|
||||
mock_get.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/versions/v1'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.get', autospec=True)
|
||||
def test_list_versions(self, mock_get):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.list_versions('pack')
|
||||
mock_get.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/versions'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.delete', autospec=True)
|
||||
def test_delete_version(self, mock_delete):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.delete_version('pack', 'v1')
|
||||
mock_delete.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/versions/v1'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.get', autospec=True)
|
||||
def test_get_tag(self, mock_get):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.get_tag('pack', 'tag1')
|
||||
mock_get.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/tags/tag1'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.get', autospec=True)
|
||||
def test_list_tags(self, mock_get):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.list_tags('pack')
|
||||
mock_get.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/tags'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.delete', autospec=True)
|
||||
def test_delete_tag(self, mock_delete):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.delete_tag('pack', 'tag1')
|
||||
mock_delete.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/tags/tag1'),
|
||||
data='',
|
||||
headers=None,
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.post', autospec=True)
|
||||
def test_create_tag(self, mock_post):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.create_tag('pack', 'sha256:abcde12345', 'tag1')
|
||||
expected_data = json.dumps({
|
||||
'name':
|
||||
'',
|
||||
'version': ('projects/proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/versions/sha256:abcde12345')
|
||||
})
|
||||
mock_post.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/tags?tagId=tag1'),
|
||||
data=expected_data,
|
||||
headers={
|
||||
'Content-type': 'application/json',
|
||||
},
|
||||
auth=mock.ANY)
|
||||
|
||||
@mock.patch('requests.patch', autospec=True)
|
||||
def test_update_tag(self, mock_patch):
|
||||
host = _DEFAULT_HOST
|
||||
client = RegistryClient(host=host, auth=ApiAuth(''))
|
||||
client.update_tag('pack', 'sha256:abcde12345', 'tag1')
|
||||
expected_data = json.dumps({
|
||||
'name':
|
||||
'',
|
||||
'version': ('projects/proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/versions/sha256:abcde12345')
|
||||
})
|
||||
mock_patch.assert_called_once_with(
|
||||
url=('https://artifactregistry.googleapis.com/v1/projects/'
|
||||
'proj/locations/us-central1/repositories'
|
||||
'/repo/packages/pack/tags/tag1?updateMask=version'),
|
||||
data=expected_data,
|
||||
headers={
|
||||
'Content-type': 'application/json',
|
||||
},
|
||||
auth=mock.ANY)
|
Loading…
Reference in New Issue