feat(components): Support uploading for model versions for ModelUploadOp
PiperOrigin-RevId: 490354372
This commit is contained in:
parent
9970f3d0ab
commit
94bdce8a32
|
|
@ -31,6 +31,11 @@ def _parse_args(args):
|
|||
# executor_input is only needed for components that emit output artifacts.
|
||||
required=True,
|
||||
default=argparse.SUPPRESS)
|
||||
parser.add_argument(
|
||||
'--parent_model_name',
|
||||
dest='parent_model_name',
|
||||
type=str,
|
||||
default=None)
|
||||
parsed_args, _ = parser.parse_known_args(args)
|
||||
return vars(parsed_args)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from google_cloud_pipeline_components.container.v1.gcp_launcher import lro_remote_runner
|
||||
from google_cloud_pipeline_components.container.v1.gcp_launcher.utils import artifact_util
|
||||
|
|
@ -49,6 +50,7 @@ def upload_model(
|
|||
payload,
|
||||
gcp_resources,
|
||||
executor_input,
|
||||
parent_model_name: Optional[str] = None,
|
||||
):
|
||||
"""Upload model and poll the LongRunningOperator till it reaches a final state."""
|
||||
api_endpoint = location + '-aiplatform.googleapis.com'
|
||||
|
|
@ -62,6 +64,8 @@ def upload_model(
|
|||
append_unmanaged_model_artifact_into_payload(
|
||||
executor_input, model_spec))
|
||||
}
|
||||
if parent_model_name:
|
||||
upload_model_request['parent_model'] = parent_model_name.rsplit('@', 1)[0]
|
||||
|
||||
# Add explanation_spec details back into the request if metadata is non-empty, as sklearn/xgboost input features can be empty.
|
||||
if (('explanation_spec' in model_spec) and
|
||||
|
|
@ -76,6 +80,8 @@ def upload_model(
|
|||
upload_model_url, json.dumps(upload_model_request), gcp_resources)
|
||||
upload_model_lro = remote_runner.poll_lro(lro=upload_model_lro)
|
||||
model_resource_name = upload_model_lro['response']['model']
|
||||
if 'model_version_id' in upload_model_lro['response']:
|
||||
model_resource_name += f'@{upload_model_lro["response"]["model_version_id"]}'
|
||||
|
||||
vertex_model = VertexModel('model', vertex_uri_prefix + model_resource_name,
|
||||
model_resource_name)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ description: |
|
|||
characters long and can be consist of any UTF-8 characters.
|
||||
description (Optional[str]):
|
||||
The description of the model.
|
||||
parent_model (Optional[google.VertexModel]):
|
||||
An artifact of a model which to upload a new version to.
|
||||
Only specify this field when uploading a new version.
|
||||
unmanaged_container_model (Optional[google.UnmanagedContainerModel]):
|
||||
Optional. The unmanaged container model to be uploaded.
|
||||
|
||||
|
|
@ -72,6 +75,7 @@ inputs:
|
|||
- {name: location, type: String, default: "us-central1"}
|
||||
- {name: display_name, type: String}
|
||||
- {name: description, type: String, optional: true, default: ''}
|
||||
- {name: parent_model, type: google.VertexModel, optional: true}
|
||||
- {name: unmanaged_container_model, type: google.UnmanagedContainerModel, optional: true}
|
||||
- {name: explanation_metadata, type: JsonObject, optional: true, default: '{}'}
|
||||
- {name: explanation_parameters, type: JsonObject, optional: true, default: '{}'}
|
||||
|
|
@ -103,4 +107,8 @@ implementation:
|
|||
--location, {inputValue: location},
|
||||
--gcp_resources, {outputPath: gcp_resources},
|
||||
--executor_input, "{{$}}",
|
||||
{if: {
|
||||
cond: {isPresent: parent_model},
|
||||
then: [concat: ["--parent_model_name ", "{{$.inputs.artifacts['parent_model'].metadata['resourceName']}}",]]
|
||||
}}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -44,4 +44,19 @@ class LauncherUploadModelUtilsTests(unittest.TestCase):
|
|||
location='us_central1',
|
||||
payload='test_payload',
|
||||
gcp_resources=self._gcp_resources,
|
||||
executor_input='executor_input')
|
||||
executor_input='executor_input',
|
||||
parent_model_name=None)
|
||||
|
||||
@mock.patch.object(
|
||||
remote_runner, 'upload_model', autospec=True)
|
||||
def test_launcher_on_upload_model_parent_model(self, mock_upload_model):
|
||||
self._input_args.extend(('--parent_model_name', 'test_parent_model_name'))
|
||||
launcher.main(self._input_args)
|
||||
mock_upload_model.assert_called_once_with(
|
||||
type='UploadModel',
|
||||
project='test_project',
|
||||
location='us_central1',
|
||||
payload='test_payload',
|
||||
gcp_resources=self._gcp_resources,
|
||||
executor_input='executor_input',
|
||||
parent_model_name='test_parent_model_name')
|
||||
|
|
|
|||
|
|
@ -274,3 +274,55 @@ class ModelUploadRemoteRunnerUtilsTests(unittest.TestCase):
|
|||
'Content-type': 'application/json',
|
||||
'Authorization': 'Bearer fake_token',
|
||||
})
|
||||
|
||||
@mock.patch.object(google.auth, 'default', autospec=True)
|
||||
@mock.patch.object(google.auth.transport.requests, 'Request', autospec=True)
|
||||
@mock.patch.object(requests, 'post', autospec=True)
|
||||
def test_model_upload_with_parent_model_remote_runner_succeeded(self, mock_post_requests, _,
|
||||
mock_auth):
|
||||
creds = mock.Mock()
|
||||
creds.token = 'fake_token'
|
||||
mock_auth.return_value = [creds, 'project']
|
||||
upload_model_lro = mock.Mock()
|
||||
upload_model_lro.json.return_value = {
|
||||
'name': self._lro_name,
|
||||
'done': True,
|
||||
'response': {
|
||||
'model': self._model_name,
|
||||
'model_version_id': '2'
|
||||
}
|
||||
}
|
||||
mock_post_requests.return_value = upload_model_lro
|
||||
|
||||
upload_model_remote_runner.upload_model(self._type, self._project,
|
||||
self._location, self._payload,
|
||||
self._gcp_resources_path,
|
||||
self._executor_input,
|
||||
self._model_name)
|
||||
mock_post_requests.assert_called_once_with(
|
||||
url=f'{self._uri_prefix}projects/test_project/locations/test_region/models:upload',
|
||||
data='{"model": {"display_name": "model1"}, "parent_model": "%s"}' %
|
||||
(self._model_name),
|
||||
headers={
|
||||
'Content-type': 'application/json',
|
||||
'Authorization': 'Bearer fake_token',
|
||||
'User-Agent': 'google-cloud-pipeline-components'
|
||||
})
|
||||
|
||||
with open(self._output_file_path) as f:
|
||||
executor_output = json.load(f, strict=False)
|
||||
self.assertEqual(
|
||||
executor_output,
|
||||
json.loads(
|
||||
'{"artifacts": {"model": {"artifacts": [{"metadata": {"resourceName": "projects/test_project/locations/test_region/models/123@2"}, "name": "foobar", "type": {"schemaTitle": "google.VertexModel"}, "uri": "https://test_region-aiplatform.googleapis.com/v1/projects/test_project/locations/test_region/models/123@2"}]}}}'
|
||||
))
|
||||
|
||||
with open(self._gcp_resources_path) as f:
|
||||
serialized_gcp_resources = f.read()
|
||||
# Instantiate GCPResources Proto
|
||||
lro_resources = json_format.Parse(serialized_gcp_resources,
|
||||
GcpResources())
|
||||
|
||||
self.assertEqual(len(lro_resources.resources), 1)
|
||||
self.assertEqual(lro_resources.resources[0].resource_uri,
|
||||
self._uri_prefix + self._lro_name)
|
||||
|
|
|
|||
Loading…
Reference in New Issue