model-registry/internal/core/model_version_test.go

1313 lines
43 KiB
Go

package core_test
import (
"fmt"
"strings"
"testing"
"github.com/kubeflow/model-registry/internal/apiutils"
"github.com/kubeflow/model-registry/pkg/api"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestUpsertModelVersion exercises UpsertModelVersion on the service obtained
// from SetupModelRegistryService. It covers creation, update, custom
// properties, nil optional fields, unicode/special-character names and
// paginated listing of the versions it creates.
//
// Checks that guard a later pointer dereference (or a later err.Error() call)
// use require rather than assert, so a failed precondition stops the subtest
// with a readable failure instead of a nil-pointer panic.
func TestUpsertModelVersion(t *testing.T) {
	_service, cleanup := SetupModelRegistryService(t)
	defer cleanup()

	t.Run("successful create", func(t *testing.T) {
		// First create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name:        "test-registered-model",
			Description: apiutils.Of("Test registered model for version"),
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create a model version
		inputVersion := &openapi.ModelVersion{
			Name:              "test-version",
			Description:       apiutils.Of("Test version description"),
			Author:            apiutils.Of("test-author"),
			ExternalId:        apiutils.Of("version-ext-123"),
			State:             apiutils.Of(openapi.MODELVERSIONSTATE_LIVE),
			RegisteredModelId: *createdModel.Id,
		}
		result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.NotNil(t, result.Id)
		assert.Equal(t, "test-version", result.Name)
		assert.Equal(t, "version-ext-123", *result.ExternalId)
		assert.Equal(t, "Test version description", *result.Description)
		assert.Equal(t, "test-author", *result.Author)
		assert.Equal(t, openapi.MODELVERSIONSTATE_LIVE, *result.State)
		assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
		assert.NotNil(t, result.CreateTimeSinceEpoch)
		assert.NotNil(t, result.LastUpdateTimeSinceEpoch)
	})

	t.Run("successful update", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "update-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create a model version
		inputVersion := &openapi.ModelVersion{
			Name:              "update-test-version",
			Description:       apiutils.Of("Original description"),
			RegisteredModelId: *createdModel.Id,
		}
		created, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, created.Id)

		// Update the version
		updateVersion := &openapi.ModelVersion{
			Id:                created.Id,
			Name:              "update-test-version", // Name should remain the same
			Description:       apiutils.Of("Updated description"),
			Author:            apiutils.Of("new-author"),
			State:             apiutils.Of(openapi.MODELVERSIONSTATE_ARCHIVED),
			RegisteredModelId: *createdModel.Id,
		}
		updated, err := _service.UpsertModelVersion(updateVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, updated)
		require.NotNil(t, updated.Id)

		assert.Equal(t, *created.Id, *updated.Id)
		assert.Equal(t, "update-test-version", updated.Name)
		assert.Equal(t, "Updated description", *updated.Description)
		assert.Equal(t, "new-author", *updated.Author)
		assert.Equal(t, openapi.MODELVERSIONSTATE_ARCHIVED, *updated.State)
	})

	t.Run("create with custom properties", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "custom-props-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		customProps := map[string]openapi.MetadataValue{
			"accuracy": {
				MetadataDoubleValue: &openapi.MetadataDoubleValue{
					DoubleValue: 0.95,
				},
			},
			"framework": {
				MetadataStringValue: &openapi.MetadataStringValue{
					StringValue: "tensorflow",
				},
			},
			"epochs": {
				MetadataIntValue: &openapi.MetadataIntValue{
					IntValue: "100",
				},
			},
			"is_production": {
				MetadataBoolValue: &openapi.MetadataBoolValue{
					BoolValue: true,
				},
			},
		}
		inputVersion := &openapi.ModelVersion{
			Name:              "custom-props-version",
			RegisteredModelId: *createdModel.Id,
			CustomProperties:  &customProps,
		}
		result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.Equal(t, "custom-props-version", result.Name)

		// The map is dereferenced next, so a nil map must abort the subtest.
		require.NotNil(t, result.CustomProperties)
		resultProps := *result.CustomProperties

		// Each typed value pointer is dereferenced below; a missing key or an
		// unset typed value would otherwise panic on the zero MetadataValue.
		require.Contains(t, resultProps, "accuracy")
		require.Contains(t, resultProps, "framework")
		require.Contains(t, resultProps, "epochs")
		require.Contains(t, resultProps, "is_production")
		require.NotNil(t, resultProps["accuracy"].MetadataDoubleValue)
		require.NotNil(t, resultProps["framework"].MetadataStringValue)
		require.NotNil(t, resultProps["epochs"].MetadataIntValue)
		require.NotNil(t, resultProps["is_production"].MetadataBoolValue)

		assert.Equal(t, 0.95, resultProps["accuracy"].MetadataDoubleValue.DoubleValue)
		assert.Equal(t, "tensorflow", resultProps["framework"].MetadataStringValue.StringValue)
		assert.Equal(t, "100", resultProps["epochs"].MetadataIntValue.IntValue)
		assert.Equal(t, true, resultProps["is_production"].MetadataBoolValue.BoolValue)
	})

	t.Run("minimal version", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "minimal-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		inputVersion := &openapi.ModelVersion{
			Name:              "minimal-version",
			RegisteredModelId: *createdModel.Id,
		}
		result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.Equal(t, "minimal-version", result.Name)
		assert.NotNil(t, result.Id)
		assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
	})

	t.Run("nil model version error", func(t *testing.T) {
		result, err := _service.UpsertModelVersion(nil, apiutils.Of("1"))
		// err.Error() is called below, so a nil error must stop the subtest.
		require.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "invalid model version pointer")
	})

	t.Run("nil fields preserved", func(t *testing.T) {
		// Create a registered model first
		registeredModel := &openapi.RegisteredModel{
			Name: "nil-fields-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create model version with nil optional fields
		inputVersion := &openapi.ModelVersion{
			Name:              "nil-fields-version",
			Description:       nil, // Explicitly set to nil
			Author:            nil, // Explicitly set to nil
			ExternalId:        nil, // Explicitly set to nil
			State:             nil, // Explicitly set to nil
			RegisteredModelId: *createdModel.Id,
		}
		result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)

		assert.NotNil(t, result.Id)
		assert.Equal(t, "nil-fields-version", result.Name)
		assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
		assert.Nil(t, result.Description) // Verify description remains nil
		assert.Nil(t, result.Author)      // Verify author remains nil
		assert.Nil(t, result.ExternalId)  // Verify external ID remains nil
		assert.Nil(t, result.State)       // Verify state remains nil
		assert.NotNil(t, result.CreateTimeSinceEpoch)
		assert.NotNil(t, result.LastUpdateTimeSinceEpoch)
	})

	t.Run("unicode characters in name", func(t *testing.T) {
		// Create a registered model first
		registeredModel := &openapi.RegisteredModel{
			Name: "unicode-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Test with unicode characters: Chinese, Russian, Japanese, and emoji
		unicodeName := "模型版本-тест-モデルバージョン-🚀"
		unicodeAuthor := "测试作者-тестовый автор-テスト作者"
		inputVersion := &openapi.ModelVersion{
			Name:              unicodeName,
			Description:       apiutils.Of("Test model version with unicode characters"),
			Author:            apiutils.Of(unicodeAuthor),
			RegisteredModelId: *createdModel.Id,
		}
		result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		// The ID is dereferenced for the lookup below.
		require.NotNil(t, result.Id)

		assert.Equal(t, unicodeName, result.Name)
		assert.Equal(t, "Test model version with unicode characters", *result.Description)
		assert.Equal(t, unicodeAuthor, *result.Author)
		assert.NotNil(t, result.CreateTimeSinceEpoch)
		assert.NotNil(t, result.LastUpdateTimeSinceEpoch)

		// Verify we can retrieve it by ID
		retrieved, err := _service.GetModelVersionById(*result.Id)
		require.NoError(t, err)
		require.NotNil(t, retrieved)
		assert.Equal(t, unicodeName, retrieved.Name)
		assert.Equal(t, unicodeAuthor, *retrieved.Author)
	})

	t.Run("special characters in name", func(t *testing.T) {
		// Create a registered model first
		registeredModel := &openapi.RegisteredModel{
			Name: "special-chars-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Test with various special characters
		specialName := "!@#$%^&*()_+-=[]{}|;':\",./<>?"
		specialAuthor := "author@#$%^&*()"
		inputVersion := &openapi.ModelVersion{
			Name:              specialName,
			Description:       apiutils.Of("Test model version with special characters"),
			Author:            apiutils.Of(specialAuthor),
			RegisteredModelId: *createdModel.Id,
		}
		result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		// The ID is dereferenced for the lookup below.
		require.NotNil(t, result.Id)

		assert.Equal(t, specialName, result.Name)
		assert.Equal(t, "Test model version with special characters", *result.Description)
		assert.Equal(t, specialAuthor, *result.Author)

		// Verify we can retrieve it by ID
		retrieved, err := _service.GetModelVersionById(*result.Id)
		require.NoError(t, err)
		require.NotNil(t, retrieved)
		assert.Equal(t, specialName, retrieved.Name)
		assert.Equal(t, specialAuthor, *retrieved.Author)
	})

	t.Run("mixed unicode and special characters", func(t *testing.T) {
		// Create a registered model first
		registeredModel := &openapi.RegisteredModel{
			Name: "mixed-chars-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Test with mixed unicode and special characters
		mixedName := "模型@#$%版本-тест!@#-モデル()バージョン-🚀[]"
		mixedAuthor := "作者@#$%-автор!@#-作者()🚀"
		inputVersion := &openapi.ModelVersion{
			Name:              mixedName,
			Description:       apiutils.Of("Test model version with mixed unicode and special characters"),
			Author:            apiutils.Of(mixedAuthor),
			RegisteredModelId: *createdModel.Id,
		}
		result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		// The ID is dereferenced for the lookup below.
		require.NotNil(t, result.Id)

		assert.Equal(t, mixedName, result.Name)
		assert.Equal(t, "Test model version with mixed unicode and special characters", *result.Description)
		assert.Equal(t, mixedAuthor, *result.Author)

		// Verify we can retrieve it by ID
		retrieved, err := _service.GetModelVersionById(*result.Id)
		require.NoError(t, err)
		require.NotNil(t, retrieved)
		assert.Equal(t, mixedName, retrieved.Name)
		assert.Equal(t, mixedAuthor, *retrieved.Author)
	})

	t.Run("pagination with 10+ model versions", func(t *testing.T) {
		const (
			versionCount = 15
			namePrefix   = "paging-test-model-version-"
		)

		// Create a registered model first
		registeredModel := &openapi.RegisteredModel{
			Name: "paging-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create 15 model versions for pagination testing, remembering their
		// IDs in a set so membership checks below are a single map lookup.
		createdIDs := make(map[string]struct{}, versionCount)
		for i := 0; i < versionCount; i++ {
			suffix := fmt.Sprintf("%02d", i)
			inputVersion := &openapi.ModelVersion{
				Name:              namePrefix + suffix,
				Description:       apiutils.Of("Pagination test model version " + suffix),
				Author:            apiutils.Of("test-author-" + suffix),
				RegisteredModelId: *createdModel.Id,
			}
			result, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
			require.NoError(t, err)
			require.NotNil(t, result)
			require.NotNil(t, result.Id)
			createdIDs[*result.Id] = struct{}{}
		}
		require.Len(t, createdIDs, versionCount, "Created model versions should have distinct IDs")

		// Test pagination with page size 5
		pageSize := int32(5)
		orderBy := "name"
		sortOrder := "ASC"
		listOptions := api.ListOptions{
			PageSize:  &pageSize,
			OrderBy:   &orderBy,
			SortOrder: &sortOrder,
		}

		// Get first page
		firstPage, err := _service.GetModelVersions(listOptions, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, firstPage)
		assert.LessOrEqual(t, len(firstPage.Items), 5, "First page should have at most 5 items")
		assert.Equal(t, int32(5), firstPage.PageSize)

		// Collect the IDs of our test model versions found in the first page
		// (those with the specific name prefix).
		firstPageIds := make(map[string]bool)
		for _, item := range firstPage.Items {
			if !strings.HasPrefix(item.Name, namePrefix) {
				continue
			}
			require.NotNil(t, item.Id)
			assert.False(t, firstPageIds[*item.Id], "Should not have duplicate IDs in first page")
			firstPageIds[*item.Id] = true
		}

		// Only proceed with second page test if we have a next page token and found test versions
		if firstPage.NextPageToken != "" && len(firstPageIds) > 0 {
			// Get second page using next page token
			listOptions.NextPageToken = &firstPage.NextPageToken
			secondPage, err := _service.GetModelVersions(listOptions, createdModel.Id)
			require.NoError(t, err)
			require.NotNil(t, secondPage)
			assert.LessOrEqual(t, len(secondPage.Items), 5, "Second page should have at most 5 items")

			// Verify no duplicates between pages (only check our test versions)
			for _, item := range secondPage.Items {
				if !strings.HasPrefix(item.Name, namePrefix) {
					continue
				}
				require.NotNil(t, item.Id)
				assert.False(t, firstPageIds[*item.Id], "Should not have duplicate IDs between pages")
			}
		}

		// Test with larger page size
		largePage := int32(100)
		listOptions = api.ListOptions{
			PageSize:  &largePage,
			OrderBy:   &orderBy,
			SortOrder: &sortOrder,
		}
		allItems, err := _service.GetModelVersions(listOptions, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, allItems)
		assert.GreaterOrEqual(t, len(allItems.Items), versionCount, "Should have at least our 15 test model versions")

		// Count our test versions in the results
		foundCount := 0
		for _, item := range allItems.Items {
			require.NotNil(t, item.Id)
			if _, ok := createdIDs[*item.Id]; ok {
				foundCount++
			}
		}
		assert.Equal(t, versionCount, foundCount, "Should find all 15 created model versions")

		// Test descending order
		descOrder := "DESC"
		listOptions = api.ListOptions{
			PageSize:  &pageSize,
			OrderBy:   &orderBy,
			SortOrder: &descOrder,
		}
		descPage, err := _service.GetModelVersions(listOptions, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, descPage)
		assert.LessOrEqual(t, len(descPage.Items), 5, "Desc page should have at most 5 items")

		// Verify ordering (names should be in descending order). The loop body
		// never runs for fewer than two items, so no extra length guard is needed.
		for i := 1; i < len(descPage.Items); i++ {
			assert.GreaterOrEqual(t, descPage.Items[i-1].Name, descPage.Items[i].Name,
				"Items should be in descending order by name")
		}
	})
}
// TestGetModelVersionById checks GetModelVersionById for a version created in
// the test, for a non-numeric ID, and for an ID that is not expected to exist.
//
// Error-path subtests use require.Error before calling err.Error(), so an
// unexpected nil error is reported as a test failure rather than a panic.
func TestGetModelVersionById(t *testing.T) {
	_service, cleanup := SetupModelRegistryService(t)
	defer cleanup()

	t.Run("successful get", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "get-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create a model version
		inputVersion := &openapi.ModelVersion{
			Name:              "get-test-version",
			Description:       apiutils.Of("Test description"),
			ExternalId:        apiutils.Of("get-ext-123"),
			State:             apiutils.Of(openapi.MODELVERSIONSTATE_LIVE),
			RegisteredModelId: *createdModel.Id,
		}
		created, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, created.Id)

		// Get the version by ID
		result, err := _service.GetModelVersionById(*created.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Id)

		assert.Equal(t, *created.Id, *result.Id)
		assert.Equal(t, "get-test-version", result.Name)
		assert.Equal(t, "get-ext-123", *result.ExternalId)
		assert.Equal(t, "Test description", *result.Description)
		assert.Equal(t, openapi.MODELVERSIONSTATE_LIVE, *result.State)
		assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
	})

	t.Run("invalid id", func(t *testing.T) {
		result, err := _service.GetModelVersionById("invalid")
		require.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "invalid")
	})

	t.Run("non-existent id", func(t *testing.T) {
		result, err := _service.GetModelVersionById("99999")
		require.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "no model version found")
	})
}
// TestGetModelVersionByInferenceService checks that
// GetModelVersionByInferenceService returns the version pinned on an inference
// service, falls back to the most recently created version when none is
// pinned, and returns an error for malformed or unknown inference service IDs.
//
// IDs returned by the setup calls are checked with require.NotNil before they
// are dereferenced, and require.Error guards the err.Error() call.
func TestGetModelVersionByInferenceService(t *testing.T) {
	_service, cleanup := SetupModelRegistryService(t)
	defer cleanup()

	t.Run("successful get with specific model version", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "inference-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create a model version
		modelVersion := &openapi.ModelVersion{
			Name:              "inference-test-version",
			RegisteredModelId: *createdModel.Id,
		}
		createdVersion, err := _service.UpsertModelVersion(modelVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, createdVersion.Id)

		// Create a serving environment
		servingEnv := &openapi.ServingEnvironment{
			Name: "inference-test-env",
		}
		createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
		require.NoError(t, err)
		require.NotNil(t, createdEnv.Id)

		// Create an inference service with specific model version
		inferenceService := &openapi.InferenceService{
			Name:                 apiutils.Of("test-inference-service"),
			ServingEnvironmentId: *createdEnv.Id,
			RegisteredModelId:    *createdModel.Id,
			ModelVersionId:       createdVersion.Id,
		}
		createdInference, err := _service.UpsertInferenceService(inferenceService)
		require.NoError(t, err)
		require.NotNil(t, createdInference.Id)

		// Get model version by inference service
		result, err := _service.GetModelVersionByInferenceService(*createdInference.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Id)

		assert.Equal(t, *createdVersion.Id, *result.Id)
		assert.Equal(t, "inference-test-version", result.Name)
	})

	t.Run("successful get with latest version", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "latest-version-test-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create multiple model versions
		version1 := &openapi.ModelVersion{
			Name:              "version-1",
			RegisteredModelId: *createdModel.Id,
		}
		_, err = _service.UpsertModelVersion(version1, createdModel.Id)
		require.NoError(t, err)

		version2 := &openapi.ModelVersion{
			Name:              "version-2",
			RegisteredModelId: *createdModel.Id,
		}
		createdVersion2, err := _service.UpsertModelVersion(version2, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, createdVersion2.Id)

		// Create a serving environment
		servingEnv := &openapi.ServingEnvironment{
			Name: "latest-version-test-env",
		}
		createdEnv, err := _service.UpsertServingEnvironment(servingEnv)
		require.NoError(t, err)
		require.NotNil(t, createdEnv.Id)

		// Create an inference service without specific model version (should get latest)
		inferenceService := &openapi.InferenceService{
			Name:                 apiutils.Of("latest-version-inference-service"),
			ServingEnvironmentId: *createdEnv.Id,
			RegisteredModelId:    *createdModel.Id,
			// ModelVersionId is nil, should get latest
		}
		createdInference, err := _service.UpsertInferenceService(inferenceService)
		require.NoError(t, err)
		require.NotNil(t, createdInference.Id)

		// Get model version by inference service (should return latest)
		result, err := _service.GetModelVersionByInferenceService(*createdInference.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Id)

		// Should be the latest version (version2)
		assert.Equal(t, *createdVersion2.Id, *result.Id)
		assert.Equal(t, "version-2", result.Name)
	})

	t.Run("invalid inference service id", func(t *testing.T) {
		result, err := _service.GetModelVersionByInferenceService("invalid")
		require.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "invalid syntax: bad request")
	})

	t.Run("non-existent inference service", func(t *testing.T) {
		result, err := _service.GetModelVersionByInferenceService("99999")
		assert.Error(t, err)
		assert.Nil(t, result)
	})
}
// TestGetModelVersionByParams checks GetModelVersionByParams lookups by
// (name, registered model ID) and by external ID, the error returned when no
// parameters are given or nothing matches, and that two registered models can
// each own a version with the same name without the lookups crossing over.
//
// require.Error guards every err.Error() call, and IDs are checked with
// require.NotNil before being dereferenced.
func TestGetModelVersionByParams(t *testing.T) {
	_service, cleanup := SetupModelRegistryService(t)
	defer cleanup()

	t.Run("successful get by name and registered model id", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "params-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create a model version
		inputVersion := &openapi.ModelVersion{
			Name:              "params-test-version",
			ExternalId:        apiutils.Of("params-ext-123"),
			RegisteredModelId: *createdModel.Id,
		}
		created, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, created.Id)

		// Get by name and registered model ID
		versionName := "params-test-version"
		result, err := _service.GetModelVersionByParams(&versionName, createdModel.Id, nil)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Id)

		assert.Equal(t, *created.Id, *result.Id)
		assert.Equal(t, "params-test-version", result.Name)
		assert.Equal(t, *createdModel.Id, result.RegisteredModelId)
	})

	t.Run("successful get by external id", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "params-ext-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create a model version
		inputVersion := &openapi.ModelVersion{
			Name:              "params-ext-test-version",
			ExternalId:        apiutils.Of("params-unique-ext-456"),
			RegisteredModelId: *createdModel.Id,
		}
		created, err := _service.UpsertModelVersion(inputVersion, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, created.Id)

		// Get by external ID
		externalId := "params-unique-ext-456"
		result, err := _service.GetModelVersionByParams(nil, nil, &externalId)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.NotNil(t, result.Id)

		assert.Equal(t, *created.Id, *result.Id)
		assert.Equal(t, "params-ext-test-version", result.Name)
		assert.Equal(t, "params-unique-ext-456", *result.ExternalId)
	})

	t.Run("invalid parameters", func(t *testing.T) {
		result, err := _service.GetModelVersionByParams(nil, nil, nil)
		require.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "invalid parameters call")
	})

	t.Run("no version found", func(t *testing.T) {
		versionName := "nonexistent-version"
		registeredModelId := "999"
		result, err := _service.GetModelVersionByParams(&versionName, &registeredModelId, nil)
		require.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "no model versions found")
	})

	t.Run("same version name across different models", func(t *testing.T) {
		// This test catches the bug where ParentResourceID was not being used to filter model versions

		// Create first registered model
		registeredModel1 := &openapi.RegisteredModel{
			Name: "model-with-shared-version-1",
		}
		createdModel1, err := _service.UpsertRegisteredModel(registeredModel1)
		require.NoError(t, err)
		require.NotNil(t, createdModel1.Id)

		// Create second registered model
		registeredModel2 := &openapi.RegisteredModel{
			Name: "model-with-shared-version-2",
		}
		createdModel2, err := _service.UpsertRegisteredModel(registeredModel2)
		require.NoError(t, err)
		require.NotNil(t, createdModel2.Id)

		// Create version "shared-version-name-test" for the first model
		version1 := &openapi.ModelVersion{
			Name:              "shared-version-name-test",
			RegisteredModelId: *createdModel1.Id,
			Description:       apiutils.Of("Version for model 1"),
		}
		createdVersion1, err := _service.UpsertModelVersion(version1, createdModel1.Id)
		require.NoError(t, err)
		require.NotNil(t, createdVersion1.Id)

		// Create version "shared-version-name-test" for the second model
		version2 := &openapi.ModelVersion{
			Name:              "shared-version-name-test",
			RegisteredModelId: *createdModel2.Id,
			Description:       apiutils.Of("Version for model 2"),
		}
		createdVersion2, err := _service.UpsertModelVersion(version2, createdModel2.Id)
		require.NoError(t, err)
		require.NotNil(t, createdVersion2.Id)

		// Query for version "shared-version-name-test" of the first model
		versionName := "shared-version-name-test"
		result1, err := _service.GetModelVersionByParams(&versionName, createdModel1.Id, nil)
		require.NoError(t, err)
		require.NotNil(t, result1)
		require.NotNil(t, result1.Id)
		assert.Equal(t, *createdVersion1.Id, *result1.Id)
		assert.Equal(t, *createdModel1.Id, result1.RegisteredModelId)
		assert.Equal(t, "Version for model 1", *result1.Description)

		// Query for version "shared-version-name-test" of the second model
		result2, err := _service.GetModelVersionByParams(&versionName, createdModel2.Id, nil)
		require.NoError(t, err)
		require.NotNil(t, result2)
		require.NotNil(t, result2.Id)
		assert.Equal(t, *createdVersion2.Id, *result2.Id)
		assert.Equal(t, *createdModel2.Id, result2.RegisteredModelId)
		assert.Equal(t, "Version for model 2", *result2.Description)

		// Ensure we got different versions
		assert.NotEqual(t, *result1.Id, *result2.Id)
	})
}
// TestGetModelVersions checks GetModelVersions when scoped to one registered
// model, when called without a registered model filter, with a small page size
// plus ordering options, and with a malformed registered model ID.
//
// IDs are checked with require.NotNil before being dereferenced, and
// require.Error guards the err.Error() call in the error-path subtest.
func TestGetModelVersions(t *testing.T) {
	_service, cleanup := SetupModelRegistryService(t)
	defer cleanup()

	t.Run("successful list", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "list-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create multiple model versions
		versions := []*openapi.ModelVersion{
			{Name: "list-version-1", ExternalId: apiutils.Of("list-ext-1"), RegisteredModelId: *createdModel.Id},
			{Name: "list-version-2", ExternalId: apiutils.Of("list-ext-2"), RegisteredModelId: *createdModel.Id},
			{Name: "list-version-3", ExternalId: apiutils.Of("list-ext-3"), RegisteredModelId: *createdModel.Id},
		}
		// Remember the created IDs in a set so the membership check below is
		// a single map lookup per listed item.
		createdIDs := make(map[string]struct{}, len(versions))
		for _, version := range versions {
			created, err := _service.UpsertModelVersion(version, createdModel.Id)
			require.NoError(t, err)
			require.NotNil(t, created)
			require.NotNil(t, created.Id)
			createdIDs[*created.Id] = struct{}{}
		}

		// List versions for the registered model
		pageSize := int32(10)
		listOptions := api.ListOptions{
			PageSize: &pageSize,
		}
		result, err := _service.GetModelVersions(listOptions, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.Len(t, result.Items, 3) // Should have exactly our 3 versions
		assert.Equal(t, int32(10), result.PageSize)

		// Verify our versions are in the result
		foundVersions := 0
		for _, item := range result.Items {
			require.NotNil(t, item.Id)
			if _, ok := createdIDs[*item.Id]; ok {
				foundVersions++
			}
		}
		assert.Equal(t, 3, foundVersions, "All created versions should be found in the list")
	})

	t.Run("list all versions without registered model filter", func(t *testing.T) {
		// Create multiple registered models with versions
		for i := 0; i < 2; i++ {
			letter := string(rune('A' + i))
			registeredModel := &openapi.RegisteredModel{
				Name: "all-versions-model-" + letter,
			}
			createdModel, err := _service.UpsertRegisteredModel(registeredModel)
			require.NoError(t, err)
			require.NotNil(t, createdModel.Id)

			version := &openapi.ModelVersion{
				Name:              "all-versions-version-" + letter,
				RegisteredModelId: *createdModel.Id,
			}
			_, err = _service.UpsertModelVersion(version, createdModel.Id)
			require.NoError(t, err)
		}

		// List all versions (no registered model filter)
		listOptions := api.ListOptions{}
		result, err := _service.GetModelVersions(listOptions, nil)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.GreaterOrEqual(t, len(result.Items), 2) // Should have at least our 2 versions
	})

	t.Run("pagination and ordering", func(t *testing.T) {
		// Create a registered model
		registeredModel := &openapi.RegisteredModel{
			Name: "pagination-test-registered-model",
		}
		createdModel, err := _service.UpsertRegisteredModel(registeredModel)
		require.NoError(t, err)
		require.NotNil(t, createdModel.Id)

		// Create several versions for pagination testing
		for i := 0; i < 5; i++ {
			letter := string(rune('A' + i))
			version := &openapi.ModelVersion{
				Name:              "pagination-version-" + letter,
				ExternalId:        apiutils.Of("pagination-ext-" + letter),
				RegisteredModelId: *createdModel.Id,
			}
			_, err := _service.UpsertModelVersion(version, createdModel.Id)
			require.NoError(t, err)
		}

		// Test with small page size and ordering
		pageSize := int32(2)
		orderBy := "name"
		sortOrder := "asc"
		listOptions := api.ListOptions{
			PageSize:  &pageSize,
			OrderBy:   &orderBy,
			SortOrder: &sortOrder,
		}
		result, err := _service.GetModelVersions(listOptions, createdModel.Id)
		require.NoError(t, err)
		require.NotNil(t, result)
		assert.GreaterOrEqual(t, len(result.Items), 2) // Should have at least 2 items
		assert.Equal(t, int32(2), result.PageSize)
	})

	t.Run("invalid registered model id", func(t *testing.T) {
		invalidId := "invalid"
		listOptions := api.ListOptions{}
		result, err := _service.GetModelVersions(listOptions, &invalidId)
		require.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "invalid syntax: bad request")
	})
}
// TestModelVersionRoundTrip walks one model version through create, read,
// update and re-read, asserting after each step that the service returns the
// values that were written.
func TestModelVersionRoundTrip(t *testing.T) {
	_service, cleanup := SetupModelRegistryService(t)
	defer cleanup()

	t.Run("complete roundtrip", func(t *testing.T) {
		// Parent registered model that owns the version under test.
		parent, err := _service.UpsertRegisteredModel(&openapi.RegisteredModel{
			Name:        "roundtrip-registered-model",
			Description: apiutils.Of("Roundtrip test registered model"),
		})
		require.NoError(t, err)

		// Version payload with every optional field populated.
		want := &openapi.ModelVersion{
			Name:              "roundtrip-version",
			Description:       apiutils.Of("Roundtrip test description"),
			Author:            apiutils.Of("roundtrip-author"),
			ExternalId:        apiutils.Of("roundtrip-ext-123"),
			State:             apiutils.Of(openapi.MODELVERSIONSTATE_LIVE),
			RegisteredModelId: *parent.Id,
		}

		// Step 1: create.
		stored, err := _service.UpsertModelVersion(want, parent.Id)
		require.NoError(t, err)
		require.NotNil(t, stored.Id)

		// Step 2: read back by ID and compare field by field.
		fetched, err := _service.GetModelVersionById(*stored.Id)
		require.NoError(t, err)

		assert.Equal(t, *stored.Id, *fetched.Id)
		assert.Equal(t, "roundtrip-version", fetched.Name)
		assert.Equal(t, *want.Description, *fetched.Description)
		assert.Equal(t, *want.Author, *fetched.Author)
		assert.Equal(t, *want.ExternalId, *fetched.ExternalId)
		assert.Equal(t, *want.State, *fetched.State)
		assert.Equal(t, want.RegisteredModelId, fetched.RegisteredModelId)

		// Step 3: change description and state on the fetched object and upsert it.
		fetched.Description = apiutils.Of("Updated description")
		fetched.State = apiutils.Of(openapi.MODELVERSIONSTATE_ARCHIVED)

		afterUpdate, err := _service.UpsertModelVersion(fetched, parent.Id)
		require.NoError(t, err)

		assert.Equal(t, *stored.Id, *afterUpdate.Id)
		assert.Equal(t, "Updated description", *afterUpdate.Description)
		assert.Equal(t, openapi.MODELVERSIONSTATE_ARCHIVED, *afterUpdate.State)

		// Step 4: read once more to confirm the update was persisted.
		reloaded, err := _service.GetModelVersionById(*stored.Id)
		require.NoError(t, err)
		assert.Equal(t, "Updated description", *reloaded.Description)
		assert.Equal(t, openapi.MODELVERSIONSTATE_ARCHIVED, *reloaded.State)
	})
}
func TestGetModelVersionsWithFilterQuery(t *testing.T) {
_service, cleanup := SetupModelRegistryService(t)
defer cleanup()
// Create a registered model to associate versions with
registeredModel := &openapi.RegisteredModel{
Name: "test-model-for-versions",
}
createdModel, err := _service.UpsertRegisteredModel(registeredModel)
require.NoError(t, err)
// Create test model versions with various properties for filtering
testVersions := []struct {
version *openapi.ModelVersion
}{
{
version: &openapi.ModelVersion{
Name: "v1.0.0",
Description: apiutils.Of("Initial release version"),
ExternalId: apiutils.Of("ext-v1-001"),
Author: apiutils.Of("alice"),
State: (*openapi.ModelVersionState)(apiutils.Of("LIVE")),
RegisteredModelId: *createdModel.Id,
CustomProperties: &map[string]openapi.MetadataValue{
"stage": {
MetadataStringValue: &openapi.MetadataStringValue{
StringValue: "production",
MetadataType: "MetadataStringValue",
},
},
"accuracy": {
MetadataDoubleValue: &openapi.MetadataDoubleValue{
DoubleValue: 0.95,
MetadataType: "MetadataDoubleValue",
},
},
"batch_size": {
MetadataIntValue: &openapi.MetadataIntValue{
IntValue: "32",
MetadataType: "MetadataIntValue",
},
},
},
},
},
{
version: &openapi.ModelVersion{
Name: "v2.0.0",
Description: apiutils.Of("Major update with improvements"),
ExternalId: apiutils.Of("ext-v2-002"),
Author: apiutils.Of("bob"),
State: (*openapi.ModelVersionState)(apiutils.Of("ARCHIVED")),
RegisteredModelId: *createdModel.Id,
CustomProperties: &map[string]openapi.MetadataValue{
"stage": {
MetadataStringValue: &openapi.MetadataStringValue{
StringValue: "staging",
MetadataType: "MetadataStringValue",
},
},
"accuracy": {
MetadataDoubleValue: &openapi.MetadataDoubleValue{
DoubleValue: 0.87,
MetadataType: "MetadataDoubleValue",
},
},
"batch_size": {
MetadataIntValue: &openapi.MetadataIntValue{
IntValue: "64",
MetadataType: "MetadataIntValue",
},
},
},
},
},
{
version: &openapi.ModelVersion{
Name: "v1.1.0",
Description: apiutils.Of("Minor update with bug fixes"),
ExternalId: apiutils.Of("ext-v1-1-003"),
Author: apiutils.Of("alice"),
RegisteredModelId: *createdModel.Id,
CustomProperties: &map[string]openapi.MetadataValue{
"stage": {
MetadataStringValue: &openapi.MetadataStringValue{
StringValue: "production",
MetadataType: "MetadataStringValue",
},
},
"accuracy": {
MetadataDoubleValue: &openapi.MetadataDoubleValue{
DoubleValue: 0.96,
MetadataType: "MetadataDoubleValue",
},
},
"batch_size": {
MetadataIntValue: &openapi.MetadataIntValue{
IntValue: "32",
MetadataType: "MetadataIntValue",
},
},
"experimental": {
MetadataBoolValue: &openapi.MetadataBoolValue{
BoolValue: true,
MetadataType: "MetadataBoolValue",
},
},
},
},
},
{
version: &openapi.ModelVersion{
Name: "v3.0.0-beta",
Description: apiutils.Of("Beta version for testing"),
ExternalId: apiutils.Of("ext-v3-beta-004"),
Author: apiutils.Of("charlie"),
RegisteredModelId: *createdModel.Id,
CustomProperties: &map[string]openapi.MetadataValue{
"stage": {
MetadataStringValue: &openapi.MetadataStringValue{
StringValue: "development",
MetadataType: "MetadataStringValue",
},
},
"accuracy": {
MetadataDoubleValue: &openapi.MetadataDoubleValue{
DoubleValue: 0.92,
MetadataType: "MetadataDoubleValue",
},
},
},
},
},
}
// Seed the registry: store each fixture version under the parent registered model.
// Any failure aborts the test immediately, since every later subtest relies on this data.
for idx := range testVersions {
	_, upsertErr := _service.UpsertModelVersion(testVersions[idx].version, createdModel.Id)
	require.NoError(t, upsertErr)
}
// testCases pairs a filterQuery expression with the model versions it is
// expected to select from the fixtures created above. Each entry is executed
// as its own subtest by the loop that follows this table.
testCases := []struct {
// name is the subtest name passed to t.Run.
name string
// filterQuery is the filter expression handed to GetModelVersions.
filterQuery string
// expectedCount is the number of versions the filter should select.
expectedCount int
// expectedNames lists the names of the versions the filter should select
// (order is not significant; it is checked with ElementsMatch).
expectedNames []string
}{
// --- Filters on built-in ModelVersion attributes ---
{
name: "Filter by exact name",
filterQuery: "name = 'v1.0.0'",
expectedCount: 1,
expectedNames: []string{"v1.0.0"},
},
{
name: "Filter by name pattern",
filterQuery: "name LIKE 'v1.%'",
expectedCount: 2,
expectedNames: []string{"v1.0.0", "v1.1.0"},
},
{
name: "Filter by description",
filterQuery: "description LIKE '%bug fixes%'",
expectedCount: 1,
expectedNames: []string{"v1.1.0"},
},
{
name: "Filter by author",
filterQuery: "author = 'alice'",
expectedCount: 2,
expectedNames: []string{"v1.0.0", "v1.1.0"},
},
{
name: "Filter by external ID",
filterQuery: "externalId = 'ext-v2-002'",
expectedCount: 1,
expectedNames: []string{"v2.0.0"},
},
{
name: "Filter by state",
filterQuery: "state = 'ARCHIVED'",
expectedCount: 1,
expectedNames: []string{"v2.0.0"},
},
// --- Filters on custom properties (string, double, int, bool) ---
{
name: "Filter by custom property - string",
filterQuery: "stage = 'production'",
expectedCount: 2,
expectedNames: []string{"v1.0.0", "v1.1.0"},
},
{
// Written without spaces around '=' on purpose: exercises the compact form.
name: "OpenAPI example: Filter by name and state",
filterQuery: "name='v1.0.0' AND state='LIVE'",
expectedCount: 1,
expectedNames: []string{"v1.0.0"},
},
{
name: "Filter by custom property - numeric comparison",
filterQuery: "accuracy > 0.94",
expectedCount: 2,
expectedNames: []string{"v1.0.0", "v1.1.0"},
},
{
name: "Filter by custom property - integer",
filterQuery: "batch_size = 32",
expectedCount: 2,
expectedNames: []string{"v1.0.0", "v1.1.0"},
},
{
name: "Filter by custom property - boolean",
filterQuery: "experimental = true",
expectedCount: 1,
expectedNames: []string{"v1.1.0"},
},
// --- Compound expressions: AND / OR / grouping ---
{
name: "Complex filter with AND",
filterQuery: "stage = 'production' AND accuracy > 0.95",
expectedCount: 1,
expectedNames: []string{"v1.1.0"},
},
{
name: "Complex filter with OR",
filterQuery: "author = 'alice' OR author = 'charlie'",
expectedCount: 3,
expectedNames: []string{"v1.0.0", "v1.1.0", "v3.0.0-beta"},
},
{
name: "Complex filter with parentheses",
filterQuery: "(stage = 'production' OR stage = 'staging') AND accuracy < 0.95",
expectedCount: 1,
expectedNames: []string{"v2.0.0"},
},
// --- Pattern matching and inequality ---
{
// The stored name is lowercase ("v3.0.0-beta"); the pattern is uppercase.
name: "Case insensitive pattern matching",
filterQuery: "name ILIKE '%BETA%'",
expectedCount: 1,
expectedNames: []string{"v3.0.0-beta"},
},
{
// Despite the case name, the query uses the != operator, not a NOT keyword.
name: "Filter with NOT condition",
filterQuery: "stage != 'development'",
expectedCount: 3,
expectedNames: []string{"v1.0.0", "v2.0.0", "v1.1.0"},
},
{
name: "Filter by name pattern with version suffix",
filterQuery: "name LIKE '%-beta'",
expectedCount: 1,
expectedNames: []string{"v3.0.0-beta"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pageSize := int32(10)
listOptions := api.ListOptions{
PageSize: &pageSize,
FilterQuery: &tc.filterQuery,
}
result, err := _service.GetModelVersions(listOptions, nil)
require.NoError(t, err)
require.NotNil(t, result)
// Extract names from results
var actualNames []string
for _, item := range result.Items {
for _, expectedName := range tc.expectedNames {
if item.Name == expectedName {
actualNames = append(actualNames, item.Name)
break
}
}
}
assert.Equal(t, tc.expectedCount, len(actualNames),
"Expected %d versions for filter '%s', but got %d",
tc.expectedCount, tc.filterQuery, len(actualNames))
// Verify the expected versions are present
assert.ElementsMatch(t, tc.expectedNames, actualNames,
"Expected versions %v for filter '%s', but got %v",
tc.expectedNames, tc.filterQuery, actualNames)
})
}
// Test error cases
t.Run("Invalid filter syntax", func(t *testing.T) {
pageSize := int32(10)
invalidFilter := "invalid <<< syntax"
listOptions := api.ListOptions{
PageSize: &pageSize,
FilterQuery: &invalidFilter,
}
result, err := _service.GetModelVersions(listOptions, nil)
if assert.Error(t, err) {
assert.Nil(t, result)
assert.Contains(t, err.Error(), "invalid filter query")
}
})
// Verify that filterQuery composes with the registeredModelId scope argument.
t.Run("Filter with registeredModelId parameter", func(t *testing.T) {
	// A second registered model that owns a single production-stage version.
	otherModel, err := _service.UpsertRegisteredModel(&openapi.RegisteredModel{
		Name: "another-model",
	})
	require.NoError(t, err)

	otherProps := map[string]openapi.MetadataValue{
		"stage": {
			MetadataStringValue: &openapi.MetadataStringValue{
				StringValue:  "production",
				MetadataType: "MetadataStringValue",
			},
		},
	}
	_, err = _service.UpsertModelVersion(&openapi.ModelVersion{
		Name:              "v1.0.0",
		RegisteredModelId: *otherModel.Id,
		CustomProperties:  &otherProps,
	}, otherModel.Id)
	require.NoError(t, err)

	// The same options are reused for the unscoped and the scoped query.
	limit := int32(10)
	query := "stage = 'production'"
	opts := api.ListOptions{
		PageSize:    &limit,
		FilterQuery: &query,
	}

	// Unscoped (nil parent): 2 matches from the first model + 1 from the second.
	unscoped, err := _service.GetModelVersions(opts, nil)
	require.NoError(t, err)
	assert.Equal(t, 3, len(unscoped.Items))

	// Scoped to the first model: only its 2 matches, each carrying its ID.
	scoped, err := _service.GetModelVersions(opts, createdModel.Id)
	require.NoError(t, err)
	assert.Equal(t, 2, len(scoped.Items))
	for i := range scoped.Items {
		assert.Equal(t, *createdModel.Id, scoped.Items[i].RegisteredModelId)
	}
})
// Verify that filterQuery composes with pagination: with a page size of 1,
// the two production-stage versions of the first model arrive on two
// consecutive pages.
t.Run("Filter with pagination", func(t *testing.T) {
	pageSize := int32(1)
	filterQuery := "stage = 'production'"
	listOptions := api.ListOptions{
		PageSize:    &pageSize,
		FilterQuery: &filterQuery,
	}

	// Get first page. The length and token checks are fatal (require) because
	// the code below indexes Items[0] and reuses the token; continuing after a
	// failure would panic or silently re-fetch the first page.
	firstPage, err := _service.GetModelVersions(listOptions, createdModel.Id)
	require.NoError(t, err)
	require.NotNil(t, firstPage)
	require.Len(t, firstPage.Items, 1)
	require.NotEmpty(t, firstPage.NextPageToken)

	// Get second page using the token returned with the first one.
	listOptions.NextPageToken = &firstPage.NextPageToken
	secondPage, err := _service.GetModelVersions(listOptions, createdModel.Id)
	require.NoError(t, err)
	require.NotNil(t, secondPage)
	require.Len(t, secondPage.Items, 1)

	// Ensure different items on each page. IDs are pointers, so guard against
	// nil and compare the pointed-to values for a readable failure message.
	require.NotNil(t, firstPage.Items[0].Id)
	require.NotNil(t, secondPage.Items[0].Id)
	assert.NotEqual(t, *firstPage.Items[0].Id, *secondPage.Items[0].Id)
})
// Test empty results
t.Run("Filter with no matches", func(t *testing.T) {
pageSize := int32(10)
filterQuery := "stage = 'nonexistent'"
listOptions := api.ListOptions{
PageSize: &pageSize,
FilterQuery: &filterQuery,
}
result, err := _service.GetModelVersions(listOptions, nil)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, 0, len(result.Items))
assert.Equal(t, int32(0), result.Size)
})
}