Merge model-catalog-enhancements branch (#1656)

* feat(catalog): deploy postgres for model catalog (#1584)

* feat(catalog): initialize postgres for model catalog

Initialize PostgreSQL database for model catalog with connection setup.
Refactored embedmd initialization code from model registry for reuse.
Added Kubernetes manifests for PostgreSQL StatefulSet, PVC, and service
configuration.

Signed-off-by: Paul Boyd <paul@pboyd.io>

* fix(catalog): remove chdir before loading catalogs

Instead of changing directories, pass the directory that files should be
relative to when loading a catalog.

Signed-off-by: Paul Boyd <paul@pboyd.io>

---------

Signed-off-by: Paul Boyd <paul@pboyd.io>

* feat(catalog): add postgres for catalog dev (#1623)

Adding to tilt and docker compose.

Signed-off-by: Paul Boyd <paul@pboyd.io>

* feat(datastore): refactor type creation (#1636)

- Refactor embedmd to created types from a spec, instead of database
  migrations.
- Add TypeRepository and TypePropertyRepository services

Signed-off-by: Paul Boyd <paul@pboyd.io>

* Add artifact types to the model catalog (#1649)

* chore: update testutils to take a datastore spec argument

Signed-off-by: Paul Boyd <paul@pboyd.io>

* feat(catalog): add artifact types to openapi spec

And re-generate code.

Signed-off-by: Paul Boyd <paul@pboyd.io>

* feat(catalog): add database models and services

Signed-off-by: Paul Boyd <paul@pboyd.io>

---------

Signed-off-by: Paul Boyd <paul@pboyd.io>

* Update api/openapi/src/catalog.yaml

Co-authored-by: Dhiraj Bokde <dhirajsb@users.noreply.github.com>
Signed-off-by: Paul Boyd <paul@pboyd.io>

* fix(catalog): set source_id type correctly

Signed-off-by: Paul Boyd <paul@pboyd.io>

* fix(catalog): fix k8s resource names

Signed-off-by: Paul Boyd <paul@pboyd.io>

---------

Signed-off-by: Paul Boyd <paul@pboyd.io>
Co-authored-by: Dhiraj Bokde <dhirajsb@users.noreply.github.com>
This commit is contained in:
Paul Boyd 2025-10-01 06:47:33 -04:00 committed by GitHub
parent 122dbfd933
commit 79f837c3c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
98 changed files with 4491 additions and 895 deletions

View File

@ -107,14 +107,14 @@ paths:
required: true
/api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name}/artifacts:
description: >-
The REST endpoint/path used to list `CatalogModelArtifacts`.
The REST endpoint/path used to list `CatalogArtifacts`.
get:
summary: List CatalogModelArtifacts.
summary: List CatalogArtifacts.
tags:
- ModelCatalogService
responses:
"200":
$ref: "#/components/responses/CatalogModelArtifactListResponse"
$ref: "#/components/responses/CatalogArtifactListResponse"
"401":
$ref: "#/components/responses/Unauthorized"
"404":
@ -234,6 +234,51 @@ components:
format: int32
description: Number of items in result list.
type: integer
CatalogArtifact:
description: A single artifact in the catalog API.
oneOf:
- $ref: "#/components/schemas/CatalogModelArtifact"
- $ref: "#/components/schemas/CatalogMetricsArtifact"
discriminator:
propertyName: artifactType
mapping:
model-artifact: "#/components/schemas/CatalogModelArtifact"
metrics-artifact: "#/components/schemas/CatalogMetricsArtifact"
CatalogArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogMetricsArtifact:
description: A metadata Artifact Entity.
allOf:
- type: object
required:
- artifactType
- metricsType
properties:
artifactType:
type: string
default: metrics-artifact
metricsType:
type: string
enum:
- performance-metrics
- accuracy-metrics
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModel:
description: A model in the model catalog.
allOf:
@ -251,35 +296,26 @@ components:
- $ref: "#/components/schemas/BaseResourceDates"
- $ref: "#/components/schemas/BaseModel"
CatalogModelArtifact:
description: A single artifact for a catalog model.
description: A Catalog Model Artifact Entity.
allOf:
- type: object
required:
- artifactType
- uri
properties:
artifactType:
type: string
default: model-artifact
uri:
type: string
format: uri
description: URI where the artifact can be retrieved.
description: URI where the model can be retrieved.
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModelArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogModelArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogModelArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogModelList:
description: List of CatalogModel entities.
allOf:
@ -468,12 +504,12 @@ components:
schema:
$ref: "#/components/schemas/Error"
description: Bad Request parameters
CatalogModelArtifactListResponse:
CatalogArtifactListResponse:
content:
application/json:
schema:
$ref: "#/components/schemas/CatalogModelArtifactList"
description: A response containing a list of CatalogModelArtifact entities.
$ref: "#/components/schemas/CatalogArtifactList"
description: A response containing a list of CatalogArtifact entities.
CatalogModelListResponse:
content:
application/json:

View File

@ -107,14 +107,14 @@ paths:
required: true
/api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name}/artifacts:
description: >-
The REST endpoint/path used to list `CatalogModelArtifacts`.
The REST endpoint/path used to list `CatalogArtifacts`.
get:
summary: List CatalogModelArtifacts.
summary: List CatalogArtifacts.
tags:
- ModelCatalogService
responses:
"200":
$ref: "#/components/responses/CatalogModelArtifactListResponse"
$ref: "#/components/responses/CatalogArtifactListResponse"
"401":
$ref: "#/components/responses/Unauthorized"
"404":
@ -137,6 +137,51 @@ paths:
required: true
components:
schemas:
CatalogArtifact:
description: A single artifact in the catalog API.
oneOf:
- $ref: "#/components/schemas/CatalogModelArtifact"
- $ref: "#/components/schemas/CatalogMetricsArtifact"
discriminator:
propertyName: artifactType
mapping:
model-artifact: "#/components/schemas/CatalogModelArtifact"
metrics-artifact: "#/components/schemas/CatalogMetricsArtifact"
CatalogArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogMetricsArtifact:
description: A metadata Artifact Entity.
allOf:
- type: object
required:
- artifactType
- metricsType
properties:
artifactType:
type: string
default: metrics-artifact
metricsType:
type: string
enum:
- performance-metrics
- accuracy-metrics
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModel:
description: A model in the model catalog.
allOf:
@ -154,35 +199,26 @@ components:
- $ref: "#/components/schemas/BaseResourceDates"
- $ref: "#/components/schemas/BaseModel"
CatalogModelArtifact:
description: A single artifact for a catalog model.
description: A Catalog Model Artifact Entity.
allOf:
- type: object
required:
- artifactType
- uri
properties:
artifactType:
type: string
default: model-artifact
uri:
type: string
format: uri
description: URI where the artifact can be retrieved.
description: URI where the model can be retrieved.
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModelArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogModelArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogModelArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogModelList:
description: List of CatalogModel entities.
allOf:
@ -240,12 +276,12 @@ components:
type: string
responses:
CatalogModelArtifactListResponse:
CatalogArtifactListResponse:
content:
application/json:
schema:
$ref: "#/components/schemas/CatalogModelArtifactList"
description: A response containing a list of CatalogModelArtifact entities.
$ref: "#/components/schemas/CatalogArtifactList"
description: A response containing a list of CatalogArtifact entities.
CatalogModelListResponse:
content:
application/json:

View File

@ -6,31 +6,51 @@ import (
"github.com/golang/glog"
"github.com/kubeflow/model-registry/catalog/internal/catalog"
"github.com/kubeflow/model-registry/catalog/internal/db/service"
"github.com/kubeflow/model-registry/catalog/internal/server/openapi"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
"github.com/spf13/cobra"
)
var catalogCfg = struct {
ListenAddress string
ConfigPath string
ConfigPath []string
}{
ListenAddress: "0.0.0.0:8080",
ConfigPath: "sources.yaml",
ConfigPath: []string{"sources.yaml"},
}
var CatalogCmd = &cobra.Command{
Use: "catalog",
Short: "Catalog API server",
Long: `Launch the API server for the model catalog`,
RunE: runCatalogServer,
Long: `Launch the API server for the model catalog. Use PostgreSQL's
environment variables
(https://www.postgresql.org/docs/current/libpq-envars.html) to
configure the database connection.`,
RunE: runCatalogServer,
}
func init() {
CatalogCmd.Flags().StringVarP(&catalogCfg.ListenAddress, "listen", "l", catalogCfg.ListenAddress, "Address to listen on")
CatalogCmd.Flags().StringVar(&catalogCfg.ConfigPath, "catalogs-path", catalogCfg.ConfigPath, "Path to catalog source configuration file")
fs := CatalogCmd.Flags()
fs.StringVarP(&catalogCfg.ListenAddress, "listen", "l", catalogCfg.ListenAddress, "Address to listen on")
fs.StringSliceVar(&catalogCfg.ConfigPath, "catalogs-path", catalogCfg.ConfigPath, "Path to catalog source configuration file")
}
func runCatalogServer(cmd *cobra.Command, args []string) error {
ds, err := datastore.NewConnector("embedmd", &embedmd.EmbedMDConfig{
DatabaseType: "postgres", // We only support postgres right now
DatabaseDSN: "", // Empty DSN, see https://www.postgresql.org/docs/current/libpq-envars.html
})
if err != nil {
return fmt.Errorf("error creating datastore: %w", err)
}
_, err = ds.Connect(service.DatastoreSpec())
if err != nil {
return fmt.Errorf("error initializing datastore: %v", err)
}
sources, err := catalog.LoadCatalogSources(catalogCfg.ConfigPath)
if err != nil {
return fmt.Errorf("error loading catalog sources: %v", err)

View File

@ -33,7 +33,7 @@ type CatalogSourceProvider interface {
// GetArtifacts returns all artifacts for a particular model. If no
// model is found with that name, it returns nil. If the model is
// found, but has no artifacts, an empty list is returned.
GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error)
GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error)
}
// CatalogSourceConfig is a single entry from the catalog sources YAML file.
@ -52,7 +52,7 @@ type sourceConfig struct {
Catalogs []CatalogSourceConfig `json:"catalogs"`
}
type CatalogTypeRegisterFunc func(source *CatalogSourceConfig) (CatalogSourceProvider, error)
type CatalogTypeRegisterFunc func(source *CatalogSourceConfig, reldir string) (CatalogSourceProvider, error)
var registeredCatalogTypes = make(map[string]CatalogTypeRegisterFunc, 0)
@ -103,24 +103,6 @@ func (sc *SourceCollection) load(path string) error {
// Get the directory of the config file to resolve relative paths
configDir := filepath.Dir(absConfigPath)
// Save current working directory
originalWd, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to get current working directory: %v", err)
}
// Change to the config directory to make relative paths work
if err := os.Chdir(configDir); err != nil {
return fmt.Errorf("failed to change to config directory %s: %v", configDir, err)
}
// Ensure we restore the original working directory when we're done
defer func() {
if err := os.Chdir(originalWd); err != nil {
glog.Errorf("failed to restore original working directory %s: %v", originalWd, err)
}
}()
config := sourceConfig{}
bytes, err := os.ReadFile(absConfigPath)
if err != nil {
@ -157,12 +139,13 @@ func (sc *SourceCollection) load(path string) error {
if _, exists := sources[id]; exists {
return fmt.Errorf("duplicate catalog id %s", id)
}
labels := make([]string, 0)
if catalogConfig.GetLabels() != nil {
labels = catalogConfig.GetLabels()
}
catalogConfig.CatalogSource.Labels = labels
provider, err := registerFunc(&catalogConfig)
provider, err := registerFunc(&catalogConfig, configDir)
if err != nil {
return fmt.Errorf("error reading catalog type %s with id %s: %v", catalogType, id, err)
}
@ -182,29 +165,32 @@ func (sc *SourceCollection) load(path string) error {
return nil
}
func LoadCatalogSources(path string) (*SourceCollection, error) {
func LoadCatalogSources(paths []string) (*SourceCollection, error) {
sc := &SourceCollection{}
err := sc.load(path)
if err != nil {
return nil, err
}
go func() {
changes, err := getMonitor().Path(path)
for _, path := range paths {
err := sc.load(path)
if err != nil {
glog.Errorf("unable to watch sources file: %v", err)
// Not fatal, we just won't get automatic updates.
return nil, err
}
for range changes {
glog.Infof("Reloading sources %s", path)
err = sc.load(path)
go func(path string) {
changes, err := getMonitor().Path(path)
if err != nil {
glog.Errorf("unable to load sources: %v", err)
glog.Errorf("unable to watch sources file (%s): %v", path, err)
// Not fatal, we just won't get automatic updates.
}
}
}()
for range changes {
glog.Infof("Reloading sources %s", path)
err = sc.load(path)
if err != nil {
glog.Errorf("unable to load sources: %v", err)
}
}
}(path)
}
return sc, nil
}

View File

@ -21,13 +21,13 @@ func TestLoadCatalogSources(t *testing.T) {
{
name: "test-catalog-sources",
args: args{catalogsPath: "testdata/test-catalog-sources.yaml"},
want: []string{"catalog1", "catalog3", "catalog4"},
want: []string{"catalog1"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := LoadCatalogSources(tt.args.catalogsPath)
got, err := LoadCatalogSources([]string{tt.args.catalogsPath})
if (err != nil) != tt.wantErr {
t.Errorf("LoadCatalogSources() error = %v, wantErr %v", err, tt.wantErr)
return
@ -64,23 +64,13 @@ func TestLoadCatalogSourcesEnabledDisabled(t *testing.T) {
Name: "Catalog 1",
Enabled: &trueValue,
},
"catalog3": {
Id: "catalog3",
Name: "Catalog 3",
Enabled: &trueValue,
},
"catalog4": {
Id: "catalog4",
Name: "Catalog 4",
Enabled: &trueValue,
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := LoadCatalogSources(tt.args.catalogsPath)
got, err := LoadCatalogSources([]string{tt.args.catalogsPath})
if (err != nil) != tt.wantErr {
t.Errorf("LoadCatalogSources() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -39,11 +39,11 @@ func (h *hfCatalogImpl) ListModels(ctx context.Context, params ListModelsParams)
}, nil
}
func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, name string) (*openapi.CatalogModelArtifactList, error) {
func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, name string) (*openapi.CatalogArtifactList, error) {
// TODO: Implement HuggingFace model artifacts retrieval
// For now, return empty list to satisfy interface
return &openapi.CatalogModelArtifactList{
Items: []openapi.CatalogModelArtifact{},
return &openapi.CatalogArtifactList{
Items: []openapi.CatalogArtifact{},
PageSize: 0,
Size: 0,
}, nil
@ -82,7 +82,7 @@ func (h *hfCatalogImpl) validateCredentials(ctx context.Context) error {
}
// newHfCatalog creates a new HuggingFace catalog source
func newHfCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
func newHfCatalog(source *CatalogSourceConfig, reldir string) (CatalogSourceProvider, error) {
apiKey, ok := source.Properties["apiKey"].(string)
if !ok || apiKey == "" {
return nil, fmt.Errorf("missing or invalid 'apiKey' property for HuggingFace catalog")

View File

@ -22,7 +22,7 @@ func TestNewHfCatalog_MissingAPIKey(t *testing.T) {
},
}
_, err := newHfCatalog(source)
_, err := newHfCatalog(source, "")
if err == nil {
t.Fatal("Expected error for missing API key, got nil")
}
@ -65,7 +65,7 @@ func TestNewHfCatalog_WithValidCredentials(t *testing.T) {
},
}
catalog, err := newHfCatalog(source)
catalog, err := newHfCatalog(source, "")
if err != nil {
t.Fatalf("Failed to create HF catalog: %v", err)
}
@ -130,7 +130,7 @@ func TestNewHfCatalog_InvalidCredentials(t *testing.T) {
},
}
_, err := newHfCatalog(source)
_, err := newHfCatalog(source, "")
if err == nil {
t.Fatal("Expected error for invalid credentials, got nil")
}
@ -159,7 +159,7 @@ func TestNewHfCatalog_DefaultConfiguration(t *testing.T) {
},
}
catalog, err := newHfCatalog(source)
catalog, err := newHfCatalog(source, "")
if err != nil {
t.Fatalf("Failed to create HF catalog with defaults: %v", err)
}

View File

@ -325,6 +325,14 @@ models:
lastUpdateTimeSinceEpoch: "1734637721000"
artifacts:
- uri: oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892
- artifactType: metrics-artifact
createTimeSinceEpoch: "1733514949000"
lastUpdateTimeSinceEpoch: "1734637721000"
customProperties:
x:
int_value: 1
y:
double_value: 2.1
- name: rhelai1/granite-8b-code-instruct
provider: IBM
description: |-

View File

@ -2,6 +2,7 @@ package catalog
import (
"context"
"encoding/json"
"fmt"
"math"
"os"
@ -19,7 +20,40 @@ import (
type yamlModel struct {
model.CatalogModel `yaml:",inline"`
Artifacts []*model.CatalogModelArtifact `yaml:"artifacts"`
Artifacts []*yamlArtifact `yaml:"artifacts"`
}
type yamlArtifact struct {
model.CatalogArtifact
}
func (a *yamlArtifact) UnmarshalJSON(buf []byte) error {
// This is very similar to generated code to unmarshal a
// CatalogArtifact, but this version properly handles artifacts without
// an artifactType, which is important for backwards compatibility.
var yat struct {
ArtifactType string `json:"artifactType"`
}
err := json.Unmarshal(buf, &yat)
if err != nil {
return err
}
switch yat.ArtifactType {
case "model-artifact", "":
err = json.Unmarshal(buf, &a.CatalogArtifact.CatalogModelArtifact)
if a.CatalogArtifact.CatalogModelArtifact != nil {
// Ensure artifactType is set even if it wasn't initially.
a.CatalogArtifact.CatalogModelArtifact.ArtifactType = "model-artifact"
}
case "metrics-artifact":
err = json.Unmarshal(buf, &a.CatalogArtifact.CatalogMetricsArtifact)
default:
return fmt.Errorf("unknown artifactType: %s", yat.ArtifactType)
}
return err
}
type yamlCatalog struct {
@ -125,7 +159,7 @@ func (y *yamlCatalogImpl) ListModels(ctx context.Context, params ListModelsParam
return list, nil // Return the struct value directly
}
func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error) {
func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error) {
y.modelsLock.RLock()
defer y.modelsLock.RUnlock()
@ -139,13 +173,13 @@ func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model
count = math.MaxInt32
}
list := model.CatalogModelArtifactList{
Items: make([]model.CatalogModelArtifact, count),
list := model.CatalogArtifactList{
Items: make([]model.CatalogArtifact, count),
PageSize: int32(count),
Size: int32(count),
}
for i := range list.Items {
list.Items[i] = *ym.Artifacts[i]
list.Items[i] = ym.Artifacts[i].CatalogArtifact
}
return &list, nil
}
@ -192,16 +226,13 @@ func (y *yamlCatalogImpl) load(path string, excludedModelsList []string) error {
const yamlCatalogPath = "yamlCatalogPath"
func newYamlCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
func newYamlCatalog(source *CatalogSourceConfig, reldir string) (CatalogSourceProvider, error) {
yamlModelFile, exists := source.Properties[yamlCatalogPath].(string)
if !exists || yamlModelFile == "" {
return nil, fmt.Errorf("missing %s string property", yamlCatalogPath)
}
yamlModelFile, err := filepath.Abs(yamlModelFile)
if err != nil {
return nil, fmt.Errorf("abs: %w", err)
}
yamlModelFile = filepath.Join(reldir, yamlModelFile)
// Excluded models is an optional source property.
var excludedModels []string
@ -222,7 +253,7 @@ func newYamlCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error)
p := &yamlCatalogImpl{
models: make(map[string]*yamlModel),
}
err = p.load(yamlModelFile, excludedModels)
err := p.load(yamlModelFile, excludedModels)
if err != nil {
return nil, err
}

View File

@ -38,10 +38,17 @@ func TestYAMLCatalogGetArtifacts(t *testing.T) {
artifacts, err := provider.GetArtifacts(context.Background(), "rhelai1/granite-8b-code-base")
if assert.NoError(err) {
assert.NotNil(artifacts)
assert.Equal(int32(1), artifacts.Size)
assert.Equal(int32(1), artifacts.PageSize)
assert.Len(artifacts.Items, 1)
assert.Equal("oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892", artifacts.Items[0].Uri)
assert.Equal(int32(2), artifacts.Size)
assert.Equal(int32(2), artifacts.PageSize)
assert.Len(artifacts.Items, 2)
if assert.NotNil(artifacts.Items[0].CatalogModelArtifact) {
assert.Equal("model-artifact", artifacts.Items[0].CatalogModelArtifact.ArtifactType)
assert.Equal("oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892", artifacts.Items[0].CatalogModelArtifact.Uri)
}
if assert.NotNil(artifacts.Items[1].CatalogMetricsArtifact) {
assert.Equal("metrics-artifact", artifacts.Items[1].CatalogMetricsArtifact.ArtifactType)
assert.NotNil(artifacts.Items[1].CatalogMetricsArtifact.CustomProperties)
}
}
// Test case 2: Model with no artifacts
@ -199,7 +206,7 @@ func testYAMLProviderWithExclusions(t *testing.T, path string, excludedModels []
}
provider, err := newYamlCatalog(&CatalogSourceConfig{
Properties: properties,
})
}, "")
if err != nil {
t.Fatalf("newYamlCatalog(%s) with exclusions failed: %v", path, err)
}

View File

@ -0,0 +1,45 @@
package models
import (
"github.com/kubeflow/model-registry/internal/db/filter"
"github.com/kubeflow/model-registry/internal/db/models"
)
type MetricsType string
const (
MetricsTypePerformance MetricsType = "performance-metrics"
MetricsTypeAccuracy MetricsType = "accuracy-metrics"
)
type CatalogMetricsArtifactListOptions struct {
models.Pagination
Name *string
ExternalID *string
ParentResourceID *int32
}
// GetRestEntityType implements the FilterApplier interface
func (c *CatalogMetricsArtifactListOptions) GetRestEntityType() filter.RestEntityType {
return filter.RestEntityModelArtifact // Reusing existing filter type
}
type CatalogMetricsArtifactAttributes struct {
Name *string
MetricsType MetricsType
ExternalID *string
CreateTimeSinceEpoch *int64
LastUpdateTimeSinceEpoch *int64
}
type CatalogMetricsArtifact interface {
models.Entity[CatalogMetricsArtifactAttributes]
}
type CatalogMetricsArtifactImpl = models.BaseEntity[CatalogMetricsArtifactAttributes]
type CatalogMetricsArtifactRepository interface {
GetByID(id int32) (CatalogMetricsArtifact, error)
List(listOptions CatalogMetricsArtifactListOptions) (*models.ListWrapper[CatalogMetricsArtifact], error)
Save(metricsArtifact CatalogMetricsArtifact, parentResourceID *int32) (CatalogMetricsArtifact, error)
}

View File

@ -0,0 +1,36 @@
package models
import (
"github.com/kubeflow/model-registry/internal/db/filter"
"github.com/kubeflow/model-registry/internal/db/models"
)
type CatalogModelListOptions struct {
models.Pagination
Name *string
ExternalID *string
}
// GetRestEntityType implements the FilterApplier interface
func (c *CatalogModelListOptions) GetRestEntityType() filter.RestEntityType {
return "CatalogModel"
}
type CatalogModelAttributes struct {
Name *string
ExternalID *string
CreateTimeSinceEpoch *int64
LastUpdateTimeSinceEpoch *int64
}
type CatalogModel interface {
models.Entity[CatalogModelAttributes]
}
type CatalogModelImpl = models.BaseEntity[CatalogModelAttributes]
type CatalogModelRepository interface {
GetByID(id int32) (CatalogModel, error)
List(listOptions CatalogModelListOptions) (*models.ListWrapper[CatalogModel], error)
Save(model CatalogModel) (CatalogModel, error)
}

View File

@ -0,0 +1,40 @@
package models
import (
"github.com/kubeflow/model-registry/internal/db/filter"
"github.com/kubeflow/model-registry/internal/db/models"
)
const CatalogModelArtifactType = "catalog-model-artifact"
type CatalogModelArtifactListOptions struct {
models.Pagination
Name *string
ExternalID *string
ParentResourceID *int32
}
// GetRestEntityType implements the FilterApplier interface
func (c *CatalogModelArtifactListOptions) GetRestEntityType() filter.RestEntityType {
return filter.RestEntityModelArtifact // Reusing existing filter type
}
type CatalogModelArtifactAttributes struct {
Name *string
URI *string
ExternalID *string
CreateTimeSinceEpoch *int64
LastUpdateTimeSinceEpoch *int64
}
type CatalogModelArtifact interface {
models.Entity[CatalogModelArtifactAttributes]
}
type CatalogModelArtifactImpl = models.BaseEntity[CatalogModelArtifactAttributes]
type CatalogModelArtifactRepository interface {
GetByID(id int32) (CatalogModelArtifact, error)
List(listOptions CatalogModelArtifactListOptions) (*models.ListWrapper[CatalogModelArtifact], error)
Save(modelArtifact CatalogModelArtifact, parentResourceID *int32) (CatalogModelArtifact, error)
}

View File

@ -0,0 +1,162 @@
package service
import (
"errors"
"fmt"
"github.com/kubeflow/model-registry/catalog/internal/db/models"
"github.com/kubeflow/model-registry/internal/apiutils"
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/db/utils"
"gorm.io/gorm"
)
var ErrCatalogMetricsArtifactNotFound = errors.New("catalog metrics artifact by id not found")
type CatalogMetricsArtifactRepositoryImpl struct {
*service.GenericRepository[models.CatalogMetricsArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogMetricsArtifactListOptions]
}
func NewCatalogMetricsArtifactRepository(db *gorm.DB, typeID int64) models.CatalogMetricsArtifactRepository {
config := service.GenericRepositoryConfig[models.CatalogMetricsArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogMetricsArtifactListOptions]{
DB: db,
TypeID: typeID,
EntityToSchema: mapCatalogMetricsArtifactToArtifact,
SchemaToEntity: mapDataLayerToCatalogMetricsArtifact,
EntityToProperties: mapCatalogMetricsArtifactToArtifactProperties,
NotFoundError: ErrCatalogMetricsArtifactNotFound,
EntityName: "catalog metrics artifact",
PropertyFieldName: "artifact_id",
ApplyListFilters: applyCatalogMetricsArtifactListFilters,
IsNewEntity: func(entity models.CatalogMetricsArtifact) bool { return entity.GetID() == nil },
HasCustomProperties: func(entity models.CatalogMetricsArtifact) bool { return entity.GetCustomProperties() != nil },
}
return &CatalogMetricsArtifactRepositoryImpl{
GenericRepository: service.NewGenericRepository(config),
}
}
func (r *CatalogMetricsArtifactRepositoryImpl) List(listOptions models.CatalogMetricsArtifactListOptions) (*dbmodels.ListWrapper[models.CatalogMetricsArtifact], error) {
return r.GenericRepository.List(&listOptions)
}
func (r *CatalogMetricsArtifactRepositoryImpl) Save(ma models.CatalogMetricsArtifact, parentResourceID *int32) (models.CatalogMetricsArtifact, error) {
attr := ma.GetAttributes()
if attr == nil {
return ma, fmt.Errorf("invalid artifact: nil attributes")
}
switch attr.MetricsType {
case models.MetricsTypeAccuracy, models.MetricsTypePerformance:
// OK
default:
return ma, fmt.Errorf("invalid artifact: unknown metrics type: %s", attr.MetricsType)
}
return r.GenericRepository.Save(ma, parentResourceID)
}
func applyCatalogMetricsArtifactListFilters(query *gorm.DB, listOptions *models.CatalogMetricsArtifactListOptions) *gorm.DB {
if listOptions.Name != nil {
query = query.Where("name LIKE ?", fmt.Sprintf("%%:%s", *listOptions.Name))
} else if listOptions.ExternalID != nil {
query = query.Where("external_id = ?", listOptions.ExternalID)
}
if listOptions.ParentResourceID != nil {
query = query.Joins(utils.BuildAttributionJoin(query)).
Where(utils.GetColumnRef(query, &schema.Attribution{}, "context_id")+" = ?", listOptions.ParentResourceID)
}
return query
}
func mapCatalogMetricsArtifactToArtifact(catalogMetricsArtifact models.CatalogMetricsArtifact) schema.Artifact {
if catalogMetricsArtifact == nil {
return schema.Artifact{}
}
artifact := schema.Artifact{
ID: apiutils.ZeroIfNil(catalogMetricsArtifact.GetID()),
TypeID: apiutils.ZeroIfNil(catalogMetricsArtifact.GetTypeID()),
}
if catalogMetricsArtifact.GetAttributes() != nil {
artifact.Name = catalogMetricsArtifact.GetAttributes().Name
artifact.ExternalID = catalogMetricsArtifact.GetAttributes().ExternalID
artifact.CreateTimeSinceEpoch = apiutils.ZeroIfNil(catalogMetricsArtifact.GetAttributes().CreateTimeSinceEpoch)
artifact.LastUpdateTimeSinceEpoch = apiutils.ZeroIfNil(catalogMetricsArtifact.GetAttributes().LastUpdateTimeSinceEpoch)
}
return artifact
}
func mapCatalogMetricsArtifactToArtifactProperties(catalogMetricsArtifact models.CatalogMetricsArtifact, artifactID int32) []schema.ArtifactProperty {
if catalogMetricsArtifact == nil {
return []schema.ArtifactProperty{}
}
properties := []schema.ArtifactProperty{}
// Add the metricsType as a property
if catalogMetricsArtifact.GetAttributes() != nil {
metricsTypeProp := dbmodels.Properties{
Name: "metricsType",
StringValue: apiutils.Of(string(catalogMetricsArtifact.GetAttributes().MetricsType)),
}
properties = append(properties, service.MapPropertiesToArtifactProperty(metricsTypeProp, artifactID, false))
}
if catalogMetricsArtifact.GetProperties() != nil {
for _, prop := range *catalogMetricsArtifact.GetProperties() {
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, false))
}
}
if catalogMetricsArtifact.GetCustomProperties() != nil {
for _, prop := range *catalogMetricsArtifact.GetCustomProperties() {
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, true))
}
}
return properties
}
func mapDataLayerToCatalogMetricsArtifact(artifact schema.Artifact, artProperties []schema.ArtifactProperty) models.CatalogMetricsArtifact {
catalogMetricsArtifact := models.CatalogMetricsArtifactImpl{
ID: &artifact.ID,
TypeID: &artifact.TypeID,
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: artifact.Name,
ExternalID: artifact.ExternalID,
CreateTimeSinceEpoch: &artifact.CreateTimeSinceEpoch,
LastUpdateTimeSinceEpoch: &artifact.LastUpdateTimeSinceEpoch,
},
}
customProperties := []dbmodels.Properties{}
properties := []dbmodels.Properties{}
for _, prop := range artProperties {
mappedProperty := service.MapArtifactPropertyToProperties(prop)
// Extract metricsType from properties and set it as an attribute
if mappedProperty.Name == "metricsType" && !prop.IsCustomProperty {
if mappedProperty.StringValue != nil {
catalogMetricsArtifact.Attributes.MetricsType = models.MetricsType(*mappedProperty.StringValue)
}
} else if prop.IsCustomProperty {
customProperties = append(customProperties, mappedProperty)
} else {
properties = append(properties, mappedProperty)
}
}
catalogMetricsArtifact.CustomProperties = &customProperties
catalogMetricsArtifact.Properties = &properties
return &catalogMetricsArtifact
}

View File

@ -0,0 +1,461 @@
package service_test
import (
"fmt"
"testing"
"time"
"github.com/kubeflow/model-registry/catalog/internal/db/models"
"github.com/kubeflow/model-registry/catalog/internal/db/service"
"github.com/kubeflow/model-registry/internal/apiutils"
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func TestCatalogMetricsArtifactRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the CatalogMetricsArtifact type ID
typeID := getCatalogMetricsArtifactTypeID(t, sharedDB)
repo := service.NewCatalogMetricsArtifactRepository(sharedDB, typeID)
// Also get CatalogModel type ID for creating parent entities
catalogModelTypeID := getCatalogModelTypeID(t, sharedDB)
catalogModelRepo := service.NewCatalogModelRepository(sharedDB, catalogModelTypeID)
t.Run("TestSave", func(t *testing.T) {
// First create a catalog model for attribution
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-metrics"),
ExternalID: apiutils.Of("catalog-model-metrics-ext-123"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Test creating a new catalog metrics artifact
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("test-catalog-metrics-artifact"),
ExternalID: apiutils.Of("catalog-metrics-ext-123"),
MetricsType: models.MetricsTypeAccuracy,
},
Properties: &[]dbmodels.Properties{
{
Name: "description",
StringValue: apiutils.Of("Test catalog metrics artifact description"),
},
},
CustomProperties: &[]dbmodels.Properties{
{
Name: "custom-metrics-prop",
StringValue: apiutils.Of("custom-metrics-value"),
},
},
}
saved, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, saved)
require.NotNil(t, saved.GetID())
assert.Equal(t, "test-catalog-metrics-artifact", *saved.GetAttributes().Name)
assert.Equal(t, "catalog-metrics-ext-123", *saved.GetAttributes().ExternalID)
assert.Equal(t, models.MetricsTypeAccuracy, saved.GetAttributes().MetricsType)
// Test updating the same catalog metrics artifact
catalogMetricsArtifact.ID = saved.GetID()
catalogMetricsArtifact.GetAttributes().Name = apiutils.Of("updated-catalog-metrics-artifact")
catalogMetricsArtifact.GetAttributes().MetricsType = models.MetricsTypePerformance
// Preserve CreateTimeSinceEpoch from the saved entity
catalogMetricsArtifact.GetAttributes().CreateTimeSinceEpoch = saved.GetAttributes().CreateTimeSinceEpoch
updated, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, updated)
assert.Equal(t, *saved.GetID(), *updated.GetID())
assert.Equal(t, "updated-catalog-metrics-artifact", *updated.GetAttributes().Name)
assert.Equal(t, models.MetricsTypePerformance, updated.GetAttributes().MetricsType)
})
t.Run("TestGetByID", func(t *testing.T) {
// First create a catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-getbyid-metrics"),
ExternalID: apiutils.Of("catalog-model-getbyid-metrics-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create a catalog metrics artifact to retrieve
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("get-test-catalog-metrics-artifact"),
ExternalID: apiutils.Of("get-catalog-metrics-ext-123"),
MetricsType: models.MetricsTypeAccuracy,
},
}
saved, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, saved.GetID())
// Test retrieving by ID
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
require.NotNil(t, retrieved)
assert.Equal(t, *saved.GetID(), *retrieved.GetID())
assert.Equal(t, "get-test-catalog-metrics-artifact", *retrieved.GetAttributes().Name)
assert.Equal(t, "get-catalog-metrics-ext-123", *retrieved.GetAttributes().ExternalID)
assert.Equal(t, models.MetricsTypeAccuracy, retrieved.GetAttributes().MetricsType)
// Test retrieving non-existent ID
_, err = repo.GetByID(99999)
assert.ErrorIs(t, err, service.ErrCatalogMetricsArtifactNotFound)
})
t.Run("TestList", func(t *testing.T) {
// Create a catalog model for the artifacts
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-list-metrics"),
ExternalID: apiutils.Of("catalog-model-list-metrics-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create multiple catalog metrics artifacts for listing
testArtifacts := []*models.CatalogMetricsArtifactImpl{
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("list-catalog-metrics-artifact-1"),
ExternalID: apiutils.Of("list-catalog-metrics-ext-1"),
MetricsType: models.MetricsTypeAccuracy,
},
},
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("list-catalog-metrics-artifact-2"),
ExternalID: apiutils.Of("list-catalog-metrics-ext-2"),
MetricsType: models.MetricsTypePerformance,
},
},
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("list-catalog-metrics-artifact-3"),
ExternalID: apiutils.Of("list-catalog-metrics-ext-3"),
MetricsType: models.MetricsTypePerformance,
},
},
}
// Save all test artifacts
var savedArtifacts []models.CatalogMetricsArtifact
for _, artifact := range testArtifacts {
saved, err := repo.Save(artifact, savedCatalogModel.GetID())
require.NoError(t, err)
savedArtifacts = append(savedArtifacts, saved)
}
// Test listing all artifacts
listOptions := models.CatalogMetricsArtifactListOptions{}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.GreaterOrEqual(t, len(result.Items), 3) // At least our 3 test artifacts
// Test filtering by name
nameFilter := "list-catalog-metrics-artifact-1"
listOptions = models.CatalogMetricsArtifactListOptions{
Name: &nameFilter,
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
if len(result.Items) > 0 {
assert.Equal(t, 1, len(result.Items))
assert.Equal(t, "list-catalog-metrics-artifact-1", *result.Items[0].GetAttributes().Name)
}
// Test filtering by external ID
externalIDFilter := "list-catalog-metrics-ext-2"
listOptions = models.CatalogMetricsArtifactListOptions{
ExternalID: &externalIDFilter,
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
if len(result.Items) > 0 {
assert.Equal(t, 1, len(result.Items))
assert.Equal(t, "list-catalog-metrics-ext-2", *result.Items[0].GetAttributes().ExternalID)
}
// Test filtering by parent resource ID (catalog model)
listOptions = models.CatalogMetricsArtifactListOptions{
ParentResourceID: savedCatalogModel.GetID(),
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.GreaterOrEqual(t, len(result.Items), 3) // Should find our 3 test artifacts
})
t.Run("TestListWithPropertiesAndCustomProperties", func(t *testing.T) {
// Create a catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-props-metrics"),
ExternalID: apiutils.Of("catalog-model-props-metrics-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create a catalog metrics artifact with both properties and custom properties
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("props-test-catalog-metrics-artifact"),
ExternalID: apiutils.Of("props-catalog-metrics-ext-123"),
MetricsType: models.MetricsTypeAccuracy,
},
Properties: &[]dbmodels.Properties{
{
Name: "version",
StringValue: apiutils.Of("1.0.0"),
},
{
Name: "value",
DoubleValue: apiutils.Of(0.95),
},
},
CustomProperties: &[]dbmodels.Properties{
{
Name: "team",
StringValue: apiutils.Of("catalog-metrics-team"),
},
{
Name: "is_validated",
BoolValue: apiutils.Of(true),
},
},
}
saved, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, saved)
// Retrieve and verify properties
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
require.NotNil(t, retrieved)
// Check that metricsType is properly set
assert.Equal(t, models.MetricsTypeAccuracy, retrieved.GetAttributes().MetricsType)
// Check regular properties
require.NotNil(t, retrieved.GetProperties())
assert.Len(t, *retrieved.GetProperties(), 2)
// Check custom properties
require.NotNil(t, retrieved.GetCustomProperties())
assert.Len(t, *retrieved.GetCustomProperties(), 2)
// Verify specific properties exist
properties := *retrieved.GetProperties()
var foundVersion, foundValue bool
for _, prop := range properties {
switch prop.Name {
case "version":
foundVersion = true
assert.Equal(t, "1.0.0", *prop.StringValue)
case "value":
foundValue = true
assert.Equal(t, 0.95, *prop.DoubleValue)
}
}
assert.True(t, foundVersion, "Should find version property")
assert.True(t, foundValue, "Should find value property")
// Verify custom properties
customProperties := *retrieved.GetCustomProperties()
var foundTeam, foundIsValidated bool
for _, prop := range customProperties {
switch prop.Name {
case "team":
foundTeam = true
assert.Equal(t, "catalog-metrics-team", *prop.StringValue)
case "is_validated":
foundIsValidated = true
assert.Equal(t, true, *prop.BoolValue)
}
}
assert.True(t, foundTeam, "Should find team custom property")
assert.True(t, foundIsValidated, "Should find is_validated custom property")
})
t.Run("TestSaveWithoutParentResource", func(t *testing.T) {
// Test creating a catalog metrics artifact without parent resource attribution
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("standalone-catalog-metrics-artifact"),
ExternalID: apiutils.Of("standalone-catalog-metrics-ext"),
MetricsType: models.MetricsTypeAccuracy,
},
Properties: &[]dbmodels.Properties{
{
Name: "description",
StringValue: apiutils.Of("Standalone catalog metrics artifact without parent"),
},
},
}
saved, err := repo.Save(catalogMetricsArtifact, nil)
require.NoError(t, err)
require.NotNil(t, saved)
require.NotNil(t, saved.GetID())
assert.Equal(t, "standalone-catalog-metrics-artifact", *saved.GetAttributes().Name)
assert.Equal(t, models.MetricsTypeAccuracy, saved.GetAttributes().MetricsType)
// Verify it can be retrieved
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
assert.Equal(t, "standalone-catalog-metrics-artifact", *retrieved.GetAttributes().Name)
assert.Equal(t, models.MetricsTypeAccuracy, retrieved.GetAttributes().MetricsType)
})
t.Run("TestListOrdering", func(t *testing.T) {
// Create a catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-ordering-metrics"),
ExternalID: apiutils.Of("catalog-model-ordering-metrics-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create artifacts sequentially with time delays to ensure deterministic ordering
artifact1 := &models.CatalogMetricsArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("time-test-catalog-metrics-artifact-1"),
ExternalID: apiutils.Of("time-catalog-metrics-ext-1"),
MetricsType: models.MetricsTypeAccuracy,
},
}
saved1, err := repo.Save(artifact1, savedCatalogModel.GetID())
require.NoError(t, err)
// Small delay to ensure different timestamps
time.Sleep(10 * time.Millisecond)
artifact2 := &models.CatalogMetricsArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of("time-test-catalog-metrics-artifact-2"),
ExternalID: apiutils.Of("time-catalog-metrics-ext-2"),
MetricsType: models.MetricsTypePerformance,
},
}
saved2, err := repo.Save(artifact2, savedCatalogModel.GetID())
require.NoError(t, err)
// Test ordering by CREATE_TIME
listOptions := models.CatalogMetricsArtifactListOptions{
Pagination: dbmodels.Pagination{
OrderBy: apiutils.Of("CREATE_TIME"),
},
}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
// Find our test artifacts in the results
var foundArtifact1, foundArtifact2 models.CatalogMetricsArtifact
var index1, index2 = -1, -1
for i, item := range result.Items {
if *item.GetID() == *saved1.GetID() {
foundArtifact1 = item
index1 = i
}
if *item.GetID() == *saved2.GetID() {
foundArtifact2 = item
index2 = i
}
}
// Verify both artifacts were found and artifact1 comes before artifact2 (ascending order)
require.NotEqual(t, -1, index1, "Artifact 1 should be found in results")
require.NotEqual(t, -1, index2, "Artifact 2 should be found in results")
assert.Less(t, index1, index2, "Artifact 1 should come before Artifact 2 when ordered by CREATE_TIME")
assert.Less(t, *foundArtifact1.GetAttributes().CreateTimeSinceEpoch, *foundArtifact2.GetAttributes().CreateTimeSinceEpoch, "Artifact 1 should have earlier create time")
})
t.Run("TestMetricsTypeField", func(t *testing.T) {
// Test various metrics types
metricsTypes := []models.MetricsType{models.MetricsTypeAccuracy, models.MetricsTypePerformance}
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-metrics-types"),
ExternalID: apiutils.Of("catalog-model-metrics-types-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
for i, metricsType := range metricsTypes {
artifact := &models.CatalogMetricsArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogMetricsArtifactAttributes{
Name: apiutils.Of(fmt.Sprintf("metrics-type-test-%d", i)),
ExternalID: apiutils.Of(fmt.Sprintf("metrics-type-ext-%d", i)),
MetricsType: metricsType,
},
}
saved, err := repo.Save(artifact, savedCatalogModel.GetID())
require.NoError(t, err)
assert.Equal(t, metricsType, saved.GetAttributes().MetricsType)
// Verify retrieval preserves metricsType
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
assert.Equal(t, metricsType, retrieved.GetAttributes().MetricsType)
}
})
}
// Helper function to get or create CatalogMetricsArtifact type ID
func getCatalogMetricsArtifactTypeID(t *testing.T, db *gorm.DB) int64 {
var typeRecord schema.Type
err := db.Where("name = ?", service.CatalogMetricsArtifactTypeName).First(&typeRecord).Error
if err != nil {
require.NoError(t, err, "Failed to query CatalogMetricsArtifact type")
}
return int64(typeRecord.ID)
}

View File

@ -0,0 +1,129 @@
package service
import (
"errors"
"github.com/kubeflow/model-registry/catalog/internal/db/models"
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/db/service"
"gorm.io/gorm"
)
var ErrCatalogModelNotFound = errors.New("catalog model by id not found")
type CatalogModelRepositoryImpl struct {
*service.GenericRepository[models.CatalogModel, schema.Context, schema.ContextProperty, *models.CatalogModelListOptions]
}
func NewCatalogModelRepository(db *gorm.DB, typeID int64) models.CatalogModelRepository {
config := service.GenericRepositoryConfig[models.CatalogModel, schema.Context, schema.ContextProperty, *models.CatalogModelListOptions]{
DB: db,
TypeID: typeID,
EntityToSchema: mapCatalogModelToContext,
SchemaToEntity: mapDataLayerToCatalogModel,
EntityToProperties: mapCatalogModelToContextProperties,
NotFoundError: ErrCatalogModelNotFound,
EntityName: "catalog model",
PropertyFieldName: "context_id",
ApplyListFilters: applyCatalogModelListFilters,
IsNewEntity: func(entity models.CatalogModel) bool { return entity.GetID() == nil },
HasCustomProperties: func(entity models.CatalogModel) bool { return entity.GetCustomProperties() != nil },
}
return &CatalogModelRepositoryImpl{
GenericRepository: service.NewGenericRepository(config),
}
}
func (r *CatalogModelRepositoryImpl) Save(model models.CatalogModel) (models.CatalogModel, error) {
return r.GenericRepository.Save(model, nil)
}
func (r *CatalogModelRepositoryImpl) List(listOptions models.CatalogModelListOptions) (*dbmodels.ListWrapper[models.CatalogModel], error) {
return r.GenericRepository.List(&listOptions)
}
func applyCatalogModelListFilters(query *gorm.DB, listOptions *models.CatalogModelListOptions) *gorm.DB {
if listOptions.Name != nil {
query = query.Where("name LIKE ?", listOptions.Name)
} else if listOptions.ExternalID != nil {
query = query.Where("external_id = ?", listOptions.ExternalID)
}
return query
}
func mapCatalogModelToContext(model models.CatalogModel) schema.Context {
attrs := model.GetAttributes()
context := schema.Context{
TypeID: *model.GetTypeID(),
}
if model.GetID() != nil {
context.ID = *model.GetID()
}
if attrs != nil {
if attrs.Name != nil {
context.Name = *attrs.Name
}
context.ExternalID = attrs.ExternalID
if attrs.CreateTimeSinceEpoch != nil {
context.CreateTimeSinceEpoch = *attrs.CreateTimeSinceEpoch
}
if attrs.LastUpdateTimeSinceEpoch != nil {
context.LastUpdateTimeSinceEpoch = *attrs.LastUpdateTimeSinceEpoch
}
}
return context
}
func mapCatalogModelToContextProperties(model models.CatalogModel, contextID int32) []schema.ContextProperty {
var properties []schema.ContextProperty
if model.GetProperties() != nil {
for _, prop := range *model.GetProperties() {
properties = append(properties, service.MapPropertiesToContextProperty(prop, contextID, false))
}
}
if model.GetCustomProperties() != nil {
for _, prop := range *model.GetCustomProperties() {
properties = append(properties, service.MapPropertiesToContextProperty(prop, contextID, true))
}
}
return properties
}
func mapDataLayerToCatalogModel(modelCtx schema.Context, propertiesCtx []schema.ContextProperty) models.CatalogModel {
catalogModel := &models.CatalogModelImpl{
ID: &modelCtx.ID,
TypeID: &modelCtx.TypeID,
Attributes: &models.CatalogModelAttributes{
Name: &modelCtx.Name,
ExternalID: modelCtx.ExternalID,
CreateTimeSinceEpoch: &modelCtx.CreateTimeSinceEpoch,
LastUpdateTimeSinceEpoch: &modelCtx.LastUpdateTimeSinceEpoch,
},
}
properties := []dbmodels.Properties{}
customProperties := []dbmodels.Properties{}
for _, prop := range propertiesCtx {
mappedProperty := service.MapContextPropertyToProperties(prop)
if prop.IsCustomProperty {
customProperties = append(customProperties, mappedProperty)
} else {
properties = append(properties, mappedProperty)
}
}
catalogModel.Properties = &properties
catalogModel.CustomProperties = &customProperties
return catalogModel
}

View File

@ -0,0 +1,132 @@
package service
import (
"errors"
"fmt"
"github.com/kubeflow/model-registry/catalog/internal/db/models"
"github.com/kubeflow/model-registry/internal/apiutils"
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/db/utils"
"gorm.io/gorm"
)
var ErrCatalogModelArtifactNotFound = errors.New("catalog model artifact by id not found")
type CatalogModelArtifactRepositoryImpl struct {
*service.GenericRepository[models.CatalogModelArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogModelArtifactListOptions]
}
func NewCatalogModelArtifactRepository(db *gorm.DB, typeID int64) models.CatalogModelArtifactRepository {
config := service.GenericRepositoryConfig[models.CatalogModelArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogModelArtifactListOptions]{
DB: db,
TypeID: typeID,
EntityToSchema: mapCatalogModelArtifactToArtifact,
SchemaToEntity: mapDataLayerToCatalogModelArtifact,
EntityToProperties: mapCatalogModelArtifactToArtifactProperties,
NotFoundError: ErrCatalogModelArtifactNotFound,
EntityName: "catalog model artifact",
PropertyFieldName: "artifact_id",
ApplyListFilters: applyCatalogModelArtifactListFilters,
IsNewEntity: func(entity models.CatalogModelArtifact) bool { return entity.GetID() == nil },
HasCustomProperties: func(entity models.CatalogModelArtifact) bool { return entity.GetCustomProperties() != nil },
}
return &CatalogModelArtifactRepositoryImpl{
GenericRepository: service.NewGenericRepository(config),
}
}
func (r *CatalogModelArtifactRepositoryImpl) List(listOptions models.CatalogModelArtifactListOptions) (*dbmodels.ListWrapper[models.CatalogModelArtifact], error) {
return r.GenericRepository.List(&listOptions)
}
func applyCatalogModelArtifactListFilters(query *gorm.DB, listOptions *models.CatalogModelArtifactListOptions) *gorm.DB {
if listOptions.Name != nil {
query = query.Where("name LIKE ?", fmt.Sprintf("%%:%s", *listOptions.Name))
} else if listOptions.ExternalID != nil {
query = query.Where("external_id = ?", listOptions.ExternalID)
}
if listOptions.ParentResourceID != nil {
query = query.Joins(utils.BuildAttributionJoin(query)).
Where(utils.GetColumnRef(query, &schema.Attribution{}, "context_id")+" = ?", listOptions.ParentResourceID)
}
return query
}
func mapCatalogModelArtifactToArtifact(catalogModelArtifact models.CatalogModelArtifact) schema.Artifact {
if catalogModelArtifact == nil {
return schema.Artifact{}
}
artifact := schema.Artifact{
ID: apiutils.ZeroIfNil(catalogModelArtifact.GetID()),
TypeID: apiutils.ZeroIfNil(catalogModelArtifact.GetTypeID()),
}
if catalogModelArtifact.GetAttributes() != nil {
artifact.Name = catalogModelArtifact.GetAttributes().Name
artifact.URI = catalogModelArtifact.GetAttributes().URI
artifact.ExternalID = catalogModelArtifact.GetAttributes().ExternalID
artifact.CreateTimeSinceEpoch = apiutils.ZeroIfNil(catalogModelArtifact.GetAttributes().CreateTimeSinceEpoch)
artifact.LastUpdateTimeSinceEpoch = apiutils.ZeroIfNil(catalogModelArtifact.GetAttributes().LastUpdateTimeSinceEpoch)
}
return artifact
}
func mapCatalogModelArtifactToArtifactProperties(catalogModelArtifact models.CatalogModelArtifact, artifactID int32) []schema.ArtifactProperty {
if catalogModelArtifact == nil {
return []schema.ArtifactProperty{}
}
properties := []schema.ArtifactProperty{}
if catalogModelArtifact.GetProperties() != nil {
for _, prop := range *catalogModelArtifact.GetProperties() {
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, false))
}
}
if catalogModelArtifact.GetCustomProperties() != nil {
for _, prop := range *catalogModelArtifact.GetCustomProperties() {
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, true))
}
}
return properties
}
func mapDataLayerToCatalogModelArtifact(artifact schema.Artifact, artProperties []schema.ArtifactProperty) models.CatalogModelArtifact {
catalogModelArtifact := models.CatalogModelArtifactImpl{
ID: &artifact.ID,
TypeID: &artifact.TypeID,
Attributes: &models.CatalogModelArtifactAttributes{
Name: artifact.Name,
URI: artifact.URI,
ExternalID: artifact.ExternalID,
CreateTimeSinceEpoch: &artifact.CreateTimeSinceEpoch,
LastUpdateTimeSinceEpoch: &artifact.LastUpdateTimeSinceEpoch,
},
}
customProperties := []dbmodels.Properties{}
properties := []dbmodels.Properties{}
for _, prop := range artProperties {
if prop.IsCustomProperty {
customProperties = append(customProperties, service.MapArtifactPropertyToProperties(prop))
} else {
properties = append(properties, service.MapArtifactPropertyToProperties(prop))
}
}
catalogModelArtifact.CustomProperties = &customProperties
catalogModelArtifact.Properties = &properties
return &catalogModelArtifact
}

View File

@ -0,0 +1,465 @@
package service_test
import (
"fmt"
"testing"
"time"
"github.com/kubeflow/model-registry/catalog/internal/db/models"
"github.com/kubeflow/model-registry/catalog/internal/db/service"
"github.com/kubeflow/model-registry/internal/apiutils"
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func TestCatalogModelArtifactRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the CatalogModelArtifact type ID
typeID := getCatalogModelArtifactTypeID(t, sharedDB)
repo := service.NewCatalogModelArtifactRepository(sharedDB, typeID)
// Also get CatalogModel type ID for creating parent entities
catalogModelTypeID := getCatalogModelTypeID(t, sharedDB)
catalogModelRepo := service.NewCatalogModelRepository(sharedDB, catalogModelTypeID)
t.Run("TestSave", func(t *testing.T) {
// First create a catalog model for attribution
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-artifact"),
ExternalID: apiutils.Of("catalog-model-ext-123"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Test creating a new catalog model artifact
catalogModelArtifact := &models.CatalogModelArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("test-catalog-model-artifact"),
ExternalID: apiutils.Of("catalog-artifact-ext-123"),
URI: apiutils.Of("s3://catalog-bucket/model.pkl"),
},
Properties: &[]dbmodels.Properties{
{
Name: "description",
StringValue: apiutils.Of("Test catalog model artifact description"),
},
},
CustomProperties: &[]dbmodels.Properties{
{
Name: "custom-catalog-prop",
StringValue: apiutils.Of("custom-catalog-value"),
},
},
}
saved, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, saved)
require.NotNil(t, saved.GetID())
assert.Equal(t, "test-catalog-model-artifact", *saved.GetAttributes().Name)
assert.Equal(t, "catalog-artifact-ext-123", *saved.GetAttributes().ExternalID)
assert.Equal(t, "s3://catalog-bucket/model.pkl", *saved.GetAttributes().URI)
// Test updating the same catalog model artifact
catalogModelArtifact.ID = saved.GetID()
catalogModelArtifact.GetAttributes().Name = apiutils.Of("updated-catalog-model-artifact")
catalogModelArtifact.GetAttributes().URI = apiutils.Of("s3://catalog-bucket/updated-model.pkl")
// Preserve CreateTimeSinceEpoch from the saved entity
catalogModelArtifact.GetAttributes().CreateTimeSinceEpoch = saved.GetAttributes().CreateTimeSinceEpoch
updated, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, updated)
assert.Equal(t, *saved.GetID(), *updated.GetID())
assert.Equal(t, "updated-catalog-model-artifact", *updated.GetAttributes().Name)
assert.Equal(t, "s3://catalog-bucket/updated-model.pkl", *updated.GetAttributes().URI)
})
t.Run("TestGetByID", func(t *testing.T) {
// First create a catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-getbyid"),
ExternalID: apiutils.Of("catalog-model-getbyid-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create a catalog model artifact to retrieve
catalogModelArtifact := &models.CatalogModelArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("get-test-catalog-model-artifact"),
ExternalID: apiutils.Of("get-catalog-artifact-ext-123"),
URI: apiutils.Of("s3://catalog-bucket/get-model.pkl"),
},
}
saved, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, saved.GetID())
// Test retrieving by ID
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
require.NotNil(t, retrieved)
assert.Equal(t, *saved.GetID(), *retrieved.GetID())
assert.Equal(t, "get-test-catalog-model-artifact", *retrieved.GetAttributes().Name)
assert.Equal(t, "get-catalog-artifact-ext-123", *retrieved.GetAttributes().ExternalID)
assert.Equal(t, "s3://catalog-bucket/get-model.pkl", *retrieved.GetAttributes().URI)
// Test retrieving non-existent ID
_, err = repo.GetByID(99999)
assert.ErrorIs(t, err, service.ErrCatalogModelArtifactNotFound)
})
t.Run("TestList", func(t *testing.T) {
// Create a catalog model for the artifacts
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-list"),
ExternalID: apiutils.Of("catalog-model-list-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create multiple catalog model artifacts for listing
testArtifacts := []*models.CatalogModelArtifactImpl{
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("list-catalog-artifact-1"),
ExternalID: apiutils.Of("list-catalog-artifact-ext-1"),
URI: apiutils.Of("s3://catalog-bucket/list-model-1.pkl"),
},
},
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("list-catalog-artifact-2"),
ExternalID: apiutils.Of("list-catalog-artifact-ext-2"),
URI: apiutils.Of("s3://catalog-bucket/list-model-2.pkl"),
},
},
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("list-catalog-artifact-3"),
ExternalID: apiutils.Of("list-catalog-artifact-ext-3"),
URI: apiutils.Of("s3://catalog-bucket/list-model-3.pkl"),
},
},
}
// Save all test artifacts
var savedArtifacts []models.CatalogModelArtifact
for _, artifact := range testArtifacts {
saved, err := repo.Save(artifact, savedCatalogModel.GetID())
require.NoError(t, err)
savedArtifacts = append(savedArtifacts, saved)
}
// Test listing all artifacts
listOptions := models.CatalogModelArtifactListOptions{}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.GreaterOrEqual(t, len(result.Items), 3) // At least our 3 test artifacts
// Test filtering by name
nameFilter := "list-catalog-artifact-1"
listOptions = models.CatalogModelArtifactListOptions{
Name: &nameFilter,
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
if len(result.Items) > 0 {
assert.Equal(t, 1, len(result.Items))
assert.Equal(t, "list-catalog-artifact-1", *result.Items[0].GetAttributes().Name)
}
// Test filtering by external ID
externalIDFilter := "list-catalog-artifact-ext-2"
listOptions = models.CatalogModelArtifactListOptions{
ExternalID: &externalIDFilter,
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
if len(result.Items) > 0 {
assert.Equal(t, 1, len(result.Items))
assert.Equal(t, "list-catalog-artifact-ext-2", *result.Items[0].GetAttributes().ExternalID)
}
// Test filtering by parent resource ID (catalog model)
listOptions = models.CatalogModelArtifactListOptions{
ParentResourceID: savedCatalogModel.GetID(),
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.GreaterOrEqual(t, len(result.Items), 3) // Should find our 3 test artifacts
})
t.Run("TestListWithPropertiesAndCustomProperties", func(t *testing.T) {
// Create a catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-props"),
ExternalID: apiutils.Of("catalog-model-props-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create a catalog model artifact with both properties and custom properties
catalogModelArtifact := &models.CatalogModelArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("props-test-catalog-artifact"),
ExternalID: apiutils.Of("props-catalog-artifact-ext-123"),
URI: apiutils.Of("s3://catalog-bucket/props-model.pkl"),
},
Properties: &[]dbmodels.Properties{
{
Name: "version",
StringValue: apiutils.Of("1.0.0"),
},
{
Name: "size_bytes",
IntValue: apiutils.Of(int32(2048000)),
},
},
CustomProperties: &[]dbmodels.Properties{
{
Name: "team",
StringValue: apiutils.Of("catalog-ml-team"),
},
{
Name: "is_public",
BoolValue: apiutils.Of(true),
},
},
}
saved, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
require.NoError(t, err)
require.NotNil(t, saved)
// Retrieve and verify properties
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
require.NotNil(t, retrieved)
// Check regular properties
require.NotNil(t, retrieved.GetProperties())
assert.Len(t, *retrieved.GetProperties(), 2)
// Check custom properties
require.NotNil(t, retrieved.GetCustomProperties())
assert.Len(t, *retrieved.GetCustomProperties(), 2)
// Verify specific properties exist
properties := *retrieved.GetProperties()
var foundVersion, foundSizeBytes bool
for _, prop := range properties {
switch prop.Name {
case "version":
foundVersion = true
assert.Equal(t, "1.0.0", *prop.StringValue)
case "size_bytes":
foundSizeBytes = true
assert.Equal(t, int32(2048000), *prop.IntValue)
}
}
assert.True(t, foundVersion, "Should find version property")
assert.True(t, foundSizeBytes, "Should find size_bytes property")
// Verify custom properties
customProperties := *retrieved.GetCustomProperties()
var foundTeam, foundIsPublic bool
for _, prop := range customProperties {
switch prop.Name {
case "team":
foundTeam = true
assert.Equal(t, "catalog-ml-team", *prop.StringValue)
case "is_public":
foundIsPublic = true
assert.Equal(t, true, *prop.BoolValue)
}
}
assert.True(t, foundTeam, "Should find team custom property")
assert.True(t, foundIsPublic, "Should find is_public custom property")
})
t.Run("TestSaveWithoutParentResource", func(t *testing.T) {
// Test creating a catalog model artifact without parent resource attribution
catalogModelArtifact := &models.CatalogModelArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("standalone-catalog-artifact"),
ExternalID: apiutils.Of("standalone-catalog-artifact-ext"),
URI: apiutils.Of("s3://catalog-bucket/standalone-model.pkl"),
},
Properties: &[]dbmodels.Properties{
{
Name: "description",
StringValue: apiutils.Of("Standalone catalog artifact without parent"),
},
},
}
saved, err := repo.Save(catalogModelArtifact, nil)
require.NoError(t, err)
require.NotNil(t, saved)
require.NotNil(t, saved.GetID())
assert.Equal(t, "standalone-catalog-artifact", *saved.GetAttributes().Name)
assert.Equal(t, "s3://catalog-bucket/standalone-model.pkl", *saved.GetAttributes().URI)
// Verify it can be retrieved
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
assert.Equal(t, "standalone-catalog-artifact", *retrieved.GetAttributes().Name)
})
t.Run("TestListOrdering", func(t *testing.T) {
// Create a catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-ordering"),
ExternalID: apiutils.Of("catalog-model-ordering-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create artifacts sequentially with time delays to ensure deterministic ordering
artifact1 := &models.CatalogModelArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("time-test-catalog-artifact-1"),
ExternalID: apiutils.Of("time-catalog-artifact-ext-1"),
URI: apiutils.Of("s3://catalog-bucket/time-model-1.pkl"),
},
}
saved1, err := repo.Save(artifact1, savedCatalogModel.GetID())
require.NoError(t, err)
// Small delay to ensure different timestamps
time.Sleep(10 * time.Millisecond)
artifact2 := &models.CatalogModelArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of("time-test-catalog-artifact-2"),
ExternalID: apiutils.Of("time-catalog-artifact-ext-2"),
URI: apiutils.Of("s3://catalog-bucket/time-model-2.pkl"),
},
}
saved2, err := repo.Save(artifact2, savedCatalogModel.GetID())
require.NoError(t, err)
// Test ordering by CREATE_TIME
listOptions := models.CatalogModelArtifactListOptions{
Pagination: dbmodels.Pagination{
OrderBy: apiutils.Of("CREATE_TIME"),
},
}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
// Find our test artifacts in the results
var foundArtifact1, foundArtifact2 models.CatalogModelArtifact
var index1, index2 = -1, -1
for i, item := range result.Items {
if *item.GetID() == *saved1.GetID() {
foundArtifact1 = item
index1 = i
}
if *item.GetID() == *saved2.GetID() {
foundArtifact2 = item
index2 = i
}
}
// Verify both artifacts were found and artifact1 comes before artifact2 (ascending order)
require.NotEqual(t, -1, index1, "Artifact 1 should be found in results")
require.NotEqual(t, -1, index2, "Artifact 2 should be found in results")
assert.Less(t, index1, index2, "Artifact 1 should come before Artifact 2 when ordered by CREATE_TIME")
assert.Less(t, *foundArtifact1.GetAttributes().CreateTimeSinceEpoch, *foundArtifact2.GetAttributes().CreateTimeSinceEpoch, "Artifact 1 should have earlier create time")
})
t.Run("TestListPagination", func(t *testing.T) {
// Create a catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(catalogModelTypeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model-for-pagination"),
ExternalID: apiutils.Of("catalog-model-pagination-ext"),
},
}
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
require.NoError(t, err)
// Create multiple artifacts for pagination testing
for i := 0; i < 5; i++ {
artifact := &models.CatalogModelArtifactImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelArtifactAttributes{
Name: apiutils.Of(fmt.Sprintf("pagination-artifact-%d", i)),
ExternalID: apiutils.Of(fmt.Sprintf("pagination-artifact-ext-%d", i)),
URI: apiutils.Of(fmt.Sprintf("s3://catalog-bucket/pagination-model-%d.pkl", i)),
},
}
_, err := repo.Save(artifact, savedCatalogModel.GetID())
require.NoError(t, err)
}
// Test pagination with page size
pageSize := int32(2)
listOptions := models.CatalogModelArtifactListOptions{
ParentResourceID: savedCatalogModel.GetID(),
Pagination: dbmodels.Pagination{
PageSize: &pageSize,
OrderBy: apiutils.Of("ID"),
},
}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.LessOrEqual(t, len(result.Items), 2, "Should respect page size limit")
assert.GreaterOrEqual(t, len(result.Items), 1, "Should return at least one item")
})
}
// Helper function to get or create CatalogModelArtifact type ID
func getCatalogModelArtifactTypeID(t *testing.T, db *gorm.DB) int64 {
var typeRecord schema.Type
err := db.Where("name = ?", service.CatalogModelArtifactTypeName).First(&typeRecord).Error
if err != nil {
require.NoError(t, err, "Failed to query CatalogModelArtifact type")
}
return int64(typeRecord.ID)
}

View File

@ -0,0 +1,209 @@
package service_test
import (
"testing"
"github.com/kubeflow/model-registry/catalog/internal/db/models"
"github.com/kubeflow/model-registry/catalog/internal/db/service"
"github.com/kubeflow/model-registry/internal/apiutils"
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func TestCatalogModelRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Create or get the CatalogModel type ID
typeID := getCatalogModelTypeID(t, sharedDB)
repo := service.NewCatalogModelRepository(sharedDB, typeID)
t.Run("TestSave", func(t *testing.T) {
// Test creating a new catalog model
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("test-catalog-model"),
ExternalID: apiutils.Of("catalog-ext-123"),
},
Properties: &[]dbmodels.Properties{
{
Name: "description",
StringValue: apiutils.Of("Test catalog model description"),
},
},
CustomProperties: &[]dbmodels.Properties{
{
Name: "custom-prop",
StringValue: apiutils.Of("custom-value"),
},
},
}
saved, err := repo.Save(catalogModel)
require.NoError(t, err)
require.NotNil(t, saved)
require.NotNil(t, saved.GetID())
assert.Equal(t, "test-catalog-model", *saved.GetAttributes().Name)
assert.Equal(t, "catalog-ext-123", *saved.GetAttributes().ExternalID)
// Test updating the same model
catalogModel.ID = saved.GetID()
catalogModel.GetAttributes().Name = apiutils.Of("updated-catalog-model")
// Preserve CreateTimeSinceEpoch from the saved entity
catalogModel.GetAttributes().CreateTimeSinceEpoch = saved.GetAttributes().CreateTimeSinceEpoch
updated, err := repo.Save(catalogModel)
require.NoError(t, err)
require.NotNil(t, updated)
assert.Equal(t, *saved.GetID(), *updated.GetID())
assert.Equal(t, "updated-catalog-model", *updated.GetAttributes().Name)
})
t.Run("TestGetByID", func(t *testing.T) {
// First create a model to retrieve
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("get-test-catalog-model"),
ExternalID: apiutils.Of("get-catalog-ext-123"),
},
}
saved, err := repo.Save(catalogModel)
require.NoError(t, err)
require.NotNil(t, saved.GetID())
// Test retrieving by ID
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
require.NotNil(t, retrieved)
assert.Equal(t, *saved.GetID(), *retrieved.GetID())
assert.Equal(t, "get-test-catalog-model", *retrieved.GetAttributes().Name)
assert.Equal(t, "get-catalog-ext-123", *retrieved.GetAttributes().ExternalID)
// Test retrieving non-existent ID
_, err = repo.GetByID(99999)
assert.ErrorIs(t, err, service.ErrCatalogModelNotFound)
})
t.Run("TestList", func(t *testing.T) {
// Create multiple models for listing
testModels := []*models.CatalogModelImpl{
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("list-catalog-model-1"),
ExternalID: apiutils.Of("list-catalog-ext-1"),
},
},
{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("list-catalog-model-2"),
ExternalID: apiutils.Of("list-catalog-ext-2"),
},
},
}
// Save all test models
var savedModels []models.CatalogModel
for _, model := range testModels {
saved, err := repo.Save(model)
require.NoError(t, err)
savedModels = append(savedModels, saved)
}
// Test listing all models
listOptions := models.CatalogModelListOptions{}
result, err := repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.GreaterOrEqual(t, len(result.Items), 2) // At least our 2 test models
// Test filtering by name
nameFilter := "list-catalog-model-1"
listOptions = models.CatalogModelListOptions{
Name: &nameFilter,
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, 1, len(result.Items))
assert.Equal(t, "list-catalog-model-1", *result.Items[0].GetAttributes().Name)
// Test filtering by external ID
externalIDFilter := "list-catalog-ext-2"
listOptions = models.CatalogModelListOptions{
ExternalID: &externalIDFilter,
}
result, err = repo.List(listOptions)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, 1, len(result.Items))
assert.Equal(t, "list-catalog-ext-2", *result.Items[0].GetAttributes().ExternalID)
})
t.Run("TestListWithPropertiesAndCustomProperties", func(t *testing.T) {
// Create a model with both properties and custom properties
catalogModel := &models.CatalogModelImpl{
TypeID: apiutils.Of(int32(typeID)),
Attributes: &models.CatalogModelAttributes{
Name: apiutils.Of("props-test-catalog-model"),
ExternalID: apiutils.Of("props-catalog-ext-123"),
},
Properties: &[]dbmodels.Properties{
{
Name: "version",
StringValue: apiutils.Of("1.0.0"),
},
{
Name: "priority",
IntValue: apiutils.Of(int32(5)),
},
},
CustomProperties: &[]dbmodels.Properties{
{
Name: "team",
StringValue: apiutils.Of("ml-team"),
},
{
Name: "active",
BoolValue: apiutils.Of(true),
},
},
}
saved, err := repo.Save(catalogModel)
require.NoError(t, err)
require.NotNil(t, saved)
// Retrieve and verify properties
retrieved, err := repo.GetByID(*saved.GetID())
require.NoError(t, err)
require.NotNil(t, retrieved)
// Check regular properties
require.NotNil(t, retrieved.GetProperties())
assert.Len(t, *retrieved.GetProperties(), 2)
// Check custom properties
require.NotNil(t, retrieved.GetCustomProperties())
assert.Len(t, *retrieved.GetCustomProperties(), 2)
})
}
// Helper function to get or create CatalogModel type ID
func getCatalogModelTypeID(t *testing.T, db *gorm.DB) int64 {
var typeRecord schema.Type
err := db.Where("name = ?", service.CatalogModelTypeName).First(&typeRecord).Error
if err != nil {
require.NoError(t, err, "Failed to query CatalogModel type")
}
return int64(typeRecord.ID)
}

View File

@ -0,0 +1,12 @@
package service
import (
"os"
"testing"
"github.com/kubeflow/model-registry/internal/testutils"
)
func TestMain(m *testing.M) {
os.Exit(testutils.TestMainHelper(m))
}

View File

@ -0,0 +1,36 @@
package service
import (
"github.com/kubeflow/model-registry/internal/datastore"
)
const (
CatalogModelTypeName = "kf.CatalogModel"
CatalogModelArtifactTypeName = "kf.CatalogModelArtifact"
CatalogMetricsArtifactTypeName = "kf.CatalogMetricsArtifact"
)
func DatastoreSpec() *datastore.Spec {
return datastore.NewSpec().
AddContext(CatalogModelTypeName, datastore.NewSpecType(NewCatalogModelRepository).
AddString("source_id").
AddString("description").
AddString("owner").
AddString("state").
AddStruct("language").
AddString("library_name").
AddString("license_link").
AddString("license").
AddString("logo").
AddString("maturity").
AddString("provider").
AddString("readme").
AddStruct("tasks"),
).
AddArtifact(CatalogModelArtifactTypeName, datastore.NewSpecType(NewCatalogModelArtifactRepository).
AddString("uri"),
).
AddArtifact(CatalogMetricsArtifactTypeName, datastore.NewSpecType(NewCatalogMetricsArtifactRepository).
AddString("metricsType"),
)
}

View File

@ -0,0 +1,5 @@
[mysqld]
character-set-server = utf8mb4
collation-server = utf8mb4_general_ci
!includedir /etc/mysql/conf.d/

View File

@ -8,9 +8,11 @@ model_artifact_type_query_param.go
model_base_model.go
model_base_resource_dates.go
model_base_resource_list.go
model_catalog_artifact.go
model_catalog_artifact_list.go
model_catalog_metrics_artifact.go
model_catalog_model.go
model_catalog_model_artifact.go
model_catalog_model_artifact_list.go
model_catalog_model_list.go
model_catalog_source.go
model_catalog_source_list.go

View File

@ -126,7 +126,7 @@ func (c *ModelCatalogServiceAPIController) GetModel(w http.ResponseWriter, r *ht
EncodeJSONResponse(result.Body, &result.Code, w)
}
// GetAllModelArtifacts - List CatalogModelArtifacts.
// GetAllModelArtifacts - List CatalogArtifacts.
func (c *ModelCatalogServiceAPIController) GetAllModelArtifacts(w http.ResponseWriter, r *http.Request) {
sourceIdParam := chi.URLParam(r, "source_id")
modelNameParam := chi.URLParam(r, "model_name")

View File

@ -725,7 +725,7 @@ func TestFindSources(t *testing.T) {
// Define a mock model provider
type mockModelProvider struct {
models map[string]*model.CatalogModel
artifacts map[string][]model.CatalogModelArtifact
artifacts map[string][]model.CatalogArtifact
}
// Implement GetModel method for the mock provider
@ -784,17 +784,17 @@ func (m *mockModelProvider) ListModels(ctx context.Context, params catalog.ListM
}, nil
}
func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error) {
func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error) {
artifacts, exists := m.artifacts[name]
if !exists {
return &model.CatalogModelArtifactList{
Items: []model.CatalogModelArtifact{},
return &model.CatalogArtifactList{
Items: []model.CatalogArtifact{},
Size: 0,
PageSize: 0, // Or a default page size if applicable
NextPageToken: "",
}, nil
}
return &model.CatalogModelArtifactList{
return &model.CatalogArtifactList{
Items: artifacts,
Size: int32(len(artifacts)),
PageSize: int32(len(artifacts)),
@ -922,7 +922,7 @@ func TestGetAllModelArtifacts(t *testing.T) {
sourceID string
modelName string
expectedStatus int
expectedArtifacts []model.CatalogModelArtifact
expectedArtifacts []model.CatalogArtifact
}{
{
name: "Existing artifacts for model in source",
@ -930,13 +930,17 @@ func TestGetAllModelArtifacts(t *testing.T) {
"source1": {
Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"},
Provider: &mockModelProvider{
artifacts: map[string][]model.CatalogModelArtifact{
artifacts: map[string][]model.CatalogArtifact{
"test-model": {
{
Uri: "s3://bucket/artifact1",
CatalogModelArtifact: &model.CatalogModelArtifact{
Uri: "s3://bucket/artifact1",
},
},
{
Uri: "s3://bucket/artifact2",
CatalogModelArtifact: &model.CatalogModelArtifact{
Uri: "s3://bucket/artifact2",
},
},
},
},
@ -946,12 +950,16 @@ func TestGetAllModelArtifacts(t *testing.T) {
sourceID: "source1",
modelName: "test-model",
expectedStatus: http.StatusOK,
expectedArtifacts: []model.CatalogModelArtifact{
expectedArtifacts: []model.CatalogArtifact{
{
Uri: "s3://bucket/artifact1",
CatalogModelArtifact: &model.CatalogModelArtifact{
Uri: "s3://bucket/artifact1",
},
},
{
Uri: "s3://bucket/artifact2",
CatalogModelArtifact: &model.CatalogModelArtifact{
Uri: "s3://bucket/artifact2",
},
},
},
},
@ -973,14 +981,14 @@ func TestGetAllModelArtifacts(t *testing.T) {
"source1": {
Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"},
Provider: &mockModelProvider{
artifacts: map[string][]model.CatalogModelArtifact{},
artifacts: map[string][]model.CatalogArtifact{},
},
},
},
sourceID: "source1",
modelName: "test-model",
expectedStatus: http.StatusOK,
expectedArtifacts: []model.CatalogModelArtifact{}, // Should be an empty slice, not nil
expectedArtifacts: []model.CatalogArtifact{}, // Should be an empty slice, not nil
},
}
@ -1008,8 +1016,8 @@ func TestGetAllModelArtifacts(t *testing.T) {
require.NotNil(t, resp.Body)
// Type assertion to access the list of artifacts
artifactList, ok := resp.Body.(*model.CatalogModelArtifactList)
require.True(t, ok, "Response body should be a CatalogModelArtifactList")
artifactList, ok := resp.Body.(*model.CatalogArtifactList)
require.True(t, ok, "Response body should be a CatalogArtifactList")
// Check the artifacts
assert.Equal(t, tc.expectedArtifacts, artifactList.Items)

View File

@ -67,18 +67,13 @@ func AssertBaseResourceListRequired(obj model.BaseResourceList) error {
return nil
}
// AssertCatalogModelArtifactConstraints checks if the values respects the defined constraints
func AssertCatalogModelArtifactConstraints(obj model.CatalogModelArtifact) error {
// AssertCatalogArtifactListConstraints checks if the values respects the defined constraints
func AssertCatalogArtifactListConstraints(obj model.CatalogArtifactList) error {
return nil
}
// AssertCatalogModelArtifactListConstraints checks if the values respects the defined constraints
func AssertCatalogModelArtifactListConstraints(obj model.CatalogModelArtifactList) error {
return nil
}
// AssertCatalogModelArtifactListRequired checks if the required fields are not zero-ed
func AssertCatalogModelArtifactListRequired(obj model.CatalogModelArtifactList) error {
// AssertCatalogArtifactListRequired checks if the required fields are not zero-ed
func AssertCatalogArtifactListRequired(obj model.CatalogArtifactList) error {
elements := map[string]interface{}{
"nextPageToken": obj.NextPageToken,
"pageSize": obj.PageSize,
@ -92,17 +87,43 @@ func AssertCatalogModelArtifactListRequired(obj model.CatalogModelArtifactList)
}
for _, el := range obj.Items {
if err := AssertCatalogModelArtifactRequired(el); err != nil {
if err := AssertCatalogArtifactRequired(el); err != nil {
return err
}
}
return nil
}
// AssertCatalogMetricsArtifactConstraints checks if the values respects the defined constraints
func AssertCatalogMetricsArtifactConstraints(obj model.CatalogMetricsArtifact) error {
return nil
}
// AssertCatalogMetricsArtifactRequired checks if the required fields are not zero-ed
func AssertCatalogMetricsArtifactRequired(obj model.CatalogMetricsArtifact) error {
elements := map[string]interface{}{
"artifactType": obj.ArtifactType,
"metricsType": obj.MetricsType,
}
for name, el := range elements {
if isZero := IsZeroValue(el); isZero {
return &RequiredError{Field: name}
}
}
return nil
}
// AssertCatalogModelArtifactConstraints checks if the values respects the defined constraints
func AssertCatalogModelArtifactConstraints(obj model.CatalogModelArtifact) error {
return nil
}
// AssertCatalogModelArtifactRequired checks if the required fields are not zero-ed
func AssertCatalogModelArtifactRequired(obj model.CatalogModelArtifact) error {
elements := map[string]interface{}{
"uri": obj.Uri,
"artifactType": obj.ArtifactType,
"uri": obj.Uri,
}
for name, el := range elements {
if isZero := IsZeroValue(el); isZero {

View File

@ -0,0 +1,12 @@
package openapi
import (
model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)
// AssertCatalogArtifactRequired checks if the required fields are not zero-ed
func AssertCatalogArtifactRequired(obj model.CatalogArtifact) error {
// CatalogArtifact has no required fields but the openapi code gen
// checks the fields from CatalogModelArtifact, which doesn't compile.
return nil
}

View File

@ -5,9 +5,11 @@ model_artifact_type_query_param.go
model_base_model.go
model_base_resource_dates.go
model_base_resource_list.go
model_catalog_artifact.go
model_catalog_artifact_list.go
model_catalog_metrics_artifact.go
model_catalog_model.go
model_catalog_model_artifact.go
model_catalog_model_artifact_list.go
model_catalog_model_list.go
model_catalog_source.go
model_catalog_source_list.go

View File

@ -424,12 +424,12 @@ type ApiGetAllModelArtifactsRequest struct {
modelName string
}
func (r ApiGetAllModelArtifactsRequest) Execute() (*CatalogModelArtifactList, *http.Response, error) {
func (r ApiGetAllModelArtifactsRequest) Execute() (*CatalogArtifactList, *http.Response, error) {
return r.ApiService.GetAllModelArtifactsExecute(r)
}
/*
GetAllModelArtifacts List CatalogModelArtifacts.
GetAllModelArtifacts List CatalogArtifacts.
@param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background().
@param sourceId A unique identifier for a `CatalogSource`.
@ -447,13 +447,13 @@ func (a *ModelCatalogServiceAPIService) GetAllModelArtifacts(ctx context.Context
// Execute executes the request
//
// @return CatalogModelArtifactList
func (a *ModelCatalogServiceAPIService) GetAllModelArtifactsExecute(r ApiGetAllModelArtifactsRequest) (*CatalogModelArtifactList, *http.Response, error) {
// @return CatalogArtifactList
func (a *ModelCatalogServiceAPIService) GetAllModelArtifactsExecute(r ApiGetAllModelArtifactsRequest) (*CatalogArtifactList, *http.Response, error) {
var (
localVarHTTPMethod = http.MethodGet
localVarPostBody interface{}
formFiles []formFile
localVarReturnValue *CatalogModelArtifactList
localVarReturnValue *CatalogArtifactList
)
localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelCatalogServiceAPIService.GetAllModelArtifacts")

View File

@ -0,0 +1,163 @@
/*
Model Catalog REST API
REST API for Model Registry to create and manage ML model metadata
API version: v1alpha1
*/
// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
package openapi
import (
"encoding/json"
"fmt"
)
// CatalogArtifact - A single artifact in the catalog API.
type CatalogArtifact struct {
CatalogMetricsArtifact *CatalogMetricsArtifact
CatalogModelArtifact *CatalogModelArtifact
}
// CatalogMetricsArtifactAsCatalogArtifact is a convenience function that returns CatalogMetricsArtifact wrapped in CatalogArtifact
func CatalogMetricsArtifactAsCatalogArtifact(v *CatalogMetricsArtifact) CatalogArtifact {
return CatalogArtifact{
CatalogMetricsArtifact: v,
}
}
// CatalogModelArtifactAsCatalogArtifact is a convenience function that returns CatalogModelArtifact wrapped in CatalogArtifact
func CatalogModelArtifactAsCatalogArtifact(v *CatalogModelArtifact) CatalogArtifact {
return CatalogArtifact{
CatalogModelArtifact: v,
}
}
// Unmarshal JSON data into one of the pointers in the struct
func (dst *CatalogArtifact) UnmarshalJSON(data []byte) error {
var err error
// use discriminator value to speed up the lookup
var jsonDict map[string]interface{}
err = newStrictDecoder(data).Decode(&jsonDict)
if err != nil {
return fmt.Errorf("failed to unmarshal JSON into map for the discriminator lookup")
}
// check if the discriminator value is 'CatalogMetricsArtifact'
if jsonDict["artifactType"] == "CatalogMetricsArtifact" {
// try to unmarshal JSON data into CatalogMetricsArtifact
err = json.Unmarshal(data, &dst.CatalogMetricsArtifact)
if err == nil {
return nil // data stored in dst.CatalogMetricsArtifact, return on the first match
} else {
dst.CatalogMetricsArtifact = nil
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogMetricsArtifact: %s", err.Error())
}
}
// check if the discriminator value is 'CatalogModelArtifact'
if jsonDict["artifactType"] == "CatalogModelArtifact" {
// try to unmarshal JSON data into CatalogModelArtifact
err = json.Unmarshal(data, &dst.CatalogModelArtifact)
if err == nil {
return nil // data stored in dst.CatalogModelArtifact, return on the first match
} else {
dst.CatalogModelArtifact = nil
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogModelArtifact: %s", err.Error())
}
}
// check if the discriminator value is 'metrics-artifact'
if jsonDict["artifactType"] == "metrics-artifact" {
// try to unmarshal JSON data into CatalogMetricsArtifact
err = json.Unmarshal(data, &dst.CatalogMetricsArtifact)
if err == nil {
return nil // data stored in dst.CatalogMetricsArtifact, return on the first match
} else {
dst.CatalogMetricsArtifact = nil
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogMetricsArtifact: %s", err.Error())
}
}
// check if the discriminator value is 'model-artifact'
if jsonDict["artifactType"] == "model-artifact" {
// try to unmarshal JSON data into CatalogModelArtifact
err = json.Unmarshal(data, &dst.CatalogModelArtifact)
if err == nil {
return nil // data stored in dst.CatalogModelArtifact, return on the first match
} else {
dst.CatalogModelArtifact = nil
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogModelArtifact: %s", err.Error())
}
}
return nil
}
// Marshal data from the first non-nil pointers in the struct to JSON
func (src CatalogArtifact) MarshalJSON() ([]byte, error) {
if src.CatalogMetricsArtifact != nil {
return json.Marshal(&src.CatalogMetricsArtifact)
}
if src.CatalogModelArtifact != nil {
return json.Marshal(&src.CatalogModelArtifact)
}
return nil, nil // no data in oneOf schemas
}
// Get the actual instance
func (obj *CatalogArtifact) GetActualInstance() interface{} {
if obj == nil {
return nil
}
if obj.CatalogMetricsArtifact != nil {
return obj.CatalogMetricsArtifact
}
if obj.CatalogModelArtifact != nil {
return obj.CatalogModelArtifact
}
// all schemas are nil
return nil
}
type NullableCatalogArtifact struct {
value *CatalogArtifact
isSet bool
}
func (v NullableCatalogArtifact) Get() *CatalogArtifact {
return v.value
}
func (v *NullableCatalogArtifact) Set(val *CatalogArtifact) {
v.value = val
v.isSet = true
}
func (v NullableCatalogArtifact) IsSet() bool {
return v.isSet
}
func (v *NullableCatalogArtifact) Unset() {
v.value = nil
v.isSet = false
}
func NewNullableCatalogArtifact(val *CatalogArtifact) *NullableCatalogArtifact {
return &NullableCatalogArtifact{value: val, isSet: true}
}
func (v NullableCatalogArtifact) MarshalJSON() ([]byte, error) {
return json.Marshal(v.value)
}
func (v *NullableCatalogArtifact) UnmarshalJSON(src []byte) error {
v.isSet = true
return json.Unmarshal(src, &v.value)
}

View File

@ -14,27 +14,27 @@ import (
"encoding/json"
)
// checks if the CatalogModelArtifactList type satisfies the MappedNullable interface at compile time
var _ MappedNullable = &CatalogModelArtifactList{}
// checks if the CatalogArtifactList type satisfies the MappedNullable interface at compile time
var _ MappedNullable = &CatalogArtifactList{}
// CatalogModelArtifactList List of CatalogModel entities.
type CatalogModelArtifactList struct {
// CatalogArtifactList List of CatalogModel entities.
type CatalogArtifactList struct {
// Token to use to retrieve next page of results.
NextPageToken string `json:"nextPageToken"`
// Maximum number of resources to return in the result.
PageSize int32 `json:"pageSize"`
// Number of items in result list.
Size int32 `json:"size"`
// Array of `CatalogModelArtifact` entities.
Items []CatalogModelArtifact `json:"items"`
// Array of `CatalogArtifact` entities.
Items []CatalogArtifact `json:"items"`
}
// NewCatalogModelArtifactList instantiates a new CatalogModelArtifactList object
// NewCatalogArtifactList instantiates a new CatalogArtifactList object
// This constructor will assign default values to properties that have it defined,
// and makes sure properties required by API are set, but the set of arguments
// will change when the set of required properties is changed
func NewCatalogModelArtifactList(nextPageToken string, pageSize int32, size int32, items []CatalogModelArtifact) *CatalogModelArtifactList {
this := CatalogModelArtifactList{}
func NewCatalogArtifactList(nextPageToken string, pageSize int32, size int32, items []CatalogArtifact) *CatalogArtifactList {
this := CatalogArtifactList{}
this.NextPageToken = nextPageToken
this.PageSize = pageSize
this.Size = size
@ -42,16 +42,16 @@ func NewCatalogModelArtifactList(nextPageToken string, pageSize int32, size int3
return &this
}
// NewCatalogModelArtifactListWithDefaults instantiates a new CatalogModelArtifactList object
// NewCatalogArtifactListWithDefaults instantiates a new CatalogArtifactList object
// This constructor will only assign default values to properties that have it defined,
// but it doesn't guarantee that properties required by API are set
func NewCatalogModelArtifactListWithDefaults() *CatalogModelArtifactList {
this := CatalogModelArtifactList{}
func NewCatalogArtifactListWithDefaults() *CatalogArtifactList {
this := CatalogArtifactList{}
return &this
}
// GetNextPageToken returns the NextPageToken field value
func (o *CatalogModelArtifactList) GetNextPageToken() string {
func (o *CatalogArtifactList) GetNextPageToken() string {
if o == nil {
var ret string
return ret
@ -62,7 +62,7 @@ func (o *CatalogModelArtifactList) GetNextPageToken() string {
// GetNextPageTokenOk returns a tuple with the NextPageToken field value
// and a boolean to check if the value has been set.
func (o *CatalogModelArtifactList) GetNextPageTokenOk() (*string, bool) {
func (o *CatalogArtifactList) GetNextPageTokenOk() (*string, bool) {
if o == nil {
return nil, false
}
@ -70,12 +70,12 @@ func (o *CatalogModelArtifactList) GetNextPageTokenOk() (*string, bool) {
}
// SetNextPageToken sets field value
func (o *CatalogModelArtifactList) SetNextPageToken(v string) {
func (o *CatalogArtifactList) SetNextPageToken(v string) {
o.NextPageToken = v
}
// GetPageSize returns the PageSize field value
func (o *CatalogModelArtifactList) GetPageSize() int32 {
func (o *CatalogArtifactList) GetPageSize() int32 {
if o == nil {
var ret int32
return ret
@ -86,7 +86,7 @@ func (o *CatalogModelArtifactList) GetPageSize() int32 {
// GetPageSizeOk returns a tuple with the PageSize field value
// and a boolean to check if the value has been set.
func (o *CatalogModelArtifactList) GetPageSizeOk() (*int32, bool) {
func (o *CatalogArtifactList) GetPageSizeOk() (*int32, bool) {
if o == nil {
return nil, false
}
@ -94,12 +94,12 @@ func (o *CatalogModelArtifactList) GetPageSizeOk() (*int32, bool) {
}
// SetPageSize sets field value
func (o *CatalogModelArtifactList) SetPageSize(v int32) {
func (o *CatalogArtifactList) SetPageSize(v int32) {
o.PageSize = v
}
// GetSize returns the Size field value
func (o *CatalogModelArtifactList) GetSize() int32 {
func (o *CatalogArtifactList) GetSize() int32 {
if o == nil {
var ret int32
return ret
@ -110,7 +110,7 @@ func (o *CatalogModelArtifactList) GetSize() int32 {
// GetSizeOk returns a tuple with the Size field value
// and a boolean to check if the value has been set.
func (o *CatalogModelArtifactList) GetSizeOk() (*int32, bool) {
func (o *CatalogArtifactList) GetSizeOk() (*int32, bool) {
if o == nil {
return nil, false
}
@ -118,14 +118,14 @@ func (o *CatalogModelArtifactList) GetSizeOk() (*int32, bool) {
}
// SetSize sets field value
func (o *CatalogModelArtifactList) SetSize(v int32) {
func (o *CatalogArtifactList) SetSize(v int32) {
o.Size = v
}
// GetItems returns the Items field value
func (o *CatalogModelArtifactList) GetItems() []CatalogModelArtifact {
func (o *CatalogArtifactList) GetItems() []CatalogArtifact {
if o == nil {
var ret []CatalogModelArtifact
var ret []CatalogArtifact
return ret
}
@ -134,7 +134,7 @@ func (o *CatalogModelArtifactList) GetItems() []CatalogModelArtifact {
// GetItemsOk returns a tuple with the Items field value
// and a boolean to check if the value has been set.
func (o *CatalogModelArtifactList) GetItemsOk() ([]CatalogModelArtifact, bool) {
func (o *CatalogArtifactList) GetItemsOk() ([]CatalogArtifact, bool) {
if o == nil {
return nil, false
}
@ -142,11 +142,11 @@ func (o *CatalogModelArtifactList) GetItemsOk() ([]CatalogModelArtifact, bool) {
}
// SetItems sets field value
func (o *CatalogModelArtifactList) SetItems(v []CatalogModelArtifact) {
func (o *CatalogArtifactList) SetItems(v []CatalogArtifact) {
o.Items = v
}
func (o CatalogModelArtifactList) MarshalJSON() ([]byte, error) {
func (o CatalogArtifactList) MarshalJSON() ([]byte, error) {
toSerialize, err := o.ToMap()
if err != nil {
return []byte{}, err
@ -154,7 +154,7 @@ func (o CatalogModelArtifactList) MarshalJSON() ([]byte, error) {
return json.Marshal(toSerialize)
}
func (o CatalogModelArtifactList) ToMap() (map[string]interface{}, error) {
func (o CatalogArtifactList) ToMap() (map[string]interface{}, error) {
toSerialize := map[string]interface{}{}
toSerialize["nextPageToken"] = o.NextPageToken
toSerialize["pageSize"] = o.PageSize
@ -163,38 +163,38 @@ func (o CatalogModelArtifactList) ToMap() (map[string]interface{}, error) {
return toSerialize, nil
}
type NullableCatalogModelArtifactList struct {
value *CatalogModelArtifactList
type NullableCatalogArtifactList struct {
value *CatalogArtifactList
isSet bool
}
func (v NullableCatalogModelArtifactList) Get() *CatalogModelArtifactList {
func (v NullableCatalogArtifactList) Get() *CatalogArtifactList {
return v.value
}
func (v *NullableCatalogModelArtifactList) Set(val *CatalogModelArtifactList) {
func (v *NullableCatalogArtifactList) Set(val *CatalogArtifactList) {
v.value = val
v.isSet = true
}
func (v NullableCatalogModelArtifactList) IsSet() bool {
func (v NullableCatalogArtifactList) IsSet() bool {
return v.isSet
}
func (v *NullableCatalogModelArtifactList) Unset() {
func (v *NullableCatalogArtifactList) Unset() {
v.value = nil
v.isSet = false
}
func NewNullableCatalogModelArtifactList(val *CatalogModelArtifactList) *NullableCatalogModelArtifactList {
return &NullableCatalogModelArtifactList{value: val, isSet: true}
func NewNullableCatalogArtifactList(val *CatalogArtifactList) *NullableCatalogArtifactList {
return &NullableCatalogArtifactList{value: val, isSet: true}
}
func (v NullableCatalogModelArtifactList) MarshalJSON() ([]byte, error) {
func (v NullableCatalogArtifactList) MarshalJSON() ([]byte, error) {
return json.Marshal(v.value)
}
func (v *NullableCatalogModelArtifactList) UnmarshalJSON(src []byte) error {
func (v *NullableCatalogArtifactList) UnmarshalJSON(src []byte) error {
v.isSet = true
return json.Unmarshal(src, &v.value)
}

View File

@ -0,0 +1,255 @@
/*
Model Catalog REST API
REST API for Model Registry to create and manage ML model metadata
API version: v1alpha1
*/
// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
package openapi
import (
"encoding/json"
)
// checks if the CatalogMetricsArtifact type satisfies the MappedNullable interface at compile time
var _ MappedNullable = &CatalogMetricsArtifact{}
// CatalogMetricsArtifact A metadata Artifact Entity.
type CatalogMetricsArtifact struct {
// Output only. Create time of the resource in millisecond since epoch.
CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"`
// Output only. Last update time of the resource since epoch in millisecond since epoch.
LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"`
ArtifactType string `json:"artifactType"`
MetricsType string `json:"metricsType"`
// User provided custom properties which are not defined by its type.
CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"`
}
// NewCatalogMetricsArtifact instantiates a new CatalogMetricsArtifact object
// This constructor will assign default values to properties that have it defined,
// and makes sure properties required by API are set, but the set of arguments
// will change when the set of required properties is changed
func NewCatalogMetricsArtifact(artifactType string, metricsType string) *CatalogMetricsArtifact {
this := CatalogMetricsArtifact{}
this.ArtifactType = artifactType
this.MetricsType = metricsType
return &this
}
// NewCatalogMetricsArtifactWithDefaults instantiates a new CatalogMetricsArtifact object
// This constructor will only assign default values to properties that have it defined,
// but it doesn't guarantee that properties required by API are set
func NewCatalogMetricsArtifactWithDefaults() *CatalogMetricsArtifact {
this := CatalogMetricsArtifact{}
var artifactType string = "metrics-artifact"
this.ArtifactType = artifactType
return &this
}
// GetCreateTimeSinceEpoch returns the CreateTimeSinceEpoch field value if set, zero value otherwise.
func (o *CatalogMetricsArtifact) GetCreateTimeSinceEpoch() string {
if o == nil || IsNil(o.CreateTimeSinceEpoch) {
var ret string
return ret
}
return *o.CreateTimeSinceEpoch
}
// GetCreateTimeSinceEpochOk returns a tuple with the CreateTimeSinceEpoch field value if set, nil otherwise
// and a boolean to check if the value has been set.
func (o *CatalogMetricsArtifact) GetCreateTimeSinceEpochOk() (*string, bool) {
if o == nil || IsNil(o.CreateTimeSinceEpoch) {
return nil, false
}
return o.CreateTimeSinceEpoch, true
}
// HasCreateTimeSinceEpoch returns a boolean if a field has been set.
func (o *CatalogMetricsArtifact) HasCreateTimeSinceEpoch() bool {
if o != nil && !IsNil(o.CreateTimeSinceEpoch) {
return true
}
return false
}
// SetCreateTimeSinceEpoch gets a reference to the given string and assigns it to the CreateTimeSinceEpoch field.
func (o *CatalogMetricsArtifact) SetCreateTimeSinceEpoch(v string) {
o.CreateTimeSinceEpoch = &v
}
// GetLastUpdateTimeSinceEpoch returns the LastUpdateTimeSinceEpoch field value if set, zero value otherwise.
func (o *CatalogMetricsArtifact) GetLastUpdateTimeSinceEpoch() string {
if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) {
var ret string
return ret
}
return *o.LastUpdateTimeSinceEpoch
}
// GetLastUpdateTimeSinceEpochOk returns a tuple with the LastUpdateTimeSinceEpoch field value if set, nil otherwise
// and a boolean to check if the value has been set.
func (o *CatalogMetricsArtifact) GetLastUpdateTimeSinceEpochOk() (*string, bool) {
if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) {
return nil, false
}
return o.LastUpdateTimeSinceEpoch, true
}
// HasLastUpdateTimeSinceEpoch returns a boolean if a field has been set.
func (o *CatalogMetricsArtifact) HasLastUpdateTimeSinceEpoch() bool {
if o != nil && !IsNil(o.LastUpdateTimeSinceEpoch) {
return true
}
return false
}
// SetLastUpdateTimeSinceEpoch gets a reference to the given string and assigns it to the LastUpdateTimeSinceEpoch field.
func (o *CatalogMetricsArtifact) SetLastUpdateTimeSinceEpoch(v string) {
o.LastUpdateTimeSinceEpoch = &v
}
// GetArtifactType returns the ArtifactType field value
func (o *CatalogMetricsArtifact) GetArtifactType() string {
if o == nil {
var ret string
return ret
}
return o.ArtifactType
}
// GetArtifactTypeOk returns a tuple with the ArtifactType field value
// and a boolean to check if the value has been set.
func (o *CatalogMetricsArtifact) GetArtifactTypeOk() (*string, bool) {
if o == nil {
return nil, false
}
return &o.ArtifactType, true
}
// SetArtifactType sets field value
func (o *CatalogMetricsArtifact) SetArtifactType(v string) {
o.ArtifactType = v
}
// GetMetricsType returns the MetricsType field value
func (o *CatalogMetricsArtifact) GetMetricsType() string {
if o == nil {
var ret string
return ret
}
return o.MetricsType
}
// GetMetricsTypeOk returns a tuple with the MetricsType field value
// and a boolean to check if the value has been set.
func (o *CatalogMetricsArtifact) GetMetricsTypeOk() (*string, bool) {
if o == nil {
return nil, false
}
return &o.MetricsType, true
}
// SetMetricsType sets field value
func (o *CatalogMetricsArtifact) SetMetricsType(v string) {
o.MetricsType = v
}
// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise.
func (o *CatalogMetricsArtifact) GetCustomProperties() map[string]MetadataValue {
if o == nil || IsNil(o.CustomProperties) {
var ret map[string]MetadataValue
return ret
}
return *o.CustomProperties
}
// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise
// and a boolean to check if the value has been set.
func (o *CatalogMetricsArtifact) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) {
if o == nil || IsNil(o.CustomProperties) {
return nil, false
}
return o.CustomProperties, true
}
// HasCustomProperties returns a boolean if a field has been set.
func (o *CatalogMetricsArtifact) HasCustomProperties() bool {
if o != nil && !IsNil(o.CustomProperties) {
return true
}
return false
}
// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field.
func (o *CatalogMetricsArtifact) SetCustomProperties(v map[string]MetadataValue) {
o.CustomProperties = &v
}
func (o CatalogMetricsArtifact) MarshalJSON() ([]byte, error) {
toSerialize, err := o.ToMap()
if err != nil {
return []byte{}, err
}
return json.Marshal(toSerialize)
}
func (o CatalogMetricsArtifact) ToMap() (map[string]interface{}, error) {
toSerialize := map[string]interface{}{}
if !IsNil(o.CreateTimeSinceEpoch) {
toSerialize["createTimeSinceEpoch"] = o.CreateTimeSinceEpoch
}
if !IsNil(o.LastUpdateTimeSinceEpoch) {
toSerialize["lastUpdateTimeSinceEpoch"] = o.LastUpdateTimeSinceEpoch
}
toSerialize["artifactType"] = o.ArtifactType
toSerialize["metricsType"] = o.MetricsType
if !IsNil(o.CustomProperties) {
toSerialize["customProperties"] = o.CustomProperties
}
return toSerialize, nil
}
type NullableCatalogMetricsArtifact struct {
value *CatalogMetricsArtifact
isSet bool
}
func (v NullableCatalogMetricsArtifact) Get() *CatalogMetricsArtifact {
return v.value
}
func (v *NullableCatalogMetricsArtifact) Set(val *CatalogMetricsArtifact) {
v.value = val
v.isSet = true
}
func (v NullableCatalogMetricsArtifact) IsSet() bool {
return v.isSet
}
func (v *NullableCatalogMetricsArtifact) Unset() {
v.value = nil
v.isSet = false
}
func NewNullableCatalogMetricsArtifact(val *CatalogMetricsArtifact) *NullableCatalogMetricsArtifact {
return &NullableCatalogMetricsArtifact{value: val, isSet: true}
}
func (v NullableCatalogMetricsArtifact) MarshalJSON() ([]byte, error) {
return json.Marshal(v.value)
}
func (v *NullableCatalogMetricsArtifact) UnmarshalJSON(src []byte) error {
v.isSet = true
return json.Unmarshal(src, &v.value)
}

View File

@ -17,13 +17,14 @@ import (
// checks if the CatalogModelArtifact type satisfies the MappedNullable interface at compile time
var _ MappedNullable = &CatalogModelArtifact{}
// CatalogModelArtifact A single artifact for a catalog model.
// CatalogModelArtifact A Catalog Model Artifact Entity.
type CatalogModelArtifact struct {
// Output only. Create time of the resource in millisecond since epoch.
CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"`
// Output only. Last update time of the resource since epoch in millisecond since epoch.
LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"`
// URI where the artifact can be retrieved.
ArtifactType string `json:"artifactType"`
// URI where the model can be retrieved.
Uri string `json:"uri"`
// User provided custom properties which are not defined by its type.
CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"`
@ -33,8 +34,9 @@ type CatalogModelArtifact struct {
// This constructor will assign default values to properties that have it defined,
// and makes sure properties required by API are set, but the set of arguments
// will change when the set of required properties is changed
func NewCatalogModelArtifact(uri string) *CatalogModelArtifact {
func NewCatalogModelArtifact(artifactType string, uri string) *CatalogModelArtifact {
this := CatalogModelArtifact{}
this.ArtifactType = artifactType
this.Uri = uri
return &this
}
@ -44,6 +46,8 @@ func NewCatalogModelArtifact(uri string) *CatalogModelArtifact {
// but it doesn't guarantee that properties required by API are set
func NewCatalogModelArtifactWithDefaults() *CatalogModelArtifact {
this := CatalogModelArtifact{}
var artifactType string = "model-artifact"
this.ArtifactType = artifactType
return &this
}
@ -111,6 +115,30 @@ func (o *CatalogModelArtifact) SetLastUpdateTimeSinceEpoch(v string) {
o.LastUpdateTimeSinceEpoch = &v
}
// GetArtifactType returns the ArtifactType field value
func (o *CatalogModelArtifact) GetArtifactType() string {
if o == nil {
var ret string
return ret
}
return o.ArtifactType
}
// GetArtifactTypeOk returns a tuple with the ArtifactType field value
// and a boolean to check if the value has been set.
func (o *CatalogModelArtifact) GetArtifactTypeOk() (*string, bool) {
if o == nil {
return nil, false
}
return &o.ArtifactType, true
}
// SetArtifactType sets field value
func (o *CatalogModelArtifact) SetArtifactType(v string) {
o.ArtifactType = v
}
// GetUri returns the Uri field value
func (o *CatalogModelArtifact) GetUri() string {
if o == nil {
@ -183,6 +211,7 @@ func (o CatalogModelArtifact) ToMap() (map[string]interface{}, error) {
if !IsNil(o.LastUpdateTimeSinceEpoch) {
toSerialize["lastUpdateTimeSinceEpoch"] = o.LastUpdateTimeSinceEpoch
}
toSerialize["artifactType"] = o.ArtifactType
toSerialize["uri"] = o.Uri
if !IsNil(o.CustomProperties) {
toSerialize["customProperties"] = o.CustomProperties

View File

@ -5,10 +5,9 @@ set -e
ASSERT_FILE_PATH="$1/type_asserts.go"
PROJECT_ROOT=$(realpath "$(dirname "$0")"/..)
PATCH="${PROJECT_ROOT}/patches/type_asserts.patch"
# AssertMetadataValueRequired from this file generates with the incorrect logic.
rm -f $1/model_metadata_value.go
# These files generate with incorrect logic:
rm -f "$1/model_metadata_value.go" "$1/model_catalog_artifact.go"
python3 "${PROJECT_ROOT}/scripts/gen_type_asserts.py" $1 >"$ASSERT_FILE_PATH"

View File

@ -3,12 +3,16 @@ package cmd
import (
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"github.com/golang/glog"
"github.com/kubeflow/model-registry/internal/core"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/proxy"
"github.com/kubeflow/model-registry/internal/server/middleware"
"github.com/kubeflow/model-registry/internal/server/openapi"
@ -18,7 +22,8 @@ import (
)
type ProxyConfig struct {
Datastore datastore.Datastore
EmbedMD embedmd.EmbedMDConfig
DatastoreType string
}
const (
@ -28,11 +33,9 @@ const (
var (
proxyCfg = ProxyConfig{
Datastore: datastore.Datastore{
Type: "embedmd",
EmbedMD: embedmd.EmbedMDConfig{
TLSConfig: &tls.TLSConfig{},
},
DatastoreType: "embedmd",
EmbedMD: embedmd.EmbedMDConfig{
TLSConfig: &tls.TLSConfig{},
},
}
@ -102,10 +105,19 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
http.Error(w, datastoreUnavailableMessage, http.StatusServiceUnavailable)
}))
mrHealthChecker := &ConditionalModelRegistryHealthChecker{holder: serviceHolder}
dbHealthChecker := proxy.NewDatabaseHealthChecker(proxyCfg.Datastore)
generalReadinessHandler := proxy.GeneralReadinessHandler(proxyCfg.Datastore, dbHealthChecker, mrHealthChecker)
readinessHandler := proxy.GeneralReadinessHandler(proxyCfg.Datastore, dbHealthChecker)
readyChecks := []proxy.HealthChecker{}
generalChecks := []proxy.HealthChecker{
&ConditionalModelRegistryHealthChecker{holder: serviceHolder},
}
if proxyCfg.DatastoreType == "embedmd" {
dbHealthChecker := proxy.NewDatabaseHealthChecker()
readyChecks = append(readyChecks, dbHealthChecker)
generalChecks = append(generalChecks, dbHealthChecker)
}
generalReadinessHandler := proxy.GeneralReadinessHandler(generalChecks...)
readinessHandler := proxy.GeneralReadinessHandler(readyChecks...)
// route health endpoints appropriately
mainHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -142,13 +154,13 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
defer wg.Done()
ds, err = datastore.NewConnector(proxyCfg.Datastore)
ds, err = datastore.NewConnector(proxyCfg.DatastoreType, &proxyCfg.EmbedMD)
if err != nil {
errChan <- fmt.Errorf("error creating datastore: %w", err)
return
}
conn, err := ds.Connect()
conn, err := newModelRegistryService(ds)
if err != nil {
// {{ALERT}} is used to identify this error in pod logs, DO NOT REMOVE
errChan <- fmt.Errorf("{{ALERT}} error connecting to datastore: %w", err)
@ -177,32 +189,63 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
}
}()
defer func() {
if ds != nil {
//nolint:errcheck
ds.Teardown()
}
}()
// Wait for either the Datastore server connection or the proxy server to return an error
// or for both to finish successfully.
return <-errChan
}
func newModelRegistryService(ds datastore.Connector) (api.ModelRegistryApi, error) {
repoSet, err := ds.Connect(service.DatastoreSpec())
if err != nil {
return nil, err
}
modelRegistryService := core.NewModelRegistryService(
getRepo[models.ArtifactRepository](repoSet),
getRepo[models.ModelArtifactRepository](repoSet),
getRepo[models.DocArtifactRepository](repoSet),
getRepo[models.RegisteredModelRepository](repoSet),
getRepo[models.ModelVersionRepository](repoSet),
getRepo[models.ServingEnvironmentRepository](repoSet),
getRepo[models.InferenceServiceRepository](repoSet),
getRepo[models.ServeModelRepository](repoSet),
getRepo[models.ExperimentRepository](repoSet),
getRepo[models.ExperimentRunRepository](repoSet),
getRepo[models.DataSetRepository](repoSet),
getRepo[models.MetricRepository](repoSet),
getRepo[models.ParameterRepository](repoSet),
getRepo[models.MetricHistoryRepository](repoSet),
repoSet.TypeMap(),
)
glog.Infof("EmbedMD service connected")
return modelRegistryService, nil
}
func getRepo[T any](repoSet datastore.RepoSet) T {
repo, err := repoSet.Repository(reflect.TypeFor[T]())
if err != nil {
panic(fmt.Sprintf("unable to get repository: %v", err))
}
return repo.(T)
}
func init() {
rootCmd.AddCommand(proxyCmd)
proxyCmd.Flags().StringVarP(&cfg.Hostname, "hostname", "n", cfg.Hostname, "Proxy server listen hostname")
proxyCmd.Flags().IntVarP(&cfg.Port, "port", "p", cfg.Port, "Proxy server listen port")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.DatabaseType, "embedmd-database-type", "mysql", "EmbedMD database type")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.DatabaseDSN, "embedmd-database-dsn", "", "EmbedMD database DSN")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.CertPath, "embedmd-database-ssl-cert", "", "EmbedMD SSL cert path")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.KeyPath, "embedmd-database-ssl-key", "", "EmbedMD SSL key path")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.RootCertPath, "embedmd-database-ssl-root-cert", "", "EmbedMD SSL root cert path")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.CAPath, "embedmd-database-ssl-ca", "", "EmbedMD SSL CA path")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.Cipher, "embedmd-database-ssl-cipher", "", "Colon-separated list of allowed TLS ciphers for the EmbedMD database connection. Values are from the list at https://pkg.go.dev/crypto/tls#pkg-constants e.g. 'TLS_AES_128_GCM_SHA256:TLS_CHACHA20_POLY1305_SHA256'")
proxyCmd.Flags().BoolVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.VerifyServerCert, "embedmd-database-ssl-verify-server-cert", false, "EmbedMD SSL verify server cert")
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.DatabaseType, "embedmd-database-type", "mysql", "EmbedMD database type")
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.DatabaseDSN, "embedmd-database-dsn", "", "EmbedMD database DSN")
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.CertPath, "embedmd-database-ssl-cert", "", "EmbedMD SSL cert path")
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.KeyPath, "embedmd-database-ssl-key", "", "EmbedMD SSL key path")
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.RootCertPath, "embedmd-database-ssl-root-cert", "", "EmbedMD SSL root cert path")
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.CAPath, "embedmd-database-ssl-ca", "", "EmbedMD SSL CA path")
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.Cipher, "embedmd-database-ssl-cipher", "", "Colon-separated list of allowed TLS ciphers for the EmbedMD database connection. Values are from the list at https://pkg.go.dev/crypto/tls#pkg-constants e.g. 'TLS_AES_128_GCM_SHA256:TLS_CHACHA20_POLY1305_SHA256'")
proxyCmd.Flags().BoolVar(&proxyCfg.EmbedMD.TLSConfig.VerifyServerCert, "embedmd-database-ssl-verify-server-cert", false, "EmbedMD SSL verify server cert")
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.Type, "datastore-type", proxyCfg.Datastore.Type, "Datastore type")
proxyCmd.Flags().StringVar(&proxyCfg.DatastoreType, "datastore-type", proxyCfg.DatastoreType, "Datastore type")
}

View File

@ -39,6 +39,7 @@ k8s_resource(
new_name="db",
labels="backend",
resource_deps=["kubeflow-namespace"],
port_forwards="3306:3306",
)
k8s_resource(

View File

@ -1,4 +1,4 @@
manifests = kustomize("../../../manifests/kustomize/options/catalog")
manifests = kustomize("../../../manifests/kustomize/options/catalog/base")
objects = decode_yaml_stream(manifests)
@ -20,3 +20,12 @@ k8s_resource(
port_forwards="8082:8080",
trigger_mode=TRIGGER_MODE_AUTO
)
k8s_resource(
workload="model-catalog-postgres",
new_name="catalog-db",
labels="backend",
resource_deps=["kubeflow-namespace"],
port_forwards="5432:5432",
trigger_mode=TRIGGER_MODE_AUTO
)

View File

@ -19,6 +19,15 @@ services:
- "8081:8081"
volumes:
- ./catalog/internal/catalog/testdata:/testdata
depends_on:
- postgres
profiles:
- postgres
environment:
- PGHOST=postgres
- PGDATABASE=model_catalog
- PGUSER=postgres
- PGPASSWORD=demo
model-registry:
build:
@ -82,7 +91,16 @@ services:
start_period: 20s
profiles:
- postgres
configs:
- source: pginit
target: /docker-entrypoint-initdb.d/pginit.sql
volumes:
mysql_data:
postgres_data:
configs:
pginit:
content: |
CREATE DATABASE model_catalog;
GRANT ALL PRIVILEGES ON model_catalog TO postgres;

View File

@ -18,6 +18,15 @@ services:
- "8081:8081"
volumes:
- ./catalog/internal/catalog/testdata:/testdata
depends_on:
- postgres
profiles:
- postgres
environment:
- PGHOST=postgres
- PGDATABASE=model_catalog
- PGUSER=postgres
- PGPASSWORD=demo
model-registry:
image: ghcr.io/kubeflow/model-registry/server:latest
@ -80,7 +89,16 @@ services:
start_period: 20s
profiles:
- postgres
configs:
- source: pginit
target: /docker-entrypoint-initdb.d/pginit.sql
volumes:
mysql_data:
postgres_data:
configs:
pginit:
content: |
CREATE DATABASE model_catalog;
GRANT ALL PRIVILEGES ON model_catalog TO postgres;

View File

@ -18,7 +18,7 @@ func TestMain(m *testing.M) {
}
func setupTestDB(t *testing.T) (*gorm.DB, func()) {
db, dbCleanup := testutils.SetupMySQLWithMigrations(t)
db, dbCleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
// Clean up test data before each test
testutils.CleanupTestData(t, db)
@ -69,7 +69,14 @@ func createModelRegistryService(t *testing.T, db *gorm.DB) *core.ModelRegistrySe
typesMap := getTypeIDs(t, db)
// Create all repositories
artifactRepo := service.NewArtifactRepository(db, typesMap[defaults.ModelArtifactTypeName], typesMap[defaults.DocArtifactTypeName], typesMap[defaults.DataSetTypeName], typesMap[defaults.MetricTypeName], typesMap[defaults.ParameterTypeName], typesMap[defaults.MetricHistoryTypeName])
artifactRepo := service.NewArtifactRepository(db, map[string]int64{
defaults.ModelArtifactTypeName: typesMap[defaults.ModelArtifactTypeName],
defaults.DocArtifactTypeName: typesMap[defaults.DocArtifactTypeName],
defaults.DataSetTypeName: typesMap[defaults.DataSetTypeName],
defaults.MetricTypeName: typesMap[defaults.MetricTypeName],
defaults.ParameterTypeName: typesMap[defaults.ParameterTypeName],
defaults.MetricHistoryTypeName: typesMap[defaults.MetricHistoryTypeName],
})
modelArtifactRepo := service.NewModelArtifactRepository(db, typesMap[defaults.ModelArtifactTypeName])
docArtifactRepo := service.NewDocArtifactRepository(db, typesMap[defaults.DocArtifactTypeName])
registeredModelRepo := service.NewRegisteredModelRepository(db, typesMap[defaults.RegisteredModelTypeName])

View File

@ -3,9 +3,8 @@ package datastore
import (
"errors"
"fmt"
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
"github.com/kubeflow/model-registry/pkg/api"
"maps"
"slices"
)
var (
@ -13,32 +12,29 @@ var (
ErrUnsupportedDatastore = errors.New("unsupported datastore type")
)
type TeardownFunc func() error
type Datastore struct {
EmbedMD embedmd.EmbedMDConfig
Type string
}
type Connector interface {
Connect() (api.ModelRegistryApi, error)
Teardown() error
Type() string
Connect(spec *Spec) (RepoSet, error)
}
func NewConnector(ds Datastore) (Connector, error) {
switch ds.Type {
case "embedmd":
if err := ds.EmbedMD.Validate(); err != nil {
return nil, fmt.Errorf("invalid EmbedMD config: %w", err)
}
var connectorTypes map[string]func(any) (Connector, error)
embedmd, err := embedmd.NewEmbedMDService(&ds.EmbedMD)
if err != nil {
return nil, fmt.Errorf("error creating EmbedMD service: %w", err)
}
return embedmd, nil
default:
return nil, fmt.Errorf("%w: %s. Supported types: embedmd", ErrUnsupportedDatastore, ds.Type)
func Register(t string, fn func(config any) (Connector, error)) {
if connectorTypes == nil {
connectorTypes = make(map[string]func(any) (Connector, error), 1)
}
if _, exists := connectorTypes[t]; exists {
panic(fmt.Sprintf("duplicate connector type: %s", t))
}
connectorTypes[t] = fn
}
func NewConnector(t string, config any) (Connector, error) {
if fn, ok := connectorTypes[t]; ok {
return fn(config)
}
return nil, fmt.Errorf("%w: %s. Supported types: %v", ErrUnsupportedDatastore, t, slices.Sorted(maps.Keys(connectorTypes)))
}

View File

@ -23,19 +23,6 @@ func (Type) TableName() string {
return "Type"
}
// TypeProperty represents the TypeProperty table structure
type TypeProperty struct {
ID int64 `gorm:"primaryKey"`
TypeID int64 `gorm:"column:type_id"`
Name string `gorm:"column:name"`
DataType string `gorm:"column:data_type"`
Description string `gorm:"column:description"`
}
func (TypeProperty) TableName() string {
return "TypeProperty"
}
func TestMain(m *testing.M) {
os.Exit(testutils.TestMainHelper(m))
}
@ -63,11 +50,6 @@ func TestMigrations(t *testing.T) {
err = sharedDB.Model(&Type{}).Count(&count).Error
require.NoError(t, err)
assert.Greater(t, count, int64(0))
// Verify TypeProperty table has expected entries
err = sharedDB.Model(&TypeProperty{}).Count(&count).Error
require.NoError(t, err)
assert.Greater(t, count, int64(0))
}
func TestDownMigrations(t *testing.T) {

View File

@ -7,12 +7,5 @@ DELETE FROM `Type` WHERE `name` IN (
'mlmd.Transform',
'mlmd.Process',
'mlmd.Evaluate',
'mlmd.Deploy',
'kf.RegisteredModel',
'kf.ModelVersion',
'kf.DocArtifact',
'kf.ModelArtifact',
'kf.ServingEnvironment',
'kf.InferenceService',
'kf.ServeModel'
'mlmd.Deploy'
);

View File

@ -9,13 +9,6 @@ SELECT t.* FROM (
UNION ALL SELECT 'mlmd.Process', NULL, 0, NULL, NULL, NULL, NULL
UNION ALL SELECT 'mlmd.Evaluate', NULL, 0, NULL, NULL, NULL, NULL
UNION ALL SELECT 'mlmd.Deploy', NULL, 0, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.RegisteredModel', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ModelVersion', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.DocArtifact', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ModelArtifact', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ServingEnvironment', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.InferenceService', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ServeModel', NULL, 0, NULL, NULL, NULL, NULL
) t
WHERE NOT EXISTS (
SELECT 1 FROM `Type`

View File

@ -1,27 +1 @@
DELETE FROM `TypeProperty`
WHERE (`type_id`, `name`, `data_type`) IN (
((SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'owner', 3),
((SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'state', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'author', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'model_name', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'state', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'version', 3),
((SELECT id FROM `Type` WHERE name = 'kf.DocArtifact'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_key', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_path', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ServingEnvironment'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'desired_state', 3),
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'model_version_id', 1),
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'registered_model_id', 1),
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'runtime', 3),
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1),
((SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'model_version_id', 1)
);
-- Migration removed

View File

@ -1,25 +1 @@
INSERT IGNORE INTO `TypeProperty` (`type_id`, `name`, `data_type`)
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'owner', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'state', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'author', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'model_name', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'state', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'version', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DocArtifact'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_key', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_path', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ServingEnvironment'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'desired_state', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'model_version_id', 1 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'registered_model_id', 1 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'runtime', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'model_version_id', 1;
-- Migration removed

View File

@ -1,13 +1 @@
DELETE FROM `TypeProperty` WHERE type_id=(
SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'
) AND `name` IN (
'language',
'library_name',
'license_link',
'license',
'logo',
'maturity',
'provider',
'readme',
'tasks'
);
-- Migration removed

View File

@ -1,10 +1 @@
INSERT IGNORE INTO `TypeProperty` (`type_id`, `name`, `data_type`)
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'language', 4 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'library_name', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'license_link', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'license', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'logo', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'maturity', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'provider', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'readme', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'tasks', 4;
-- Migration removed

View File

@ -1,8 +1 @@
DELETE FROM `Type` WHERE `name` IN (
'kf.MetricHistory',
'kf.Experiment',
'kf.ExperimentRun',
'kf.DataSet',
'kf.Metric',
'kf.Parameter'
);
-- Migration removed

View File

@ -1,13 +1 @@
INSERT INTO `Type` (`name`, `version`, `type_kind`, `description`, `input_type`, `output_type`, `external_id`)
SELECT t.* FROM (
SELECT 'kf.MetricHistory' as name, NULL as version, 1 as type_kind, NULL as description, NULL as input_type, NULL as output_type, NULL as external_id
UNION ALL SELECT 'kf.Experiment', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ExperimentRun', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.DataSet', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.Metric', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.Parameter', NULL, 1, NULL, NULL, NULL, NULL
) t
WHERE NOT EXISTS (
SELECT 1 FROM `Type`
WHERE `name` = t.name AND `version` IS NULL
);
-- Migration removed

View File

@ -1,30 +1 @@
DELETE FROM `TypeProperty`
WHERE (`type_id`, `name`, `data_type`) IN (
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'value', 5),
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'timestamp', 3),
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'step', 1),
((SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'value', 3),
((SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'parameter_type', 3),
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'value', 5),
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'timestamp', 3),
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'step', 1),
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'digest', 3),
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source_type', 3),
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source', 3),
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'schema', 3),
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'profile', 3),
((SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'owner', 3),
((SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'state', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'description', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'owner', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'state', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'status', 3),
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1),
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1),
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1)
);
-- Migration removed

View File

@ -1,28 +1 @@
INSERT IGNORE INTO `TypeProperty` (`type_id`, `name`, `data_type`)
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'value', 5 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'timestamp', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'step', 1 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'value', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'parameter_type', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'value', 5 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'timestamp', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'step', 1 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'digest', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source_type', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'schema', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'profile', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'owner', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'state', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'description', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'owner', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'state', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'status', 3 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1 UNION ALL
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1;
-- Migration removed

View File

@ -28,19 +28,6 @@ func (Type) TableName() string {
return "Type"
}
// TypeProperty represents the TypeProperty table structure
type TypeProperty struct {
ID int64 `gorm:"primaryKey"`
TypeID int64 `gorm:"column:type_id"`
Name string `gorm:"column:name"`
DataType string `gorm:"column:data_type"`
Description string `gorm:"column:description"`
}
func (TypeProperty) TableName() string {
return "TypeProperty"
}
// Package-level shared database instance
var (
sharedDB *gorm.DB
@ -190,11 +177,6 @@ func TestMigrations(t *testing.T) {
err = sharedDB.Model(&Type{}).Count(&count).Error
require.NoError(t, err)
assert.Greater(t, count, int64(0))
// Verify TypeProperty table has expected entries
err = sharedDB.Model(&TypeProperty{}).Count(&count).Error
require.NoError(t, err)
assert.Greater(t, count, int64(0))
}
func TestDownMigrations(t *testing.T) {

View File

@ -7,12 +7,5 @@ DELETE FROM "Type" WHERE name IN (
'mlmd.Transform',
'mlmd.Process',
'mlmd.Evaluate',
'mlmd.Deploy',
'kf.RegisteredModel',
'kf.ModelVersion',
'kf.DocArtifact',
'kf.ModelArtifact',
'kf.ServingEnvironment',
'kf.InferenceService',
'kf.ServeModel'
);
'mlmd.Deploy'
);

View File

@ -10,15 +10,8 @@ SELECT t.* FROM (
UNION ALL SELECT 'mlmd.Process', NULL, 0, NULL, NULL, NULL, NULL
UNION ALL SELECT 'mlmd.Evaluate', NULL, 0, NULL, NULL, NULL, NULL
UNION ALL SELECT 'mlmd.Deploy', NULL, 0, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.RegisteredModel', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ModelVersion', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.DocArtifact', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ModelArtifact', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ServingEnvironment', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.InferenceService', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ServeModel', NULL, 0, NULL, NULL, NULL, NULL
) t
WHERE NOT EXISTS (
SELECT 1 FROM "Type"
WHERE name = t.name AND version IS NULL
);
);

View File

@ -1,28 +1 @@
-- Clear seeded TypeProperty data
DELETE FROM "TypeProperty"
WHERE (type_id, name, data_type) IN (
((SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'owner', 3),
((SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'state', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'author', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'model_name', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'state', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'version', 3),
((SELECT id FROM "Type" WHERE name = 'kf.DocArtifact'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_key', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_path', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ServingEnvironment'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'desired_state', 3),
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'model_version_id', 1),
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'registered_model_id', 1),
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'runtime', 3),
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1),
((SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'model_version_id', 1)
);
-- Migration removed

View File

@ -1,26 +1 @@
-- Seed TypeProperty table
INSERT INTO "TypeProperty" (type_id, name, data_type)
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'owner', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'state', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'author', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'model_name', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'state', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'version', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DocArtifact'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_key', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_path', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ServingEnvironment'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'desired_state', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'model_version_id', 1 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'registered_model_id', 1 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'runtime', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'model_version_id', 1;
-- Migration removed

View File

@ -1,13 +1 @@
DELETE FROM "TypeProperty" WHERE type_id=(
SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'
) AND name IN (
'language',
'library_name',
'license_link',
'license',
'logo',
'maturity',
'provider',
'readme',
'tasks'
);
-- Migration removed

View File

@ -1,10 +1 @@
INSERT INTO "TypeProperty" (type_id, name, data_type)
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'language', 4 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'library_name', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'license_link', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'license', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'logo', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'maturity', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'provider', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'readme', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'tasks', 4;
-- Migration removed

View File

@ -1,8 +1 @@
DELETE FROM "Type" WHERE name IN (
'kf.MetricHistory',
'kf.Experiment',
'kf.ExperimentRun',
'kf.DataSet',
'kf.Metric',
'kf.Parameter'
);
-- Migration removed

View File

@ -1,14 +1 @@
-- Add missing types for experiment tracking and metric history
INSERT INTO "Type" (name, version, type_kind, description, input_type, output_type, external_id)
SELECT t.* FROM (
SELECT 'kf.MetricHistory' as name, NULL as version, 1 as type_kind, NULL as description, NULL as input_type, NULL as output_type, NULL as external_id
UNION ALL SELECT 'kf.Experiment', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.ExperimentRun', NULL, 2, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.DataSet', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.Metric', NULL, 1, NULL, NULL, NULL, NULL
UNION ALL SELECT 'kf.Parameter', NULL, 1, NULL, NULL, NULL, NULL
) t
WHERE NOT EXISTS (
SELECT 1 FROM "Type"
WHERE name = t.name AND version IS NULL
);
-- Migration removed

View File

@ -1,31 +1 @@
-- Remove TypeProperty entries for experiment-related types
DELETE FROM "TypeProperty"
WHERE (type_id, name, data_type) IN (
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'value', 5),
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'timestamp', 3),
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'step', 1),
((SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'value', 3),
((SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'parameter_type', 3),
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'value', 5),
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'timestamp', 3),
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'step', 1),
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'digest', 3),
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source_type', 3),
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source', 3),
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'schema', 3),
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'profile', 3),
((SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'owner', 3),
((SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'state', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'description', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'owner', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'state', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'status', 3),
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1),
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1),
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1)
);
-- Migration removed

View File

@ -1,29 +1 @@
-- Add TypeProperty entries for experiment-related types
INSERT INTO "TypeProperty" (type_id, name, data_type)
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'value', 5 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'timestamp', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'step', 1 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'value', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'parameter_type', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'value', 5 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'timestamp', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'step', 1 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'digest', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source_type', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'schema', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'profile', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'owner', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'state', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'description', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'owner', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'state', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'status', 3 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1 UNION ALL
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1;
-- Migration removed

View File

@ -0,0 +1,196 @@
package embedmd
import (
"fmt"
"reflect"
"github.com/golang/glog"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/db/service"
"gorm.io/gorm"
)
var _ datastore.RepoSet = (*repoSetImpl)(nil)
type repoSetImpl struct {
db *gorm.DB
spec *datastore.Spec
nameIDMap map[string]int64
repos map[reflect.Type]any
}
func newRepoSet(db *gorm.DB, spec *datastore.Spec) (datastore.RepoSet, error) {
typeRepository := service.NewTypeRepository(db)
glog.Infof("Getting types...")
types, err := typeRepository.GetAll()
if err != nil {
return nil, err
}
nameIDMap := make(map[string]int64, len(types))
for _, t := range types {
nameIDMap[*t.GetAttributes().Name] = int64(*t.GetID())
}
glog.Infof("Types retrieved")
// Add debug logging to see what types are actually available
glog.V(2).Infof("DEBUG: Available types:")
for typeName, typeID := range nameIDMap {
glog.V(2).Infof(" %s = %d", typeName, typeID)
}
// Validate that all required types are registered
requiredTypes := spec.AllNames()
for _, requiredType := range requiredTypes {
if _, exists := nameIDMap[requiredType]; !exists {
return nil, fmt.Errorf("required type '%s' not found in database. Please ensure all migrations have been applied", requiredType)
}
}
glog.Infof("All required types validated successfully")
rs := &repoSetImpl{
db: db,
spec: spec,
nameIDMap: nameIDMap,
repos: make(map[reflect.Type]any, len(requiredTypes)+1),
}
artifactTypes := makeTypeMap[datastore.ArtifactTypeMap](spec.ArtifactTypes, nameIDMap)
contextTypes := makeTypeMap[datastore.ContextTypeMap](spec.ContextTypes, nameIDMap)
executionTypes := makeTypeMap[datastore.ExecutionTypeMap](spec.ExecutionTypes, nameIDMap)
args := map[reflect.Type]any{
reflect.TypeOf(db): db,
reflect.TypeOf(artifactTypes): artifactTypes,
reflect.TypeOf(contextTypes): contextTypes,
reflect.TypeOf(executionTypes): executionTypes,
}
for i, fn := range spec.Others {
repo, err := rs.call(fn, args)
if err != nil {
return nil, fmt.Errorf("embedmd: other %d: %w", i, err)
}
rs.put(repo)
}
for name, specType := range spec.ArtifactTypes {
args[reflect.TypeOf(nameIDMap[name])] = nameIDMap[name]
repo, err := rs.call(specType.InitFn, args)
if err != nil {
return nil, fmt.Errorf("embedmd: %s: %w", name, err)
}
rs.put(repo)
}
for name, specType := range spec.ContextTypes {
args[reflect.TypeOf(nameIDMap[name])] = nameIDMap[name]
repo, err := rs.call(specType.InitFn, args)
if err != nil {
return nil, fmt.Errorf("embedmd: %s: %w", name, err)
}
rs.put(repo)
}
for name, specType := range spec.ExecutionTypes {
args[reflect.TypeOf(nameIDMap[name])] = nameIDMap[name]
repo, err := rs.call(specType.InitFn, args)
if err != nil {
return nil, fmt.Errorf("embedmd: %s: %w", name, err)
}
rs.put(repo)
}
return rs, nil
}
// call invokes the function pointed to by fn. It matches fn's arguments to the
// types in args. fn must return at least one argument, and may optionally
// return an error.
func (rs *repoSetImpl) call(fn any, args map[reflect.Type]any) (any, error) {
t := reflect.TypeOf(fn)
if t.Kind() != reflect.Func {
return nil, fmt.Errorf("initializer is not a function (got type %T)", fn)
}
switch t.NumOut() {
case 0:
return nil, fmt.Errorf("initializer has no return value")
case 1, 2:
// OK
default:
return nil, fmt.Errorf("unknown initializer type, more than 2 return values")
}
fnArgs := make([]reflect.Value, t.NumIn())
for i := range t.NumIn() {
v, ok := args[t.In(i)]
if !ok {
return nil, fmt.Errorf("no initializer argument for type %v", t.In(i))
}
fnArgs[i] = reflect.ValueOf(v)
}
out := reflect.ValueOf(fn).Call(fnArgs)
var err error
if len(out) > 1 {
ierr := out[1].Interface()
if ierr != nil {
var ok bool
err, ok = ierr.(error)
if !ok {
return nil, fmt.Errorf("unknown return value, expected error, got %T", err)
}
}
}
return out[0].Interface(), err
}
// put adds one repository to the set.
func (rs *repoSetImpl) put(repo any) {
rs.repos[reflect.TypeOf(repo)] = repo
}
func (rs *repoSetImpl) Repository(t reflect.Type) (any, error) {
// First try an exact match for the requested type.
repo, ok := rs.repos[t]
if ok {
return repo, nil
}
// If the attempt above failed and the requested type is an interface,
// use the first repo that implements it.
if t.Kind() == reflect.Interface {
for repoType, repo := range rs.repos {
if repoType.Implements(t) {
return repo, nil
}
}
}
return nil, fmt.Errorf("unknown repository type: %s", t.Name())
}
func (rs *repoSetImpl) TypeMap() map[string]int64 {
clone := make(map[string]int64, len(rs.nameIDMap))
for k, v := range rs.nameIDMap {
clone[k] = v
}
return clone
}
func makeTypeMap[T ~map[string]int64](specMap map[string]*datastore.SpecType, nameIDMap map[string]int64) T {
returnMap := make(T, len(specMap))
for k := range specMap {
returnMap[k] = nameIDMap[k]
}
return returnMap
}

View File

@ -0,0 +1,403 @@
package embedmd
import (
"context"
"errors"
"reflect"
"testing"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/datastore/embedmd/mysql"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/tls"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cont_mysql "github.com/testcontainers/testcontainers-go/modules/mysql"
"gorm.io/gorm"
)
// Mock repository types for testing
type mockArtifactRepo struct {
typeID int64
}
type mockContextRepo struct {
typeID int64
}
type mockExecutionRepo struct {
typeID int64
}
type mockOtherRepo struct {
db *gorm.DB
}
// Mock initializer functions
func newMockArtifactRepo(db *gorm.DB, typeID int64) *mockArtifactRepo {
return &mockArtifactRepo{typeID: typeID}
}
func newMockContextRepo(db *gorm.DB, typeID int64) *mockContextRepo {
return &mockContextRepo{typeID: typeID}
}
func newMockExecutionRepo(db *gorm.DB, typeID int64) *mockExecutionRepo {
return &mockExecutionRepo{typeID: typeID}
}
func newMockOtherRepo(db *gorm.DB) *mockOtherRepo {
return &mockOtherRepo{db: db}
}
func newMockOtherRepoWithError(db *gorm.DB) (*mockOtherRepo, error) {
return nil, errors.New("mock initialization error")
}
// Test helper to create a test database with types using a simplified MySQL setup
func setupTestDB(t *testing.T) (*gorm.DB, func()) {
// Create a simple MySQL container without config file
ctx := context.Background()
mysqlContainer, err := cont_mysql.Run(ctx, "mysql:8.3",
cont_mysql.WithDatabase("test"),
cont_mysql.WithUsername("root"),
cont_mysql.WithPassword("root"),
)
require.NoError(t, err)
// Get connection string
dsn, err := mysqlContainer.ConnectionString(ctx)
require.NoError(t, err)
// Connect to database
dbConnector := mysql.NewMySQLDBConnector(dsn, &tls.TLSConfig{})
db, err := dbConnector.Connect()
require.NoError(t, err)
// Create the required tables
err = db.AutoMigrate(&schema.Type{})
require.NoError(t, err)
// Insert test types
testTypes := []schema.Type{
{ID: 1, Name: "TestArtifact", TypeKind: 1},
{ID: 2, Name: "TestDoc", TypeKind: 1},
{ID: 3, Name: "TestContext", TypeKind: 2},
{ID: 4, Name: "TestModel", TypeKind: 2},
{ID: 5, Name: "TestExecution", TypeKind: 3},
}
for _, typ := range testTypes {
require.NoError(t, db.Create(&typ).Error)
}
cleanup := func() {
sqlDB, err := db.DB()
if err == nil {
//nolint:errcheck
sqlDB.Close()
}
//nolint:errcheck
mysqlContainer.Terminate(ctx)
}
return db, cleanup
}
func TestNewRepoSet_Success(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
spec := datastore.NewSpec().
AddArtifact("TestArtifact", datastore.NewSpecType(newMockArtifactRepo)).
AddArtifact("TestDoc", datastore.NewSpecType(newMockArtifactRepo)).
AddContext("TestContext", datastore.NewSpecType(newMockContextRepo)).
AddContext("TestModel", datastore.NewSpecType(newMockContextRepo)).
AddExecution("TestExecution", datastore.NewSpecType(newMockExecutionRepo)).
AddOther(newMockOtherRepo)
repoSet, err := newRepoSet(db, spec)
require.NoError(t, err)
assert.NotNil(t, repoSet)
// Verify we can get repositories by type
mockArtifact, err := repoSet.Repository(reflect.TypeOf(&mockArtifactRepo{}))
require.NoError(t, err)
assert.NotNil(t, mockArtifact)
mockContext, err := repoSet.Repository(reflect.TypeOf(&mockContextRepo{}))
require.NoError(t, err)
assert.NotNil(t, mockContext)
mockExecution, err := repoSet.Repository(reflect.TypeOf(&mockExecutionRepo{}))
require.NoError(t, err)
assert.NotNil(t, mockExecution)
mockOther, err := repoSet.Repository(reflect.TypeOf(&mockOtherRepo{}))
require.NoError(t, err)
assert.NotNil(t, mockOther)
// Verify TypeMap returns correct mappings
typeMap := repoSet.TypeMap()
assert.Equal(t, int64(1), typeMap["TestArtifact"])
assert.Equal(t, int64(2), typeMap["TestDoc"])
assert.Equal(t, int64(3), typeMap["TestContext"])
assert.Equal(t, int64(4), typeMap["TestModel"])
assert.Equal(t, int64(5), typeMap["TestExecution"])
}
func TestNewRepoSet_MissingType(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
spec := datastore.NewSpec().
AddArtifact("NonExistentType", datastore.NewSpecType(newMockArtifactRepo))
repoSet, err := newRepoSet(db, spec)
assert.Error(t, err)
assert.Nil(t, repoSet)
assert.Contains(t, err.Error(), "required type 'NonExistentType' not found in database")
}
func TestNewRepoSet_InitializerError(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
spec := datastore.NewSpec().AddOther(newMockOtherRepoWithError)
repoSet, err := newRepoSet(db, spec)
assert.Error(t, err)
assert.Nil(t, repoSet)
assert.Contains(t, err.Error(), "mock initialization error")
}
// Define a simple interface and implementation for interface testing
type TestInterface interface {
TestMethod() string
}
type testImpl struct{}
func (ti *testImpl) TestMethod() string {
return "test"
}
func TestRepoSetImpl_Repository_InterfaceMatch(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
// Create initializer that returns the implementation
newTestImpl := func(db *gorm.DB) *testImpl {
return &testImpl{}
}
spec := datastore.NewSpec().AddOther(newTestImpl)
repoSet, err := newRepoSet(db, spec)
require.NoError(t, err)
// Should be able to get repository by interface type
repo, err := repoSet.Repository(reflect.TypeOf((*TestInterface)(nil)).Elem())
require.NoError(t, err)
assert.NotNil(t, repo)
// Verify it's the correct implementation
impl, ok := repo.(*testImpl)
assert.True(t, ok)
assert.Equal(t, "test", impl.TestMethod())
}
func TestRepoSetImpl_Repository_UnknownType(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
spec := datastore.NewSpec().
AddArtifact("TestArtifact", datastore.NewSpecType(newMockArtifactRepo))
repoSet, err := newRepoSet(db, spec)
require.NoError(t, err)
// Try to get a repository type that doesn't exist
type unknownType struct{}
repo, err := repoSet.Repository(reflect.TypeOf(&unknownType{}))
assert.Error(t, err)
assert.Nil(t, repo)
assert.Contains(t, err.Error(), "unknown repository type")
}
func TestRepoSetImpl_Call_InvalidInitializer(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
rs := &repoSetImpl{
db: db,
}
args := map[reflect.Type]any{
reflect.TypeOf(db): db,
}
// Test with non-function
_, err := rs.call("not a function", args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "initializer is not a function")
// Test with function that has no return values
noReturnFunc := func() {}
_, err = rs.call(noReturnFunc, args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "initializer has no return value")
// Test with function that has too many return values
tooManyReturnsFunc := func() (int, string, error) {
return 0, "", nil
}
_, err = rs.call(tooManyReturnsFunc, args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "more than 2 return values")
// Test with missing argument type
needsIntFunc := func(i int) string {
return "test"
}
_, err = rs.call(needsIntFunc, args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no initializer argument for type")
}
func TestRepoSetImpl_Call_ValidInitializers(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
rs := &repoSetImpl{
db: db,
}
args := map[reflect.Type]any{
reflect.TypeOf(db): db,
reflect.TypeOf(int64(0)): int64(42),
}
// Test function with one return value
oneReturnFunc := func(db *gorm.DB) *mockOtherRepo {
return &mockOtherRepo{db: db}
}
result, err := rs.call(oneReturnFunc, args)
require.NoError(t, err)
assert.NotNil(t, result)
mockRepo, ok := result.(*mockOtherRepo)
assert.True(t, ok)
assert.Equal(t, db, mockRepo.db)
// Test function with two return values (success case)
twoReturnSuccessFunc := func(db *gorm.DB, id int64) (*mockArtifactRepo, error) {
return &mockArtifactRepo{typeID: id}, nil
}
result, err = rs.call(twoReturnSuccessFunc, args)
require.NoError(t, err)
assert.NotNil(t, result)
mockArtifact, ok := result.(*mockArtifactRepo)
assert.True(t, ok)
assert.Equal(t, int64(42), mockArtifact.typeID)
// Test function with two return values (error case)
twoReturnErrorFunc := func(db *gorm.DB) (*mockOtherRepo, error) {
return nil, errors.New("initialization failed")
}
result, err = rs.call(twoReturnErrorFunc, args)
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "initialization failed")
}
func TestMakeTypeMap(t *testing.T) {
specMap := map[string]*datastore.SpecType{
"type1": datastore.NewSpecType("func1"),
"type2": datastore.NewSpecType("func2"),
"type3": datastore.NewSpecType("func3"),
}
nameIDMap := map[string]int64{
"type1": 10,
"type2": 20,
"type3": 30,
}
// Test ArtifactTypeMap
artifactMap := makeTypeMap[datastore.ArtifactTypeMap](specMap, nameIDMap)
assert.Equal(t, datastore.ArtifactTypeMap{"type1": 10, "type2": 20, "type3": 30}, artifactMap)
// Test ContextTypeMap
contextMap := makeTypeMap[datastore.ContextTypeMap](specMap, nameIDMap)
assert.Equal(t, datastore.ContextTypeMap{"type1": 10, "type2": 20, "type3": 30}, contextMap)
// Test ExecutionTypeMap
executionMap := makeTypeMap[datastore.ExecutionTypeMap](specMap, nameIDMap)
assert.Equal(t, datastore.ExecutionTypeMap{"type1": 10, "type2": 20, "type3": 30}, executionMap)
}
func TestRepoSetImpl_TypeMapCloning(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
spec := datastore.NewSpec().
AddArtifact("TestArtifact", datastore.NewSpecType(newMockArtifactRepo))
repoSet, err := newRepoSet(db, spec)
require.NoError(t, err)
rs := repoSet.(*repoSetImpl)
// Get the type map
typeMap1 := rs.TypeMap()
typeMap2 := rs.TypeMap()
// Verify they have the same values
assert.Equal(t, typeMap1, typeMap2)
// Verify they are different objects (cloned)
typeMap1["TestModification"] = 999
assert.NotEqual(t, typeMap1, typeMap2)
assert.NotContains(t, typeMap2, "TestModification")
}
// Integration test with real service repositories
func TestRepoSetImpl_WithRealRepositories(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
// Create artifact types that match what the real repositories expect
realArtifactTypes := []schema.Type{
{ID: 10, Name: "model-artifact", TypeKind: 1},
{ID: 11, Name: "doc-artifact", TypeKind: 1},
}
for _, at := range realArtifactTypes {
require.NoError(t, db.Create(&at).Error)
}
spec := datastore.NewSpec().
AddArtifact("model-artifact", datastore.NewSpecType(service.NewModelArtifactRepository)).
AddArtifact("doc-artifact", datastore.NewSpecType(service.NewDocArtifactRepository)).
AddOther(service.NewArtifactRepository)
repoSet, err := newRepoSet(db, spec)
require.NoError(t, err)
assert.NotNil(t, repoSet)
// Verify we can get the real repositories
modelRepo, err := repoSet.Repository(reflect.TypeOf((*models.ModelArtifactRepository)(nil)).Elem())
require.NoError(t, err)
assert.NotNil(t, modelRepo)
docRepo, err := repoSet.Repository(reflect.TypeOf((*models.DocArtifactRepository)(nil)).Elem())
require.NoError(t, err)
assert.NotNil(t, docRepo)
artifactRepo, err := repoSet.Repository(reflect.TypeOf((*models.ArtifactRepository)(nil)).Elem())
require.NoError(t, err)
assert.NotNil(t, artifactRepo)
}

View File

@ -1,41 +1,69 @@
package embedmd
import (
"errors"
"fmt"
"github.com/golang/glog"
"github.com/kubeflow/model-registry/internal/core"
"github.com/kubeflow/model-registry/internal/apiutils"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/db"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/db/types"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/tls"
"github.com/kubeflow/model-registry/pkg/api"
"gorm.io/gorm"
)
const connectorType = "embedmd"
func init() {
datastore.Register(connectorType, func(cfg any) (datastore.Connector, error) {
emdbCfg, ok := cfg.(*EmbedMDConfig)
if !ok {
return nil, fmt.Errorf("invalid EmbedMD config type (%T)", cfg)
}
if err := emdbCfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid EmbedMD config: %w", err)
}
return NewEmbedMDService(cfg.(*EmbedMDConfig))
})
}
type EmbedMDConfig struct {
DatabaseType string
DatabaseDSN string
TLSConfig *tls.TLSConfig
// DB is an already connected database instance that, if provided, will
// be used instead of making a new connection.
DB *gorm.DB
}
func (c *EmbedMDConfig) Validate() error {
if c.DatabaseType != types.DatabaseTypeMySQL && c.DatabaseType != types.DatabaseTypePostgres {
return fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", c.DatabaseType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
if c.DB == nil {
if c.DatabaseType != types.DatabaseTypeMySQL && c.DatabaseType != types.DatabaseTypePostgres {
return fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", c.DatabaseType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
}
}
return nil
}
type EmbedMDService struct {
*EmbedMDConfig
dbConnector db.Connector
}
func NewEmbedMDService(cfg *EmbedMDConfig) (*EmbedMDService, error) {
err := db.Init(cfg.DatabaseType, cfg.DatabaseDSN, cfg.TLSConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize database connector: %w", err)
if cfg.DB != nil {
db.SetDB(cfg.DB)
} else {
err := db.Init(cfg.DatabaseType, cfg.DatabaseDSN, cfg.TLSConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize database connector: %w", err)
}
}
dbConnector, ok := db.GetConnector()
@ -44,12 +72,11 @@ func NewEmbedMDService(cfg *EmbedMDConfig) (*EmbedMDService, error) {
}
return &EmbedMDService{
EmbedMDConfig: cfg,
dbConnector: dbConnector,
dbConnector: dbConnector,
}, nil
}
func (s *EmbedMDService) Connect() (api.ModelRegistryApi, error) {
func (s *EmbedMDService) Connect(spec *datastore.Spec) (datastore.RepoSet, error) {
glog.Infof("Connecting to EmbedMD service...")
connectedDB, err := s.dbConnector.Connect()
@ -59,7 +86,7 @@ func (s *EmbedMDService) Connect() (api.ModelRegistryApi, error) {
glog.Infof("Connected to EmbedMD service")
migrator, err := db.NewDBMigrator(s.DatabaseType, connectedDB)
migrator, err := db.NewDBMigrator(connectedDB)
if err != nil {
return nil, err
}
@ -73,93 +100,88 @@ func (s *EmbedMDService) Connect() (api.ModelRegistryApi, error) {
glog.Infof("Migrations completed")
typeRepository := service.NewTypeRepository(connectedDB)
glog.Infof("Getting types...")
types, err := typeRepository.GetAll()
glog.Infof("Syncing types...")
err = s.syncTypes(connectedDB, spec)
if err != nil {
return nil, err
}
glog.Infof("Syncing types completed")
typesMap := make(map[string]int64)
return newRepoSet(connectedDB, spec)
}
for _, t := range types {
typesMap[*t.GetAttributes().Name] = int64(*t.GetID())
func (s EmbedMDService) Type() string {
return connectorType
}
const (
executionTypeKind int32 = iota
artifactTypeKind
contextTypeKind
)
func (s *EmbedMDService) syncTypes(conn *gorm.DB, spec *datastore.Spec) error {
idMap := make(map[string]int32, len(spec.ExecutionTypes)+len(spec.ArtifactTypes)+len(spec.ContextTypes))
var errs []error
typeRepository := service.NewTypeRepository(conn)
errs = append(errs, s.createTypes(typeRepository, spec.ExecutionTypes, executionTypeKind, idMap))
errs = append(errs, s.createTypes(typeRepository, spec.ArtifactTypes, artifactTypeKind, idMap))
errs = append(errs, s.createTypes(typeRepository, spec.ContextTypes, contextTypeKind, idMap))
typePropertyRepository := service.NewTypePropertyRepository(conn)
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ExecutionTypes, idMap))
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ArtifactTypes, idMap))
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ContextTypes, idMap))
return errors.Join(errs...)
}
func (s *EmbedMDService) createTypes(repo models.TypeRepository, types map[string]*datastore.SpecType, kind int32, idMap map[string]int32) error {
var errs []error
for name := range types {
t, err := repo.Save(&models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &name,
TypeKind: &kind,
},
})
if err != nil {
errs = append(errs, fmt.Errorf("%s: unable to create type: %w", name, err))
continue
}
id := t.GetID()
if id == nil {
errs = append(errs, fmt.Errorf("%s: unable to determine type ID", name))
continue
}
idMap[name] = *id
}
glog.Infof("Types retrieved")
return errors.Join(errs...)
}
// Add debug logging to see what types are actually available
glog.V(2).Infof("DEBUG: Available types in typesMap:")
for typeName, typeID := range typesMap {
glog.V(2).Infof(" %s = %d", typeName, typeID)
}
func (s *EmbedMDService) createTypeProperties(repo models.TypePropertyRepository, types map[string]*datastore.SpecType, idMap map[string]int32) error {
var errs []error
for typeName, typeSpec := range types {
typeID := idMap[typeName]
if typeID == 0 {
errs = append(errs, fmt.Errorf("%s: unknown type", typeName))
continue
}
// Validate that all required types are registered
requiredTypes := []string{
defaults.ModelArtifactTypeName,
defaults.DocArtifactTypeName,
defaults.DataSetTypeName,
defaults.MetricTypeName,
defaults.ParameterTypeName,
defaults.MetricHistoryTypeName,
defaults.RegisteredModelTypeName,
defaults.ModelVersionTypeName,
defaults.ServingEnvironmentTypeName,
defaults.InferenceServiceTypeName,
defaults.ServeModelTypeName,
defaults.ExperimentTypeName,
defaults.ExperimentRunTypeName,
}
for _, requiredType := range requiredTypes {
if _, exists := typesMap[requiredType]; !exists {
return nil, fmt.Errorf("required type '%s' not found in database. Please ensure all migrations have been applied", requiredType)
for name, dataType := range typeSpec.Properties {
_, err := repo.Save(&models.TypePropertyImpl{
TypeID: typeID,
Name: name,
DataType: apiutils.Of(int32(dataType)),
})
if err != nil {
errs = append(errs, fmt.Errorf("%s-%s: %w", typeName, name, err))
}
}
}
glog.Infof("All required types validated successfully")
artifactRepository := service.NewArtifactRepository(connectedDB, typesMap[defaults.ModelArtifactTypeName], typesMap[defaults.DocArtifactTypeName], typesMap[defaults.DataSetTypeName], typesMap[defaults.MetricTypeName], typesMap[defaults.ParameterTypeName], typesMap[defaults.MetricHistoryTypeName])
modelArtifactRepository := service.NewModelArtifactRepository(connectedDB, typesMap[defaults.ModelArtifactTypeName])
docArtifactRepository := service.NewDocArtifactRepository(connectedDB, typesMap[defaults.DocArtifactTypeName])
registeredModelRepository := service.NewRegisteredModelRepository(connectedDB, typesMap[defaults.RegisteredModelTypeName])
modelVersionRepository := service.NewModelVersionRepository(connectedDB, typesMap[defaults.ModelVersionTypeName])
servingEnvironmentRepository := service.NewServingEnvironmentRepository(connectedDB, typesMap[defaults.ServingEnvironmentTypeName])
inferenceServiceRepository := service.NewInferenceServiceRepository(connectedDB, typesMap[defaults.InferenceServiceTypeName])
serveModelRepository := service.NewServeModelRepository(connectedDB, typesMap[defaults.ServeModelTypeName])
experimentRepository := service.NewExperimentRepository(connectedDB, typesMap[defaults.ExperimentTypeName])
experimentRunRepository := service.NewExperimentRunRepository(connectedDB, typesMap[defaults.ExperimentRunTypeName])
dataSetRepository := service.NewDataSetRepository(connectedDB, typesMap[defaults.DataSetTypeName])
metricRepository := service.NewMetricRepository(connectedDB, typesMap[defaults.MetricTypeName])
parameterRepository := service.NewParameterRepository(connectedDB, typesMap[defaults.ParameterTypeName])
metricHistoryRepository := service.NewMetricHistoryRepository(connectedDB, typesMap[defaults.MetricHistoryTypeName])
modelRegistryService := core.NewModelRegistryService(
artifactRepository,
modelArtifactRepository,
docArtifactRepository,
registeredModelRepository,
modelVersionRepository,
servingEnvironmentRepository,
inferenceServiceRepository,
serveModelRepository,
experimentRepository,
experimentRunRepository,
dataSetRepository,
metricRepository,
parameterRepository,
metricHistoryRepository,
typesMap,
)
glog.Infof("EmbedMD service connected")
return modelRegistryService, nil
}
func (s *EmbedMDService) Teardown() error {
return nil
return errors.Join(errs...)
}

164
internal/datastore/repos.go Normal file
View File

@ -0,0 +1,164 @@
package datastore
import "reflect"
// Spec is the specification for the datastore.
// Each entry
type Spec struct {
// Maps artifact type names to an initializer.
ArtifactTypes map[string]*SpecType
// Maps context type names to an initializer.
ContextTypes map[string]*SpecType
// Maps execution type names to an initializer.
ExecutionTypes map[string]*SpecType
// Any repo initialization functions that don't map to a type can be
// added here.
Others []any
}
// NewSpec returns an empty Spec instance.
func NewSpec() *Spec {
return &Spec{
ArtifactTypes: map[string]*SpecType{},
ContextTypes: map[string]*SpecType{},
ExecutionTypes: map[string]*SpecType{},
Others: []any{},
}
}
// AddArtifact adds an artifact type to the spec.
func (s *Spec) AddArtifact(name string, t *SpecType) *Spec {
s.ArtifactTypes[name] = t
return s
}
// AddContext adds a context type to the spec.
func (s *Spec) AddContext(name string, t *SpecType) *Spec {
s.ContextTypes[name] = t
return s
}
// AddExecution adds an execution type to the spec.
func (s *Spec) AddExecution(name string, t *SpecType) *Spec {
s.ExecutionTypes[name] = t
return s
}
// AddOther adds a repo initializer to the spec.
func (s *Spec) AddOther(initFn any) *Spec {
s.Others = append(s.Others, initFn)
return s
}
// AllNames returns all the type names in the spec.
func (s *Spec) AllNames() []string {
names := make([]string, 0, len(s.ArtifactTypes)+len(s.ContextTypes)+len(s.ExecutionTypes))
for n := range s.ArtifactTypes {
names = append(names, n)
}
for n := range s.ContextTypes {
names = append(names, n)
}
for n := range s.ExecutionTypes {
names = append(names, n)
}
return names
}
// PropertyType is the data type of a property's value.
type PropertyType int32
const (
PropertyTypeUnknown = iota
PropertyTypeInt
PropertyTypeDouble
PropertyTypeString
PropertyTypeStruct
PropertyTypeProto
PropertyTypeBoolean
)
// SpecType is a single type in the spec.
type SpecType struct {
// InitFn is a pointer to an initialization function that returns the
// repository instance and an optional error.
//
// Data Store implementations must pass arguments to the initialization
// functions that are required by the repository (database handles,
// HTTP clients, etc.) and should also provide values for the following
// types:
//
// - int64: type id
// - ArtifactTypeMap: Map of all artifact names to type IDs
// - ContextTypeMap: Map of all context names to type IDs
// - ExecutionTypeMap: Map of all execution names to type IDs
InitFn any
// Defined (non-custom) properties of the type.
Properties map[string]PropertyType
}
// NewSpecType creates a SpecType instance.
func NewSpecType(initFn any) *SpecType {
return &SpecType{
InitFn: initFn,
Properties: map[string]PropertyType{},
}
}
// AddInt adds an int property to the type spec.
func (st *SpecType) AddInt(name string) *SpecType {
st.Properties[name] = PropertyTypeInt
return st
}
// AddDouble adds a double property to the type spec.
func (st *SpecType) AddDouble(name string) *SpecType {
st.Properties[name] = PropertyTypeDouble
return st
}
// AddString adds a string property to the type spec.
func (st *SpecType) AddString(name string) *SpecType {
st.Properties[name] = PropertyTypeString
return st
}
// AddStruct adds a struct property to the type spec.
func (st *SpecType) AddStruct(name string) *SpecType {
st.Properties[name] = PropertyTypeStruct
return st
}
// AddProto adds a proto property to the type spec.
func (st *SpecType) AddProto(name string) *SpecType {
st.Properties[name] = PropertyTypeProto
return st
}
// AddBoolean adds a boolean property to the type spec.
func (st *SpecType) AddBoolean(name string) *SpecType {
st.Properties[name] = PropertyTypeBoolean
return st
}
// RepoSet holds repository implementions.
type RepoSet interface {
// TypeMap returns a map of type names to IDs
TypeMap() map[string]int64
// Repository returns a repository instance of the specified type.
Repository(t reflect.Type) (any, error)
}
// ArtifactTypeMap maps artifact type names to IDs
type ArtifactTypeMap map[string]int64
// ContextTypeMap maps context type names to IDs
type ContextTypeMap map[string]int64
// ExecutionTypeMap maps execution type names to IDs
type ExecutionTypeMap map[string]int64

View File

@ -30,9 +30,9 @@ func Init(dbType string, dsn string, tlsConfig *tls.TLSConfig) error {
}
switch dbType {
case "mysql":
case types.DatabaseTypeMySQL:
_connectorInstance = mysql.NewMySQLDBConnector(dsn, tlsConfig)
case "postgres":
case types.DatabaseTypePostgres:
_connectorInstance = postgres.NewPostgresDBConnector(dsn, tlsConfig)
default:
return fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", dbType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
@ -41,6 +41,12 @@ func Init(dbType string, dsn string, tlsConfig *tls.TLSConfig) error {
return nil
}
func SetDB(connectedDB *gorm.DB) {
connectorMutex.Lock()
defer connectorMutex.Unlock()
_connectorInstance = ConnectedConnector{ConnectedDB: connectedDB}
}
func GetConnector() (Connector, bool) {
connectorMutex.RLock()
defer connectorMutex.RUnlock()
@ -54,3 +60,17 @@ func ClearConnector() {
_connectorInstance = nil
}
// ConnectedConnector satifies the connector interface for an already connected
// gorm.DB instance.
type ConnectedConnector struct {
ConnectedDB *gorm.DB
}
func (c ConnectedConnector) Connect() (*gorm.DB, error) {
return c.ConnectedDB, nil
}
func (c ConnectedConnector) DB() *gorm.DB {
return c.ConnectedDB
}

View File

@ -15,13 +15,13 @@ type DBMigrator interface {
Down(steps *int) error
}
func NewDBMigrator(dbType string, db *gorm.DB) (DBMigrator, error) {
switch dbType {
func NewDBMigrator(db *gorm.DB) (DBMigrator, error) {
switch db.Name() {
case types.DatabaseTypeMySQL:
return mysql.NewMySQLMigrator(db)
case types.DatabaseTypePostgres:
return postgres.NewPostgresMigrator(db)
}
return nil, fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", dbType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
return nil, fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", db.Name(), types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
}

View File

@ -29,5 +29,9 @@ func (t *TypeImpl) GetAttributes() *TypeAttributes {
}
type TypeRepository interface {
// GetAll returns every registered type.
GetAll() ([]Type, error)
// Save updates a type, if the definition differs from what's stored.
Save(t Type) (Type, error)
}

View File

@ -0,0 +1,32 @@
package models
type TypeProperty interface {
GetTypeID() int32
GetName() string
GetDataType() *int32
}
var _ TypeProperty = (*TypePropertyImpl)(nil)
type TypePropertyImpl struct {
TypeID int32
Name string
DataType *int32
}
func (tp *TypePropertyImpl) GetTypeID() int32 {
return tp.TypeID
}
func (tp *TypePropertyImpl) GetName() string {
return tp.Name
}
func (tp *TypePropertyImpl) GetDataType() *int32 {
return tp.DataType
}
type TypePropertyRepository interface {
// Save stores a type property if it doesn't exist.
Save(tp TypeProperty) (TypeProperty, error)
}

View File

@ -4,10 +4,12 @@ import (
"errors"
"fmt"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/db/scopes"
"github.com/kubeflow/model-registry/internal/db/utils"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/pkg/api"
"github.com/kubeflow/model-registry/pkg/openapi"
"gorm.io/gorm"
@ -16,24 +18,21 @@ import (
var ErrArtifactNotFound = errors.New("artifact by id not found")
type ArtifactRepositoryImpl struct {
db *gorm.DB
modelArtifactTypeID int64
docArtifactTypeID int64
dataSetTypeID int64
metricTypeID int64
parameterTypeID int64
metricHistoryTypeID int64
db *gorm.DB
idToName map[int64]string
nameToID datastore.ArtifactTypeMap
}
func NewArtifactRepository(db *gorm.DB, modelArtifactTypeID int64, docArtifactTypeID int64, dataSetTypeID int64, metricTypeID int64, parameterTypeID int64, metricHistoryTypeID int64) models.ArtifactRepository {
func NewArtifactRepository(db *gorm.DB, artifactTypes datastore.ArtifactTypeMap) models.ArtifactRepository {
idToName := make(map[int64]string, len(artifactTypes))
for name, id := range artifactTypes {
idToName[id] = name
}
return &ArtifactRepositoryImpl{
db: db,
modelArtifactTypeID: modelArtifactTypeID,
docArtifactTypeID: docArtifactTypeID,
dataSetTypeID: dataSetTypeID,
metricTypeID: metricTypeID,
parameterTypeID: parameterTypeID,
metricHistoryTypeID: metricHistoryTypeID,
db: db,
nameToID: artifactTypes,
idToName: idToName,
}
}
@ -73,7 +72,9 @@ func (r *ArtifactRepositoryImpl) List(listOptions models.ArtifactListOptions) (*
query := r.db.Model(&schema.Artifact{})
// Exclude metric history records - they should only be returned via metric history endpoints
query = query.Where("type_id != ?", r.metricHistoryTypeID)
if metricHistoryTypeID, ok := r.nameToID[defaults.MetricHistoryTypeName]; ok {
query = query.Where("type_id != ?", metricHistoryTypeID)
}
if listOptions.Name != nil {
// Name is not prefixed with the parent resource id to allow for filtering by name only
@ -170,17 +171,17 @@ func (r *ArtifactRepositoryImpl) List(listOptions models.ArtifactListOptions) (*
// getTypeIDFromArtifactType maps artifact type strings to their corresponding type IDs
func (r *ArtifactRepositoryImpl) getTypeIDFromArtifactType(artifactType string) (int64, error) {
switch artifactType {
case string(openapi.ARTIFACTTYPEQUERYPARAM_MODEL_ARTIFACT):
return r.modelArtifactTypeID, nil
case string(openapi.ARTIFACTTYPEQUERYPARAM_DOC_ARTIFACT):
return r.docArtifactTypeID, nil
case string(openapi.ARTIFACTTYPEQUERYPARAM_DATASET_ARTIFACT):
return r.dataSetTypeID, nil
case string(openapi.ARTIFACTTYPEQUERYPARAM_METRIC):
return r.metricTypeID, nil
case string(openapi.ARTIFACTTYPEQUERYPARAM_PARAMETER):
return r.parameterTypeID, nil
switch openapi.ArtifactTypeQueryParam(artifactType) {
case openapi.ARTIFACTTYPEQUERYPARAM_MODEL_ARTIFACT:
return r.nameToID[defaults.ModelArtifactTypeName], nil
case openapi.ARTIFACTTYPEQUERYPARAM_DOC_ARTIFACT:
return r.nameToID[defaults.DocArtifactTypeName], nil
case openapi.ARTIFACTTYPEQUERYPARAM_DATASET_ARTIFACT:
return r.nameToID[defaults.DataSetTypeName], nil
case openapi.ARTIFACTTYPEQUERYPARAM_METRIC:
return r.nameToID[defaults.MetricTypeName], nil
case openapi.ARTIFACTTYPEQUERYPARAM_PARAMETER:
return r.nameToID[defaults.ParameterTypeName], nil
default:
return 0, fmt.Errorf("unsupported artifact type: %s: %w", artifactType, api.ErrBadRequest)
}
@ -189,25 +190,26 @@ func (r *ArtifactRepositoryImpl) getTypeIDFromArtifactType(artifactType string)
func (r *ArtifactRepositoryImpl) mapDataLayerToArtifact(artifact schema.Artifact, properties []schema.ArtifactProperty) (models.Artifact, error) {
artToReturn := models.Artifact{}
switch artifact.TypeID {
case int32(r.modelArtifactTypeID):
typeName := r.idToName[int64(artifact.TypeID)]
switch typeName {
case defaults.ModelArtifactTypeName:
modelArtifact := mapDataLayerToModelArtifact(artifact, properties)
artToReturn.ModelArtifact = &modelArtifact
case int32(r.docArtifactTypeID):
case defaults.DocArtifactTypeName:
docArtifact := mapDataLayerToDocArtifact(artifact, properties)
artToReturn.DocArtifact = &docArtifact
case int32(r.dataSetTypeID):
case defaults.DataSetTypeName:
dataSet := mapDataLayerToDataSet(artifact, properties)
artToReturn.DataSet = &dataSet
case int32(r.metricTypeID):
case defaults.MetricTypeName:
metric := mapDataLayerToMetric(artifact, properties)
artToReturn.Metric = &metric
case int32(r.parameterTypeID):
case defaults.ParameterTypeName:
parameter := mapDataLayerToParameter(artifact, properties)
artToReturn.Parameter = &parameter
default:
return models.Artifact{}, fmt.Errorf("invalid artifact type: %d (expected: modelArtifact=%d, docArtifact=%d, dataSet=%d, metric=%d, parameter=%d, metricHistory=%d [filtered])",
artifact.TypeID, r.modelArtifactTypeID, r.docArtifactTypeID, r.dataSetTypeID, r.metricTypeID, r.parameterTypeID, r.metricHistoryTypeID)
return models.Artifact{}, fmt.Errorf("invalid artifact type: %s=%d (expected: %v)", typeName, artifact.TypeID, r.idToName)
}
return artToReturn, nil

View File

@ -8,13 +8,14 @@ import (
"github.com/kubeflow/model-registry/internal/apiutils"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestArtifactRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual type IDs from the database
@ -24,7 +25,14 @@ func TestArtifactRepository(t *testing.T) {
metricTypeID := getMetricTypeID(t, sharedDB)
parameterTypeID := getParameterTypeID(t, sharedDB)
metricHistoryTypeID := getMetricHistoryTypeID(t, sharedDB)
repo := service.NewArtifactRepository(sharedDB, modelArtifactTypeID, docArtifactTypeID, dataSetTypeID, metricTypeID, parameterTypeID, metricHistoryTypeID)
repo := service.NewArtifactRepository(sharedDB, map[string]int64{
defaults.ModelArtifactTypeName: modelArtifactTypeID,
defaults.DocArtifactTypeName: docArtifactTypeID,
defaults.DataSetTypeName: dataSetTypeID,
defaults.MetricTypeName: metricTypeID,
defaults.ParameterTypeName: parameterTypeID,
defaults.MetricHistoryTypeName: metricHistoryTypeID,
})
// Also get other type IDs for creating related entities
registeredModelTypeID := getRegisteredModelTypeID(t, sharedDB)

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/testutils"
"github.com/stretchr/testify/require"
@ -16,7 +17,7 @@ func TestMain(m *testing.M) {
}
func setupTestDB(t *testing.T) (*gorm.DB, func()) {
db, dbCleanup := testutils.SetupMySQLWithMigrations(t)
db, dbCleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
// Clean up test data before each test
testutils.CleanupTestData(t, db)

View File

@ -14,7 +14,7 @@ import (
)
func TestDocArtifactRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual DocArtifact type ID from the database

View File

@ -13,7 +13,7 @@ import (
)
func TestInferenceServiceRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual InferenceService type ID from the database

View File

@ -14,7 +14,7 @@ import (
)
func TestModelArtifactRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual ModelArtifact type ID from the database

View File

@ -14,7 +14,7 @@ import (
)
func TestModelVersionRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual ModelVersion type ID from the database

View File

@ -13,7 +13,7 @@ import (
)
func TestRegisteredModelRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual RegisteredModel type ID from the database

View File

@ -14,7 +14,7 @@ import (
)
func TestServeModelRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual ServeModel type ID from the database

View File

@ -13,7 +13,7 @@ import (
)
func TestServingEnvironmentRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Get the actual ServingEnvironment type ID from the database

View File

@ -0,0 +1,97 @@
package service
import (
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/defaults"
)
func DatastoreSpec() *datastore.Spec {
return datastore.NewSpec().
AddArtifact(defaults.ModelArtifactTypeName, datastore.NewSpecType(NewModelArtifactRepository).
AddString("description").
AddString("model_format_name").
AddString("model_format_version").
AddString("service_account_name").
AddString("storage_key").
AddString("storage_path"),
).
AddArtifact(defaults.DocArtifactTypeName, datastore.NewSpecType(NewDocArtifactRepository).
AddString("description"),
).
AddArtifact(defaults.DataSetTypeName, datastore.NewSpecType(NewDataSetRepository).
AddString("description").
AddString("digest").
AddString("source_type").
AddString("source").
AddString("schema").
AddString("profile"),
).
AddArtifact(defaults.MetricTypeName, datastore.NewSpecType(NewMetricRepository).
AddString("description").
AddProto("value").
AddString("timestamp").
AddInt("step"),
).
AddArtifact(defaults.ParameterTypeName, datastore.NewSpecType(NewParameterRepository).
AddString("description").
AddString("value").
AddString("parameter_type"),
).
AddArtifact(defaults.MetricHistoryTypeName, datastore.NewSpecType(NewMetricHistoryRepository).
AddString("description").
AddProto("value").
AddString("timestamp").
AddInt("step"),
).
AddContext(defaults.RegisteredModelTypeName, datastore.NewSpecType(NewRegisteredModelRepository).
AddString("description").
AddString("owner").
AddString("state").
AddStruct("language").
AddString("library_name").
AddString("license_link").
AddString("license").
AddString("logo").
AddString("maturity").
AddString("provider").
AddString("readme").
AddStruct("tasks"),
).
AddContext(defaults.ModelVersionTypeName, datastore.NewSpecType(NewModelVersionRepository).
AddString("author").
AddString("description").
AddString("model_name").
AddString("state").
AddString("version"),
).
AddContext(defaults.ServingEnvironmentTypeName, datastore.NewSpecType(NewServingEnvironmentRepository).
AddString("description"),
).
AddContext(defaults.InferenceServiceTypeName, datastore.NewSpecType(NewInferenceServiceRepository).
AddString("description").
AddString("desired_state").
AddInt("model_version_id").
AddInt("registered_model_id").
AddString("runtime").
AddInt("serving_environment_id"),
).
AddContext(defaults.ExperimentTypeName, datastore.NewSpecType(NewExperimentRepository).
AddString("description").
AddString("owner").
AddString("state"),
).
AddContext(defaults.ExperimentRunTypeName, datastore.NewSpecType(NewExperimentRunRepository).
AddString("description").
AddString("owner").
AddString("state").
AddString("status").
AddInt("start_time_since_epoch").
AddInt("end_time_since_epoch").
AddInt("experiment_id"),
).
AddExecution(defaults.ServeModelTypeName, datastore.NewSpecType(NewServeModelRepository).
AddString("description").
AddInt("model_version_id"),
).
AddOther(NewArtifactRepository)
}

View File

@ -1,6 +1,9 @@
package service
import (
"errors"
"fmt"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"gorm.io/gorm"
@ -40,3 +43,58 @@ func (r *TypeRepositoryImpl) GetAll() ([]models.Type, error) {
return typesModels, nil
}
func (r *TypeRepositoryImpl) Save(t models.Type) (models.Type, error) {
attr := t.GetAttributes()
if attr == nil {
return t, errors.New("invalid type: missing attributes")
}
if attr.Name == nil {
return t, errors.New("invalid type: missing name")
}
if attr.TypeKind == nil {
return t, errors.New("invalid type: missing kind")
}
var st schema.Type
err := r.db.Where("name = ?", *attr.Name).First(&st).Error
if err == nil {
// Record already exists. We don't support updates, but we can return the full details.
// Catch this case in particular.
if st.TypeKind != *attr.TypeKind {
return t, fmt.Errorf("invalid type: kind is %d, cannot change to kind %d", st.TypeKind, *attr.TypeKind)
}
} else if errors.Is(err, gorm.ErrRecordNotFound) {
// Record doesn't exist, so we'll create it.
st = schema.Type{
Name: *attr.Name,
Version: attr.Version,
TypeKind: *attr.TypeKind,
Description: attr.Description,
InputType: attr.InputType,
OutputType: attr.OutputType,
ExternalID: attr.ExternalID,
}
if err := r.db.Create(&st).Error; err != nil {
return t, err
}
} else {
return t, err
}
return &models.TypeImpl{
ID: &st.ID,
Attributes: &models.TypeAttributes{
Name: &st.Name,
Version: st.Version,
TypeKind: &st.TypeKind,
Description: st.Description,
InputType: st.InputType,
OutputType: st.OutputType,
ExternalID: st.ExternalID,
},
}, nil
}

View File

@ -0,0 +1,53 @@
package service
import (
"errors"
"fmt"
"strconv"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/schema"
"golang.org/x/exp/constraints"
"gorm.io/gorm"
)
type typePropertyRepositoryImpl struct {
db *gorm.DB
}
func NewTypePropertyRepository(db *gorm.DB) models.TypePropertyRepository {
return &typePropertyRepositoryImpl{db: db}
}
func (r *typePropertyRepositoryImpl) Save(tp models.TypeProperty) (models.TypeProperty, error) {
var stp schema.TypeProperty
err := r.db.Where("type_id=? AND name=?", tp.GetTypeID(), tp.GetName()).First(&stp).Error
if err == nil {
oldType := intPointerString(stp.DataType)
newType := intPointerString(tp.GetDataType())
if oldType != newType {
return tp, fmt.Errorf("invalid property type: data type is %s, cannot change to %s", oldType, newType)
}
} else if errors.Is(err, gorm.ErrRecordNotFound) {
stp.TypeID = tp.GetTypeID()
stp.Name = tp.GetName()
stp.DataType = tp.GetDataType()
if err := r.db.Create(&stp).Error; err != nil {
return tp, err
}
}
return &models.TypePropertyImpl{
TypeID: stp.TypeID,
Name: stp.Name,
DataType: stp.DataType,
}, nil
}
func intPointerString[T constraints.Integer](v *T) string {
if v == nil {
return "<nil>"
}
return strconv.Itoa(int(*v))
}

View File

@ -0,0 +1,238 @@
package service_test
import (
"fmt"
"testing"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTypePropertyRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
repo := service.NewTypePropertyRepository(sharedDB)
typeRepo := service.NewTypeRepository(sharedDB)
// Create a test type to use for properties
testTypeName := "test-type-for-properties"
testTypeKind := int32(1)
testType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &testTypeName,
TypeKind: &testTypeKind,
},
}
savedType, err := typeRepo.Save(testType)
require.NoError(t, err)
require.NotNil(t, savedType)
typeID := *savedType.GetID()
t.Run("TestSave", func(t *testing.T) {
// Test saving a new type property
propertyName := "test-property"
dataType := int32(1)
property := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: &dataType,
}
savedProperty, err := repo.Save(property)
require.NoError(t, err)
require.NotNil(t, savedProperty)
// Verify the saved property
assert.Equal(t, typeID, savedProperty.GetTypeID())
assert.Equal(t, propertyName, savedProperty.GetName())
assert.Equal(t, dataType, *savedProperty.GetDataType())
})
t.Run("TestSaveExisting", func(t *testing.T) {
// Test saving a property that already exists with same data type
propertyName := "test-existing-property"
dataType := int32(2)
// First, save the property
firstProperty := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: &dataType,
}
savedProperty1, err := repo.Save(firstProperty)
require.NoError(t, err)
require.NotNil(t, savedProperty1)
// Now try to save the same property again
secondProperty := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: &dataType,
}
savedProperty2, err := repo.Save(secondProperty)
require.NoError(t, err)
require.NotNil(t, savedProperty2)
// Should return the existing property
assert.Equal(t, savedProperty1.GetTypeID(), savedProperty2.GetTypeID())
assert.Equal(t, savedProperty1.GetName(), savedProperty2.GetName())
assert.Equal(t, *savedProperty1.GetDataType(), *savedProperty2.GetDataType())
})
t.Run("TestSaveInvalidDataTypeChange", func(t *testing.T) {
// Test trying to save a property with different data type than existing
propertyName := "test-datatype-change-property"
originalDataType := int32(1)
newDataType := int32(2)
// First, save the property with original data type
originalProperty := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: &originalDataType,
}
_, err := repo.Save(originalProperty)
require.NoError(t, err)
// Now try to save with different data type - should fail
changedProperty := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: &newDataType,
}
_, err = repo.Save(changedProperty)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot change to")
})
t.Run("TestSaveMultiplePropertiesForSameType", func(t *testing.T) {
// Test saving multiple properties for the same type
properties := []struct {
name string
dataType int32
}{
{"property1", 1},
{"property2", 2},
{"property3", 3},
}
for _, prop := range properties {
property := &models.TypePropertyImpl{
TypeID: typeID,
Name: prop.name,
DataType: &prop.dataType,
}
savedProperty, err := repo.Save(property)
require.NoError(t, err, "Failed to save property %s", prop.name)
require.NotNil(t, savedProperty)
assert.Equal(t, typeID, savedProperty.GetTypeID())
assert.Equal(t, prop.name, savedProperty.GetName())
assert.Equal(t, prop.dataType, *savedProperty.GetDataType())
}
})
t.Run("TestSaveWithDifferentTypes", func(t *testing.T) {
// Create another test type
anotherTypeName := "another-test-type"
anotherTypeKind := int32(2)
anotherType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &anotherTypeName,
TypeKind: &anotherTypeKind,
},
}
savedAnotherType, err := typeRepo.Save(anotherType)
require.NoError(t, err)
require.NotNil(t, savedAnotherType)
anotherTypeID := *savedAnotherType.GetID()
// Test saving properties with same name for different types
propertyName := "shared-property-name"
dataType := int32(1)
// Save property for first type
property1 := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: &dataType,
}
savedProperty1, err := repo.Save(property1)
require.NoError(t, err)
require.NotNil(t, savedProperty1)
// Save property with same name for second type - should work
property2 := &models.TypePropertyImpl{
TypeID: anotherTypeID,
Name: propertyName,
DataType: &dataType,
}
savedProperty2, err := repo.Save(property2)
require.NoError(t, err)
require.NotNil(t, savedProperty2)
// Verify both properties exist with different type IDs
assert.Equal(t, typeID, savedProperty1.GetTypeID())
assert.Equal(t, anotherTypeID, savedProperty2.GetTypeID())
assert.Equal(t, propertyName, savedProperty1.GetName())
assert.Equal(t, propertyName, savedProperty2.GetName())
assert.Equal(t, dataType, *savedProperty1.GetDataType())
assert.Equal(t, dataType, *savedProperty2.GetDataType())
})
t.Run("TestSaveNilDataType", func(t *testing.T) {
// Test saving a property with nil data type for a new property
propertyName := "test-nil-datatype-property"
property := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: nil,
}
savedProperty, err := repo.Save(property)
require.NoError(t, err)
require.NotNil(t, savedProperty)
// Verify the saved property
assert.Equal(t, typeID, savedProperty.GetTypeID())
assert.Equal(t, propertyName, savedProperty.GetName())
assert.Nil(t, savedProperty.GetDataType())
})
t.Run("TestSaveValidDataTypes", func(t *testing.T) {
// Test saving properties with various valid data types
validDataTypes := []int32{0, 1, 2, 3, 4, 5, 10, 100}
for i, dataType := range validDataTypes {
propertyName := fmt.Sprintf("test-datatype-%d-property", i)
property := &models.TypePropertyImpl{
TypeID: typeID,
Name: propertyName,
DataType: &dataType,
}
savedProperty, err := repo.Save(property)
require.NoError(t, err, "Failed to save property with data type %d", dataType)
require.NotNil(t, savedProperty)
assert.Equal(t, typeID, savedProperty.GetTypeID())
assert.Equal(t, propertyName, savedProperty.GetName())
assert.Equal(t, dataType, *savedProperty.GetDataType())
}
})
}

View File

@ -3,6 +3,7 @@ package service_test
import (
"testing"
"github.com/kubeflow/model-registry/internal/apiutils"
"github.com/kubeflow/model-registry/internal/db/models"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/kubeflow/model-registry/internal/defaults"
@ -12,7 +13,7 @@ import (
)
func TestTypeRepository(t *testing.T) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
repo := service.NewTypeRepository(sharedDB)
@ -188,4 +189,168 @@ func TestTypeRepository(t *testing.T) {
assert.NotEmpty(t, types, "Migrated database should have default types")
assert.GreaterOrEqual(t, len(types), 1, "Should have at least one type after migration")
})
t.Run("TestSave", func(t *testing.T) {
// Test saving a new type
newTypeName := "test-custom-type"
newTypeKind := int32(1)
newVersion := "1.0.0"
newDescription := "Test custom type description"
newInputType := "application/json"
newOutputType := "application/json"
newExternalID := "external-123"
newType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &newTypeName,
TypeKind: &newTypeKind,
Version: &newVersion,
Description: &newDescription,
InputType: &newInputType,
OutputType: &newOutputType,
ExternalID: &newExternalID,
},
}
savedType, err := repo.Save(newType)
require.NoError(t, err)
require.NotNil(t, savedType)
// Verify the saved type has an ID
assert.NotNil(t, savedType.GetID())
assert.Greater(t, *savedType.GetID(), int32(0))
// Verify all attributes are preserved
attrs := savedType.GetAttributes()
assert.Equal(t, newTypeName, *attrs.Name)
assert.Equal(t, newTypeKind, *attrs.TypeKind)
assert.Equal(t, newVersion, *attrs.Version)
assert.Equal(t, newDescription, *attrs.Description)
assert.Equal(t, newInputType, *attrs.InputType)
assert.Equal(t, newOutputType, *attrs.OutputType)
assert.Equal(t, newExternalID, *attrs.ExternalID)
})
t.Run("TestSaveExisting", func(t *testing.T) {
// Test saving a type that already exists
existingTypeName := "test-existing-type"
existingTypeKind := int32(2)
// First, save the type
firstType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &existingTypeName,
TypeKind: &existingTypeKind,
},
}
savedType1, err := repo.Save(firstType)
require.NoError(t, err)
require.NotNil(t, savedType1)
// Now try to save the same type again
secondType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &existingTypeName,
TypeKind: &existingTypeKind,
},
}
savedType2, err := repo.Save(secondType)
require.NoError(t, err)
require.NotNil(t, savedType2)
// Should return the existing type with same ID
assert.Equal(t, *savedType1.GetID(), *savedType2.GetID())
assert.Equal(t, *savedType1.GetAttributes().Name, *savedType2.GetAttributes().Name)
assert.Equal(t, *savedType1.GetAttributes().TypeKind, *savedType2.GetAttributes().TypeKind)
})
t.Run("TestSaveInvalidTypeKindChange", func(t *testing.T) {
// Test trying to save a type with different kind than existing
typeName := "test-kind-change-type"
originalKind := int32(1)
newKind := int32(2)
// First, save the type with original kind
originalType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &typeName,
TypeKind: &originalKind,
},
}
_, err := repo.Save(originalType)
require.NoError(t, err)
// Now try to save with different kind - should fail
changedType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &typeName,
TypeKind: &newKind,
},
}
_, err = repo.Save(changedType)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot change to kind")
})
t.Run("TestSaveValidationErrors", func(t *testing.T) {
// Test saving type without attributes
emptyType := &models.TypeImpl{}
_, err := repo.Save(emptyType)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing attributes")
// Test saving type without name
noNameType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
TypeKind: apiutils.Of(int32(1)),
},
}
_, err = repo.Save(noNameType)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing name")
// Test saving type without kind
noKindType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: apiutils.Of("test-no-kind"),
},
}
_, err = repo.Save(noKindType)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing kind")
})
t.Run("TestSaveMinimalType", func(t *testing.T) {
// Test saving type with only required fields
typeName := "test-minimal-type"
typeKind := int32(3)
minimalType := &models.TypeImpl{
Attributes: &models.TypeAttributes{
Name: &typeName,
TypeKind: &typeKind,
},
}
savedType, err := repo.Save(minimalType)
require.NoError(t, err)
require.NotNil(t, savedType)
// Verify required fields
assert.NotNil(t, savedType.GetID())
assert.Equal(t, typeName, *savedType.GetAttributes().Name)
assert.Equal(t, typeKind, *savedType.GetAttributes().TypeKind)
// Optional fields should be nil
attrs := savedType.GetAttributes()
assert.Nil(t, attrs.Version)
assert.Nil(t, attrs.Description)
assert.Nil(t, attrs.InputType)
assert.Nil(t, attrs.OutputType)
assert.Nil(t, attrs.ExternalID)
})
}

View File

@ -6,7 +6,6 @@ import (
"net/http"
"time"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/db"
"github.com/kubeflow/model-registry/pkg/api"
)
@ -72,13 +71,10 @@ type HealthChecker interface {
// DatabaseHealthChecker checks database connectivity and schema state
type DatabaseHealthChecker struct {
datastore datastore.Datastore
}
func NewDatabaseHealthChecker(datastore datastore.Datastore) *DatabaseHealthChecker {
return &DatabaseHealthChecker{
datastore: datastore,
}
func NewDatabaseHealthChecker() *DatabaseHealthChecker {
return &DatabaseHealthChecker{}
}
func (d *DatabaseHealthChecker) Check() HealthCheck {
@ -87,22 +83,6 @@ func (d *DatabaseHealthChecker) Check() HealthCheck {
Details: make(map[string]interface{}),
}
// Skip embedmd check for mlmd datastore
if d.datastore.Type != "embedmd" {
check.Status = StatusPass
check.Message = "MLMD datastore - skipping database check"
check.Details[detailDatastoreType] = d.datastore.Type
return check
}
// Check DSN configuration
dsn := d.datastore.EmbedMD.DatabaseDSN
if dsn == "" {
check.Status = StatusFail
check.Message = "database DSN not configured"
return check
}
// Check database connector
dbConnector, ok := db.GetConnector()
if !ok {
@ -259,7 +239,7 @@ func (m *ModelRegistryHealthChecker) Check() HealthCheck {
}
// GeneralReadinessHandler creates a general readiness handler with configurable health checks
func GeneralReadinessHandler(datastore datastore.Datastore, additionalCheckers ...HealthChecker) http.Handler {
func GeneralReadinessHandler(additionalCheckers ...HealthChecker) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)

View File

@ -10,8 +10,6 @@ import (
"time"
"github.com/kubeflow/model-registry/internal/core"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
"github.com/kubeflow/model-registry/internal/db"
"github.com/kubeflow/model-registry/internal/db/schema"
"github.com/kubeflow/model-registry/internal/db/service"
@ -29,7 +27,7 @@ func TestMain(m *testing.M) {
}
func setupTestDB(t *testing.T) (*gorm.DB, string, api.ModelRegistryApi, func()) {
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
dsn := testutils.GetSharedMySQLDSN(t)
svc := setupModelRegistryService(sharedDB)
@ -78,7 +76,14 @@ func setupModelRegistryService(sharedDB *gorm.DB) api.ModelRegistryApi {
typesMap := getTypeIDs(sharedDB)
// Create all repositories
artifactRepo := service.NewArtifactRepository(sharedDB, typesMap[defaults.ModelArtifactTypeName], typesMap[defaults.DocArtifactTypeName], typesMap[defaults.DataSetTypeName], typesMap[defaults.MetricTypeName], typesMap[defaults.ParameterTypeName], typesMap[defaults.MetricHistoryTypeName])
artifactRepo := service.NewArtifactRepository(sharedDB, map[string]int64{
defaults.ModelArtifactTypeName: typesMap[defaults.ModelArtifactTypeName],
defaults.DocArtifactTypeName: typesMap[defaults.DocArtifactTypeName],
defaults.DataSetTypeName: typesMap[defaults.DataSetTypeName],
defaults.MetricTypeName: typesMap[defaults.MetricTypeName],
defaults.ParameterTypeName: typesMap[defaults.ParameterTypeName],
defaults.MetricHistoryTypeName: typesMap[defaults.MetricHistoryTypeName],
})
modelArtifactRepo := service.NewModelArtifactRepository(sharedDB, typesMap[defaults.ModelArtifactTypeName])
docArtifactRepo := service.NewDocArtifactRepository(sharedDB, typesMap[defaults.DocArtifactTypeName])
registeredModelRepo := service.NewRegisteredModelRepository(sharedDB, typesMap[defaults.RegisteredModelTypeName])
@ -129,28 +134,15 @@ func setDirtySchemaState(t *testing.T, sharedDB *gorm.DB) {
require.NoError(t, err)
}
// createTestDatastore creates a datastore config for testing
func createTestDatastore(sharedDSN string) datastore.Datastore {
return datastore.Datastore{
Type: "embedmd",
EmbedMD: embedmd.EmbedMDConfig{
DatabaseType: "mysql",
DatabaseDSN: sharedDSN,
},
}
}
func TestReadinessHandler_EmbedMD_Success(t *testing.T) {
// Ensure clean state before test
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
sharedDB, _, _, cleanup := setupTestDB(t)
defer cleanup()
cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
dbHealthChecker := NewDatabaseHealthChecker(ds)
handler := GeneralReadinessHandler(ds, dbHealthChecker)
dbHealthChecker := NewDatabaseHealthChecker()
handler := GeneralReadinessHandler(dbHealthChecker)
req, err := http.NewRequest("GET", "/readyz/isDirty", nil)
require.NoError(t, err)
@ -163,16 +155,14 @@ func TestReadinessHandler_EmbedMD_Success(t *testing.T) {
func TestReadinessHandler_EmbedMD_Dirty(t *testing.T) {
// Set dirty state for this test
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
sharedDB, _, _, cleanup := setupTestDB(t)
defer cleanup()
setDirtySchemaState(t, sharedDB)
defer cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
dbHealthChecker := NewDatabaseHealthChecker(ds)
handler := GeneralReadinessHandler(ds, dbHealthChecker)
dbHealthChecker := NewDatabaseHealthChecker()
handler := GeneralReadinessHandler(dbHealthChecker)
req, err := http.NewRequest("GET", "/readyz/isDirty", nil)
require.NoError(t, err)
@ -208,17 +198,15 @@ func TestReadinessHandler_EmbedMD_Dirty(t *testing.T) {
func TestGeneralReadinessHandler_WithModelRegistry_Success(t *testing.T) {
// Ensure clean state before test
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
defer cleanup()
cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
// Create both health checkers
dbHealthChecker := NewDatabaseHealthChecker(ds)
dbHealthChecker := NewDatabaseHealthChecker()
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
req, err := http.NewRequest("GET", "/readyz/health", nil)
require.NoError(t, err)
@ -232,17 +220,15 @@ func TestGeneralReadinessHandler_WithModelRegistry_Success(t *testing.T) {
func TestGeneralReadinessHandler_WithModelRegistry_JSONFormat(t *testing.T) {
// Ensure clean state before test
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
defer cleanup()
cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
// Create both health checkers
dbHealthChecker := NewDatabaseHealthChecker(ds)
dbHealthChecker := NewDatabaseHealthChecker()
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
require.NoError(t, err)
@ -279,18 +265,16 @@ func TestGeneralReadinessHandler_WithModelRegistry_JSONFormat(t *testing.T) {
func TestGeneralReadinessHandler_WithModelRegistry_DatabaseFail(t *testing.T) {
// Set dirty state to make database check fail
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
defer cleanup()
setDirtySchemaState(t, sharedDB)
defer cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
// Create both health checkers
dbHealthChecker := NewDatabaseHealthChecker(ds)
dbHealthChecker := NewDatabaseHealthChecker()
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
require.NoError(t, err)
@ -319,17 +303,15 @@ func TestGeneralReadinessHandler_WithModelRegistry_DatabaseFail(t *testing.T) {
func TestGeneralReadinessHandler_WithModelRegistry_ModelRegistryNil(t *testing.T) {
// Ensure clean state before test
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
sharedDB, _, _, cleanup := setupTestDB(t)
defer cleanup()
cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
// Create health checkers - with nil model registry service
dbHealthChecker := NewDatabaseHealthChecker(ds)
dbHealthChecker := NewDatabaseHealthChecker()
mrHealthChecker := NewModelRegistryHealthChecker(nil)
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
require.NoError(t, err)
@ -358,18 +340,16 @@ func TestGeneralReadinessHandler_WithModelRegistry_ModelRegistryNil(t *testing.T
func TestGeneralReadinessHandler_SimpleTextResponse_Failure(t *testing.T) {
// Set dirty state to make database check fail
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
defer cleanup()
setDirtySchemaState(t, sharedDB)
defer cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
// Create both health checkers
dbHealthChecker := NewDatabaseHealthChecker(ds)
dbHealthChecker := NewDatabaseHealthChecker()
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
req, err := http.NewRequest("GET", "/readyz/health", nil)
require.NoError(t, err)
@ -382,36 +362,17 @@ func TestGeneralReadinessHandler_SimpleTextResponse_Failure(t *testing.T) {
assert.Contains(t, rr.Body.String(), "database schema is in dirty state")
}
func TestDatabaseHealthChecker_EmptyDSN(t *testing.T) {
ds := datastore.Datastore{
Type: "embedmd",
EmbedMD: embedmd.EmbedMDConfig{
DatabaseType: "mysql",
DatabaseDSN: "", // Empty DSN
},
}
checker := NewDatabaseHealthChecker(ds)
result := checker.Check()
assert.Equal(t, HealthCheckDatabase, result.Name)
assert.Equal(t, StatusFail, result.Status)
assert.Equal(t, "database DSN not configured", result.Message)
}
func TestGeneralReadinessHandler_MultipleFailures(t *testing.T) {
// Test with both database and model registry failing
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
sharedDB, _, _, cleanup := setupTestDB(t)
defer cleanup()
setDirtySchemaState(t, sharedDB)
defer cleanupSchemaState(t, sharedDB)
ds := createTestDatastore(sharedDSN)
dbHealthChecker := NewDatabaseHealthChecker(ds)
dbHealthChecker := NewDatabaseHealthChecker()
mrHealthChecker := NewModelRegistryHealthChecker(nil) // Nil service to make it fail
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
require.NoError(t, err)

View File

@ -7,6 +7,8 @@ import (
"sync"
"testing"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
"github.com/kubeflow/model-registry/internal/datastore/embedmd/mysql"
"github.com/kubeflow/model-registry/internal/tls"
"github.com/stretchr/testify/require"
@ -102,14 +104,18 @@ func CleanupSharedMySQL() {
}
// SetupMySQLWithMigrations returns a migrated MySQL database connection
func SetupMySQLWithMigrations(t *testing.T) (*gorm.DB, func()) {
func SetupMySQLWithMigrations(t *testing.T, spec *datastore.Spec) (*gorm.DB, func()) {
db, cleanup := GetSharedMySQLDB(t)
// Run migrations
migrator, err := mysql.NewMySQLMigrator(db)
require.NoError(t, err)
err = migrator.Migrate()
require.NoError(t, err)
ds, err := datastore.NewConnector("embedmd", &embedmd.EmbedMDConfig{DB: db})
if err != nil {
t.Fatalf("unable get datastore connector: %v", err)
}
_, err = ds.Connect(spec)
if err != nil {
t.Fatalf("unable to connect to datastore: %v", err)
}
return db, cleanup
}

View File

@ -4,6 +4,7 @@ import (
"os"
"testing"
"github.com/kubeflow/model-registry/internal/db/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -32,7 +33,7 @@ func TestSharedMySQLUtility(t *testing.T) {
})
t.Run("SetupMySQLWithMigrations", func(t *testing.T) {
db, cleanup := SetupMySQLWithMigrations(t)
db, cleanup := SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Test that migrations were applied
@ -43,7 +44,7 @@ func TestSharedMySQLUtility(t *testing.T) {
})
t.Run("CleanupTestData", func(t *testing.T) {
db, cleanup := SetupMySQLWithMigrations(t)
db, cleanup := SetupMySQLWithMigrations(t, service.DatastoreSpec())
defer cleanup()
// Insert some test data

View File

@ -4,10 +4,16 @@ This directory contains manifests for deploying the Model Catalog using Kustomiz
## Deployment
To deploy the Model Catalog to your Kubernetes cluster, run the following command in the `base` directory:
The model catalog manifests deploy a PostgreSQL database, set `POSTGRES_PASSWORD` in `postgres.env` before deploying the manifests. On Linux, you can generate a random password with:
```sh
kubectl apply -k . -n <your-namespace>
(echo POSTGRES_USER=postgres ; echo -n POSTGRES_PASSWORD=; dd if=/dev/random of=/dev/stdout bs=15 count=1 status=none | base64) >base/postgres.env
```
To deploy the Model Catalog to your Kubernetes cluster (without Kubeflow--see below for Istio support), run the following command from this directory:
```sh
kubectl apply -k base -n <your-namespace>
```
Replace `<your-namespace>` with the Kubernetes namespace where you want to deploy the catalog.
@ -16,11 +22,13 @@ This command will create:
* A `Deployment` to run the Model Catalog server.
* A `Service` to expose the Model Catalog server.
* A `ConfigMap` named `model-catalog-sources` containing the configuration for the catalog sources.
* A `StatefulSet` with a PostgreSQL database
* A `PersistentVolumeClaim` for PostgreSQL
For deployment in a Kubeflow environment with Istio support, use the overlay, running the following command in the `overlay` directory::
For deployment in a Kubeflow environment with Istio support, use the `overlay` directory instead:
```sh
kubectl apply -k . -n <your-namespace>
kubectl apply -k overlay -n <your-namespace>
```
## Configuring Catalog Sources
@ -118,4 +126,4 @@ catalogs:
### Sample Catalog File
You can refer to `sample-catalog.yaml` in this directory for an example of how to structure your model definitions file.
You can refer to `sample-catalog.yaml` in this directory for an example of how to structure your model definitions file.

View File

@ -3,17 +3,23 @@ kind: Deployment
metadata:
name: model-catalog-server
labels:
component: model-catalog-server
app.kubernetes.io/name: model-catalog
app.kubernetes.io/part-of: model-catalog
app.kubernetes.io/component: server
spec:
replicas: 1
selector:
matchLabels:
component: model-catalog-server
app.kubernetes.io/name: model-catalog
app.kubernetes.io/part-of: model-catalog
app.kubernetes.io/component: server
template:
metadata:
labels:
sidecar.istio.io/inject: "true"
component: model-catalog-server
app.kubernetes.io/name: model-catalog
app.kubernetes.io/part-of: model-catalog
app.kubernetes.io/component: server
spec:
securityContext:
seccompProfile:
@ -25,6 +31,19 @@ spec:
name: model-catalog-sources
containers:
- name: catalog
env:
- name: PGHOST
value: model-catalog-postgres
- name: PGUSER
valueFrom:
secretKeyRef:
name: model-catalog-postgres
key: POSTGRES_USER
- name: PGPASSWORD
valueFrom:
secretKeyRef:
name: model-catalog-postgres
key: POSTGRES_PASSWORD
command:
- /model-registry
- catalog

View File

@ -4,6 +4,9 @@ kind: Kustomization
resources:
- deployment.yaml
- service.yaml
- postgres-statefulset.yaml
- postgres-pvc.yaml
- postgres-service.yaml
configMapGenerator:
- behavior: create
@ -13,3 +16,10 @@ configMapGenerator:
name: model-catalog-sources
options:
disableNameSuffixHash: true
secretGenerator:
- name: postgres
envs:
- postgres.env
options:
disableNameSuffixHash: true

View File

@ -0,0 +1,10 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: model-catalog-postgres
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 5Gi

View File

@ -0,0 +1,18 @@
kind: Service
apiVersion: v1
metadata:
labels:
app.kubernetes.io/name: model-catalog
app.kubernetes.io/part-of: model-catalog
app.kubernetes.io/component: database
name: model-catalog-postgres
spec:
selector:
app.kubernetes.io/part-of: model-catalog
app.kubernetes.io/name: postgres
app.kubernetes.io/component: database
type: ClusterIP
ports:
- port: 5432
protocol: TCP
name: postgres

View File

@ -0,0 +1,52 @@
apiVersion: apps/v1
kind: StatefulSet
metadata:
name: model-catalog-postgres
labels:
component: model-catalog-server
spec:
selector:
matchLabels:
app.kubernetes.io/name: postgres
app.kubernetes.io/part-of: model-catalog
replicas: 1
template:
metadata:
name: postgres
labels:
app.kubernetes.io/name: postgres
app.kubernetes.io/part-of: model-catalog
app.kubernetes.io/component: database
sidecar.istio.io/inject: "false"
spec:
securityContext:
seccompProfile:
type: RuntimeDefault
runAsNonRoot: true
fsGroup: 70
containers:
- name: postgres
image: postgres:17.6
env:
- name: PGDATA
value: /var/lib/postgresql/data/pgdata
envFrom:
- secretRef:
name: model-catalog-postgres
ports:
- name: postgres
containerPort: 5432
volumeMounts:
- name: model-catalog-postgres
mountPath: /var/lib/postgresql/data
securityContext:
runAsUser: 70
runAsGroup: 70
allowPrivilegeEscalation: false
capabilities:
drop:
- ALL
volumes:
- name: model-catalog-postgres
persistentVolumeClaim:
claimName: model-catalog-postgres

View File

@ -0,0 +1,2 @@
POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres

View File

@ -2,16 +2,14 @@ kind: Service
apiVersion: v1
metadata:
labels:
app: model-catalog-service
app.kubernetes.io/component: model-catalog
app.kubernetes.io/instance: model-catalog-service
app.kubernetes.io/name: model-catalog
app.kubernetes.io/component: server
app.kubernetes.io/part-of: model-catalog
component: model-catalog
name: model-catalog
spec:
selector:
component: model-catalog-server
app.kubernetes.io/part-of: model-catalog
app.kubernetes.io/component: server
type: ClusterIP
ports:
- port: 8080