1101 lines
37 KiB
Go
1101 lines
37 KiB
Go
package core_test
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/kubeflow/model-registry/internal/apiutils"
|
|
"github.com/kubeflow/model-registry/pkg/api"
|
|
"github.com/kubeflow/model-registry/pkg/openapi"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestUpsertInferenceService(t *testing.T) {
|
|
_service, cleanup := SetupModelRegistryService(t)
|
|
defer cleanup()
|
|
|
|
t.Run("successful create", func(t *testing.T) {
|
|
// Create prerequisites: registered model and serving environment
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create inference service
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("test-inference-service"),
|
|
Description: apiutils.Of("Test inference service description"),
|
|
ExternalId: apiutils.Of("inference-ext-123"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.NotNil(t, result.Id)
|
|
assert.Equal(t, "test-inference-service", *result.Name)
|
|
assert.Equal(t, "inference-ext-123", *result.ExternalId)
|
|
assert.Equal(t, "Test inference service description", *result.Description)
|
|
assert.Equal(t, *createdEnv.Id, result.ServingEnvironmentId)
|
|
assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
|
|
assert.Equal(t, "tensorflow", *result.Runtime)
|
|
assert.NotNil(t, result.CreateTimeSinceEpoch)
|
|
assert.NotNil(t, result.LastUpdateTimeSinceEpoch)
|
|
})
|
|
|
|
t.Run("successful update", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "update-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "update-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create first
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("update-test-inference-service"),
|
|
Description: apiutils.Of("Original description"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
}
|
|
|
|
created, err := _service.UpsertInferenceService(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, created.Id)
|
|
|
|
// Update
|
|
update := &openapi.InferenceService{
|
|
Id: created.Id,
|
|
Name: apiutils.Of("update-test-inference-service"), // Name should remain the same
|
|
Description: apiutils.Of("Updated description"),
|
|
ExternalId: apiutils.Of("updated-ext-456"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("pytorch"),
|
|
}
|
|
|
|
updated, err := _service.UpsertInferenceService(update)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updated)
|
|
assert.Equal(t, *created.Id, *updated.Id)
|
|
assert.Equal(t, "update-test-inference-service", *updated.Name)
|
|
assert.Equal(t, "Updated description", *updated.Description)
|
|
assert.Equal(t, "updated-ext-456", *updated.ExternalId)
|
|
assert.Equal(t, "pytorch", *updated.Runtime)
|
|
})
|
|
|
|
t.Run("create with custom properties", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "custom-props-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "custom-props-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
customProps := map[string]openapi.MetadataValue{
|
|
"model_uri": {
|
|
MetadataStringValue: &openapi.MetadataStringValue{
|
|
StringValue: "s3://bucket/model",
|
|
},
|
|
},
|
|
"batch_size": {
|
|
MetadataIntValue: &openapi.MetadataIntValue{
|
|
IntValue: "32",
|
|
},
|
|
},
|
|
"enable_logging": {
|
|
MetadataBoolValue: &openapi.MetadataBoolValue{
|
|
BoolValue: true,
|
|
},
|
|
},
|
|
"confidence_threshold": {
|
|
MetadataDoubleValue: &openapi.MetadataDoubleValue{
|
|
DoubleValue: 0.85,
|
|
},
|
|
},
|
|
}
|
|
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("custom-props-inference-service"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
CustomProperties: &customProps,
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, "custom-props-inference-service", *result.Name)
|
|
assert.NotNil(t, result.CustomProperties)
|
|
|
|
resultProps := *result.CustomProperties
|
|
assert.Contains(t, resultProps, "model_uri")
|
|
assert.Contains(t, resultProps, "batch_size")
|
|
assert.Contains(t, resultProps, "enable_logging")
|
|
assert.Contains(t, resultProps, "confidence_threshold")
|
|
|
|
assert.Equal(t, "s3://bucket/model", resultProps["model_uri"].MetadataStringValue.StringValue)
|
|
assert.Equal(t, "32", resultProps["batch_size"].MetadataIntValue.IntValue)
|
|
assert.Equal(t, true, resultProps["enable_logging"].MetadataBoolValue.BoolValue)
|
|
assert.Equal(t, 0.85, resultProps["confidence_threshold"].MetadataDoubleValue.DoubleValue)
|
|
})
|
|
|
|
t.Run("minimal inference service", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "minimal-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "minimal-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("minimal-inference-service"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, "minimal-inference-service", *result.Name)
|
|
assert.NotNil(t, result.Id)
|
|
assert.Equal(t, *createdEnv.Id, result.ServingEnvironmentId)
|
|
assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
|
|
})
|
|
|
|
t.Run("nil inference service error", func(t *testing.T) {
|
|
result, err := _service.UpsertInferenceService(nil)
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
assert.Contains(t, err.Error(), "invalid inference service pointer")
|
|
})
|
|
|
|
t.Run("nil desired state preserved", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "nil-state-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "nil-state-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create inference service with nil DesiredState and other optional fields
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("nil-state-inference-service"),
|
|
Description: apiutils.Of("Test inference service with nil desired state"),
|
|
ExternalId: apiutils.Of("nil-state-ext-123"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
DesiredState: nil, // Explicitly set to nil
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.NotNil(t, result.Id)
|
|
assert.Equal(t, "nil-state-inference-service", *result.Name)
|
|
assert.Equal(t, "nil-state-ext-123", *result.ExternalId)
|
|
assert.Equal(t, "Test inference service with nil desired state", *result.Description)
|
|
assert.Equal(t, *createdEnv.Id, result.ServingEnvironmentId)
|
|
assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
|
|
assert.Equal(t, "tensorflow", *result.Runtime)
|
|
assert.Nil(t, result.DesiredState) // Verify desired state remains nil
|
|
assert.NotNil(t, result.CreateTimeSinceEpoch)
|
|
assert.NotNil(t, result.LastUpdateTimeSinceEpoch)
|
|
})
|
|
|
|
t.Run("unicode characters in name", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "unicode-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "unicode-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Test with unicode characters: Chinese, Russian, Japanese, and emoji
|
|
unicodeName := "推理服务-тест-推論サービス-🚀"
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of(unicodeName),
|
|
Description: apiutils.Of("Test inference service with unicode characters"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, unicodeName, *result.Name)
|
|
assert.Equal(t, "Test inference service with unicode characters", *result.Description)
|
|
assert.NotNil(t, result.Id)
|
|
assert.NotNil(t, result.CreateTimeSinceEpoch)
|
|
assert.NotNil(t, result.LastUpdateTimeSinceEpoch)
|
|
|
|
// Verify we can retrieve it by ID
|
|
retrieved, err := _service.GetInferenceServiceById(*result.Id)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, unicodeName, *retrieved.Name)
|
|
})
|
|
|
|
t.Run("special characters in name", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "special-chars-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "special-chars-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Test with various special characters
|
|
specialName := "!@#$%^&*()_+-=[]{}|;':\",./<>?"
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of(specialName),
|
|
Description: apiutils.Of("Test inference service with special characters"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("pytorch"),
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, specialName, *result.Name)
|
|
assert.Equal(t, "Test inference service with special characters", *result.Description)
|
|
assert.NotNil(t, result.Id)
|
|
|
|
// Verify we can retrieve it by ID
|
|
retrieved, err := _service.GetInferenceServiceById(*result.Id)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, specialName, *retrieved.Name)
|
|
})
|
|
|
|
t.Run("mixed unicode and special characters", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "mixed-chars-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "mixed-chars-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Test with mixed unicode and special characters
|
|
mixedName := "推理@#$%服务-тест!@#-推論()サービス-🚀[]"
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of(mixedName),
|
|
Description: apiutils.Of("Test inference service with mixed unicode and special characters"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("onnx"),
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, mixedName, *result.Name)
|
|
assert.Equal(t, "Test inference service with mixed unicode and special characters", *result.Description)
|
|
assert.NotNil(t, result.Id)
|
|
|
|
// Verify we can retrieve it by ID
|
|
retrieved, err := _service.GetInferenceServiceById(*result.Id)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, mixedName, *retrieved.Name)
|
|
})
|
|
|
|
t.Run("pagination with 10+ inference services", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "paging-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "paging-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create 15 inference services for pagination testing
|
|
var createdServices []string
|
|
for i := 0; i < 15; i++ {
|
|
serviceName := "paging-test-inference-service-" + fmt.Sprintf("%02d", i)
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of(serviceName),
|
|
Description: apiutils.Of("Pagination test inference service " + fmt.Sprintf("%02d", i)),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
}
|
|
|
|
result, err := _service.UpsertInferenceService(input)
|
|
require.NoError(t, err)
|
|
createdServices = append(createdServices, *result.Id)
|
|
}
|
|
|
|
// Test pagination with page size 5
|
|
pageSize := int32(5)
|
|
orderBy := "name"
|
|
sortOrder := "ASC"
|
|
listOptions := api.ListOptions{
|
|
PageSize: &pageSize,
|
|
OrderBy: &orderBy,
|
|
SortOrder: &sortOrder,
|
|
}
|
|
|
|
// Get first page
|
|
firstPage, err := _service.GetInferenceServices(listOptions, nil, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, firstPage)
|
|
assert.LessOrEqual(t, len(firstPage.Items), 5, "First page should have at most 5 items")
|
|
assert.Equal(t, int32(5), firstPage.PageSize)
|
|
|
|
// Filter to only our test inference services in first page
|
|
var firstPageTestServices []openapi.InferenceService
|
|
firstPageIds := make(map[string]bool)
|
|
for _, item := range firstPage.Items {
|
|
// Only include our test services (those with the specific prefix)
|
|
if strings.HasPrefix(*item.Name, "paging-test-inference-service-") {
|
|
assert.False(t, firstPageIds[*item.Id], "Should not have duplicate IDs in first page")
|
|
firstPageIds[*item.Id] = true
|
|
firstPageTestServices = append(firstPageTestServices, item)
|
|
}
|
|
}
|
|
|
|
// Only proceed with second page test if we have a next page token and found test services
|
|
if firstPage.NextPageToken != "" && len(firstPageTestServices) > 0 {
|
|
// Get second page using next page token
|
|
listOptions.NextPageToken = &firstPage.NextPageToken
|
|
secondPage, err := _service.GetInferenceServices(listOptions, nil, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, secondPage)
|
|
assert.LessOrEqual(t, len(secondPage.Items), 5, "Second page should have at most 5 items")
|
|
|
|
// Verify no duplicates between pages (only check our test services)
|
|
for _, item := range secondPage.Items {
|
|
if strings.HasPrefix(*item.Name, "paging-test-inference-service-") {
|
|
assert.False(t, firstPageIds[*item.Id], "Should not have duplicate IDs between pages")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Test with larger page size
|
|
largePage := int32(100)
|
|
listOptions = api.ListOptions{
|
|
PageSize: &largePage,
|
|
OrderBy: &orderBy,
|
|
SortOrder: &sortOrder,
|
|
}
|
|
|
|
allItems, err := _service.GetInferenceServices(listOptions, nil, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, allItems)
|
|
assert.GreaterOrEqual(t, len(allItems.Items), 15, "Should have at least our 15 test inference services")
|
|
|
|
// Count our test services in the results
|
|
foundCount := 0
|
|
for _, item := range allItems.Items {
|
|
for _, createdId := range createdServices {
|
|
if *item.Id == createdId {
|
|
foundCount++
|
|
break
|
|
}
|
|
}
|
|
}
|
|
assert.Equal(t, 15, foundCount, "Should find all 15 created inference services")
|
|
|
|
// Test descending order
|
|
descOrder := "DESC"
|
|
listOptions = api.ListOptions{
|
|
PageSize: &pageSize,
|
|
OrderBy: &orderBy,
|
|
SortOrder: &descOrder,
|
|
}
|
|
|
|
descPage, err := _service.GetInferenceServices(listOptions, nil, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, descPage)
|
|
assert.LessOrEqual(t, len(descPage.Items), 5, "Desc page should have at most 5 items")
|
|
|
|
// Verify ordering (names should be in descending order)
|
|
if len(descPage.Items) > 1 {
|
|
for i := 1; i < len(descPage.Items); i++ {
|
|
assert.GreaterOrEqual(t, *descPage.Items[i-1].Name, *descPage.Items[i].Name,
|
|
"Items should be in descending order by name")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGetInferenceServiceById(t *testing.T) {
|
|
_service, cleanup := SetupModelRegistryService(t)
|
|
defer cleanup()
|
|
|
|
t.Run("successful get", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "get-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "get-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// First create an inference service to retrieve
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("get-test-inference-service"),
|
|
Description: apiutils.Of("Test description"),
|
|
ExternalId: apiutils.Of("get-ext-123"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
}
|
|
|
|
created, err := _service.UpsertInferenceService(input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, created.Id)
|
|
|
|
// Get the inference service by ID
|
|
result, err := _service.GetInferenceServiceById(*created.Id)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, *created.Id, *result.Id)
|
|
assert.Equal(t, "get-test-inference-service", *result.Name)
|
|
assert.Equal(t, "get-ext-123", *result.ExternalId)
|
|
assert.Equal(t, "Test description", *result.Description)
|
|
assert.Equal(t, "tensorflow", *result.Runtime)
|
|
})
|
|
|
|
t.Run("invalid id", func(t *testing.T) {
|
|
result, err := _service.GetInferenceServiceById("invalid")
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
assert.Contains(t, err.Error(), "invalid")
|
|
})
|
|
|
|
t.Run("non-existent id", func(t *testing.T) {
|
|
result, err := _service.GetInferenceServiceById("99999")
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
assert.Contains(t, err.Error(), "no InferenceService found")
|
|
})
|
|
}
|
|
|
|
func TestGetInferenceServiceByParams(t *testing.T) {
|
|
_service, cleanup := SetupModelRegistryService(t)
|
|
defer cleanup()
|
|
|
|
t.Run("successful get by name and parent resource id", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "params-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "params-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("params-test-inference-service"),
|
|
ExternalId: apiutils.Of("params-ext-123"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
}
|
|
created, err := _service.UpsertInferenceService(input)
|
|
require.NoError(t, err)
|
|
|
|
// Get by name and parent resource ID
|
|
serviceName := "params-test-inference-service"
|
|
result, err := _service.GetInferenceServiceByParams(&serviceName, createdEnv.Id, nil)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, *created.Id, *result.Id)
|
|
assert.Equal(t, "params-test-inference-service", *result.Name)
|
|
})
|
|
|
|
t.Run("successful get by external id", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "params-ext-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "params-ext-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
input := &openapi.InferenceService{
|
|
Name: apiutils.Of("params-ext-test-inference-service"),
|
|
ExternalId: apiutils.Of("params-unique-ext-456"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
}
|
|
created, err := _service.UpsertInferenceService(input)
|
|
require.NoError(t, err)
|
|
|
|
// Get by external ID
|
|
externalId := "params-unique-ext-456"
|
|
result, err := _service.GetInferenceServiceByParams(nil, nil, &externalId)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, *created.Id, *result.Id)
|
|
assert.Equal(t, "params-unique-ext-456", *result.ExternalId)
|
|
})
|
|
|
|
t.Run("invalid parameters", func(t *testing.T) {
|
|
result, err := _service.GetInferenceServiceByParams(nil, nil, nil)
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
assert.Contains(t, err.Error(), "invalid parameters call")
|
|
})
|
|
|
|
t.Run("same inference service name across different serving environments", func(t *testing.T) {
|
|
// This test catches the bug where ParentResourceID was not being used to filter inference services
|
|
|
|
// Create first serving environment
|
|
servingEnv1 := &openapi.ServingEnvironment{
|
|
Name: "serving-env-with-shared-service-1",
|
|
}
|
|
createdEnv1, err := _service.UpsertServingEnvironment(servingEnv1)
|
|
require.NoError(t, err)
|
|
|
|
// Create second serving environment
|
|
servingEnv2 := &openapi.ServingEnvironment{
|
|
Name: "serving-env-with-shared-service-2",
|
|
}
|
|
createdEnv2, err := _service.UpsertServingEnvironment(servingEnv2)
|
|
require.NoError(t, err)
|
|
|
|
// Create registered model for the services
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "model-for-shared-services",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
// Create inference service "shared-service-name-test" for the first environment
|
|
service1 := &openapi.InferenceService{
|
|
Name: apiutils.Of("shared-service-name-test"),
|
|
ServingEnvironmentId: *createdEnv1.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Description: apiutils.Of("Service for environment 1"),
|
|
}
|
|
createdService1, err := _service.UpsertInferenceService(service1)
|
|
require.NoError(t, err)
|
|
|
|
// Create inference service "shared-service-name-test" for the second environment
|
|
service2 := &openapi.InferenceService{
|
|
Name: apiutils.Of("shared-service-name-test"),
|
|
ServingEnvironmentId: *createdEnv2.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Description: apiutils.Of("Service for environment 2"),
|
|
}
|
|
createdService2, err := _service.UpsertInferenceService(service2)
|
|
require.NoError(t, err)
|
|
|
|
// Query for service "shared-service-name-test" of the first environment
|
|
serviceName := "shared-service-name-test"
|
|
result1, err := _service.GetInferenceServiceByParams(&serviceName, createdEnv1.Id, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result1)
|
|
assert.Equal(t, *createdService1.Id, *result1.Id)
|
|
assert.Equal(t, *createdEnv1.Id, result1.ServingEnvironmentId)
|
|
assert.Equal(t, "Service for environment 1", *result1.Description)
|
|
|
|
// Query for service "shared-service-name-test" of the second environment
|
|
result2, err := _service.GetInferenceServiceByParams(&serviceName, createdEnv2.Id, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result2)
|
|
assert.Equal(t, *createdService2.Id, *result2.Id)
|
|
assert.Equal(t, *createdEnv2.Id, result2.ServingEnvironmentId)
|
|
assert.Equal(t, "Service for environment 2", *result2.Description)
|
|
|
|
// Ensure we got different services
|
|
assert.NotEqual(t, *result1.Id, *result2.Id)
|
|
})
|
|
|
|
t.Run("no inference service found", func(t *testing.T) {
|
|
serviceName := "nonexistent-inference-service"
|
|
parentResourceId := "999"
|
|
result, err := _service.GetInferenceServiceByParams(&serviceName, &parentResourceId, nil)
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
assert.Contains(t, err.Error(), "no inference service found")
|
|
})
|
|
}
|
|
|
|
func TestGetInferenceServices(t *testing.T) {
|
|
_service, cleanup := SetupModelRegistryService(t)
|
|
defer cleanup()
|
|
|
|
t.Run("successful list", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "list-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "list-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create multiple inference services for listing
|
|
testInferenceServices := []*openapi.InferenceService{
|
|
{
|
|
Name: apiutils.Of("list-inference-service-1"),
|
|
ExternalId: apiutils.Of("list-ext-1"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
},
|
|
{
|
|
Name: apiutils.Of("list-inference-service-2"),
|
|
ExternalId: apiutils.Of("list-ext-2"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("pytorch"),
|
|
},
|
|
{
|
|
Name: apiutils.Of("list-inference-service-3"),
|
|
ExternalId: apiutils.Of("list-ext-3"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("onnx"),
|
|
},
|
|
}
|
|
|
|
var createdIds []string
|
|
for _, infSvc := range testInferenceServices {
|
|
created, err := _service.UpsertInferenceService(infSvc)
|
|
require.NoError(t, err)
|
|
createdIds = append(createdIds, *created.Id)
|
|
}
|
|
|
|
// List inference services with basic pagination
|
|
pageSize := int32(10)
|
|
listOptions := api.ListOptions{
|
|
PageSize: &pageSize,
|
|
}
|
|
|
|
result, err := _service.GetInferenceServices(listOptions, nil, nil)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.GreaterOrEqual(t, len(result.Items), 3) // Should have at least our 3 test inference services
|
|
assert.Equal(t, int32(10), result.PageSize)
|
|
|
|
// Verify our inference services are in the result
|
|
foundServices := 0
|
|
for _, item := range result.Items {
|
|
for _, createdId := range createdIds {
|
|
if *item.Id == createdId {
|
|
foundServices++
|
|
break
|
|
}
|
|
}
|
|
}
|
|
assert.Equal(t, 3, foundServices, "All created inference services should be found in the list")
|
|
})
|
|
|
|
t.Run("list with serving environment filter", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "filter-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv1 := &openapi.ServingEnvironment{
|
|
Name: "filter-test-serving-env-1",
|
|
}
|
|
createdEnv1, err := _service.UpsertServingEnvironment(servingEnv1)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv2 := &openapi.ServingEnvironment{
|
|
Name: "filter-test-serving-env-2",
|
|
}
|
|
createdEnv2, err := _service.UpsertServingEnvironment(servingEnv2)
|
|
require.NoError(t, err)
|
|
|
|
// Create inference services in different serving environments
|
|
infSvc1 := &openapi.InferenceService{
|
|
Name: apiutils.Of("filter-inference-service-1"),
|
|
ServingEnvironmentId: *createdEnv1.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
}
|
|
created1, err := _service.UpsertInferenceService(infSvc1)
|
|
require.NoError(t, err)
|
|
|
|
infSvc2 := &openapi.InferenceService{
|
|
Name: apiutils.Of("filter-inference-service-2"),
|
|
ServingEnvironmentId: *createdEnv2.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
}
|
|
_, err = _service.UpsertInferenceService(infSvc2)
|
|
require.NoError(t, err)
|
|
|
|
// List inference services filtered by serving environment
|
|
pageSize := int32(10)
|
|
listOptions := api.ListOptions{
|
|
PageSize: &pageSize,
|
|
}
|
|
|
|
result, err := _service.GetInferenceServices(listOptions, createdEnv1.Id, nil)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.GreaterOrEqual(t, len(result.Items), 1) // Should have at least our 1 service in env1
|
|
|
|
// Verify that only services from the specified environment are returned
|
|
found := false
|
|
for _, item := range result.Items {
|
|
if *item.Id == *created1.Id {
|
|
found = true
|
|
assert.Equal(t, *createdEnv1.Id, item.ServingEnvironmentId)
|
|
}
|
|
}
|
|
assert.True(t, found, "Should find the inference service in the specified serving environment")
|
|
})
|
|
|
|
t.Run("list with runtime filter", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "runtime-filter-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "runtime-filter-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create inference services with different runtimes
|
|
infSvcTensorflow := &openapi.InferenceService{
|
|
Name: apiutils.Of("runtime-tensorflow-service"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
}
|
|
createdTensorflow, err := _service.UpsertInferenceService(infSvcTensorflow)
|
|
require.NoError(t, err)
|
|
|
|
infSvcPytorch := &openapi.InferenceService{
|
|
Name: apiutils.Of("runtime-pytorch-service"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("pytorch"),
|
|
}
|
|
_, err = _service.UpsertInferenceService(infSvcPytorch)
|
|
require.NoError(t, err)
|
|
|
|
// List inference services filtered by runtime
|
|
pageSize := int32(10)
|
|
listOptions := api.ListOptions{
|
|
PageSize: &pageSize,
|
|
}
|
|
runtime := "tensorflow"
|
|
|
|
result, err := _service.GetInferenceServices(listOptions, nil, &runtime)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.GreaterOrEqual(t, len(result.Items), 1) // Should have at least our 1 tensorflow service
|
|
|
|
// Verify that only services with the specified runtime are returned
|
|
found := false
|
|
for _, item := range result.Items {
|
|
if *item.Id == *createdTensorflow.Id {
|
|
found = true
|
|
assert.Equal(t, "tensorflow", *item.Runtime)
|
|
}
|
|
}
|
|
assert.True(t, found, "Should find the inference service with the specified runtime")
|
|
})
|
|
|
|
t.Run("pagination and ordering", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "pagination-test-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "pagination-test-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create several inference services for pagination testing
|
|
for i := 0; i < 5; i++ {
|
|
infSvc := &openapi.InferenceService{
|
|
Name: apiutils.Of("pagination-inference-service-" + string(rune('A'+i))),
|
|
ExternalId: apiutils.Of("pagination-ext-" + string(rune('A'+i))),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
}
|
|
_, err := _service.UpsertInferenceService(infSvc)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Test with small page size and ordering
|
|
pageSize := int32(2)
|
|
orderBy := "name"
|
|
sortOrder := "asc"
|
|
listOptions := api.ListOptions{
|
|
PageSize: &pageSize,
|
|
OrderBy: &orderBy,
|
|
SortOrder: &sortOrder,
|
|
}
|
|
|
|
result, err := _service.GetInferenceServices(listOptions, nil, nil)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.GreaterOrEqual(t, len(result.Items), 2) // Should have at least 2 items
|
|
assert.Equal(t, int32(2), result.PageSize)
|
|
})
|
|
|
|
t.Run("invalid serving environment id", func(t *testing.T) {
	// Listing with the serving environment id "invalid" is expected to fail
	// with a bad-request error and return no result.
	invalidId := "invalid"
	listOptions := api.ListOptions{}

	result, err := _service.GetInferenceServices(listOptions, &invalidId, nil)

	// require, not assert: err.Error() below would dereference a nil error
	// and panic if the call unexpectedly succeeded, hiding the real failure.
	require.Error(t, err)
	assert.Nil(t, result)
	assert.Contains(t, err.Error(), "invalid syntax: bad request")
})
|
|
}
|
|
|
|
func TestInferenceServiceRoundTrip(t *testing.T) {
|
|
_service, cleanup := SetupModelRegistryService(t)
|
|
defer cleanup()
|
|
|
|
t.Run("complete roundtrip", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "roundtrip-registered-model",
|
|
Description: apiutils.Of("Roundtrip test registered model"),
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "roundtrip-serving-env",
|
|
Description: apiutils.Of("Roundtrip test serving environment"),
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
// Create an inference service with all fields
|
|
original := &openapi.InferenceService{
|
|
Name: apiutils.Of("roundtrip-inference-service"),
|
|
Description: apiutils.Of("Roundtrip test description"),
|
|
ExternalId: apiutils.Of("roundtrip-ext-123"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
Runtime: apiutils.Of("tensorflow"),
|
|
}
|
|
|
|
// Create
|
|
created, err := _service.UpsertInferenceService(original)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, created.Id)
|
|
|
|
// Get by ID
|
|
retrieved, err := _service.GetInferenceServiceById(*created.Id)
|
|
require.NoError(t, err)
|
|
|
|
// Verify all fields match
|
|
assert.Equal(t, *created.Id, *retrieved.Id)
|
|
assert.Equal(t, *original.Name, *retrieved.Name)
|
|
assert.Equal(t, *original.Description, *retrieved.Description)
|
|
assert.Equal(t, *original.ExternalId, *retrieved.ExternalId)
|
|
assert.Equal(t, original.ServingEnvironmentId, retrieved.ServingEnvironmentId)
|
|
assert.Equal(t, original.RegisteredModelId, retrieved.RegisteredModelId)
|
|
assert.Equal(t, *original.Runtime, *retrieved.Runtime)
|
|
|
|
// Update
|
|
retrieved.Description = apiutils.Of("Updated description")
|
|
retrieved.Runtime = apiutils.Of("pytorch")
|
|
|
|
updated, err := _service.UpsertInferenceService(retrieved)
|
|
require.NoError(t, err)
|
|
|
|
// Verify update
|
|
assert.Equal(t, *created.Id, *updated.Id)
|
|
assert.Equal(t, "Updated description", *updated.Description)
|
|
assert.Equal(t, "pytorch", *updated.Runtime)
|
|
|
|
// Get again to verify persistence
|
|
final, err := _service.GetInferenceServiceById(*created.Id)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "Updated description", *final.Description)
|
|
assert.Equal(t, "pytorch", *final.Runtime)
|
|
})
|
|
|
|
t.Run("roundtrip with custom properties", func(t *testing.T) {
|
|
// Create prerequisites
|
|
registeredModel := &openapi.RegisteredModel{
|
|
Name: "roundtrip-custom-props-registered-model",
|
|
}
|
|
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
|
|
require.NoError(t, err)
|
|
|
|
servingEnv := &openapi.ServingEnvironment{
|
|
Name: "roundtrip-custom-props-serving-env",
|
|
}
|
|
createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
|
|
require.NoError(t, err)
|
|
|
|
customProps := map[string]openapi.MetadataValue{
|
|
"deployment_type": {
|
|
MetadataStringValue: &openapi.MetadataStringValue{
|
|
StringValue: "canary",
|
|
},
|
|
},
|
|
"replicas": {
|
|
MetadataIntValue: &openapi.MetadataIntValue{
|
|
IntValue: "3",
|
|
},
|
|
},
|
|
}
|
|
|
|
original := &openapi.InferenceService{
|
|
Name: apiutils.Of("roundtrip-custom-props-inference-service"),
|
|
ServingEnvironmentId: *createdEnv.Id,
|
|
RegisteredModelId: *createdModel.Id,
|
|
CustomProperties: &customProps,
|
|
}
|
|
|
|
// Create
|
|
created, err := _service.UpsertInferenceService(original)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, created.Id)
|
|
|
|
// Get by ID
|
|
retrieved, err := _service.GetInferenceServiceById(*created.Id)
|
|
require.NoError(t, err)
|
|
|
|
// Verify custom properties
|
|
assert.NotNil(t, retrieved.CustomProperties)
|
|
retrievedProps := *retrieved.CustomProperties
|
|
assert.Contains(t, retrievedProps, "deployment_type")
|
|
assert.Contains(t, retrievedProps, "replicas")
|
|
assert.Equal(t, "canary", retrievedProps["deployment_type"].MetadataStringValue.StringValue)
|
|
assert.Equal(t, "3", retrievedProps["replicas"].MetadataIntValue.IntValue)
|
|
|
|
// Update custom properties
|
|
updatedProps := map[string]openapi.MetadataValue{
|
|
"deployment_type": {
|
|
MetadataStringValue: &openapi.MetadataStringValue{
|
|
StringValue: "blue-green",
|
|
},
|
|
},
|
|
"replicas": {
|
|
MetadataIntValue: &openapi.MetadataIntValue{
|
|
IntValue: "5",
|
|
},
|
|
},
|
|
"new_prop": {
|
|
MetadataStringValue: &openapi.MetadataStringValue{
|
|
StringValue: "new_value",
|
|
},
|
|
},
|
|
}
|
|
retrieved.CustomProperties = &updatedProps
|
|
|
|
updated, err := _service.UpsertInferenceService(retrieved)
|
|
require.NoError(t, err)
|
|
|
|
// Verify updated custom properties
|
|
assert.NotNil(t, updated.CustomProperties)
|
|
finalProps := *updated.CustomProperties
|
|
assert.Equal(t, "blue-green", finalProps["deployment_type"].MetadataStringValue.StringValue)
|
|
assert.Equal(t, "5", finalProps["replicas"].MetadataIntValue.IntValue)
|
|
assert.Equal(t, "new_value", finalProps["new_prop"].MetadataStringValue.StringValue)
|
|
})
|
|
}
|