// Source: model-registry/internal/db/service/inference_service_test.go
// (546 lines, 17 KiB, Go)

package service_test
import (
"testing"
"time"
"github.com/kubeflow/model-registry/internal/apiutils"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestInferenceServiceRepository exercises Save, GetByID and List on the
// InferenceService repository using the database returned by
// testutils.SetupMySQLWithMigrations.
//
// Every subtest runs against the same database and repository instances, so
// rows created by earlier subtests remain visible to later ones. That is why
// the List assertions use lower bounds ("at least N") instead of exact counts.
func TestInferenceServiceRepository(t *testing.T) {
	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
	defer cleanup()

	// Get the actual InferenceService type ID from the database
	typeID := getInferenceServiceTypeID(t, sharedDB)
	repo := service.NewInferenceServiceRepository(sharedDB, typeID)

	// Also get other type IDs for creating parent and related entities
	servingEnvironmentTypeID := getServingEnvironmentTypeID(t, sharedDB)
	servingEnvironmentRepo := service.NewServingEnvironmentRepository(sharedDB, servingEnvironmentTypeID)
	registeredModelTypeID := getRegisteredModelTypeID(t, sharedDB)
	registeredModelRepo := service.NewRegisteredModelRepository(sharedDB, registeredModelTypeID)
	modelVersionTypeID := getModelVersionTypeID(t, sharedDB)
	modelVersionRepo := service.NewModelVersionRepository(sharedDB, modelVersionTypeID)

	// createParents saves a serving environment and a registered model with the
	// given names and returns their IDs. It fails the calling test immediately
	// if either save fails or yields no ID, so callers can use the returned
	// pointers without further nil checks.
	createParents := func(t *testing.T, servingEnvName, registeredModelName string) (servingEnvID, registeredModelID *int32) {
		t.Helper()

		savedServingEnv, err := servingEnvironmentRepo.Save(&models.ServingEnvironmentImpl{
			TypeID: apiutils.Of(int32(servingEnvironmentTypeID)),
			Attributes: &models.ServingEnvironmentAttributes{
				Name: apiutils.Of(servingEnvName),
			},
		})
		require.NoError(t, err)
		require.NotNil(t, savedServingEnv)
		require.NotNil(t, savedServingEnv.GetID())

		savedRegisteredModel, err := registeredModelRepo.Save(&models.RegisteredModelImpl{
			TypeID: apiutils.Of(int32(registeredModelTypeID)),
			Attributes: &models.RegisteredModelAttributes{
				Name: apiutils.Of(registeredModelName),
			},
		})
		require.NoError(t, err)
		require.NotNil(t, savedRegisteredModel)
		require.NotNil(t, savedRegisteredModel.GetID())

		return savedServingEnv.GetID(), savedRegisteredModel.GetID()
	}

	t.Run("TestSave", func(t *testing.T) {
		servingEnvID, registeredModelID := createParents(t, "parent-serving-env-for-inference", "test-registered-model")

		// Test creating a new inference service
		inferenceService := &models.InferenceServiceImpl{
			TypeID: apiutils.Of(int32(typeID)),
			Attributes: &models.InferenceServiceAttributes{
				Name:       apiutils.Of("test-inference-service"),
				ExternalID: apiutils.Of("inference-ext-123"),
			},
			Properties: &[]models.Properties{
				{
					Name:        "description",
					StringValue: apiutils.Of("Test inference service description"),
				},
				{
					Name:     "serving_environment_id",
					IntValue: servingEnvID,
				},
				{
					Name:     "registered_model_id",
					IntValue: registeredModelID,
				},
				{
					Name:        "runtime",
					StringValue: apiutils.Of("tensorflow"),
				},
			},
			CustomProperties: &[]models.Properties{
				{
					Name:        "custom-inference-prop",
					StringValue: apiutils.Of("custom-inference-value"),
				},
			},
		}

		saved, err := repo.Save(inferenceService)
		require.NoError(t, err)
		require.NotNil(t, saved)
		require.NotNil(t, saved.GetID())
		require.NotNil(t, saved.GetAttributes())
		assert.Equal(t, "test-inference-service", *saved.GetAttributes().Name)
		assert.Equal(t, "inference-ext-123", *saved.GetAttributes().ExternalID)

		// Test updating the same inference service
		inferenceService.ID = saved.GetID()
		inferenceService.GetAttributes().Name = apiutils.Of("updated-inference-service")
		// Preserve CreateTimeSinceEpoch from the saved entity (simulating what OpenAPI converter would do)
		inferenceService.GetAttributes().CreateTimeSinceEpoch = saved.GetAttributes().CreateTimeSinceEpoch

		updated, err := repo.Save(inferenceService)
		require.NoError(t, err)
		require.NotNil(t, updated)
		require.NotNil(t, updated.GetID())
		require.NotNil(t, updated.GetAttributes())
		assert.Equal(t, *saved.GetID(), *updated.GetID())
		assert.Equal(t, "updated-inference-service", *updated.GetAttributes().Name)
	})

	t.Run("TestGetByID", func(t *testing.T) {
		servingEnvID, registeredModelID := createParents(t, "parent-serving-env-for-getbyid", "test-registered-model-getbyid")

		// First create an inference service to retrieve
		inferenceService := &models.InferenceServiceImpl{
			TypeID: apiutils.Of(int32(typeID)),
			Attributes: &models.InferenceServiceAttributes{
				Name:       apiutils.Of("get-test-inference-service"),
				ExternalID: apiutils.Of("get-inference-ext-123"),
			},
			Properties: &[]models.Properties{
				{
					Name:     "serving_environment_id",
					IntValue: servingEnvID,
				},
				{
					Name:     "registered_model_id",
					IntValue: registeredModelID,
				},
			},
		}

		saved, err := repo.Save(inferenceService)
		require.NoError(t, err)
		require.NotNil(t, saved)
		require.NotNil(t, saved.GetID())

		// Test retrieving by ID
		retrieved, err := repo.GetByID(*saved.GetID())
		require.NoError(t, err)
		require.NotNil(t, retrieved)
		require.NotNil(t, retrieved.GetID())
		require.NotNil(t, retrieved.GetAttributes())
		assert.Equal(t, *saved.GetID(), *retrieved.GetID())
		assert.Equal(t, "get-test-inference-service", *retrieved.GetAttributes().Name)
		assert.Equal(t, "get-inference-ext-123", *retrieved.GetAttributes().ExternalID)

		// Test retrieving non-existent ID
		_, err = repo.GetByID(99999)
		assert.Error(t, err)
	})

	t.Run("TestList", func(t *testing.T) {
		servingEnvID, registeredModelID := createParents(t, "parent-serving-env-for-list", "test-registered-model-list")

		// Create multiple inference services for listing. They share the same
		// parents and differ only in name, external ID and runtime.
		fixtures := []struct {
			name       string
			externalID string
			runtime    string
		}{
			{name: "list-inference-service-1", externalID: "list-inference-ext-1", runtime: "tensorflow"},
			{name: "list-inference-service-2", externalID: "list-inference-ext-2", runtime: "pytorch"},
			{name: "list-inference-service-3", externalID: "list-inference-ext-3", runtime: "tensorflow"},
		}
		for _, fx := range fixtures {
			_, err := repo.Save(&models.InferenceServiceImpl{
				TypeID: apiutils.Of(int32(typeID)),
				Attributes: &models.InferenceServiceAttributes{
					Name:       apiutils.Of(fx.name),
					ExternalID: apiutils.Of(fx.externalID),
				},
				Properties: &[]models.Properties{
					{
						Name:     "serving_environment_id",
						IntValue: servingEnvID,
					},
					{
						Name:     "registered_model_id",
						IntValue: registeredModelID,
					},
					{
						Name:        "runtime",
						StringValue: apiutils.Of(fx.runtime),
					},
				},
			})
			require.NoError(t, err)
		}

		// Test listing all inference services with basic pagination
		pageSize := int32(10)
		listOptions := models.InferenceServiceListOptions{}
		listOptions.PageSize = &pageSize

		result, err := repo.List(listOptions)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.GreaterOrEqual(t, len(result.Items), 3) // At least our 3 test inference services

		// Test listing by name
		listOptions = models.InferenceServiceListOptions{
			Name: apiutils.Of("list-inference-service-1"),
		}
		listOptions.PageSize = &pageSize

		result, err = repo.List(listOptions)
		require.NoError(t, err)
		require.NotNil(t, result)
		// NOTE(review): this guard lets the check pass when the filter returns
		// nothing. Confirm the repository's name-matching semantics and, if an
		// exact match on the stored name is expected, replace the guard with
		// require.Len(t, result.Items, 1).
		if len(result.Items) > 0 {
			assert.Equal(t, 1, len(result.Items))
			assert.Equal(t, "list-inference-service-1", *result.Items[0].GetAttributes().Name)
		}

		// Test listing by external ID
		listOptions = models.InferenceServiceListOptions{
			ExternalID: apiutils.Of("list-inference-ext-2"),
		}
		listOptions.PageSize = &pageSize

		result, err = repo.List(listOptions)
		require.NoError(t, err)
		require.NotNil(t, result)
		// NOTE(review): same caveat as the name filter above; an empty result
		// is currently accepted silently.
		if len(result.Items) > 0 {
			assert.Equal(t, 1, len(result.Items))
			assert.Equal(t, "list-inference-ext-2", *result.Items[0].GetAttributes().ExternalID)
		}

		// Test listing by parent resource ID (serving environment)
		listOptions = models.InferenceServiceListOptions{
			ParentResourceID: servingEnvID,
		}
		listOptions.PageSize = &pageSize

		result, err = repo.List(listOptions)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.GreaterOrEqual(t, len(result.Items), 3) // Should find our 3 test inference services

		// Test listing by runtime
		listOptions = models.InferenceServiceListOptions{
			Runtime: apiutils.Of("tensorflow"),
		}
		listOptions.PageSize = &pageSize

		result, err = repo.List(listOptions)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.GreaterOrEqual(t, len(result.Items), 2) // Should find 2 tensorflow services

		// Test ordering by ID (deterministic)
		listOptions = models.InferenceServiceListOptions{
			Pagination: models.Pagination{
				OrderBy: apiutils.Of("ID"),
			},
		}
		listOptions.PageSize = &pageSize

		result, err = repo.List(listOptions)
		require.NoError(t, err)
		require.NotNil(t, result)

		// Verify we get results back and that every adjacent pair is in
		// strictly ascending ID order, not just the first two items.
		assert.GreaterOrEqual(t, len(result.Items), 1)
		for i := 1; i < len(result.Items); i++ {
			prevID := result.Items[i-1].GetID()
			currID := result.Items[i].GetID()
			require.NotNil(t, prevID)
			require.NotNil(t, currID)
			assert.Less(t, *prevID, *currID, "Results should be ordered by ID ascending")
		}
	})

	t.Run("TestListOrdering", func(t *testing.T) {
		servingEnvID, registeredModelID := createParents(t, "parent-serving-env-for-ordering", "test-registered-model-ordering")

		// Create inference services sequentially with time delays to ensure deterministic ordering
		infSvc1 := &models.InferenceServiceImpl{
			TypeID: apiutils.Of(int32(typeID)),
			Attributes: &models.InferenceServiceAttributes{
				Name: apiutils.Of("time-test-inference-service-1"),
			},
			Properties: &[]models.Properties{
				{
					Name:     "serving_environment_id",
					IntValue: servingEnvID,
				},
				{
					Name:     "registered_model_id",
					IntValue: registeredModelID,
				},
			},
		}
		saved1, err := repo.Save(infSvc1)
		require.NoError(t, err)
		require.NotNil(t, saved1)
		require.NotNil(t, saved1.GetID())

		// Small delay to ensure different timestamps
		time.Sleep(10 * time.Millisecond)

		infSvc2 := &models.InferenceServiceImpl{
			TypeID: apiutils.Of(int32(typeID)),
			Attributes: &models.InferenceServiceAttributes{
				Name: apiutils.Of("time-test-inference-service-2"),
			},
			Properties: &[]models.Properties{
				{
					Name:     "serving_environment_id",
					IntValue: servingEnvID,
				},
				{
					Name:     "registered_model_id",
					IntValue: registeredModelID,
				},
			},
		}
		saved2, err := repo.Save(infSvc2)
		require.NoError(t, err)
		require.NotNil(t, saved2)
		require.NotNil(t, saved2.GetID())

		// Test ordering by CREATE_TIME
		pageSize := int32(10)
		listOptions := models.InferenceServiceListOptions{
			Pagination: models.Pagination{
				OrderBy: apiutils.Of("CREATE_TIME"),
			},
		}
		listOptions.PageSize = &pageSize

		result, err := repo.List(listOptions)
		require.NoError(t, err)
		require.NotNil(t, result)

		// Find our test inference services in the results
		var foundInfSvc1, foundInfSvc2 models.InferenceService
		index1, index2 := -1, -1
		for i, item := range result.Items {
			itemID := item.GetID()
			if itemID == nil {
				continue
			}
			switch *itemID {
			case *saved1.GetID():
				foundInfSvc1 = item
				index1 = i
			case *saved2.GetID():
				foundInfSvc2 = item
				index2 = i
			}
		}

		// Verify both inference services were found and infSvc1 comes before infSvc2 (ascending order)
		require.NotEqual(t, -1, index1, "Inference Service 1 should be found in results")
		require.NotEqual(t, -1, index2, "Inference Service 2 should be found in results")
		assert.Less(t, index1, index2, "Inference Service 1 should come before Inference Service 2 when ordered by CREATE_TIME")

		// Guard the timestamp pointers before dereferencing them below.
		require.NotNil(t, foundInfSvc1.GetAttributes())
		require.NotNil(t, foundInfSvc2.GetAttributes())
		require.NotNil(t, foundInfSvc1.GetAttributes().CreateTimeSinceEpoch)
		require.NotNil(t, foundInfSvc2.GetAttributes().CreateTimeSinceEpoch)
		assert.Less(t, *foundInfSvc1.GetAttributes().CreateTimeSinceEpoch, *foundInfSvc2.GetAttributes().CreateTimeSinceEpoch, "Inference Service 1 should have earlier create time")
	})

	t.Run("TestSaveWithModelVersion", func(t *testing.T) {
		servingEnvID, registeredModelID := createParents(t, "parent-serving-env-for-model-version", "test-registered-model-with-version")

		// Create a model version
		modelVersion := &models.ModelVersionImpl{
			TypeID: apiutils.Of(int32(modelVersionTypeID)),
			Attributes: &models.ModelVersionAttributes{
				Name: apiutils.Of("test-model-version"),
			},
			Properties: &[]models.Properties{
				{
					Name:     "registered_model_id",
					IntValue: registeredModelID,
				},
			},
		}
		savedModelVersion, err := modelVersionRepo.Save(modelVersion)
		require.NoError(t, err)
		require.NotNil(t, savedModelVersion)
		require.NotNil(t, savedModelVersion.GetID())
		modelVersionID := savedModelVersion.GetID()

		// Create inference service with both registered model and model version
		inferenceService := &models.InferenceServiceImpl{
			TypeID: apiutils.Of(int32(typeID)),
			Attributes: &models.InferenceServiceAttributes{
				Name: apiutils.Of("inference-service-with-model-version"),
			},
			Properties: &[]models.Properties{
				{
					Name:        "description",
					StringValue: apiutils.Of("Inference service with model version"),
				},
				{
					Name:     "serving_environment_id",
					IntValue: servingEnvID,
				},
				{
					Name:     "registered_model_id",
					IntValue: registeredModelID,
				},
				{
					Name:     "model_version_id",
					IntValue: modelVersionID,
				},
				{
					Name:        "runtime",
					StringValue: apiutils.Of("onnx"),
				},
			},
			CustomProperties: &[]models.Properties{
				{
					Name:        "team",
					StringValue: apiutils.Of("ml-team"),
				},
				{
					Name:     "priority",
					IntValue: apiutils.Of(int32(5)),
				},
			},
		}

		saved, err := repo.Save(inferenceService)
		require.NoError(t, err)
		require.NotNil(t, saved)
		require.NotNil(t, saved.GetID())

		// Verify properties were saved. require (not assert) is used before
		// each dereference so a missing value fails the test instead of
		// panicking.
		retrieved, err := repo.GetByID(*saved.GetID())
		require.NoError(t, err)
		require.NotNil(t, retrieved)
		require.NotNil(t, retrieved.GetProperties())
		assert.Len(t, *retrieved.GetProperties(), 5) // description, serving_environment_id, registered_model_id, model_version_id, runtime
		require.NotNil(t, retrieved.GetCustomProperties())
		assert.Len(t, *retrieved.GetCustomProperties(), 2)

		// Verify the specific properties exist
		var foundModelVersionID, foundRegisteredModelID, foundServingEnvID bool
		for _, prop := range *retrieved.GetProperties() {
			if prop.IntValue == nil {
				continue
			}
			switch prop.Name {
			case "model_version_id":
				if *prop.IntValue == *modelVersionID {
					foundModelVersionID = true
				}
			case "registered_model_id":
				if *prop.IntValue == *registeredModelID {
					foundRegisteredModelID = true
				}
			case "serving_environment_id":
				if *prop.IntValue == *servingEnvID {
					foundServingEnvID = true
				}
			}
		}
		assert.True(t, foundModelVersionID, "Should find model_version_id property")
		assert.True(t, foundRegisteredModelID, "Should find registered_model_id property")
		assert.True(t, foundServingEnvID, "Should find serving_environment_id property")
	})
}