254 lines
9.2 KiB
Python
254 lines
9.2 KiB
Python
from dataclasses import asdict
|
|
import pytest
|
|
from unittest.mock import Mock, patch
|
|
from model_registry.utils import S3Params, OCIParams
|
|
from job.upload import _get_upload_params, perform_upload
|
|
from job.models import (
|
|
AsyncUploadConfig,
|
|
S3StorageConfig,
|
|
OCIStorageConfig,
|
|
ModelConfig,
|
|
StorageConfig,
|
|
RegistryConfig,
|
|
UpdateArtifactIntent
|
|
)
|
|
|
|
class TestGetUploadParams:
|
|
"""Test cases for _get_upload_params function"""
|
|
|
|
def test_get_upload_params_s3_config(self):
|
|
"""Test _get_upload_params with S3 configuration returns S3Params"""
|
|
config = AsyncUploadConfig(
|
|
source=S3StorageConfig(
|
|
bucket="source-bucket",
|
|
key="source-key",
|
|
access_key_id="source-key-id",
|
|
secret_access_key="source-secret",
|
|
region="us-west-1"
|
|
),
|
|
destination=S3StorageConfig(
|
|
bucket="test-bucket",
|
|
key="test-key",
|
|
endpoint="https://s3.amazonaws.com",
|
|
access_key_id="test-access-key",
|
|
secret_access_key="test-secret-key",
|
|
region="us-east-1"
|
|
),
|
|
model=ModelConfig(
|
|
intent=UpdateArtifactIntent(artifact_id="test-artifact")
|
|
),
|
|
storage=StorageConfig(path="/tmp/test"),
|
|
registry=RegistryConfig(server_address="test-server")
|
|
)
|
|
|
|
result = _get_upload_params(config)
|
|
|
|
assert isinstance(result, S3Params)
|
|
assert result.bucket_name == "test-bucket"
|
|
assert result.s3_prefix == "test-key"
|
|
assert result.endpoint_url == "https://s3.amazonaws.com"
|
|
assert result.access_key_id == "test-access-key"
|
|
assert result.secret_access_key == "test-secret-key"
|
|
assert result.region == "us-east-1"
|
|
|
|
def test_get_upload_params_oci_config(self):
|
|
"""Test _get_upload_params with OCI configuration returns OCIParams"""
|
|
config = AsyncUploadConfig(
|
|
source=S3StorageConfig(
|
|
bucket="source-bucket",
|
|
key="source-key",
|
|
access_key_id="source-key-id",
|
|
secret_access_key="source-secret",
|
|
region="us-west-1"
|
|
),
|
|
destination=OCIStorageConfig(
|
|
uri="quay.io/example/test:latest",
|
|
registry="quay.io",
|
|
username="test-user",
|
|
password="test-password",
|
|
base_image="foo-bar:latest",
|
|
enable_tls_verify=False,
|
|
credentials_path="/tmp/test-creds"
|
|
),
|
|
model=ModelConfig(
|
|
intent=UpdateArtifactIntent(artifact_id="123")
|
|
),
|
|
storage=StorageConfig(path="/tmp/test-model"),
|
|
registry=RegistryConfig(server_address="test-server")
|
|
)
|
|
|
|
result = _get_upload_params(config)
|
|
|
|
assert isinstance(result, OCIParams)
|
|
assert result.base_image == "foo-bar:latest"
|
|
assert result.oci_ref == "quay.io/example/test:latest"
|
|
assert result.dest_dir == "/tmp/test-model"
|
|
assert result.oci_username == "test-user"
|
|
assert result.oci_password == "test-password"
|
|
|
|
def test_get_upload_params_unsupported_type(self):
|
|
"""Test _get_upload_params with unsupported destination type raises ValueError"""
|
|
# Create a mock config with an unsupported destination type
|
|
config = Mock(spec=AsyncUploadConfig)
|
|
config.destination = Mock()
|
|
config.destination.__class__.__name__ = "UnsupportedStorageConfig"
|
|
|
|
with pytest.raises(ValueError, match="Unsupported destination type"):
|
|
_get_upload_params(config)
|
|
|
|
def test_get_upload_params_oci_with_none_values(self):
|
|
"""Test _get_upload_params with OCI config containing None values"""
|
|
config = AsyncUploadConfig(
|
|
source=S3StorageConfig(
|
|
bucket="source-bucket",
|
|
key="source-key",
|
|
access_key_id="source-key-id",
|
|
secret_access_key="source-secret",
|
|
region="us-west-1"
|
|
),
|
|
destination=OCIStorageConfig(
|
|
uri="quay.io/example/test:latest",
|
|
registry="quay.io",
|
|
username=None,
|
|
password=None,
|
|
base_image="foo-bar:latest",
|
|
enable_tls_verify=False,
|
|
credentials_path=None
|
|
),
|
|
model=ModelConfig(
|
|
intent=UpdateArtifactIntent(artifact_id="123")
|
|
),
|
|
storage=StorageConfig(path="/tmp/test-model"),
|
|
registry=RegistryConfig(server_address="test-server")
|
|
)
|
|
|
|
result = _get_upload_params(config)
|
|
|
|
assert isinstance(result, OCIParams)
|
|
assert result.base_image == "foo-bar:latest"
|
|
assert result.oci_ref == "quay.io/example/test:latest"
|
|
assert result.dest_dir == "/tmp/test-model"
|
|
assert result.oci_username is None
|
|
assert result.oci_password is None
|
|
|
|
|
|
class TestPerformUpload:
|
|
"""Test cases for perform_upload function"""
|
|
|
|
@patch("job.upload.save_to_oci_registry")
|
|
def test_perform_upload_oci(
|
|
self, mock_save_to_oci_registry
|
|
):
|
|
"""Test perform_upload with OCI destination"""
|
|
|
|
mock_save_to_oci_registry.return_value = 'quay.io/example/oci/abc:def'
|
|
|
|
config = AsyncUploadConfig(
|
|
source=S3StorageConfig(
|
|
bucket="source-bucket",
|
|
key="source-key",
|
|
access_key_id="source-key-id",
|
|
secret_access_key="source-secret",
|
|
region="us-west-1"
|
|
),
|
|
destination=OCIStorageConfig(
|
|
uri="quay.io/example/oci",
|
|
registry="quay.io",
|
|
username="oci_user",
|
|
password="oci_pass",
|
|
base_image="foo-bar:latest",
|
|
enable_tls_verify=False,
|
|
credentials_path="/tmp/test-creds"
|
|
),
|
|
model=ModelConfig(
|
|
intent=UpdateArtifactIntent(artifact_id="123")
|
|
),
|
|
storage=StorageConfig(path="/tmp/test-model"),
|
|
registry=RegistryConfig(server_address="test-server")
|
|
)
|
|
|
|
# Act
|
|
result_uri = perform_upload(config)
|
|
|
|
# - And the returned URI is forwarded
|
|
assert result_uri == "quay.io/example/oci/abc:def"
|
|
|
|
@patch("job.upload._get_upload_params")
|
|
def test_perform_upload_propagates_exceptions_from_get_upload_params(
|
|
self, mock_get_upload_params
|
|
):
|
|
"""Test perform_upload propagates exceptions from _get_upload_params"""
|
|
# Setup
|
|
mock_client = Mock()
|
|
mock_get_upload_params.side_effect = ValueError("Invalid config")
|
|
|
|
config = AsyncUploadConfig(
|
|
source=S3StorageConfig(
|
|
bucket="source-bucket",
|
|
key="source-key",
|
|
access_key_id="source-key-id",
|
|
secret_access_key="source-secret",
|
|
region="us-west-1"
|
|
),
|
|
destination=S3StorageConfig(
|
|
bucket="test-bucket",
|
|
key="test-key",
|
|
access_key_id="test-access-key",
|
|
secret_access_key="test-secret-key",
|
|
region="us-east-1"
|
|
),
|
|
model=ModelConfig(
|
|
intent=UpdateArtifactIntent(artifact_id="test-artifact")
|
|
),
|
|
storage=StorageConfig(path="/tmp/test-model"),
|
|
registry=RegistryConfig(server_address="test-server")
|
|
)
|
|
|
|
# Execute and verify exception is propagated
|
|
with pytest.raises(ValueError, match="Invalid config"):
|
|
perform_upload(config)
|
|
|
|
# Verify client method was not called
|
|
mock_client.upload_artifact_and_register_model.assert_not_called()
|
|
|
|
@patch("job.upload.save_to_oci_registry")
|
|
def test_perform_upload_propagates_exceptions_from_client(
|
|
self, mock_save_to_oci_registry
|
|
):
|
|
"""Test perform_upload propagates exceptions from client method"""
|
|
# Setup
|
|
mock_save_to_oci_registry.side_effect = Exception(
|
|
"Upload failed"
|
|
)
|
|
|
|
config = AsyncUploadConfig(
|
|
source=S3StorageConfig(
|
|
bucket="source-bucket",
|
|
key="source-key",
|
|
access_key_id="source-key-id",
|
|
secret_access_key="source-secret",
|
|
region="us-west-1"
|
|
),
|
|
destination=OCIStorageConfig(
|
|
uri="quay.io/example/oci",
|
|
registry="quay.io",
|
|
username="oci_user",
|
|
password="oci_pass",
|
|
base_image="foo-bar:latest",
|
|
enable_tls_verify=False,
|
|
credentials_path="/tmp/test-creds"
|
|
),
|
|
model=ModelConfig(
|
|
intent=UpdateArtifactIntent(artifact_id="123")
|
|
),
|
|
storage=StorageConfig(path="/tmp/test-model"),
|
|
registry=RegistryConfig(server_address="test-server")
|
|
)
|
|
|
|
# Execute and verify exception is propagated
|
|
with pytest.raises(Exception, match="Upload failed"):
|
|
perform_upload(config)
|
|
|
|
# Verify client method was called
|
|
mock_save_to_oci_registry.assert_called_once()
|