Merge model-catalog-enhancements branch (#1656)
* feat(catalog): deploy postgres for model catalog (#1584)

  * feat(catalog): initialize postgres for model catalog

    Initialize PostgreSQL database for model catalog with connection setup.
    Refactored embedmd initialization code from model registry for reuse.
    Added Kubernetes manifests for PostgreSQL StatefulSet, PVC, and service
    configuration.

  * fix(catalog): remove chdir before loading catalogs

    Instead of changing directories, pass the directory that files should be
    relative to when loading a catalog.

* feat(catalog): add postgres for catalog dev (#1623)

  Adds PostgreSQL to the Tilt and Docker Compose development setups.

* feat(datastore): refactor type creation (#1636)

  - Refactor embedmd to create types from a spec, instead of database migrations.
  - Add TypeRepository and TypePropertyRepository services.

* Add artifact types to the model catalog (#1649)

  * chore: update testutils to take a datastore spec argument
  * feat(catalog): add artifact types to openapi spec and re-generate code
  * feat(catalog): add database models and services

* Update api/openapi/src/catalog.yaml

* fix(catalog): set source_id type correctly

* fix(catalog): fix k8s resource names

Signed-off-by: Paul Boyd <paul@pboyd.io>
Co-authored-by: Dhiraj Bokde <dhirajsb@users.noreply.github.com>
This commit is contained in:

parent 122dbfd933
commit 79f837c3c1
@@ -107,14 +107,14 @@ paths:
required: true
/api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name}/artifacts:
description: >-
The REST endpoint/path used to list `CatalogModelArtifacts`.
The REST endpoint/path used to list `CatalogArtifacts`.
get:
summary: List CatalogModelArtifacts.
summary: List CatalogArtifacts.
tags:
- ModelCatalogService
responses:
"200":
$ref: "#/components/responses/CatalogModelArtifactListResponse"
$ref: "#/components/responses/CatalogArtifactListResponse"
"401":
$ref: "#/components/responses/Unauthorized"
"404":
@@ -234,6 +234,51 @@ components:
format: int32
description: Number of items in result list.
type: integer
CatalogArtifact:
description: A single artifact in the catalog API.
oneOf:
- $ref: "#/components/schemas/CatalogModelArtifact"
- $ref: "#/components/schemas/CatalogMetricsArtifact"
discriminator:
propertyName: artifactType
mapping:
model-artifact: "#/components/schemas/CatalogModelArtifact"
metrics-artifact: "#/components/schemas/CatalogMetricsArtifact"
CatalogArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogMetricsArtifact:
description: A metadata Artifact Entity.
allOf:
- type: object
required:
- artifactType
- metricsType
properties:
artifactType:
type: string
default: metrics-artifact
metricsType:
type: string
enum:
- performance-metrics
- accuracy-metrics
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModel:
description: A model in the model catalog.
allOf:
@@ -251,35 +296,26 @@ components:
- $ref: "#/components/schemas/BaseResourceDates"
- $ref: "#/components/schemas/BaseModel"
CatalogModelArtifact:
description: A single artifact for a catalog model.
description: A Catalog Model Artifact Entity.
allOf:
- type: object
required:
- artifactType
- uri
properties:
artifactType:
type: string
default: model-artifact
uri:
type: string
format: uri
description: URI where the artifact can be retrieved.
description: URI where the model can be retrieved.
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModelArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogModelArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogModelArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogModelList:
description: List of CatalogModel entities.
allOf:
@@ -468,12 +504,12 @@ components:
schema:
$ref: "#/components/schemas/Error"
description: Bad Request parameters
CatalogModelArtifactListResponse:
CatalogArtifactListResponse:
content:
application/json:
schema:
$ref: "#/components/schemas/CatalogModelArtifactList"
description: A response containing a list of CatalogModelArtifact entities.
$ref: "#/components/schemas/CatalogArtifactList"
description: A response containing a list of CatalogArtifact entities.
CatalogModelListResponse:
content:
application/json:
@@ -107,14 +107,14 @@ paths:
required: true
/api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name}/artifacts:
description: >-
The REST endpoint/path used to list `CatalogModelArtifacts`.
The REST endpoint/path used to list `CatalogArtifacts`.
get:
summary: List CatalogModelArtifacts.
summary: List CatalogArtifacts.
tags:
- ModelCatalogService
responses:
"200":
$ref: "#/components/responses/CatalogModelArtifactListResponse"
$ref: "#/components/responses/CatalogArtifactListResponse"
"401":
$ref: "#/components/responses/Unauthorized"
"404":
@@ -137,6 +137,51 @@ paths:
required: true
components:
schemas:
CatalogArtifact:
description: A single artifact in the catalog API.
oneOf:
- $ref: "#/components/schemas/CatalogModelArtifact"
- $ref: "#/components/schemas/CatalogMetricsArtifact"
discriminator:
propertyName: artifactType
mapping:
model-artifact: "#/components/schemas/CatalogModelArtifact"
metrics-artifact: "#/components/schemas/CatalogMetricsArtifact"
CatalogArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogMetricsArtifact:
description: A metadata Artifact Entity.
allOf:
- type: object
required:
- artifactType
- metricsType
properties:
artifactType:
type: string
default: metrics-artifact
metricsType:
type: string
enum:
- performance-metrics
- accuracy-metrics
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModel:
description: A model in the model catalog.
allOf:
@@ -154,35 +199,26 @@ components:
- $ref: "#/components/schemas/BaseResourceDates"
- $ref: "#/components/schemas/BaseModel"
CatalogModelArtifact:
description: A single artifact for a catalog model.
description: A Catalog Model Artifact Entity.
allOf:
- type: object
required:
- artifactType
- uri
properties:
artifactType:
type: string
default: model-artifact
uri:
type: string
format: uri
description: URI where the artifact can be retrieved.
description: URI where the model can be retrieved.
customProperties:
description: User provided custom properties which are not defined by its type.
type: object
additionalProperties:
$ref: "#/components/schemas/MetadataValue"
- $ref: "#/components/schemas/BaseResourceDates"
CatalogModelArtifactList:
description: List of CatalogModel entities.
allOf:
- type: object
properties:
items:
description: Array of `CatalogModelArtifact` entities.
type: array
items:
$ref: "#/components/schemas/CatalogModelArtifact"
required:
- items
- $ref: "#/components/schemas/BaseResourceList"
CatalogModelList:
description: List of CatalogModel entities.
allOf:
@@ -240,12 +276,12 @@ components:
type: string

responses:
CatalogModelArtifactListResponse:
CatalogArtifactListResponse:
content:
application/json:
schema:
$ref: "#/components/schemas/CatalogModelArtifactList"
description: A response containing a list of CatalogModelArtifact entities.
$ref: "#/components/schemas/CatalogArtifactList"
description: A response containing a list of CatalogArtifact entities.
CatalogModelListResponse:
content:
application/json:
@@ -6,31 +6,51 @@ import (

"github.com/golang/glog"
"github.com/kubeflow/model-registry/catalog/internal/catalog"
"github.com/kubeflow/model-registry/catalog/internal/db/service"
"github.com/kubeflow/model-registry/catalog/internal/server/openapi"
"github.com/kubeflow/model-registry/internal/datastore"
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
"github.com/spf13/cobra"
)

var catalogCfg = struct {
ListenAddress string
ConfigPath string
ConfigPath []string
}{
ListenAddress: "0.0.0.0:8080",
ConfigPath: "sources.yaml",
ConfigPath: []string{"sources.yaml"},
}

var CatalogCmd = &cobra.Command{
Use: "catalog",
Short: "Catalog API server",
Long: `Launch the API server for the model catalog`,
RunE: runCatalogServer,
Long: `Launch the API server for the model catalog. Use PostgreSQL's
environment variables
(https://www.postgresql.org/docs/current/libpq-envars.html) to
configure the database connection.`,
RunE: runCatalogServer,
}

func init() {
CatalogCmd.Flags().StringVarP(&catalogCfg.ListenAddress, "listen", "l", catalogCfg.ListenAddress, "Address to listen on")
CatalogCmd.Flags().StringVar(&catalogCfg.ConfigPath, "catalogs-path", catalogCfg.ConfigPath, "Path to catalog source configuration file")
fs := CatalogCmd.Flags()
fs.StringVarP(&catalogCfg.ListenAddress, "listen", "l", catalogCfg.ListenAddress, "Address to listen on")
fs.StringSliceVar(&catalogCfg.ConfigPath, "catalogs-path", catalogCfg.ConfigPath, "Path to catalog source configuration file")
}

func runCatalogServer(cmd *cobra.Command, args []string) error {
ds, err := datastore.NewConnector("embedmd", &embedmd.EmbedMDConfig{
DatabaseType: "postgres", // We only support postgres right now
DatabaseDSN: "", // Empty DSN, see https://www.postgresql.org/docs/current/libpq-envars.html
})
if err != nil {
return fmt.Errorf("error creating datastore: %w", err)
}

_, err = ds.Connect(service.DatastoreSpec())
if err != nil {
return fmt.Errorf("error initializing datastore: %v", err)
}

sources, err := catalog.LoadCatalogSources(catalogCfg.ConfigPath)
if err != nil {
return fmt.Errorf("error loading catalog sources: %v", err)
@@ -33,7 +33,7 @@ type CatalogSourceProvider interface {
// GetArtifacts returns all artifacts for a particular model. If no
// model is found with that name, it returns nil. If the model is
// found, but has no artifacts, an empty list is returned.
GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error)
GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error)
}

// CatalogSourceConfig is a single entry from the catalog sources YAML file.
@@ -52,7 +52,7 @@ type sourceConfig struct {
Catalogs []CatalogSourceConfig `json:"catalogs"`
}

type CatalogTypeRegisterFunc func(source *CatalogSourceConfig) (CatalogSourceProvider, error)
type CatalogTypeRegisterFunc func(source *CatalogSourceConfig, reldir string) (CatalogSourceProvider, error)

var registeredCatalogTypes = make(map[string]CatalogTypeRegisterFunc, 0)
@@ -103,24 +103,6 @@ func (sc *SourceCollection) load(path string) error {
// Get the directory of the config file to resolve relative paths
configDir := filepath.Dir(absConfigPath)

// Save current working directory
originalWd, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to get current working directory: %v", err)
}

// Change to the config directory to make relative paths work
if err := os.Chdir(configDir); err != nil {
return fmt.Errorf("failed to change to config directory %s: %v", configDir, err)
}

// Ensure we restore the original working directory when we're done
defer func() {
if err := os.Chdir(originalWd); err != nil {
glog.Errorf("failed to restore original working directory %s: %v", originalWd, err)
}
}()

config := sourceConfig{}
bytes, err := os.ReadFile(absConfigPath)
if err != nil {
@@ -157,12 +139,13 @@ func (sc *SourceCollection) load(path string) error {
if _, exists := sources[id]; exists {
return fmt.Errorf("duplicate catalog id %s", id)
}

labels := make([]string, 0)
if catalogConfig.GetLabels() != nil {
labels = catalogConfig.GetLabels()
}
catalogConfig.CatalogSource.Labels = labels
provider, err := registerFunc(&catalogConfig)
provider, err := registerFunc(&catalogConfig, configDir)
if err != nil {
return fmt.Errorf("error reading catalog type %s with id %s: %v", catalogType, id, err)
}
@@ -182,29 +165,32 @@ func (sc *SourceCollection) load(path string) error {
return nil
}

func LoadCatalogSources(path string) (*SourceCollection, error) {
func LoadCatalogSources(paths []string) (*SourceCollection, error) {
sc := &SourceCollection{}
err := sc.load(path)
if err != nil {
return nil, err
}

go func() {
changes, err := getMonitor().Path(path)
for _, path := range paths {
err := sc.load(path)
if err != nil {
glog.Errorf("unable to watch sources file: %v", err)
// Not fatal, we just won't get automatic updates.
return nil, err
}

for range changes {
glog.Infof("Reloading sources %s", path)

err = sc.load(path)
go func(path string) {
changes, err := getMonitor().Path(path)
if err != nil {
glog.Errorf("unable to load sources: %v", err)
glog.Errorf("unable to watch sources file (%s): %v", path, err)
// Not fatal, we just won't get automatic updates.
}
}
}()

for range changes {
glog.Infof("Reloading sources %s", path)

err = sc.load(path)
if err != nil {
glog.Errorf("unable to load sources: %v", err)
}
}
}(path)
}

return sc, nil
}
@@ -21,13 +21,13 @@ func TestLoadCatalogSources(t *testing.T) {
{
name: "test-catalog-sources",
args: args{catalogsPath: "testdata/test-catalog-sources.yaml"},
want: []string{"catalog1", "catalog3", "catalog4"},
want: []string{"catalog1"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := LoadCatalogSources(tt.args.catalogsPath)
got, err := LoadCatalogSources([]string{tt.args.catalogsPath})
if (err != nil) != tt.wantErr {
t.Errorf("LoadCatalogSources() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -64,23 +64,13 @@ func TestLoadCatalogSourcesEnabledDisabled(t *testing.T) {
Name: "Catalog 1",
Enabled: &trueValue,
},
"catalog3": {
Id: "catalog3",
Name: "Catalog 3",
Enabled: &trueValue,
},
"catalog4": {
Id: "catalog4",
Name: "Catalog 4",
Enabled: &trueValue,
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := LoadCatalogSources(tt.args.catalogsPath)
got, err := LoadCatalogSources([]string{tt.args.catalogsPath})
if (err != nil) != tt.wantErr {
t.Errorf("LoadCatalogSources() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -39,11 +39,11 @@ func (h *hfCatalogImpl) ListModels(ctx context.Context, params ListModelsParams)
}, nil
}

func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, name string) (*openapi.CatalogModelArtifactList, error) {
func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, name string) (*openapi.CatalogArtifactList, error) {
// TODO: Implement HuggingFace model artifacts retrieval
// For now, return empty list to satisfy interface
return &openapi.CatalogModelArtifactList{
Items: []openapi.CatalogModelArtifact{},
return &openapi.CatalogArtifactList{
Items: []openapi.CatalogArtifact{},
PageSize: 0,
Size: 0,
}, nil
@@ -82,7 +82,7 @@ func (h *hfCatalogImpl) validateCredentials(ctx context.Context) error {
}

// newHfCatalog creates a new HuggingFace catalog source
func newHfCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
func newHfCatalog(source *CatalogSourceConfig, reldir string) (CatalogSourceProvider, error) {
apiKey, ok := source.Properties["apiKey"].(string)
if !ok || apiKey == "" {
return nil, fmt.Errorf("missing or invalid 'apiKey' property for HuggingFace catalog")
@@ -22,7 +22,7 @@ func TestNewHfCatalog_MissingAPIKey(t *testing.T) {
},
}

_, err := newHfCatalog(source)
_, err := newHfCatalog(source, "")
if err == nil {
t.Fatal("Expected error for missing API key, got nil")
}
@@ -65,7 +65,7 @@ func TestNewHfCatalog_WithValidCredentials(t *testing.T) {
},
}

catalog, err := newHfCatalog(source)
catalog, err := newHfCatalog(source, "")
if err != nil {
t.Fatalf("Failed to create HF catalog: %v", err)
}
@@ -130,7 +130,7 @@ func TestNewHfCatalog_InvalidCredentials(t *testing.T) {
},
}

_, err := newHfCatalog(source)
_, err := newHfCatalog(source, "")
if err == nil {
t.Fatal("Expected error for invalid credentials, got nil")
}
@@ -159,7 +159,7 @@ func TestNewHfCatalog_DefaultConfiguration(t *testing.T) {
},
}

catalog, err := newHfCatalog(source)
catalog, err := newHfCatalog(source, "")
if err != nil {
t.Fatalf("Failed to create HF catalog with defaults: %v", err)
}
@@ -325,6 +325,14 @@ models:
lastUpdateTimeSinceEpoch: "1734637721000"
artifacts:
- uri: oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892
- artifactType: metrics-artifact
createTimeSinceEpoch: "1733514949000"
lastUpdateTimeSinceEpoch: "1734637721000"
customProperties:
x:
int_value: 1
y:
double_value: 2.1
- name: rhelai1/granite-8b-code-instruct
provider: IBM
description: |-
@@ -2,6 +2,7 @@ package catalog

import (
"context"
"encoding/json"
"fmt"
"math"
"os"
@@ -19,7 +20,40 @@ import (

type yamlModel struct {
model.CatalogModel `yaml:",inline"`
Artifacts []*model.CatalogModelArtifact `yaml:"artifacts"`
Artifacts []*yamlArtifact `yaml:"artifacts"`
}

type yamlArtifact struct {
model.CatalogArtifact
}

func (a *yamlArtifact) UnmarshalJSON(buf []byte) error {
// This is very similar to generated code to unmarshal a
// CatalogArtifact, but this version properly handles artifacts without
// an artifactType, which is important for backwards compatibility.
var yat struct {
ArtifactType string `json:"artifactType"`
}

err := json.Unmarshal(buf, &yat)
if err != nil {
return err
}

switch yat.ArtifactType {
case "model-artifact", "":
err = json.Unmarshal(buf, &a.CatalogArtifact.CatalogModelArtifact)
if a.CatalogArtifact.CatalogModelArtifact != nil {
// Ensure artifactType is set even if it wasn't initially.
a.CatalogArtifact.CatalogModelArtifact.ArtifactType = "model-artifact"
}
case "metrics-artifact":
err = json.Unmarshal(buf, &a.CatalogArtifact.CatalogMetricsArtifact)
default:
return fmt.Errorf("unknown artifactType: %s", yat.ArtifactType)
}

return err
}

type yamlCatalog struct {
@@ -125,7 +159,7 @@ func (y *yamlCatalogImpl) ListModels(ctx context.Context, params ListModelsParam
return list, nil // Return the struct value directly
}

func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error) {
func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error) {
y.modelsLock.RLock()
defer y.modelsLock.RUnlock()
@@ -139,13 +173,13 @@ func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model
count = math.MaxInt32
}

list := model.CatalogModelArtifactList{
Items: make([]model.CatalogModelArtifact, count),
list := model.CatalogArtifactList{
Items: make([]model.CatalogArtifact, count),
PageSize: int32(count),
Size: int32(count),
}
for i := range list.Items {
list.Items[i] = *ym.Artifacts[i]
list.Items[i] = ym.Artifacts[i].CatalogArtifact
}
return &list, nil
}
@@ -192,16 +226,13 @@ func (y *yamlCatalogImpl) load(path string, excludedModelsList []string) error {

const yamlCatalogPath = "yamlCatalogPath"

func newYamlCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
func newYamlCatalog(source *CatalogSourceConfig, reldir string) (CatalogSourceProvider, error) {
yamlModelFile, exists := source.Properties[yamlCatalogPath].(string)
if !exists || yamlModelFile == "" {
return nil, fmt.Errorf("missing %s string property", yamlCatalogPath)
}

yamlModelFile, err := filepath.Abs(yamlModelFile)
if err != nil {
return nil, fmt.Errorf("abs: %w", err)
}
yamlModelFile = filepath.Join(reldir, yamlModelFile)

// Excluded models is an optional source property.
var excludedModels []string
@@ -222,7 +253,7 @@ func newYamlCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error)
p := &yamlCatalogImpl{
models: make(map[string]*yamlModel),
}
err = p.load(yamlModelFile, excludedModels)
err := p.load(yamlModelFile, excludedModels)
if err != nil {
return nil, err
}
@@ -38,10 +38,17 @@ func TestYAMLCatalogGetArtifacts(t *testing.T) {
artifacts, err := provider.GetArtifacts(context.Background(), "rhelai1/granite-8b-code-base")
if assert.NoError(err) {
assert.NotNil(artifacts)
assert.Equal(int32(1), artifacts.Size)
assert.Equal(int32(1), artifacts.PageSize)
assert.Len(artifacts.Items, 1)
assert.Equal("oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892", artifacts.Items[0].Uri)
assert.Equal(int32(2), artifacts.Size)
assert.Equal(int32(2), artifacts.PageSize)
assert.Len(artifacts.Items, 2)
if assert.NotNil(artifacts.Items[0].CatalogModelArtifact) {
assert.Equal("model-artifact", artifacts.Items[0].CatalogModelArtifact.ArtifactType)
assert.Equal("oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892", artifacts.Items[0].CatalogModelArtifact.Uri)
}
if assert.NotNil(artifacts.Items[1].CatalogMetricsArtifact) {
assert.Equal("metrics-artifact", artifacts.Items[1].CatalogMetricsArtifact.ArtifactType)
assert.NotNil(artifacts.Items[1].CatalogMetricsArtifact.CustomProperties)
}
}

// Test case 2: Model with no artifacts
@@ -199,7 +206,7 @@ func testYAMLProviderWithExclusions(t *testing.T, path string, excludedModels []
}
provider, err := newYamlCatalog(&CatalogSourceConfig{
Properties: properties,
})
}, "")
if err != nil {
t.Fatalf("newYamlCatalog(%s) with exclusions failed: %v", path, err)
}
@ -0,0 +1,45 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"github.com/kubeflow/model-registry/internal/db/filter"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
)
|
||||
|
||||
type MetricsType string
|
||||
|
||||
const (
|
||||
MetricsTypePerformance MetricsType = "performance-metrics"
|
||||
MetricsTypeAccuracy MetricsType = "accuracy-metrics"
|
||||
)
|
||||
|
||||
type CatalogMetricsArtifactListOptions struct {
|
||||
models.Pagination
|
||||
Name *string
|
||||
ExternalID *string
|
||||
ParentResourceID *int32
|
||||
}
|
||||
|
||||
// GetRestEntityType implements the FilterApplier interface
|
||||
func (c *CatalogMetricsArtifactListOptions) GetRestEntityType() filter.RestEntityType {
|
||||
return filter.RestEntityModelArtifact // Reusing existing filter type
|
||||
}
|
||||
|
||||
type CatalogMetricsArtifactAttributes struct {
|
||||
Name *string
|
||||
MetricsType MetricsType
|
||||
ExternalID *string
|
||||
CreateTimeSinceEpoch *int64
|
||||
LastUpdateTimeSinceEpoch *int64
|
||||
}
|
||||
|
||||
type CatalogMetricsArtifact interface {
|
||||
models.Entity[CatalogMetricsArtifactAttributes]
|
||||
}
|
||||
|
||||
type CatalogMetricsArtifactImpl = models.BaseEntity[CatalogMetricsArtifactAttributes]
|
||||
|
||||
type CatalogMetricsArtifactRepository interface {
|
||||
GetByID(id int32) (CatalogMetricsArtifact, error)
|
||||
List(listOptions CatalogMetricsArtifactListOptions) (*models.ListWrapper[CatalogMetricsArtifact], error)
|
||||
Save(metricsArtifact CatalogMetricsArtifact, parentResourceID *int32) (CatalogMetricsArtifact, error)
|
||||
}
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"github.com/kubeflow/model-registry/internal/db/filter"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
)
|
||||
|
||||
type CatalogModelListOptions struct {
|
||||
models.Pagination
|
||||
Name *string
|
||||
ExternalID *string
|
||||
}
|
||||
|
||||
// GetRestEntityType implements the FilterApplier interface
|
||||
func (c *CatalogModelListOptions) GetRestEntityType() filter.RestEntityType {
|
||||
return "CatalogModel"
|
||||
}
|
||||
|
||||
type CatalogModelAttributes struct {
|
||||
Name *string
|
||||
ExternalID *string
|
||||
CreateTimeSinceEpoch *int64
|
||||
LastUpdateTimeSinceEpoch *int64
|
||||
}
|
||||
|
||||
type CatalogModel interface {
|
||||
models.Entity[CatalogModelAttributes]
|
||||
}
|
||||
|
||||
type CatalogModelImpl = models.BaseEntity[CatalogModelAttributes]
|
||||
|
||||
type CatalogModelRepository interface {
|
||||
GetByID(id int32) (CatalogModel, error)
|
||||
List(listOptions CatalogModelListOptions) (*models.ListWrapper[CatalogModel], error)
|
||||
Save(model CatalogModel) (CatalogModel, error)
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"github.com/kubeflow/model-registry/internal/db/filter"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
)
|
||||
|
||||
const CatalogModelArtifactType = "catalog-model-artifact"
|
||||
|
||||
type CatalogModelArtifactListOptions struct {
|
||||
models.Pagination
|
||||
Name *string
|
||||
ExternalID *string
|
||||
ParentResourceID *int32
|
||||
}
|
||||
|
||||
// GetRestEntityType implements the FilterApplier interface
|
||||
func (c *CatalogModelArtifactListOptions) GetRestEntityType() filter.RestEntityType {
|
||||
return filter.RestEntityModelArtifact // Reusing existing filter type
|
||||
}
|
||||
|
||||
type CatalogModelArtifactAttributes struct {
|
||||
Name *string
|
||||
URI *string
|
||||
ExternalID *string
|
||||
CreateTimeSinceEpoch *int64
|
||||
LastUpdateTimeSinceEpoch *int64
|
||||
}
|
||||
|
||||
type CatalogModelArtifact interface {
|
||||
models.Entity[CatalogModelArtifactAttributes]
|
||||
}
|
||||
|
||||
type CatalogModelArtifactImpl = models.BaseEntity[CatalogModelArtifactAttributes]
|
||||
|
||||
type CatalogModelArtifactRepository interface {
|
||||
GetByID(id int32) (CatalogModelArtifact, error)
|
||||
List(listOptions CatalogModelArtifactListOptions) (*models.ListWrapper[CatalogModelArtifact], error)
|
||||
Save(modelArtifact CatalogModelArtifact, parentResourceID *int32) (CatalogModelArtifact, error)
|
||||
}
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/db/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrCatalogMetricsArtifactNotFound = errors.New("catalog metrics artifact by id not found")
|
||||
|
||||
type CatalogMetricsArtifactRepositoryImpl struct {
|
||||
*service.GenericRepository[models.CatalogMetricsArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogMetricsArtifactListOptions]
|
||||
}
|
||||
|
||||
func NewCatalogMetricsArtifactRepository(db *gorm.DB, typeID int64) models.CatalogMetricsArtifactRepository {
|
||||
config := service.GenericRepositoryConfig[models.CatalogMetricsArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogMetricsArtifactListOptions]{
|
||||
DB: db,
|
||||
TypeID: typeID,
|
||||
EntityToSchema: mapCatalogMetricsArtifactToArtifact,
|
||||
SchemaToEntity: mapDataLayerToCatalogMetricsArtifact,
|
||||
EntityToProperties: mapCatalogMetricsArtifactToArtifactProperties,
|
||||
NotFoundError: ErrCatalogMetricsArtifactNotFound,
|
||||
EntityName: "catalog metrics artifact",
|
||||
PropertyFieldName: "artifact_id",
|
||||
ApplyListFilters: applyCatalogMetricsArtifactListFilters,
|
||||
IsNewEntity: func(entity models.CatalogMetricsArtifact) bool { return entity.GetID() == nil },
|
||||
HasCustomProperties: func(entity models.CatalogMetricsArtifact) bool { return entity.GetCustomProperties() != nil },
|
||||
}
|
||||
|
||||
return &CatalogMetricsArtifactRepositoryImpl{
|
||||
GenericRepository: service.NewGenericRepository(config),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *CatalogMetricsArtifactRepositoryImpl) List(listOptions models.CatalogMetricsArtifactListOptions) (*dbmodels.ListWrapper[models.CatalogMetricsArtifact], error) {
|
||||
return r.GenericRepository.List(&listOptions)
|
||||
}
|
||||
|
||||
func (r *CatalogMetricsArtifactRepositoryImpl) Save(ma models.CatalogMetricsArtifact, parentResourceID *int32) (models.CatalogMetricsArtifact, error) {
|
||||
attr := ma.GetAttributes()
|
||||
if attr == nil {
|
||||
return ma, fmt.Errorf("invalid artifact: nil attributes")
|
||||
}
|
||||
|
||||
switch attr.MetricsType {
|
||||
case models.MetricsTypeAccuracy, models.MetricsTypePerformance:
|
||||
// OK
|
||||
default:
|
||||
return ma, fmt.Errorf("invalid artifact: unknown metrics type: %s", attr.MetricsType)
|
||||
}
|
||||
|
||||
return r.GenericRepository.Save(ma, parentResourceID)
|
||||
}
|
||||
|
||||
func applyCatalogMetricsArtifactListFilters(query *gorm.DB, listOptions *models.CatalogMetricsArtifactListOptions) *gorm.DB {
|
||||
if listOptions.Name != nil {
|
||||
query = query.Where("name LIKE ?", fmt.Sprintf("%%:%s", *listOptions.Name))
|
||||
} else if listOptions.ExternalID != nil {
|
||||
query = query.Where("external_id = ?", listOptions.ExternalID)
|
||||
}
|
||||
|
||||
if listOptions.ParentResourceID != nil {
|
||||
query = query.Joins(utils.BuildAttributionJoin(query)).
|
||||
Where(utils.GetColumnRef(query, &schema.Attribution{}, "context_id")+" = ?", listOptions.ParentResourceID)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func mapCatalogMetricsArtifactToArtifact(catalogMetricsArtifact models.CatalogMetricsArtifact) schema.Artifact {
|
||||
if catalogMetricsArtifact == nil {
|
||||
return schema.Artifact{}
|
||||
}
|
||||
|
||||
artifact := schema.Artifact{
|
||||
ID: apiutils.ZeroIfNil(catalogMetricsArtifact.GetID()),
|
||||
TypeID: apiutils.ZeroIfNil(catalogMetricsArtifact.GetTypeID()),
|
||||
}
|
||||
|
||||
if catalogMetricsArtifact.GetAttributes() != nil {
|
||||
artifact.Name = catalogMetricsArtifact.GetAttributes().Name
|
||||
artifact.ExternalID = catalogMetricsArtifact.GetAttributes().ExternalID
|
||||
artifact.CreateTimeSinceEpoch = apiutils.ZeroIfNil(catalogMetricsArtifact.GetAttributes().CreateTimeSinceEpoch)
|
||||
artifact.LastUpdateTimeSinceEpoch = apiutils.ZeroIfNil(catalogMetricsArtifact.GetAttributes().LastUpdateTimeSinceEpoch)
|
||||
}
|
||||
|
||||
return artifact
|
||||
}
|
||||
|
||||
func mapCatalogMetricsArtifactToArtifactProperties(catalogMetricsArtifact models.CatalogMetricsArtifact, artifactID int32) []schema.ArtifactProperty {
|
||||
if catalogMetricsArtifact == nil {
|
||||
return []schema.ArtifactProperty{}
|
||||
}
|
||||
|
||||
properties := []schema.ArtifactProperty{}
|
||||
|
||||
// Add the metricsType as a property
|
||||
if catalogMetricsArtifact.GetAttributes() != nil {
|
||||
metricsTypeProp := dbmodels.Properties{
|
||||
Name: "metricsType",
|
||||
StringValue: apiutils.Of(string(catalogMetricsArtifact.GetAttributes().MetricsType)),
|
||||
}
|
||||
properties = append(properties, service.MapPropertiesToArtifactProperty(metricsTypeProp, artifactID, false))
|
||||
}
|
||||
|
||||
if catalogMetricsArtifact.GetProperties() != nil {
|
||||
for _, prop := range *catalogMetricsArtifact.GetProperties() {
|
||||
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, false))
|
||||
}
|
||||
}
|
||||
|
||||
if catalogMetricsArtifact.GetCustomProperties() != nil {
|
||||
for _, prop := range *catalogMetricsArtifact.GetCustomProperties() {
|
||||
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, true))
|
||||
}
|
||||
}
|
||||
|
||||
return properties
|
||||
}
|
||||
|
||||
func mapDataLayerToCatalogMetricsArtifact(artifact schema.Artifact, artProperties []schema.ArtifactProperty) models.CatalogMetricsArtifact {
|
||||
catalogMetricsArtifact := models.CatalogMetricsArtifactImpl{
|
||||
ID: &artifact.ID,
|
||||
TypeID: &artifact.TypeID,
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: artifact.Name,
|
||||
ExternalID: artifact.ExternalID,
|
||||
CreateTimeSinceEpoch: &artifact.CreateTimeSinceEpoch,
|
||||
LastUpdateTimeSinceEpoch: &artifact.LastUpdateTimeSinceEpoch,
|
||||
},
|
||||
}
|
||||
|
||||
customProperties := []dbmodels.Properties{}
|
||||
properties := []dbmodels.Properties{}
|
||||
|
||||
for _, prop := range artProperties {
|
||||
mappedProperty := service.MapArtifactPropertyToProperties(prop)
|
||||
|
||||
// Extract metricsType from properties and set it as an attribute
|
||||
if mappedProperty.Name == "metricsType" && !prop.IsCustomProperty {
|
||||
if mappedProperty.StringValue != nil {
|
||||
catalogMetricsArtifact.Attributes.MetricsType = models.MetricsType(*mappedProperty.StringValue)
|
||||
}
|
||||
} else if prop.IsCustomProperty {
|
||||
customProperties = append(customProperties, mappedProperty)
|
||||
} else {
|
||||
properties = append(properties, mappedProperty)
|
||||
}
|
||||
}
|
||||
|
||||
catalogMetricsArtifact.CustomProperties = &customProperties
|
||||
catalogMetricsArtifact.Properties = &properties
|
||||
|
||||
return &catalogMetricsArtifact
|
||||
}
|
||||
|
|
@ -0,0 +1,461 @@
|
|||
package service_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/testutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestCatalogMetricsArtifactRepository(t *testing.T) {
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
defer cleanup()
|
||||
|
||||
// Get the CatalogMetricsArtifact type ID
|
||||
typeID := getCatalogMetricsArtifactTypeID(t, sharedDB)
|
||||
repo := service.NewCatalogMetricsArtifactRepository(sharedDB, typeID)
|
||||
|
||||
// Also get CatalogModel type ID for creating parent entities
|
||||
catalogModelTypeID := getCatalogModelTypeID(t, sharedDB)
|
||||
catalogModelRepo := service.NewCatalogModelRepository(sharedDB, catalogModelTypeID)
|
||||
|
||||
t.Run("TestSave", func(t *testing.T) {
|
||||
// First create a catalog model for attribution
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-metrics"),
|
||||
ExternalID: apiutils.Of("catalog-model-metrics-ext-123"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test creating a new catalog metrics artifact
|
||||
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("test-catalog-metrics-artifact"),
|
||||
ExternalID: apiutils.Of("catalog-metrics-ext-123"),
|
||||
MetricsType: models.MetricsTypeAccuracy,
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "description",
|
||||
StringValue: apiutils.Of("Test catalog metrics artifact description"),
|
||||
},
|
||||
},
|
||||
CustomProperties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "custom-metrics-prop",
|
||||
StringValue: apiutils.Of("custom-metrics-value"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
require.NotNil(t, saved.GetID())
|
||||
assert.Equal(t, "test-catalog-metrics-artifact", *saved.GetAttributes().Name)
|
||||
assert.Equal(t, "catalog-metrics-ext-123", *saved.GetAttributes().ExternalID)
|
||||
assert.Equal(t, models.MetricsTypeAccuracy, saved.GetAttributes().MetricsType)
|
||||
|
||||
// Test updating the same catalog metrics artifact
|
||||
catalogMetricsArtifact.ID = saved.GetID()
|
||||
catalogMetricsArtifact.GetAttributes().Name = apiutils.Of("updated-catalog-metrics-artifact")
|
||||
catalogMetricsArtifact.GetAttributes().MetricsType = models.MetricsTypePerformance
|
||||
// Preserve CreateTimeSinceEpoch from the saved entity
|
||||
catalogMetricsArtifact.GetAttributes().CreateTimeSinceEpoch = saved.GetAttributes().CreateTimeSinceEpoch
|
||||
|
||||
updated, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
assert.Equal(t, *saved.GetID(), *updated.GetID())
|
||||
assert.Equal(t, "updated-catalog-metrics-artifact", *updated.GetAttributes().Name)
|
||||
assert.Equal(t, models.MetricsTypePerformance, updated.GetAttributes().MetricsType)
|
||||
})
|
||||
|
||||
t.Run("TestGetByID", func(t *testing.T) {
|
||||
// First create a catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-getbyid-metrics"),
|
||||
ExternalID: apiutils.Of("catalog-model-getbyid-metrics-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a catalog metrics artifact to retrieve
|
||||
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("get-test-catalog-metrics-artifact"),
|
||||
ExternalID: apiutils.Of("get-catalog-metrics-ext-123"),
|
||||
MetricsType: models.MetricsTypeAccuracy,
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved.GetID())
|
||||
|
||||
// Test retrieving by ID
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
assert.Equal(t, *saved.GetID(), *retrieved.GetID())
|
||||
assert.Equal(t, "get-test-catalog-metrics-artifact", *retrieved.GetAttributes().Name)
|
||||
assert.Equal(t, "get-catalog-metrics-ext-123", *retrieved.GetAttributes().ExternalID)
|
||||
assert.Equal(t, models.MetricsTypeAccuracy, retrieved.GetAttributes().MetricsType)
|
||||
|
||||
// Test retrieving non-existent ID
|
||||
_, err = repo.GetByID(99999)
|
||||
assert.ErrorIs(t, err, service.ErrCatalogMetricsArtifactNotFound)
|
||||
})
|
||||
|
||||
t.Run("TestList", func(t *testing.T) {
|
||||
// Create a catalog model for the artifacts
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-list-metrics"),
|
||||
ExternalID: apiutils.Of("catalog-model-list-metrics-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple catalog metrics artifacts for listing
|
||||
testArtifacts := []*models.CatalogMetricsArtifactImpl{
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("list-catalog-metrics-artifact-1"),
|
||||
ExternalID: apiutils.Of("list-catalog-metrics-ext-1"),
|
||||
MetricsType: models.MetricsTypeAccuracy,
|
||||
},
|
||||
},
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("list-catalog-metrics-artifact-2"),
|
||||
ExternalID: apiutils.Of("list-catalog-metrics-ext-2"),
|
||||
MetricsType: models.MetricsTypePerformance,
|
||||
},
|
||||
},
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("list-catalog-metrics-artifact-3"),
|
||||
ExternalID: apiutils.Of("list-catalog-metrics-ext-3"),
|
||||
MetricsType: models.MetricsTypePerformance,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Save all test artifacts
|
||||
var savedArtifacts []models.CatalogMetricsArtifact
|
||||
for _, artifact := range testArtifacts {
|
||||
saved, err := repo.Save(artifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
savedArtifacts = append(savedArtifacts, saved)
|
||||
}
|
||||
|
||||
// Test listing all artifacts
|
||||
listOptions := models.CatalogMetricsArtifactListOptions{}
|
||||
result, err := repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.GreaterOrEqual(t, len(result.Items), 3) // At least our 3 test artifacts
|
||||
|
||||
// Test filtering by name
|
||||
nameFilter := "list-catalog-metrics-artifact-1"
|
||||
listOptions = models.CatalogMetricsArtifactListOptions{
|
||||
Name: &nameFilter,
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
if len(result.Items) > 0 {
|
||||
assert.Equal(t, 1, len(result.Items))
|
||||
assert.Equal(t, "list-catalog-metrics-artifact-1", *result.Items[0].GetAttributes().Name)
|
||||
}
|
||||
|
||||
// Test filtering by external ID
|
||||
externalIDFilter := "list-catalog-metrics-ext-2"
|
||||
listOptions = models.CatalogMetricsArtifactListOptions{
|
||||
ExternalID: &externalIDFilter,
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
if len(result.Items) > 0 {
|
||||
assert.Equal(t, 1, len(result.Items))
|
||||
assert.Equal(t, "list-catalog-metrics-ext-2", *result.Items[0].GetAttributes().ExternalID)
|
||||
}
|
||||
|
||||
// Test filtering by parent resource ID (catalog model)
|
||||
listOptions = models.CatalogMetricsArtifactListOptions{
|
||||
ParentResourceID: savedCatalogModel.GetID(),
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.GreaterOrEqual(t, len(result.Items), 3) // Should find our 3 test artifacts
|
||||
})
|
||||
|
||||
t.Run("TestListWithPropertiesAndCustomProperties", func(t *testing.T) {
|
||||
// Create a catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-props-metrics"),
|
||||
ExternalID: apiutils.Of("catalog-model-props-metrics-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a catalog metrics artifact with both properties and custom properties
|
||||
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("props-test-catalog-metrics-artifact"),
|
||||
ExternalID: apiutils.Of("props-catalog-metrics-ext-123"),
|
||||
MetricsType: models.MetricsTypeAccuracy,
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "version",
|
||||
StringValue: apiutils.Of("1.0.0"),
|
||||
},
|
||||
{
|
||||
Name: "value",
|
||||
DoubleValue: apiutils.Of(0.95),
|
||||
},
|
||||
},
|
||||
CustomProperties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "team",
|
||||
StringValue: apiutils.Of("catalog-metrics-team"),
|
||||
},
|
||||
{
|
||||
Name: "is_validated",
|
||||
BoolValue: apiutils.Of(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogMetricsArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
|
||||
// Retrieve and verify properties
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
|
||||
// Check that metricsType is properly set
|
||||
assert.Equal(t, models.MetricsTypeAccuracy, retrieved.GetAttributes().MetricsType)
|
||||
|
||||
// Check regular properties
|
||||
require.NotNil(t, retrieved.GetProperties())
|
||||
assert.Len(t, *retrieved.GetProperties(), 2)
|
||||
|
||||
// Check custom properties
|
||||
require.NotNil(t, retrieved.GetCustomProperties())
|
||||
assert.Len(t, *retrieved.GetCustomProperties(), 2)
|
||||
|
||||
// Verify specific properties exist
|
||||
properties := *retrieved.GetProperties()
|
||||
var foundVersion, foundValue bool
|
||||
for _, prop := range properties {
|
||||
switch prop.Name {
|
||||
case "version":
|
||||
foundVersion = true
|
||||
assert.Equal(t, "1.0.0", *prop.StringValue)
|
||||
case "value":
|
||||
foundValue = true
|
||||
assert.Equal(t, 0.95, *prop.DoubleValue)
|
||||
}
|
||||
}
|
||||
assert.True(t, foundVersion, "Should find version property")
|
||||
assert.True(t, foundValue, "Should find value property")
|
||||
|
||||
// Verify custom properties
|
||||
customProperties := *retrieved.GetCustomProperties()
|
||||
var foundTeam, foundIsValidated bool
|
||||
for _, prop := range customProperties {
|
||||
switch prop.Name {
|
||||
case "team":
|
||||
foundTeam = true
|
||||
assert.Equal(t, "catalog-metrics-team", *prop.StringValue)
|
||||
case "is_validated":
|
||||
foundIsValidated = true
|
||||
assert.Equal(t, true, *prop.BoolValue)
|
||||
}
|
||||
}
|
||||
assert.True(t, foundTeam, "Should find team custom property")
|
||||
assert.True(t, foundIsValidated, "Should find is_validated custom property")
|
||||
})
|
||||
|
||||
t.Run("TestSaveWithoutParentResource", func(t *testing.T) {
|
||||
// Test creating a catalog metrics artifact without parent resource attribution
|
||||
catalogMetricsArtifact := &models.CatalogMetricsArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("standalone-catalog-metrics-artifact"),
|
||||
ExternalID: apiutils.Of("standalone-catalog-metrics-ext"),
|
||||
MetricsType: models.MetricsTypeAccuracy,
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "description",
|
||||
StringValue: apiutils.Of("Standalone catalog metrics artifact without parent"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogMetricsArtifact, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
require.NotNil(t, saved.GetID())
|
||||
assert.Equal(t, "standalone-catalog-metrics-artifact", *saved.GetAttributes().Name)
|
||||
assert.Equal(t, models.MetricsTypeAccuracy, saved.GetAttributes().MetricsType)
|
||||
|
||||
// Verify it can be retrieved
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "standalone-catalog-metrics-artifact", *retrieved.GetAttributes().Name)
|
||||
assert.Equal(t, models.MetricsTypeAccuracy, retrieved.GetAttributes().MetricsType)
|
||||
})
|
||||
|
||||
t.Run("TestListOrdering", func(t *testing.T) {
|
||||
// Create a catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-ordering-metrics"),
|
||||
ExternalID: apiutils.Of("catalog-model-ordering-metrics-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create artifacts sequentially with time delays to ensure deterministic ordering
|
||||
artifact1 := &models.CatalogMetricsArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("time-test-catalog-metrics-artifact-1"),
|
||||
ExternalID: apiutils.Of("time-catalog-metrics-ext-1"),
|
||||
MetricsType: models.MetricsTypeAccuracy,
|
||||
},
|
||||
}
|
||||
saved1, err := repo.Save(artifact1, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Small delay to ensure different timestamps
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
artifact2 := &models.CatalogMetricsArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of("time-test-catalog-metrics-artifact-2"),
|
||||
ExternalID: apiutils.Of("time-catalog-metrics-ext-2"),
|
||||
MetricsType: models.MetricsTypePerformance,
|
||||
},
|
||||
}
|
||||
saved2, err := repo.Save(artifact2, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test ordering by CREATE_TIME
|
||||
listOptions := models.CatalogMetricsArtifactListOptions{
|
||||
Pagination: dbmodels.Pagination{
|
||||
OrderBy: apiutils.Of("CREATE_TIME"),
|
||||
},
|
||||
}
|
||||
|
||||
result, err := repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Find our test artifacts in the results
|
||||
var foundArtifact1, foundArtifact2 models.CatalogMetricsArtifact
|
||||
var index1, index2 = -1, -1
|
||||
|
||||
for i, item := range result.Items {
|
||||
if *item.GetID() == *saved1.GetID() {
|
||||
foundArtifact1 = item
|
||||
index1 = i
|
||||
}
|
||||
if *item.GetID() == *saved2.GetID() {
|
||||
foundArtifact2 = item
|
||||
index2 = i
|
||||
}
|
||||
}
|
||||
|
||||
// Verify both artifacts were found and artifact1 comes before artifact2 (ascending order)
|
||||
require.NotEqual(t, -1, index1, "Artifact 1 should be found in results")
|
||||
require.NotEqual(t, -1, index2, "Artifact 2 should be found in results")
|
||||
assert.Less(t, index1, index2, "Artifact 1 should come before Artifact 2 when ordered by CREATE_TIME")
|
||||
assert.Less(t, *foundArtifact1.GetAttributes().CreateTimeSinceEpoch, *foundArtifact2.GetAttributes().CreateTimeSinceEpoch, "Artifact 1 should have earlier create time")
|
||||
})
|
||||
|
||||
t.Run("TestMetricsTypeField", func(t *testing.T) {
|
||||
// Test various metrics types
|
||||
metricsTypes := []models.MetricsType{models.MetricsTypeAccuracy, models.MetricsTypePerformance}
|
||||
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-metrics-types"),
|
||||
ExternalID: apiutils.Of("catalog-model-metrics-types-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, metricsType := range metricsTypes {
|
||||
artifact := &models.CatalogMetricsArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogMetricsArtifactAttributes{
|
||||
Name: apiutils.Of(fmt.Sprintf("metrics-type-test-%d", i)),
|
||||
ExternalID: apiutils.Of(fmt.Sprintf("metrics-type-ext-%d", i)),
|
||||
MetricsType: metricsType,
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(artifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, metricsType, saved.GetAttributes().MetricsType)
|
||||
|
||||
// Verify retrieval preserves metricsType
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, metricsType, retrieved.GetAttributes().MetricsType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to look up the CatalogMetricsArtifact type ID
func getCatalogMetricsArtifactTypeID(t *testing.T, db *gorm.DB) int64 {
	var typeRecord schema.Type
	err := db.Where("name = ?", service.CatalogMetricsArtifactTypeName).First(&typeRecord).Error
	require.NoError(t, err, "Failed to query CatalogMetricsArtifact type")

	return int64(typeRecord.ID)
}
|
|
@@ -0,0 +1,129 @@
package service

import (
	"errors"

	"github.com/kubeflow/model-registry/catalog/internal/db/models"
	dbmodels "github.com/kubeflow/model-registry/internal/db/models"
	"github.com/kubeflow/model-registry/internal/db/schema"
	"github.com/kubeflow/model-registry/internal/db/service"
	"gorm.io/gorm"
)

var ErrCatalogModelNotFound = errors.New("catalog model by id not found")

type CatalogModelRepositoryImpl struct {
	*service.GenericRepository[models.CatalogModel, schema.Context, schema.ContextProperty, *models.CatalogModelListOptions]
}

func NewCatalogModelRepository(db *gorm.DB, typeID int64) models.CatalogModelRepository {
	config := service.GenericRepositoryConfig[models.CatalogModel, schema.Context, schema.ContextProperty, *models.CatalogModelListOptions]{
		DB:                  db,
		TypeID:              typeID,
		EntityToSchema:      mapCatalogModelToContext,
		SchemaToEntity:      mapDataLayerToCatalogModel,
		EntityToProperties:  mapCatalogModelToContextProperties,
		NotFoundError:       ErrCatalogModelNotFound,
		EntityName:          "catalog model",
		PropertyFieldName:   "context_id",
		ApplyListFilters:    applyCatalogModelListFilters,
		IsNewEntity:         func(entity models.CatalogModel) bool { return entity.GetID() == nil },
		HasCustomProperties: func(entity models.CatalogModel) bool { return entity.GetCustomProperties() != nil },
	}

	return &CatalogModelRepositoryImpl{
		GenericRepository: service.NewGenericRepository(config),
	}
}

func (r *CatalogModelRepositoryImpl) Save(model models.CatalogModel) (models.CatalogModel, error) {
	return r.GenericRepository.Save(model, nil)
}

func (r *CatalogModelRepositoryImpl) List(listOptions models.CatalogModelListOptions) (*dbmodels.ListWrapper[models.CatalogModel], error) {
	return r.GenericRepository.List(&listOptions)
}

func applyCatalogModelListFilters(query *gorm.DB, listOptions *models.CatalogModelListOptions) *gorm.DB {
	if listOptions.Name != nil {
		query = query.Where("name LIKE ?", listOptions.Name)
	} else if listOptions.ExternalID != nil {
		query = query.Where("external_id = ?", listOptions.ExternalID)
	}
	return query
}

func mapCatalogModelToContext(model models.CatalogModel) schema.Context {
	attrs := model.GetAttributes()
	context := schema.Context{
		TypeID: *model.GetTypeID(),
	}

	if model.GetID() != nil {
		context.ID = *model.GetID()
	}

	if attrs != nil {
		if attrs.Name != nil {
			context.Name = *attrs.Name
		}
		context.ExternalID = attrs.ExternalID
		if attrs.CreateTimeSinceEpoch != nil {
			context.CreateTimeSinceEpoch = *attrs.CreateTimeSinceEpoch
		}
		if attrs.LastUpdateTimeSinceEpoch != nil {
			context.LastUpdateTimeSinceEpoch = *attrs.LastUpdateTimeSinceEpoch
		}
	}

	return context
}

func mapCatalogModelToContextProperties(model models.CatalogModel, contextID int32) []schema.ContextProperty {
	var properties []schema.ContextProperty

	if model.GetProperties() != nil {
		for _, prop := range *model.GetProperties() {
			properties = append(properties, service.MapPropertiesToContextProperty(prop, contextID, false))
		}
	}

	if model.GetCustomProperties() != nil {
		for _, prop := range *model.GetCustomProperties() {
			properties = append(properties, service.MapPropertiesToContextProperty(prop, contextID, true))
		}
	}

	return properties
}

func mapDataLayerToCatalogModel(modelCtx schema.Context, propertiesCtx []schema.ContextProperty) models.CatalogModel {
	catalogModel := &models.CatalogModelImpl{
		ID:     &modelCtx.ID,
		TypeID: &modelCtx.TypeID,
		Attributes: &models.CatalogModelAttributes{
			Name:                     &modelCtx.Name,
			ExternalID:               modelCtx.ExternalID,
			CreateTimeSinceEpoch:     &modelCtx.CreateTimeSinceEpoch,
			LastUpdateTimeSinceEpoch: &modelCtx.LastUpdateTimeSinceEpoch,
		},
	}

	properties := []dbmodels.Properties{}
	customProperties := []dbmodels.Properties{}

	for _, prop := range propertiesCtx {
		mappedProperty := service.MapContextPropertyToProperties(prop)

		if prop.IsCustomProperty {
			customProperties = append(customProperties, mappedProperty)
		} else {
			properties = append(properties, mappedProperty)
		}
	}

	catalogModel.Properties = &properties
	catalogModel.CustomProperties = &customProperties

	return catalogModel
}
|
|
@ -0,0 +1,132 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/db/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrCatalogModelArtifactNotFound = errors.New("catalog model artifact by id not found")
|
||||
|
||||
type CatalogModelArtifactRepositoryImpl struct {
|
||||
*service.GenericRepository[models.CatalogModelArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogModelArtifactListOptions]
|
||||
}
|
||||
|
||||
func NewCatalogModelArtifactRepository(db *gorm.DB, typeID int64) models.CatalogModelArtifactRepository {
|
||||
config := service.GenericRepositoryConfig[models.CatalogModelArtifact, schema.Artifact, schema.ArtifactProperty, *models.CatalogModelArtifactListOptions]{
|
||||
DB: db,
|
||||
TypeID: typeID,
|
||||
EntityToSchema: mapCatalogModelArtifactToArtifact,
|
||||
SchemaToEntity: mapDataLayerToCatalogModelArtifact,
|
||||
EntityToProperties: mapCatalogModelArtifactToArtifactProperties,
|
||||
NotFoundError: ErrCatalogModelArtifactNotFound,
|
||||
EntityName: "catalog model artifact",
|
||||
PropertyFieldName: "artifact_id",
|
||||
ApplyListFilters: applyCatalogModelArtifactListFilters,
|
||||
IsNewEntity: func(entity models.CatalogModelArtifact) bool { return entity.GetID() == nil },
|
||||
HasCustomProperties: func(entity models.CatalogModelArtifact) bool { return entity.GetCustomProperties() != nil },
|
||||
}
|
||||
|
||||
return &CatalogModelArtifactRepositoryImpl{
|
||||
GenericRepository: service.NewGenericRepository(config),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *CatalogModelArtifactRepositoryImpl) List(listOptions models.CatalogModelArtifactListOptions) (*dbmodels.ListWrapper[models.CatalogModelArtifact], error) {
|
||||
return r.GenericRepository.List(&listOptions)
|
||||
}
|
||||
|
||||
func applyCatalogModelArtifactListFilters(query *gorm.DB, listOptions *models.CatalogModelArtifactListOptions) *gorm.DB {
|
||||
if listOptions.Name != nil {
|
||||
query = query.Where("name LIKE ?", fmt.Sprintf("%%:%s", *listOptions.Name))
|
||||
} else if listOptions.ExternalID != nil {
|
||||
query = query.Where("external_id = ?", listOptions.ExternalID)
|
||||
}
|
||||
|
||||
if listOptions.ParentResourceID != nil {
|
||||
query = query.Joins(utils.BuildAttributionJoin(query)).
|
||||
Where(utils.GetColumnRef(query, &schema.Attribution{}, "context_id")+" = ?", listOptions.ParentResourceID)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func mapCatalogModelArtifactToArtifact(catalogModelArtifact models.CatalogModelArtifact) schema.Artifact {
|
||||
if catalogModelArtifact == nil {
|
||||
return schema.Artifact{}
|
||||
}
|
||||
|
||||
artifact := schema.Artifact{
|
||||
ID: apiutils.ZeroIfNil(catalogModelArtifact.GetID()),
|
||||
TypeID: apiutils.ZeroIfNil(catalogModelArtifact.GetTypeID()),
|
||||
}
|
||||
|
||||
if catalogModelArtifact.GetAttributes() != nil {
|
||||
artifact.Name = catalogModelArtifact.GetAttributes().Name
|
||||
artifact.URI = catalogModelArtifact.GetAttributes().URI
|
||||
artifact.ExternalID = catalogModelArtifact.GetAttributes().ExternalID
|
||||
artifact.CreateTimeSinceEpoch = apiutils.ZeroIfNil(catalogModelArtifact.GetAttributes().CreateTimeSinceEpoch)
|
||||
artifact.LastUpdateTimeSinceEpoch = apiutils.ZeroIfNil(catalogModelArtifact.GetAttributes().LastUpdateTimeSinceEpoch)
|
||||
}
|
||||
|
||||
return artifact
|
||||
}
|
||||
|
||||
func mapCatalogModelArtifactToArtifactProperties(catalogModelArtifact models.CatalogModelArtifact, artifactID int32) []schema.ArtifactProperty {
|
||||
if catalogModelArtifact == nil {
|
||||
return []schema.ArtifactProperty{}
|
||||
}
|
||||
|
||||
properties := []schema.ArtifactProperty{}
|
||||
|
||||
if catalogModelArtifact.GetProperties() != nil {
|
||||
for _, prop := range *catalogModelArtifact.GetProperties() {
|
||||
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, false))
|
||||
}
|
||||
}
|
||||
|
||||
if catalogModelArtifact.GetCustomProperties() != nil {
|
||||
for _, prop := range *catalogModelArtifact.GetCustomProperties() {
|
||||
properties = append(properties, service.MapPropertiesToArtifactProperty(prop, artifactID, true))
|
||||
}
|
||||
}
|
||||
|
||||
return properties
|
||||
}
|
||||
|
||||
func mapDataLayerToCatalogModelArtifact(artifact schema.Artifact, artProperties []schema.ArtifactProperty) models.CatalogModelArtifact {
|
||||
catalogModelArtifact := models.CatalogModelArtifactImpl{
|
||||
ID: &artifact.ID,
|
||||
TypeID: &artifact.TypeID,
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: artifact.Name,
|
||||
URI: artifact.URI,
|
||||
ExternalID: artifact.ExternalID,
|
||||
CreateTimeSinceEpoch: &artifact.CreateTimeSinceEpoch,
|
||||
LastUpdateTimeSinceEpoch: &artifact.LastUpdateTimeSinceEpoch,
|
||||
},
|
||||
}
|
||||
|
||||
customProperties := []dbmodels.Properties{}
|
||||
properties := []dbmodels.Properties{}
|
||||
|
||||
for _, prop := range artProperties {
|
||||
if prop.IsCustomProperty {
|
||||
customProperties = append(customProperties, service.MapArtifactPropertyToProperties(prop))
|
||||
} else {
|
||||
properties = append(properties, service.MapArtifactPropertyToProperties(prop))
|
||||
}
|
||||
}
|
||||
|
||||
catalogModelArtifact.CustomProperties = &customProperties
|
||||
catalogModelArtifact.Properties = &properties
|
||||
|
||||
return &catalogModelArtifact
|
||||
}
|
||||
|
|
@ -0,0 +1,465 @@
|
|||
package service_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/testutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestCatalogModelArtifactRepository(t *testing.T) {
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
defer cleanup()
|
||||
|
||||
// Get the CatalogModelArtifact type ID
|
||||
typeID := getCatalogModelArtifactTypeID(t, sharedDB)
|
||||
repo := service.NewCatalogModelArtifactRepository(sharedDB, typeID)
|
||||
|
||||
// Also get CatalogModel type ID for creating parent entities
|
||||
catalogModelTypeID := getCatalogModelTypeID(t, sharedDB)
|
||||
catalogModelRepo := service.NewCatalogModelRepository(sharedDB, catalogModelTypeID)
|
||||
|
||||
t.Run("TestSave", func(t *testing.T) {
|
||||
// First create a catalog model for attribution
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-artifact"),
|
||||
ExternalID: apiutils.Of("catalog-model-ext-123"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test creating a new catalog model artifact
|
||||
catalogModelArtifact := &models.CatalogModelArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-artifact"),
|
||||
ExternalID: apiutils.Of("catalog-artifact-ext-123"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/model.pkl"),
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "description",
|
||||
StringValue: apiutils.Of("Test catalog model artifact description"),
|
||||
},
|
||||
},
|
||||
CustomProperties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "custom-catalog-prop",
|
||||
StringValue: apiutils.Of("custom-catalog-value"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
require.NotNil(t, saved.GetID())
|
||||
assert.Equal(t, "test-catalog-model-artifact", *saved.GetAttributes().Name)
|
||||
assert.Equal(t, "catalog-artifact-ext-123", *saved.GetAttributes().ExternalID)
|
||||
assert.Equal(t, "s3://catalog-bucket/model.pkl", *saved.GetAttributes().URI)
|
||||
|
||||
// Test updating the same catalog model artifact
|
||||
catalogModelArtifact.ID = saved.GetID()
|
||||
catalogModelArtifact.GetAttributes().Name = apiutils.Of("updated-catalog-model-artifact")
|
||||
catalogModelArtifact.GetAttributes().URI = apiutils.Of("s3://catalog-bucket/updated-model.pkl")
|
||||
// Preserve CreateTimeSinceEpoch from the saved entity
|
||||
catalogModelArtifact.GetAttributes().CreateTimeSinceEpoch = saved.GetAttributes().CreateTimeSinceEpoch
|
||||
|
||||
updated, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
assert.Equal(t, *saved.GetID(), *updated.GetID())
|
||||
assert.Equal(t, "updated-catalog-model-artifact", *updated.GetAttributes().Name)
|
||||
assert.Equal(t, "s3://catalog-bucket/updated-model.pkl", *updated.GetAttributes().URI)
|
||||
})
|
||||
|
||||
t.Run("TestGetByID", func(t *testing.T) {
|
||||
// First create a catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-getbyid"),
|
||||
ExternalID: apiutils.Of("catalog-model-getbyid-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a catalog model artifact to retrieve
|
||||
catalogModelArtifact := &models.CatalogModelArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("get-test-catalog-model-artifact"),
|
||||
ExternalID: apiutils.Of("get-catalog-artifact-ext-123"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/get-model.pkl"),
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved.GetID())
|
||||
|
||||
// Test retrieving by ID
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
assert.Equal(t, *saved.GetID(), *retrieved.GetID())
|
||||
assert.Equal(t, "get-test-catalog-model-artifact", *retrieved.GetAttributes().Name)
|
||||
assert.Equal(t, "get-catalog-artifact-ext-123", *retrieved.GetAttributes().ExternalID)
|
||||
assert.Equal(t, "s3://catalog-bucket/get-model.pkl", *retrieved.GetAttributes().URI)
|
||||
|
||||
// Test retrieving non-existent ID
|
||||
_, err = repo.GetByID(99999)
|
||||
assert.ErrorIs(t, err, service.ErrCatalogModelArtifactNotFound)
|
||||
})
|
||||
|
||||
t.Run("TestList", func(t *testing.T) {
|
||||
// Create a catalog model for the artifacts
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-list"),
|
||||
ExternalID: apiutils.Of("catalog-model-list-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple catalog model artifacts for listing
|
||||
testArtifacts := []*models.CatalogModelArtifactImpl{
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("list-catalog-artifact-1"),
|
||||
ExternalID: apiutils.Of("list-catalog-artifact-ext-1"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/list-model-1.pkl"),
|
||||
},
|
||||
},
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("list-catalog-artifact-2"),
|
||||
ExternalID: apiutils.Of("list-catalog-artifact-ext-2"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/list-model-2.pkl"),
|
||||
},
|
||||
},
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("list-catalog-artifact-3"),
|
||||
ExternalID: apiutils.Of("list-catalog-artifact-ext-3"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/list-model-3.pkl"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Save all test artifacts
|
||||
var savedArtifacts []models.CatalogModelArtifact
|
||||
for _, artifact := range testArtifacts {
|
||||
saved, err := repo.Save(artifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
savedArtifacts = append(savedArtifacts, saved)
|
||||
}
|
||||
|
||||
// Test listing all artifacts
|
||||
listOptions := models.CatalogModelArtifactListOptions{}
|
||||
result, err := repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.GreaterOrEqual(t, len(result.Items), 3) // At least our 3 test artifacts
|
||||
|
||||
// Test filtering by name
|
||||
nameFilter := "list-catalog-artifact-1"
|
||||
listOptions = models.CatalogModelArtifactListOptions{
|
||||
Name: &nameFilter,
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
if len(result.Items) > 0 {
|
||||
assert.Equal(t, 1, len(result.Items))
|
||||
assert.Equal(t, "list-catalog-artifact-1", *result.Items[0].GetAttributes().Name)
|
||||
}
|
||||
|
||||
// Test filtering by external ID
|
||||
externalIDFilter := "list-catalog-artifact-ext-2"
|
||||
listOptions = models.CatalogModelArtifactListOptions{
|
||||
ExternalID: &externalIDFilter,
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
if len(result.Items) > 0 {
|
||||
assert.Equal(t, 1, len(result.Items))
|
||||
assert.Equal(t, "list-catalog-artifact-ext-2", *result.Items[0].GetAttributes().ExternalID)
|
||||
}
|
||||
|
||||
// Test filtering by parent resource ID (catalog model)
|
||||
listOptions = models.CatalogModelArtifactListOptions{
|
||||
ParentResourceID: savedCatalogModel.GetID(),
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.GreaterOrEqual(t, len(result.Items), 3) // Should find our 3 test artifacts
|
||||
})
|
||||
|
||||
t.Run("TestListWithPropertiesAndCustomProperties", func(t *testing.T) {
|
||||
// Create a catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-props"),
|
||||
ExternalID: apiutils.Of("catalog-model-props-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a catalog model artifact with both properties and custom properties
|
||||
catalogModelArtifact := &models.CatalogModelArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("props-test-catalog-artifact"),
|
||||
ExternalID: apiutils.Of("props-catalog-artifact-ext-123"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/props-model.pkl"),
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "version",
|
||||
StringValue: apiutils.Of("1.0.0"),
|
||||
},
|
||||
{
|
||||
Name: "size_bytes",
|
||||
IntValue: apiutils.Of(int32(2048000)),
|
||||
},
|
||||
},
|
||||
CustomProperties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "team",
|
||||
StringValue: apiutils.Of("catalog-ml-team"),
|
||||
},
|
||||
{
|
||||
Name: "is_public",
|
||||
BoolValue: apiutils.Of(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogModelArtifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
|
||||
// Retrieve and verify properties
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
|
||||
// Check regular properties
|
||||
require.NotNil(t, retrieved.GetProperties())
|
||||
assert.Len(t, *retrieved.GetProperties(), 2)
|
||||
|
||||
// Check custom properties
|
||||
require.NotNil(t, retrieved.GetCustomProperties())
|
||||
assert.Len(t, *retrieved.GetCustomProperties(), 2)
|
||||
|
||||
// Verify specific properties exist
|
||||
properties := *retrieved.GetProperties()
|
||||
var foundVersion, foundSizeBytes bool
|
||||
for _, prop := range properties {
|
||||
switch prop.Name {
|
||||
case "version":
|
||||
foundVersion = true
|
||||
assert.Equal(t, "1.0.0", *prop.StringValue)
|
||||
case "size_bytes":
|
||||
foundSizeBytes = true
|
||||
assert.Equal(t, int32(2048000), *prop.IntValue)
|
||||
}
|
||||
}
|
||||
assert.True(t, foundVersion, "Should find version property")
|
||||
assert.True(t, foundSizeBytes, "Should find size_bytes property")
|
||||
|
||||
// Verify custom properties
|
||||
customProperties := *retrieved.GetCustomProperties()
|
||||
var foundTeam, foundIsPublic bool
|
||||
for _, prop := range customProperties {
|
||||
switch prop.Name {
|
||||
case "team":
|
||||
foundTeam = true
|
||||
assert.Equal(t, "catalog-ml-team", *prop.StringValue)
|
||||
case "is_public":
|
||||
foundIsPublic = true
|
||||
assert.Equal(t, true, *prop.BoolValue)
|
||||
}
|
||||
}
|
||||
assert.True(t, foundTeam, "Should find team custom property")
|
||||
assert.True(t, foundIsPublic, "Should find is_public custom property")
|
||||
})
|
||||
|
||||
t.Run("TestSaveWithoutParentResource", func(t *testing.T) {
|
||||
// Test creating a catalog model artifact without parent resource attribution
|
||||
catalogModelArtifact := &models.CatalogModelArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("standalone-catalog-artifact"),
|
||||
ExternalID: apiutils.Of("standalone-catalog-artifact-ext"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/standalone-model.pkl"),
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "description",
|
||||
StringValue: apiutils.Of("Standalone catalog artifact without parent"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogModelArtifact, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
require.NotNil(t, saved.GetID())
|
||||
assert.Equal(t, "standalone-catalog-artifact", *saved.GetAttributes().Name)
|
||||
assert.Equal(t, "s3://catalog-bucket/standalone-model.pkl", *saved.GetAttributes().URI)
|
||||
|
||||
// Verify it can be retrieved
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "standalone-catalog-artifact", *retrieved.GetAttributes().Name)
|
||||
})
|
||||
|
||||
t.Run("TestListOrdering", func(t *testing.T) {
|
||||
// Create a catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-ordering"),
|
||||
ExternalID: apiutils.Of("catalog-model-ordering-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create artifacts sequentially with time delays to ensure deterministic ordering
|
||||
artifact1 := &models.CatalogModelArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("time-test-catalog-artifact-1"),
|
||||
ExternalID: apiutils.Of("time-catalog-artifact-ext-1"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/time-model-1.pkl"),
|
||||
},
|
||||
}
|
||||
saved1, err := repo.Save(artifact1, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Small delay to ensure different timestamps
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
artifact2 := &models.CatalogModelArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of("time-test-catalog-artifact-2"),
|
||||
ExternalID: apiutils.Of("time-catalog-artifact-ext-2"),
|
||||
URI: apiutils.Of("s3://catalog-bucket/time-model-2.pkl"),
|
||||
},
|
||||
}
|
||||
saved2, err := repo.Save(artifact2, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test ordering by CREATE_TIME
|
||||
listOptions := models.CatalogModelArtifactListOptions{
|
||||
Pagination: dbmodels.Pagination{
|
||||
OrderBy: apiutils.Of("CREATE_TIME"),
|
||||
},
|
||||
}
|
||||
|
||||
result, err := repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Find our test artifacts in the results
|
||||
var foundArtifact1, foundArtifact2 models.CatalogModelArtifact
|
||||
var index1, index2 = -1, -1
|
||||
|
||||
for i, item := range result.Items {
|
||||
if *item.GetID() == *saved1.GetID() {
|
||||
foundArtifact1 = item
|
||||
index1 = i
|
||||
}
|
||||
if *item.GetID() == *saved2.GetID() {
|
||||
foundArtifact2 = item
|
||||
index2 = i
|
||||
}
|
||||
}
|
||||
|
||||
// Verify both artifacts were found and artifact1 comes before artifact2 (ascending order)
|
||||
require.NotEqual(t, -1, index1, "Artifact 1 should be found in results")
|
||||
require.NotEqual(t, -1, index2, "Artifact 2 should be found in results")
|
||||
assert.Less(t, index1, index2, "Artifact 1 should come before Artifact 2 when ordered by CREATE_TIME")
|
||||
assert.Less(t, *foundArtifact1.GetAttributes().CreateTimeSinceEpoch, *foundArtifact2.GetAttributes().CreateTimeSinceEpoch, "Artifact 1 should have earlier create time")
|
||||
})
|
||||
|
||||
t.Run("TestListPagination", func(t *testing.T) {
|
||||
// Create a catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(catalogModelTypeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model-for-pagination"),
|
||||
ExternalID: apiutils.Of("catalog-model-pagination-ext"),
|
||||
},
|
||||
}
|
||||
savedCatalogModel, err := catalogModelRepo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple artifacts for pagination testing
|
||||
for i := 0; i < 5; i++ {
|
||||
artifact := &models.CatalogModelArtifactImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelArtifactAttributes{
|
||||
Name: apiutils.Of(fmt.Sprintf("pagination-artifact-%d", i)),
|
||||
ExternalID: apiutils.Of(fmt.Sprintf("pagination-artifact-ext-%d", i)),
|
||||
URI: apiutils.Of(fmt.Sprintf("s3://catalog-bucket/pagination-model-%d.pkl", i)),
|
||||
},
|
||||
}
|
||||
_, err := repo.Save(artifact, savedCatalogModel.GetID())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test pagination with page size
|
||||
pageSize := int32(2)
|
||||
listOptions := models.CatalogModelArtifactListOptions{
|
||||
ParentResourceID: savedCatalogModel.GetID(),
|
||||
Pagination: dbmodels.Pagination{
|
||||
PageSize: &pageSize,
|
||||
OrderBy: apiutils.Of("ID"),
|
||||
},
|
||||
}
|
||||
|
||||
result, err := repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.LessOrEqual(t, len(result.Items), 2, "Should respect page size limit")
|
||||
assert.GreaterOrEqual(t, len(result.Items), 1, "Should return at least one item")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to look up the CatalogModelArtifact type ID
func getCatalogModelArtifactTypeID(t *testing.T, db *gorm.DB) int64 {
	var typeRecord schema.Type
	err := db.Where("name = ?", service.CatalogModelArtifactTypeName).First(&typeRecord).Error
	require.NoError(t, err, "Failed to query CatalogModelArtifact type")

	return int64(typeRecord.ID)
}
|
|
@ -0,0 +1,209 @@
|
|||
package service_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/catalog/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/testutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestCatalogModelRepository(t *testing.T) {
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
defer cleanup()
|
||||
|
||||
// Create or get the CatalogModel type ID
|
||||
typeID := getCatalogModelTypeID(t, sharedDB)
|
||||
repo := service.NewCatalogModelRepository(sharedDB, typeID)
|
||||
|
||||
t.Run("TestSave", func(t *testing.T) {
|
||||
// Test creating a new catalog model
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("test-catalog-model"),
|
||||
ExternalID: apiutils.Of("catalog-ext-123"),
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "description",
|
||||
StringValue: apiutils.Of("Test catalog model description"),
|
||||
},
|
||||
},
|
||||
CustomProperties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "custom-prop",
|
||||
StringValue: apiutils.Of("custom-value"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
require.NotNil(t, saved.GetID())
|
||||
assert.Equal(t, "test-catalog-model", *saved.GetAttributes().Name)
|
||||
assert.Equal(t, "catalog-ext-123", *saved.GetAttributes().ExternalID)
|
||||
|
||||
// Test updating the same model
|
||||
catalogModel.ID = saved.GetID()
|
||||
catalogModel.GetAttributes().Name = apiutils.Of("updated-catalog-model")
|
||||
// Preserve CreateTimeSinceEpoch from the saved entity
|
||||
catalogModel.GetAttributes().CreateTimeSinceEpoch = saved.GetAttributes().CreateTimeSinceEpoch
|
||||
|
||||
updated, err := repo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
assert.Equal(t, *saved.GetID(), *updated.GetID())
|
||||
assert.Equal(t, "updated-catalog-model", *updated.GetAttributes().Name)
|
||||
})
|
||||
|
||||
t.Run("TestGetByID", func(t *testing.T) {
|
||||
// First create a model to retrieve
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("get-test-catalog-model"),
|
||||
ExternalID: apiutils.Of("get-catalog-ext-123"),
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved.GetID())
|
||||
|
||||
// Test retrieving by ID
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
assert.Equal(t, *saved.GetID(), *retrieved.GetID())
|
||||
assert.Equal(t, "get-test-catalog-model", *retrieved.GetAttributes().Name)
|
||||
assert.Equal(t, "get-catalog-ext-123", *retrieved.GetAttributes().ExternalID)
|
||||
|
||||
// Test retrieving non-existent ID
|
||||
_, err = repo.GetByID(99999)
|
||||
assert.ErrorIs(t, err, service.ErrCatalogModelNotFound)
|
||||
})
|
||||
|
||||
t.Run("TestList", func(t *testing.T) {
|
||||
// Create multiple models for listing
|
||||
testModels := []*models.CatalogModelImpl{
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("list-catalog-model-1"),
|
||||
ExternalID: apiutils.Of("list-catalog-ext-1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("list-catalog-model-2"),
|
||||
ExternalID: apiutils.Of("list-catalog-ext-2"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Save all test models
|
||||
var savedModels []models.CatalogModel
|
||||
for _, model := range testModels {
|
||||
saved, err := repo.Save(model)
|
||||
require.NoError(t, err)
|
||||
savedModels = append(savedModels, saved)
|
||||
}
|
||||
|
||||
// Test listing all models
|
||||
listOptions := models.CatalogModelListOptions{}
|
||||
result, err := repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.GreaterOrEqual(t, len(result.Items), 2) // At least our 2 test models
|
||||
|
||||
// Test filtering by name
|
||||
nameFilter := "list-catalog-model-1"
|
||||
listOptions = models.CatalogModelListOptions{
|
||||
Name: &nameFilter,
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, 1, len(result.Items))
|
||||
assert.Equal(t, "list-catalog-model-1", *result.Items[0].GetAttributes().Name)
|
||||
|
||||
// Test filtering by external ID
|
||||
externalIDFilter := "list-catalog-ext-2"
|
||||
listOptions = models.CatalogModelListOptions{
|
||||
ExternalID: &externalIDFilter,
|
||||
}
|
||||
result, err = repo.List(listOptions)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, 1, len(result.Items))
|
||||
assert.Equal(t, "list-catalog-ext-2", *result.Items[0].GetAttributes().ExternalID)
|
||||
})
|
||||
|
||||
t.Run("TestListWithPropertiesAndCustomProperties", func(t *testing.T) {
|
||||
// Create a model with both properties and custom properties
|
||||
catalogModel := &models.CatalogModelImpl{
|
||||
TypeID: apiutils.Of(int32(typeID)),
|
||||
Attributes: &models.CatalogModelAttributes{
|
||||
Name: apiutils.Of("props-test-catalog-model"),
|
||||
ExternalID: apiutils.Of("props-catalog-ext-123"),
|
||||
},
|
||||
Properties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "version",
|
||||
StringValue: apiutils.Of("1.0.0"),
|
||||
},
|
||||
{
|
||||
Name: "priority",
|
||||
IntValue: apiutils.Of(int32(5)),
|
||||
},
|
||||
},
|
||||
CustomProperties: &[]dbmodels.Properties{
|
||||
{
|
||||
Name: "team",
|
||||
StringValue: apiutils.Of("ml-team"),
|
||||
},
|
||||
{
|
||||
Name: "active",
|
||||
BoolValue: apiutils.Of(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
saved, err := repo.Save(catalogModel)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, saved)
|
||||
|
||||
// Retrieve and verify properties
|
||||
retrieved, err := repo.GetByID(*saved.GetID())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
|
||||
// Check regular properties
|
||||
require.NotNil(t, retrieved.GetProperties())
|
||||
assert.Len(t, *retrieved.GetProperties(), 2)
|
||||
|
||||
// Check custom properties
|
||||
require.NotNil(t, retrieved.GetCustomProperties())
|
||||
assert.Len(t, *retrieved.GetCustomProperties(), 2)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to look up the CatalogModel type ID
func getCatalogModelTypeID(t *testing.T, db *gorm.DB) int64 {
	var typeRecord schema.Type
	err := db.Where("name = ?", service.CatalogModelTypeName).First(&typeRecord).Error
	require.NoError(t, err, "Failed to query CatalogModel type")

	return int64(typeRecord.ID)
}
|
|
@@ -0,0 +1,12 @@
package service

import (
	"os"
	"testing"

	"github.com/kubeflow/model-registry/internal/testutils"
)

func TestMain(m *testing.M) {
	os.Exit(testutils.TestMainHelper(m))
}
|
|
@@ -0,0 +1,36 @@
package service

import (
	"github.com/kubeflow/model-registry/internal/datastore"
)

const (
	CatalogModelTypeName           = "kf.CatalogModel"
	CatalogModelArtifactTypeName   = "kf.CatalogModelArtifact"
	CatalogMetricsArtifactTypeName = "kf.CatalogMetricsArtifact"
)

func DatastoreSpec() *datastore.Spec {
	return datastore.NewSpec().
		AddContext(CatalogModelTypeName, datastore.NewSpecType(NewCatalogModelRepository).
			AddString("source_id").
			AddString("description").
			AddString("owner").
			AddString("state").
			AddStruct("language").
			AddString("library_name").
			AddString("license_link").
			AddString("license").
			AddString("logo").
			AddString("maturity").
			AddString("provider").
			AddString("readme").
			AddStruct("tasks"),
		).
		AddArtifact(CatalogModelArtifactTypeName, datastore.NewSpecType(NewCatalogModelArtifactRepository).
			AddString("uri"),
		).
		AddArtifact(CatalogMetricsArtifactTypeName, datastore.NewSpecType(NewCatalogMetricsArtifactRepository).
			AddString("metricsType"),
		)
}
|
|
@@ -0,0 +1,5 @@
[mysqld]
character-set-server = utf8mb4
collation-server = utf8mb4_general_ci

!includedir /etc/mysql/conf.d/
|
|
@@ -8,9 +8,11 @@ model_artifact_type_query_param.go
model_base_model.go
model_base_resource_dates.go
model_base_resource_list.go
model_catalog_artifact.go
model_catalog_artifact_list.go
model_catalog_metrics_artifact.go
model_catalog_model.go
model_catalog_model_artifact.go
model_catalog_model_artifact_list.go
model_catalog_model_list.go
model_catalog_source.go
model_catalog_source_list.go
|
|
|
|||
|
|
@ -126,7 +126,7 @@ func (c *ModelCatalogServiceAPIController) GetModel(w http.ResponseWriter, r *ht
|
|||
EncodeJSONResponse(result.Body, &result.Code, w)
|
||||
}
|
||||
|
||||
// GetAllModelArtifacts - List CatalogModelArtifacts.
|
||||
// GetAllModelArtifacts - List CatalogArtifacts.
|
||||
func (c *ModelCatalogServiceAPIController) GetAllModelArtifacts(w http.ResponseWriter, r *http.Request) {
|
||||
sourceIdParam := chi.URLParam(r, "source_id")
|
||||
modelNameParam := chi.URLParam(r, "model_name")
|
||||
|
|
|
|||
|
|
@ -725,7 +725,7 @@ func TestFindSources(t *testing.T) {
|
|||
// Define a mock model provider
|
||||
type mockModelProvider struct {
|
||||
models map[string]*model.CatalogModel
|
||||
artifacts map[string][]model.CatalogModelArtifact
|
||||
artifacts map[string][]model.CatalogArtifact
|
||||
}
|
||||
|
||||
// Implement GetModel method for the mock provider
|
||||
|
|
@ -784,17 +784,17 @@ func (m *mockModelProvider) ListModels(ctx context.Context, params catalog.ListM
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error) {
|
||||
func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error) {
|
||||
artifacts, exists := m.artifacts[name]
|
||||
if !exists {
|
||||
return &model.CatalogModelArtifactList{
|
||||
Items: []model.CatalogModelArtifact{},
|
||||
return &model.CatalogArtifactList{
|
||||
Items: []model.CatalogArtifact{},
|
||||
Size: 0,
|
||||
PageSize: 0, // Or a default page size if applicable
|
||||
NextPageToken: "",
|
||||
}, nil
|
||||
}
|
||||
return &model.CatalogModelArtifactList{
|
||||
return &model.CatalogArtifactList{
|
||||
Items: artifacts,
|
||||
Size: int32(len(artifacts)),
|
||||
PageSize: int32(len(artifacts)),
|
||||
|
|
@ -922,7 +922,7 @@ func TestGetAllModelArtifacts(t *testing.T) {
|
|||
sourceID string
|
||||
modelName string
|
||||
expectedStatus int
|
||||
expectedArtifacts []model.CatalogModelArtifact
|
||||
expectedArtifacts []model.CatalogArtifact
|
||||
}{
|
||||
{
|
||||
name: "Existing artifacts for model in source",
|
||||
|
|
@ -930,13 +930,17 @@ func TestGetAllModelArtifacts(t *testing.T) {
|
|||
"source1": {
|
||||
Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"},
|
||||
Provider: &mockModelProvider{
|
||||
artifacts: map[string][]model.CatalogModelArtifact{
|
||||
artifacts: map[string][]model.CatalogArtifact{
|
||||
"test-model": {
|
||||
{
|
||||
Uri: "s3://bucket/artifact1",
|
||||
CatalogModelArtifact: &model.CatalogModelArtifact{
|
||||
Uri: "s3://bucket/artifact1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Uri: "s3://bucket/artifact2",
|
||||
CatalogModelArtifact: &model.CatalogModelArtifact{
|
||||
Uri: "s3://bucket/artifact2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -946,12 +950,16 @@ func TestGetAllModelArtifacts(t *testing.T) {
|
|||
sourceID: "source1",
|
||||
modelName: "test-model",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedArtifacts: []model.CatalogModelArtifact{
|
||||
expectedArtifacts: []model.CatalogArtifact{
|
||||
{
|
||||
Uri: "s3://bucket/artifact1",
|
||||
CatalogModelArtifact: &model.CatalogModelArtifact{
|
||||
Uri: "s3://bucket/artifact1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Uri: "s3://bucket/artifact2",
|
||||
CatalogModelArtifact: &model.CatalogModelArtifact{
|
||||
Uri: "s3://bucket/artifact2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -973,14 +981,14 @@ func TestGetAllModelArtifacts(t *testing.T) {
|
|||
"source1": {
|
||||
Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"},
|
||||
Provider: &mockModelProvider{
|
||||
artifacts: map[string][]model.CatalogModelArtifact{},
|
||||
artifacts: map[string][]model.CatalogArtifact{},
|
||||
},
|
||||
},
|
||||
},
|
||||
sourceID: "source1",
|
||||
modelName: "test-model",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedArtifacts: []model.CatalogModelArtifact{}, // Should be an empty slice, not nil
|
||||
expectedArtifacts: []model.CatalogArtifact{}, // Should be an empty slice, not nil
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -1008,8 +1016,8 @@ func TestGetAllModelArtifacts(t *testing.T) {
|
|||
require.NotNil(t, resp.Body)
|
||||
|
||||
// Type assertion to access the list of artifacts
|
||||
artifactList, ok := resp.Body.(*model.CatalogModelArtifactList)
|
||||
require.True(t, ok, "Response body should be a CatalogModelArtifactList")
|
||||
artifactList, ok := resp.Body.(*model.CatalogArtifactList)
|
||||
require.True(t, ok, "Response body should be a CatalogArtifactList")
|
||||
|
||||
// Check the artifacts
|
||||
assert.Equal(t, tc.expectedArtifacts, artifactList.Items)
|
||||
|
|
|
|||
|
|
@ -67,18 +67,13 @@ func AssertBaseResourceListRequired(obj model.BaseResourceList) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// AssertCatalogModelArtifactConstraints checks if the values respects the defined constraints
|
||||
func AssertCatalogModelArtifactConstraints(obj model.CatalogModelArtifact) error {
|
||||
// AssertCatalogArtifactListConstraints checks if the values respects the defined constraints
|
||||
func AssertCatalogArtifactListConstraints(obj model.CatalogArtifactList) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssertCatalogModelArtifactListConstraints checks if the values respects the defined constraints
|
||||
func AssertCatalogModelArtifactListConstraints(obj model.CatalogModelArtifactList) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssertCatalogModelArtifactListRequired checks if the required fields are not zero-ed
|
||||
func AssertCatalogModelArtifactListRequired(obj model.CatalogModelArtifactList) error {
|
||||
// AssertCatalogArtifactListRequired checks if the required fields are not zero-ed
|
||||
func AssertCatalogArtifactListRequired(obj model.CatalogArtifactList) error {
|
||||
elements := map[string]interface{}{
|
||||
"nextPageToken": obj.NextPageToken,
|
||||
"pageSize": obj.PageSize,
|
||||
|
|
@ -92,17 +87,43 @@ func AssertCatalogModelArtifactListRequired(obj model.CatalogModelArtifactList)
|
|||
}
|
||||
|
||||
for _, el := range obj.Items {
|
||||
if err := AssertCatalogModelArtifactRequired(el); err != nil {
|
||||
if err := AssertCatalogArtifactRequired(el); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssertCatalogMetricsArtifactConstraints checks if the values respects the defined constraints
|
||||
func AssertCatalogMetricsArtifactConstraints(obj model.CatalogMetricsArtifact) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssertCatalogMetricsArtifactRequired checks if the required fields are not zero-ed
|
||||
func AssertCatalogMetricsArtifactRequired(obj model.CatalogMetricsArtifact) error {
|
||||
elements := map[string]interface{}{
|
||||
"artifactType": obj.ArtifactType,
|
||||
"metricsType": obj.MetricsType,
|
||||
}
|
||||
for name, el := range elements {
|
||||
if isZero := IsZeroValue(el); isZero {
|
||||
return &RequiredError{Field: name}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssertCatalogModelArtifactConstraints checks if the values respects the defined constraints
|
||||
func AssertCatalogModelArtifactConstraints(obj model.CatalogModelArtifact) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssertCatalogModelArtifactRequired checks if the required fields are not zero-ed
|
||||
func AssertCatalogModelArtifactRequired(obj model.CatalogModelArtifact) error {
|
||||
elements := map[string]interface{}{
|
||||
"uri": obj.Uri,
|
||||
"artifactType": obj.ArtifactType,
|
||||
"uri": obj.Uri,
|
||||
}
|
||||
for name, el := range elements {
|
||||
if isZero := IsZeroValue(el); isZero {
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,12 @@
package openapi

import (
	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

// AssertCatalogArtifactRequired checks if the required fields are not zero-ed
func AssertCatalogArtifactRequired(obj model.CatalogArtifact) error {
	// CatalogArtifact has no required fields but the openapi code gen
	// checks the fields from CatalogModelArtifact, which doesn't compile.
	return nil
}
|
|
@@ -5,9 +5,11 @@ model_artifact_type_query_param.go
model_base_model.go
model_base_resource_dates.go
model_base_resource_list.go
model_catalog_artifact.go
model_catalog_artifact_list.go
model_catalog_metrics_artifact.go
model_catalog_model.go
model_catalog_model_artifact.go
model_catalog_model_artifact_list.go
model_catalog_model_list.go
model_catalog_source.go
model_catalog_source_list.go
|
|
|
|||
|
|
@ -424,12 +424,12 @@ type ApiGetAllModelArtifactsRequest struct {
|
|||
modelName string
|
||||
}
|
||||
|
||||
func (r ApiGetAllModelArtifactsRequest) Execute() (*CatalogModelArtifactList, *http.Response, error) {
|
||||
func (r ApiGetAllModelArtifactsRequest) Execute() (*CatalogArtifactList, *http.Response, error) {
|
||||
return r.ApiService.GetAllModelArtifactsExecute(r)
|
||||
}
|
||||
|
||||
/*
|
||||
GetAllModelArtifacts List CatalogModelArtifacts.
|
||||
GetAllModelArtifacts List CatalogArtifacts.
|
||||
|
||||
@param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background().
|
||||
@param sourceId A unique identifier for a `CatalogSource`.
|
||||
|
|
@ -447,13 +447,13 @@ func (a *ModelCatalogServiceAPIService) GetAllModelArtifacts(ctx context.Context
|
|||
|
||||
// Execute executes the request
|
||||
//
|
||||
// @return CatalogModelArtifactList
|
||||
func (a *ModelCatalogServiceAPIService) GetAllModelArtifactsExecute(r ApiGetAllModelArtifactsRequest) (*CatalogModelArtifactList, *http.Response, error) {
|
||||
// @return CatalogArtifactList
|
||||
func (a *ModelCatalogServiceAPIService) GetAllModelArtifactsExecute(r ApiGetAllModelArtifactsRequest) (*CatalogArtifactList, *http.Response, error) {
|
||||
var (
|
||||
localVarHTTPMethod = http.MethodGet
|
||||
localVarPostBody interface{}
|
||||
formFiles []formFile
|
||||
localVarReturnValue *CatalogModelArtifactList
|
||||
localVarReturnValue *CatalogArtifactList
|
||||
)
|
||||
|
||||
localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelCatalogServiceAPIService.GetAllModelArtifacts")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
/*
|
||||
Model Catalog REST API
|
||||
|
||||
REST API for Model Registry to create and manage ML model metadata
|
||||
|
||||
API version: v1alpha1
|
||||
*/
|
||||
|
||||
// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
|
||||
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// CatalogArtifact - A single artifact in the catalog API.
|
||||
type CatalogArtifact struct {
|
||||
CatalogMetricsArtifact *CatalogMetricsArtifact
|
||||
CatalogModelArtifact *CatalogModelArtifact
|
||||
}
|
||||
|
||||
// CatalogMetricsArtifactAsCatalogArtifact is a convenience function that returns CatalogMetricsArtifact wrapped in CatalogArtifact
|
||||
func CatalogMetricsArtifactAsCatalogArtifact(v *CatalogMetricsArtifact) CatalogArtifact {
|
||||
return CatalogArtifact{
|
||||
CatalogMetricsArtifact: v,
|
||||
}
|
||||
}
|
||||
|
||||
// CatalogModelArtifactAsCatalogArtifact is a convenience function that returns CatalogModelArtifact wrapped in CatalogArtifact
|
||||
func CatalogModelArtifactAsCatalogArtifact(v *CatalogModelArtifact) CatalogArtifact {
|
||||
return CatalogArtifact{
|
||||
CatalogModelArtifact: v,
|
||||
}
|
||||
}
|
||||
|
||||
// Unmarshal JSON data into one of the pointers in the struct
|
||||
func (dst *CatalogArtifact) UnmarshalJSON(data []byte) error {
|
||||
var err error
|
||||
// use discriminator value to speed up the lookup
|
||||
var jsonDict map[string]interface{}
|
||||
err = newStrictDecoder(data).Decode(&jsonDict)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal JSON into map for the discriminator lookup")
|
||||
}
|
||||
|
||||
// check if the discriminator value is 'CatalogMetricsArtifact'
|
||||
if jsonDict["artifactType"] == "CatalogMetricsArtifact" {
|
||||
// try to unmarshal JSON data into CatalogMetricsArtifact
|
||||
err = json.Unmarshal(data, &dst.CatalogMetricsArtifact)
|
||||
if err == nil {
|
||||
return nil // data stored in dst.CatalogMetricsArtifact, return on the first match
|
||||
} else {
|
||||
dst.CatalogMetricsArtifact = nil
|
||||
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogMetricsArtifact: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// check if the discriminator value is 'CatalogModelArtifact'
|
||||
if jsonDict["artifactType"] == "CatalogModelArtifact" {
|
||||
// try to unmarshal JSON data into CatalogModelArtifact
|
||||
err = json.Unmarshal(data, &dst.CatalogModelArtifact)
|
||||
if err == nil {
|
||||
return nil // data stored in dst.CatalogModelArtifact, return on the first match
|
||||
} else {
|
||||
dst.CatalogModelArtifact = nil
|
||||
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogModelArtifact: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// check if the discriminator value is 'metrics-artifact'
|
||||
if jsonDict["artifactType"] == "metrics-artifact" {
|
||||
// try to unmarshal JSON data into CatalogMetricsArtifact
|
||||
err = json.Unmarshal(data, &dst.CatalogMetricsArtifact)
|
||||
if err == nil {
|
||||
return nil // data stored in dst.CatalogMetricsArtifact, return on the first match
|
||||
} else {
|
||||
dst.CatalogMetricsArtifact = nil
|
||||
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogMetricsArtifact: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// check if the discriminator value is 'model-artifact'
|
||||
if jsonDict["artifactType"] == "model-artifact" {
|
||||
// try to unmarshal JSON data into CatalogModelArtifact
|
||||
err = json.Unmarshal(data, &dst.CatalogModelArtifact)
|
||||
if err == nil {
|
||||
return nil // data stored in dst.CatalogModelArtifact, return on the first match
|
||||
} else {
|
||||
dst.CatalogModelArtifact = nil
|
||||
return fmt.Errorf("failed to unmarshal CatalogArtifact as CatalogModelArtifact: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal data from the first non-nil pointers in the struct to JSON
|
||||
func (src CatalogArtifact) MarshalJSON() ([]byte, error) {
|
||||
if src.CatalogMetricsArtifact != nil {
|
||||
return json.Marshal(&src.CatalogMetricsArtifact)
|
||||
}
|
||||
|
||||
if src.CatalogModelArtifact != nil {
|
||||
return json.Marshal(&src.CatalogModelArtifact)
|
||||
}
|
||||
|
||||
return nil, nil // no data in oneOf schemas
|
||||
}
|
||||
|
||||
// Get the actual instance
|
||||
func (obj *CatalogArtifact) GetActualInstance() interface{} {
|
||||
if obj == nil {
|
||||
return nil
|
||||
}
|
||||
if obj.CatalogMetricsArtifact != nil {
|
||||
return obj.CatalogMetricsArtifact
|
||||
}
|
||||
|
||||
if obj.CatalogModelArtifact != nil {
|
||||
return obj.CatalogModelArtifact
|
||||
}
|
||||
|
||||
// all schemas are nil
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullableCatalogArtifact struct {
|
||||
value *CatalogArtifact
|
||||
isSet bool
|
||||
}
|
||||
|
||||
func (v NullableCatalogArtifact) Get() *CatalogArtifact {
|
||||
return v.value
|
||||
}
|
||||
|
||||
func (v *NullableCatalogArtifact) Set(val *CatalogArtifact) {
|
||||
v.value = val
|
||||
v.isSet = true
|
||||
}
|
||||
|
||||
func (v NullableCatalogArtifact) IsSet() bool {
|
||||
return v.isSet
|
||||
}
|
||||
|
||||
func (v *NullableCatalogArtifact) Unset() {
|
||||
v.value = nil
|
||||
v.isSet = false
|
||||
}
|
||||
|
||||
func NewNullableCatalogArtifact(val *CatalogArtifact) *NullableCatalogArtifact {
|
||||
return &NullableCatalogArtifact{value: val, isSet: true}
|
||||
}
|
||||
|
||||
func (v NullableCatalogArtifact) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(v.value)
|
||||
}
|
||||
|
||||
func (v *NullableCatalogArtifact) UnmarshalJSON(src []byte) error {
|
||||
v.isSet = true
|
||||
return json.Unmarshal(src, &v.value)
|
||||
}
|
||||
|
|
@ -14,27 +14,27 @@ import (
|
|||
"encoding/json"
|
||||
)
|
||||
|
||||
// checks if the CatalogModelArtifactList type satisfies the MappedNullable interface at compile time
|
||||
var _ MappedNullable = &CatalogModelArtifactList{}
|
||||
// checks if the CatalogArtifactList type satisfies the MappedNullable interface at compile time
|
||||
var _ MappedNullable = &CatalogArtifactList{}
|
||||
|
||||
// CatalogModelArtifactList List of CatalogModel entities.
|
||||
type CatalogModelArtifactList struct {
|
||||
// CatalogArtifactList List of CatalogModel entities.
|
||||
type CatalogArtifactList struct {
|
||||
// Token to use to retrieve next page of results.
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
// Maximum number of resources to return in the result.
|
||||
PageSize int32 `json:"pageSize"`
|
||||
// Number of items in result list.
|
||||
Size int32 `json:"size"`
|
||||
// Array of `CatalogModelArtifact` entities.
|
||||
Items []CatalogModelArtifact `json:"items"`
|
||||
// Array of `CatalogArtifact` entities.
|
||||
Items []CatalogArtifact `json:"items"`
|
||||
}
|
||||
|
||||
// NewCatalogModelArtifactList instantiates a new CatalogModelArtifactList object
|
||||
// NewCatalogArtifactList instantiates a new CatalogArtifactList object
|
||||
// This constructor will assign default values to properties that have it defined,
|
||||
// and makes sure properties required by API are set, but the set of arguments
|
||||
// will change when the set of required properties is changed
|
||||
func NewCatalogModelArtifactList(nextPageToken string, pageSize int32, size int32, items []CatalogModelArtifact) *CatalogModelArtifactList {
|
||||
this := CatalogModelArtifactList{}
|
||||
func NewCatalogArtifactList(nextPageToken string, pageSize int32, size int32, items []CatalogArtifact) *CatalogArtifactList {
|
||||
this := CatalogArtifactList{}
|
||||
this.NextPageToken = nextPageToken
|
||||
this.PageSize = pageSize
|
||||
this.Size = size
|
||||
|
|
@ -42,16 +42,16 @@ func NewCatalogModelArtifactList(nextPageToken string, pageSize int32, size int3
|
|||
return &this
|
||||
}
|
||||
|
||||
// NewCatalogModelArtifactListWithDefaults instantiates a new CatalogModelArtifactList object
|
||||
// NewCatalogArtifactListWithDefaults instantiates a new CatalogArtifactList object
|
||||
// This constructor will only assign default values to properties that have it defined,
|
||||
// but it doesn't guarantee that properties required by API are set
|
||||
func NewCatalogModelArtifactListWithDefaults() *CatalogModelArtifactList {
|
||||
this := CatalogModelArtifactList{}
|
||||
func NewCatalogArtifactListWithDefaults() *CatalogArtifactList {
|
||||
this := CatalogArtifactList{}
|
||||
return &this
|
||||
}
|
||||
|
||||
// GetNextPageToken returns the NextPageToken field value
|
||||
func (o *CatalogModelArtifactList) GetNextPageToken() string {
|
||||
func (o *CatalogArtifactList) GetNextPageToken() string {
|
||||
if o == nil {
|
||||
var ret string
|
||||
return ret
|
||||
|
|
@ -62,7 +62,7 @@ func (o *CatalogModelArtifactList) GetNextPageToken() string {
|
|||
|
||||
// GetNextPageTokenOk returns a tuple with the NextPageToken field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogModelArtifactList) GetNextPageTokenOk() (*string, bool) {
|
||||
func (o *CatalogArtifactList) GetNextPageTokenOk() (*string, bool) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
|
@ -70,12 +70,12 @@ func (o *CatalogModelArtifactList) GetNextPageTokenOk() (*string, bool) {
|
|||
}
|
||||
|
||||
// SetNextPageToken sets field value
|
||||
func (o *CatalogModelArtifactList) SetNextPageToken(v string) {
|
||||
func (o *CatalogArtifactList) SetNextPageToken(v string) {
|
||||
o.NextPageToken = v
|
||||
}
|
||||
|
||||
// GetPageSize returns the PageSize field value
|
||||
func (o *CatalogModelArtifactList) GetPageSize() int32 {
|
||||
func (o *CatalogArtifactList) GetPageSize() int32 {
|
||||
if o == nil {
|
||||
var ret int32
|
||||
return ret
|
||||
|
|
@ -86,7 +86,7 @@ func (o *CatalogModelArtifactList) GetPageSize() int32 {
|
|||
|
||||
// GetPageSizeOk returns a tuple with the PageSize field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogModelArtifactList) GetPageSizeOk() (*int32, bool) {
|
||||
func (o *CatalogArtifactList) GetPageSizeOk() (*int32, bool) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
|
@ -94,12 +94,12 @@ func (o *CatalogModelArtifactList) GetPageSizeOk() (*int32, bool) {
|
|||
}
|
||||
|
||||
// SetPageSize sets field value
|
||||
func (o *CatalogModelArtifactList) SetPageSize(v int32) {
|
||||
func (o *CatalogArtifactList) SetPageSize(v int32) {
|
||||
o.PageSize = v
|
||||
}
|
||||
|
||||
// GetSize returns the Size field value
|
||||
func (o *CatalogModelArtifactList) GetSize() int32 {
|
||||
func (o *CatalogArtifactList) GetSize() int32 {
|
||||
if o == nil {
|
||||
var ret int32
|
||||
return ret
|
||||
|
|
@ -110,7 +110,7 @@ func (o *CatalogModelArtifactList) GetSize() int32 {
|
|||
|
||||
// GetSizeOk returns a tuple with the Size field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogModelArtifactList) GetSizeOk() (*int32, bool) {
|
||||
func (o *CatalogArtifactList) GetSizeOk() (*int32, bool) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
|
@ -118,14 +118,14 @@ func (o *CatalogModelArtifactList) GetSizeOk() (*int32, bool) {
|
|||
}
|
||||
|
||||
// SetSize sets field value
|
||||
func (o *CatalogModelArtifactList) SetSize(v int32) {
|
||||
func (o *CatalogArtifactList) SetSize(v int32) {
|
||||
o.Size = v
|
||||
}
|
||||
|
||||
// GetItems returns the Items field value
|
||||
func (o *CatalogModelArtifactList) GetItems() []CatalogModelArtifact {
|
||||
func (o *CatalogArtifactList) GetItems() []CatalogArtifact {
|
||||
if o == nil {
|
||||
var ret []CatalogModelArtifact
|
||||
var ret []CatalogArtifact
|
||||
return ret
|
||||
}
|
||||
|
||||
|
|
@ -134,7 +134,7 @@ func (o *CatalogModelArtifactList) GetItems() []CatalogModelArtifact {
|
|||
|
||||
// GetItemsOk returns a tuple with the Items field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogModelArtifactList) GetItemsOk() ([]CatalogModelArtifact, bool) {
|
||||
func (o *CatalogArtifactList) GetItemsOk() ([]CatalogArtifact, bool) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
|
@ -142,11 +142,11 @@ func (o *CatalogModelArtifactList) GetItemsOk() ([]CatalogModelArtifact, bool) {
|
|||
}
|
||||
|
||||
// SetItems sets field value
|
||||
func (o *CatalogModelArtifactList) SetItems(v []CatalogModelArtifact) {
|
||||
func (o *CatalogArtifactList) SetItems(v []CatalogArtifact) {
|
||||
o.Items = v
|
||||
}
|
||||
|
||||
func (o CatalogModelArtifactList) MarshalJSON() ([]byte, error) {
|
||||
func (o CatalogArtifactList) MarshalJSON() ([]byte, error) {
|
||||
toSerialize, err := o.ToMap()
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
|
|
@ -154,7 +154,7 @@ func (o CatalogModelArtifactList) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(toSerialize)
|
||||
}
|
||||
|
||||
func (o CatalogModelArtifactList) ToMap() (map[string]interface{}, error) {
|
||||
func (o CatalogArtifactList) ToMap() (map[string]interface{}, error) {
|
||||
toSerialize := map[string]interface{}{}
|
||||
toSerialize["nextPageToken"] = o.NextPageToken
|
||||
toSerialize["pageSize"] = o.PageSize
|
||||
|
|
@ -163,38 +163,38 @@ func (o CatalogModelArtifactList) ToMap() (map[string]interface{}, error) {
|
|||
return toSerialize, nil
|
||||
}
|
||||
|
||||
type NullableCatalogModelArtifactList struct {
|
||||
value *CatalogModelArtifactList
|
||||
type NullableCatalogArtifactList struct {
|
||||
value *CatalogArtifactList
|
||||
isSet bool
|
||||
}
|
||||
|
||||
func (v NullableCatalogModelArtifactList) Get() *CatalogModelArtifactList {
|
||||
func (v NullableCatalogArtifactList) Get() *CatalogArtifactList {
|
||||
return v.value
|
||||
}
|
||||
|
||||
func (v *NullableCatalogModelArtifactList) Set(val *CatalogModelArtifactList) {
|
||||
func (v *NullableCatalogArtifactList) Set(val *CatalogArtifactList) {
|
||||
v.value = val
|
||||
v.isSet = true
|
||||
}
|
||||
|
||||
func (v NullableCatalogModelArtifactList) IsSet() bool {
|
||||
func (v NullableCatalogArtifactList) IsSet() bool {
|
||||
return v.isSet
|
||||
}
|
||||
|
||||
func (v *NullableCatalogModelArtifactList) Unset() {
|
||||
func (v *NullableCatalogArtifactList) Unset() {
|
||||
v.value = nil
|
||||
v.isSet = false
|
||||
}
|
||||
|
||||
func NewNullableCatalogModelArtifactList(val *CatalogModelArtifactList) *NullableCatalogModelArtifactList {
|
||||
return &NullableCatalogModelArtifactList{value: val, isSet: true}
|
||||
func NewNullableCatalogArtifactList(val *CatalogArtifactList) *NullableCatalogArtifactList {
|
||||
return &NullableCatalogArtifactList{value: val, isSet: true}
|
||||
}
|
||||
|
||||
func (v NullableCatalogModelArtifactList) MarshalJSON() ([]byte, error) {
|
||||
func (v NullableCatalogArtifactList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(v.value)
|
||||
}
|
||||
|
||||
func (v *NullableCatalogModelArtifactList) UnmarshalJSON(src []byte) error {
|
||||
func (v *NullableCatalogArtifactList) UnmarshalJSON(src []byte) error {
|
||||
v.isSet = true
|
||||
return json.Unmarshal(src, &v.value)
|
||||
}
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
/*
Model Catalog REST API

REST API for Model Registry to create and manage ML model metadata

API version: v1alpha1
*/

// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.

package openapi

import (
	"encoding/json"
)

// checks if the CatalogMetricsArtifact type satisfies the MappedNullable interface at compile time
var _ MappedNullable = &CatalogMetricsArtifact{}

// CatalogMetricsArtifact A metadata Artifact Entity.
type CatalogMetricsArtifact struct {
	// Output only. Create time of the resource in milliseconds since epoch.
	CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"`
	// Output only. Last update time of the resource in milliseconds since epoch.
	LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"`
	ArtifactType string `json:"artifactType"`
	MetricsType string `json:"metricsType"`
	// User provided custom properties which are not defined by its type.
	CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"`
}

// NewCatalogMetricsArtifact instantiates a new CatalogMetricsArtifact object
// This constructor will assign default values to properties that have it defined,
// and makes sure properties required by API are set, but the set of arguments
// will change when the set of required properties is changed
func NewCatalogMetricsArtifact(artifactType string, metricsType string) *CatalogMetricsArtifact {
	this := CatalogMetricsArtifact{}
	this.ArtifactType = artifactType
	this.MetricsType = metricsType
	return &this
}

// NewCatalogMetricsArtifactWithDefaults instantiates a new CatalogMetricsArtifact object
// This constructor will only assign default values to properties that have it defined,
// but it doesn't guarantee that properties required by API are set
func NewCatalogMetricsArtifactWithDefaults() *CatalogMetricsArtifact {
	this := CatalogMetricsArtifact{}
	var artifactType string = "metrics-artifact"
	this.ArtifactType = artifactType
	return &this
}
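
A hedged usage sketch (illustrative only, not generated code): building a metrics artifact with the constructor above and serializing it. The "performance" metricsType value and the timestamp are assumptions, not values defined by this change.

// Hedged sketch: construct a metrics artifact and marshal it. Only the
// constructor, setter, and MarshalJSON come from the generated code; the
// literal values are illustrative.
func exampleNewMetricsArtifact() ([]byte, error) {
	m := NewCatalogMetricsArtifact("metrics-artifact", "performance")
	m.SetCreateTimeSinceEpoch("1715000000000")
	return json.Marshal(m)
}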
|
||||
|
||||
// GetCreateTimeSinceEpoch returns the CreateTimeSinceEpoch field value if set, zero value otherwise.
|
||||
func (o *CatalogMetricsArtifact) GetCreateTimeSinceEpoch() string {
|
||||
if o == nil || IsNil(o.CreateTimeSinceEpoch) {
|
||||
var ret string
|
||||
return ret
|
||||
}
|
||||
return *o.CreateTimeSinceEpoch
|
||||
}
|
||||
|
||||
// GetCreateTimeSinceEpochOk returns a tuple with the CreateTimeSinceEpoch field value if set, nil otherwise
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogMetricsArtifact) GetCreateTimeSinceEpochOk() (*string, bool) {
|
||||
if o == nil || IsNil(o.CreateTimeSinceEpoch) {
|
||||
return nil, false
|
||||
}
|
||||
return o.CreateTimeSinceEpoch, true
|
||||
}
|
||||
|
||||
// HasCreateTimeSinceEpoch returns a boolean if a field has been set.
|
||||
func (o *CatalogMetricsArtifact) HasCreateTimeSinceEpoch() bool {
|
||||
if o != nil && !IsNil(o.CreateTimeSinceEpoch) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetCreateTimeSinceEpoch gets a reference to the given string and assigns it to the CreateTimeSinceEpoch field.
|
||||
func (o *CatalogMetricsArtifact) SetCreateTimeSinceEpoch(v string) {
|
||||
o.CreateTimeSinceEpoch = &v
|
||||
}
|
||||
|
||||
// GetLastUpdateTimeSinceEpoch returns the LastUpdateTimeSinceEpoch field value if set, zero value otherwise.
|
||||
func (o *CatalogMetricsArtifact) GetLastUpdateTimeSinceEpoch() string {
|
||||
if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) {
|
||||
var ret string
|
||||
return ret
|
||||
}
|
||||
return *o.LastUpdateTimeSinceEpoch
|
||||
}
|
||||
|
||||
// GetLastUpdateTimeSinceEpochOk returns a tuple with the LastUpdateTimeSinceEpoch field value if set, nil otherwise
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogMetricsArtifact) GetLastUpdateTimeSinceEpochOk() (*string, bool) {
|
||||
if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) {
|
||||
return nil, false
|
||||
}
|
||||
return o.LastUpdateTimeSinceEpoch, true
|
||||
}
|
||||
|
||||
// HasLastUpdateTimeSinceEpoch returns a boolean if a field has been set.
|
||||
func (o *CatalogMetricsArtifact) HasLastUpdateTimeSinceEpoch() bool {
|
||||
if o != nil && !IsNil(o.LastUpdateTimeSinceEpoch) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetLastUpdateTimeSinceEpoch gets a reference to the given string and assigns it to the LastUpdateTimeSinceEpoch field.
|
||||
func (o *CatalogMetricsArtifact) SetLastUpdateTimeSinceEpoch(v string) {
|
||||
o.LastUpdateTimeSinceEpoch = &v
|
||||
}
|
||||
|
||||
// GetArtifactType returns the ArtifactType field value
|
||||
func (o *CatalogMetricsArtifact) GetArtifactType() string {
|
||||
if o == nil {
|
||||
var ret string
|
||||
return ret
|
||||
}
|
||||
|
||||
return o.ArtifactType
|
||||
}
|
||||
|
||||
// GetArtifactTypeOk returns a tuple with the ArtifactType field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogMetricsArtifact) GetArtifactTypeOk() (*string, bool) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
return &o.ArtifactType, true
|
||||
}
|
||||
|
||||
// SetArtifactType sets field value
|
||||
func (o *CatalogMetricsArtifact) SetArtifactType(v string) {
|
||||
o.ArtifactType = v
|
||||
}
|
||||
|
||||
// GetMetricsType returns the MetricsType field value
|
||||
func (o *CatalogMetricsArtifact) GetMetricsType() string {
|
||||
if o == nil {
|
||||
var ret string
|
||||
return ret
|
||||
}
|
||||
|
||||
return o.MetricsType
|
||||
}
|
||||
|
||||
// GetMetricsTypeOk returns a tuple with the MetricsType field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogMetricsArtifact) GetMetricsTypeOk() (*string, bool) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
return &o.MetricsType, true
|
||||
}
|
||||
|
||||
// SetMetricsType sets field value
|
||||
func (o *CatalogMetricsArtifact) SetMetricsType(v string) {
|
||||
o.MetricsType = v
|
||||
}
|
||||
|
||||
// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise.
|
||||
func (o *CatalogMetricsArtifact) GetCustomProperties() map[string]MetadataValue {
|
||||
if o == nil || IsNil(o.CustomProperties) {
|
||||
var ret map[string]MetadataValue
|
||||
return ret
|
||||
}
|
||||
return *o.CustomProperties
|
||||
}
|
||||
|
||||
// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogMetricsArtifact) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) {
|
||||
if o == nil || IsNil(o.CustomProperties) {
|
||||
return nil, false
|
||||
}
|
||||
return o.CustomProperties, true
|
||||
}
|
||||
|
||||
// HasCustomProperties returns a boolean if a field has been set.
|
||||
func (o *CatalogMetricsArtifact) HasCustomProperties() bool {
|
||||
if o != nil && !IsNil(o.CustomProperties) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field.
|
||||
func (o *CatalogMetricsArtifact) SetCustomProperties(v map[string]MetadataValue) {
|
||||
o.CustomProperties = &v
|
||||
}
|
||||
|
||||
func (o CatalogMetricsArtifact) MarshalJSON() ([]byte, error) {
|
||||
toSerialize, err := o.ToMap()
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
return json.Marshal(toSerialize)
|
||||
}
|
||||
|
||||
func (o CatalogMetricsArtifact) ToMap() (map[string]interface{}, error) {
|
||||
toSerialize := map[string]interface{}{}
|
||||
if !IsNil(o.CreateTimeSinceEpoch) {
|
||||
toSerialize["createTimeSinceEpoch"] = o.CreateTimeSinceEpoch
|
||||
}
|
||||
if !IsNil(o.LastUpdateTimeSinceEpoch) {
|
||||
toSerialize["lastUpdateTimeSinceEpoch"] = o.LastUpdateTimeSinceEpoch
|
||||
}
|
||||
toSerialize["artifactType"] = o.ArtifactType
|
||||
toSerialize["metricsType"] = o.MetricsType
|
||||
if !IsNil(o.CustomProperties) {
|
||||
toSerialize["customProperties"] = o.CustomProperties
|
||||
}
|
||||
return toSerialize, nil
|
||||
}
|
||||
|
||||
type NullableCatalogMetricsArtifact struct {
|
||||
value *CatalogMetricsArtifact
|
||||
isSet bool
|
||||
}
|
||||
|
||||
func (v NullableCatalogMetricsArtifact) Get() *CatalogMetricsArtifact {
|
||||
return v.value
|
||||
}
|
||||
|
||||
func (v *NullableCatalogMetricsArtifact) Set(val *CatalogMetricsArtifact) {
|
||||
v.value = val
|
||||
v.isSet = true
|
||||
}
|
||||
|
||||
func (v NullableCatalogMetricsArtifact) IsSet() bool {
|
||||
return v.isSet
|
||||
}
|
||||
|
||||
func (v *NullableCatalogMetricsArtifact) Unset() {
|
||||
v.value = nil
|
||||
v.isSet = false
|
||||
}
|
||||
|
||||
func NewNullableCatalogMetricsArtifact(val *CatalogMetricsArtifact) *NullableCatalogMetricsArtifact {
|
||||
return &NullableCatalogMetricsArtifact{value: val, isSet: true}
|
||||
}
|
||||
|
||||
func (v NullableCatalogMetricsArtifact) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(v.value)
|
||||
}
|
||||
|
||||
func (v *NullableCatalogMetricsArtifact) UnmarshalJSON(src []byte) error {
|
||||
v.isSet = true
|
||||
return json.Unmarshal(src, &v.value)
|
||||
}
|
||||
|
|
@ -17,13 +17,14 @@ import (
|
|||
// checks if the CatalogModelArtifact type satisfies the MappedNullable interface at compile time
|
||||
var _ MappedNullable = &CatalogModelArtifact{}
|
||||
|
||||
// CatalogModelArtifact A single artifact for a catalog model.
|
||||
// CatalogModelArtifact A Catalog Model Artifact Entity.
|
||||
type CatalogModelArtifact struct {
|
||||
// Output only. Create time of the resource in millisecond since epoch.
|
||||
CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"`
|
||||
// Output only. Last update time of the resource since epoch in millisecond since epoch.
|
||||
LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"`
|
||||
// URI where the artifact can be retrieved.
|
||||
ArtifactType string `json:"artifactType"`
|
||||
// URI where the model can be retrieved.
|
||||
Uri string `json:"uri"`
|
||||
// User provided custom properties which are not defined by its type.
|
||||
CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"`
|
||||
|
|
@ -33,8 +34,9 @@ type CatalogModelArtifact struct {
|
|||
// This constructor will assign default values to properties that have it defined,
|
||||
// and makes sure properties required by API are set, but the set of arguments
|
||||
// will change when the set of required properties is changed
|
||||
func NewCatalogModelArtifact(uri string) *CatalogModelArtifact {
|
||||
func NewCatalogModelArtifact(artifactType string, uri string) *CatalogModelArtifact {
|
||||
this := CatalogModelArtifact{}
|
||||
this.ArtifactType = artifactType
|
||||
this.Uri = uri
|
||||
return &this
|
||||
}
|
||||
|
|
@ -44,6 +46,8 @@ func NewCatalogModelArtifact(uri string) *CatalogModelArtifact {
|
|||
// but it doesn't guarantee that properties required by API are set
|
||||
func NewCatalogModelArtifactWithDefaults() *CatalogModelArtifact {
|
||||
this := CatalogModelArtifact{}
|
||||
var artifactType string = "model-artifact"
|
||||
this.ArtifactType = artifactType
|
||||
return &this
|
||||
}
|
||||
|
||||
|
|
@ -111,6 +115,30 @@ func (o *CatalogModelArtifact) SetLastUpdateTimeSinceEpoch(v string) {
|
|||
o.LastUpdateTimeSinceEpoch = &v
|
||||
}
|
||||
|
||||
// GetArtifactType returns the ArtifactType field value
|
||||
func (o *CatalogModelArtifact) GetArtifactType() string {
|
||||
if o == nil {
|
||||
var ret string
|
||||
return ret
|
||||
}
|
||||
|
||||
return o.ArtifactType
|
||||
}
|
||||
|
||||
// GetArtifactTypeOk returns a tuple with the ArtifactType field value
|
||||
// and a boolean to check if the value has been set.
|
||||
func (o *CatalogModelArtifact) GetArtifactTypeOk() (*string, bool) {
|
||||
if o == nil {
|
||||
return nil, false
|
||||
}
|
||||
return &o.ArtifactType, true
|
||||
}
|
||||
|
||||
// SetArtifactType sets field value
|
||||
func (o *CatalogModelArtifact) SetArtifactType(v string) {
|
||||
o.ArtifactType = v
|
||||
}
|
||||
|
||||
// GetUri returns the Uri field value
|
||||
func (o *CatalogModelArtifact) GetUri() string {
|
||||
if o == nil {
|
||||
|
|
@ -183,6 +211,7 @@ func (o CatalogModelArtifact) ToMap() (map[string]interface{}, error) {
|
|||
if !IsNil(o.LastUpdateTimeSinceEpoch) {
|
||||
toSerialize["lastUpdateTimeSinceEpoch"] = o.LastUpdateTimeSinceEpoch
|
||||
}
|
||||
toSerialize["artifactType"] = o.ArtifactType
|
||||
toSerialize["uri"] = o.Uri
|
||||
if !IsNil(o.CustomProperties) {
|
||||
toSerialize["customProperties"] = o.CustomProperties
|
||||
|
|
|
|||
|
|
@ -5,10 +5,9 @@ set -e
|
|||
ASSERT_FILE_PATH="$1/type_asserts.go"
|
||||
|
||||
PROJECT_ROOT=$(realpath "$(dirname "$0")"/..)
|
||||
PATCH="${PROJECT_ROOT}/patches/type_asserts.patch"
|
||||
|
||||
# AssertMetadataValueRequired from this file generates with the incorrect logic.
|
||||
rm -f $1/model_metadata_value.go
|
||||
# These files generate with incorrect logic:
|
||||
rm -f "$1/model_metadata_value.go" "$1/model_catalog_artifact.go"
|
||||
|
||||
python3 "${PROJECT_ROOT}/scripts/gen_type_asserts.py" $1 >"$ASSERT_FILE_PATH"
|
||||
|
||||
|
|
|
|||
cmd/proxy.go
|
|
@ -3,12 +3,16 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/glog"
|
||||
"github.com/kubeflow/model-registry/internal/core"
|
||||
"github.com/kubeflow/model-registry/internal/datastore"
|
||||
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/proxy"
|
||||
"github.com/kubeflow/model-registry/internal/server/middleware"
|
||||
"github.com/kubeflow/model-registry/internal/server/openapi"
|
||||
|
|
@ -18,7 +22,8 @@ import (
|
|||
)
|
||||
|
||||
type ProxyConfig struct {
|
||||
Datastore datastore.Datastore
|
||||
EmbedMD embedmd.EmbedMDConfig
|
||||
DatastoreType string
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
@ -28,11 +33,9 @@ const (
|
|||
|
||||
var (
|
||||
proxyCfg = ProxyConfig{
|
||||
Datastore: datastore.Datastore{
|
||||
Type: "embedmd",
|
||||
EmbedMD: embedmd.EmbedMDConfig{
|
||||
TLSConfig: &tls.TLSConfig{},
|
||||
},
|
||||
DatastoreType: "embedmd",
|
||||
EmbedMD: embedmd.EmbedMDConfig{
|
||||
TLSConfig: &tls.TLSConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -102,10 +105,19 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
|
|||
http.Error(w, datastoreUnavailableMessage, http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
mrHealthChecker := &ConditionalModelRegistryHealthChecker{holder: serviceHolder}
|
||||
dbHealthChecker := proxy.NewDatabaseHealthChecker(proxyCfg.Datastore)
|
||||
generalReadinessHandler := proxy.GeneralReadinessHandler(proxyCfg.Datastore, dbHealthChecker, mrHealthChecker)
|
||||
readinessHandler := proxy.GeneralReadinessHandler(proxyCfg.Datastore, dbHealthChecker)
|
||||
readyChecks := []proxy.HealthChecker{}
|
||||
generalChecks := []proxy.HealthChecker{
|
||||
&ConditionalModelRegistryHealthChecker{holder: serviceHolder},
|
||||
}
|
||||
|
||||
if proxyCfg.DatastoreType == "embedmd" {
|
||||
dbHealthChecker := proxy.NewDatabaseHealthChecker()
|
||||
readyChecks = append(readyChecks, dbHealthChecker)
|
||||
generalChecks = append(generalChecks, dbHealthChecker)
|
||||
}
|
||||
|
||||
generalReadinessHandler := proxy.GeneralReadinessHandler(generalChecks...)
|
||||
readinessHandler := proxy.GeneralReadinessHandler(readyChecks...)
|
||||
|
||||
// route health endpoints appropriately
|
||||
mainHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -142,13 +154,13 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
|
|||
|
||||
defer wg.Done()
|
||||
|
||||
ds, err = datastore.NewConnector(proxyCfg.Datastore)
|
||||
ds, err = datastore.NewConnector(proxyCfg.DatastoreType, &proxyCfg.EmbedMD)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error creating datastore: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := ds.Connect()
|
||||
conn, err := newModelRegistryService(ds)
|
||||
if err != nil {
|
||||
// {{ALERT}} is used to identify this error in pod logs, DO NOT REMOVE
|
||||
errChan <- fmt.Errorf("{{ALERT}} error connecting to datastore: %w", err)
|
||||
|
|
@ -177,32 +189,63 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
if ds != nil {
|
||||
//nolint:errcheck
|
||||
ds.Teardown()
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either the Datastore server connection or the proxy server to return an error
|
||||
// or for both to finish successfully.
|
||||
return <-errChan
|
||||
}
|
||||
|
||||
func newModelRegistryService(ds datastore.Connector) (api.ModelRegistryApi, error) {
|
||||
repoSet, err := ds.Connect(service.DatastoreSpec())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelRegistryService := core.NewModelRegistryService(
|
||||
getRepo[models.ArtifactRepository](repoSet),
|
||||
getRepo[models.ModelArtifactRepository](repoSet),
|
||||
getRepo[models.DocArtifactRepository](repoSet),
|
||||
getRepo[models.RegisteredModelRepository](repoSet),
|
||||
getRepo[models.ModelVersionRepository](repoSet),
|
||||
getRepo[models.ServingEnvironmentRepository](repoSet),
|
||||
getRepo[models.InferenceServiceRepository](repoSet),
|
||||
getRepo[models.ServeModelRepository](repoSet),
|
||||
getRepo[models.ExperimentRepository](repoSet),
|
||||
getRepo[models.ExperimentRunRepository](repoSet),
|
||||
getRepo[models.DataSetRepository](repoSet),
|
||||
getRepo[models.MetricRepository](repoSet),
|
||||
getRepo[models.ParameterRepository](repoSet),
|
||||
getRepo[models.MetricHistoryRepository](repoSet),
|
||||
repoSet.TypeMap(),
|
||||
)
|
||||
|
||||
glog.Infof("EmbedMD service connected")
|
||||
|
||||
return modelRegistryService, nil
|
||||
}
|
||||
|
||||
func getRepo[T any](repoSet datastore.RepoSet) T {
|
||||
repo, err := repoSet.Repository(reflect.TypeFor[T]())
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("unable to get repository: %v", err))
|
||||
}
|
||||
|
||||
return repo.(T)
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(proxyCmd)
|
||||
|
||||
proxyCmd.Flags().StringVarP(&cfg.Hostname, "hostname", "n", cfg.Hostname, "Proxy server listen hostname")
|
||||
proxyCmd.Flags().IntVarP(&cfg.Port, "port", "p", cfg.Port, "Proxy server listen port")
|
||||
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.DatabaseType, "embedmd-database-type", "mysql", "EmbedMD database type")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.DatabaseDSN, "embedmd-database-dsn", "", "EmbedMD database DSN")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.CertPath, "embedmd-database-ssl-cert", "", "EmbedMD SSL cert path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.KeyPath, "embedmd-database-ssl-key", "", "EmbedMD SSL key path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.RootCertPath, "embedmd-database-ssl-root-cert", "", "EmbedMD SSL root cert path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.CAPath, "embedmd-database-ssl-ca", "", "EmbedMD SSL CA path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.Cipher, "embedmd-database-ssl-cipher", "", "Colon-separated list of allowed TLS ciphers for the EmbedMD database connection. Values are from the list at https://pkg.go.dev/crypto/tls#pkg-constants e.g. 'TLS_AES_128_GCM_SHA256:TLS_CHACHA20_POLY1305_SHA256'")
|
||||
proxyCmd.Flags().BoolVar(&proxyCfg.Datastore.EmbedMD.TLSConfig.VerifyServerCert, "embedmd-database-ssl-verify-server-cert", false, "EmbedMD SSL verify server cert")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.DatabaseType, "embedmd-database-type", "mysql", "EmbedMD database type")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.DatabaseDSN, "embedmd-database-dsn", "", "EmbedMD database DSN")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.CertPath, "embedmd-database-ssl-cert", "", "EmbedMD SSL cert path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.KeyPath, "embedmd-database-ssl-key", "", "EmbedMD SSL key path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.RootCertPath, "embedmd-database-ssl-root-cert", "", "EmbedMD SSL root cert path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.CAPath, "embedmd-database-ssl-ca", "", "EmbedMD SSL CA path")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.EmbedMD.TLSConfig.Cipher, "embedmd-database-ssl-cipher", "", "Colon-separated list of allowed TLS ciphers for the EmbedMD database connection. Values are from the list at https://pkg.go.dev/crypto/tls#pkg-constants e.g. 'TLS_AES_128_GCM_SHA256:TLS_CHACHA20_POLY1305_SHA256'")
|
||||
proxyCmd.Flags().BoolVar(&proxyCfg.EmbedMD.TLSConfig.VerifyServerCert, "embedmd-database-ssl-verify-server-cert", false, "EmbedMD SSL verify server cert")
|
||||
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.Datastore.Type, "datastore-type", proxyCfg.Datastore.Type, "Datastore type")
|
||||
proxyCmd.Flags().StringVar(&proxyCfg.DatastoreType, "datastore-type", proxyCfg.DatastoreType, "Datastore type")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ k8s_resource(
|
|||
new_name="db",
|
||||
labels="backend",
|
||||
resource_deps=["kubeflow-namespace"],
|
||||
port_forwards="3306:3306",
|
||||
)
|
||||
|
||||
k8s_resource(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
manifests = kustomize("../../../manifests/kustomize/options/catalog")
|
||||
manifests = kustomize("../../../manifests/kustomize/options/catalog/base")
|
||||
|
||||
objects = decode_yaml_stream(manifests)
|
||||
|
||||
|
|
@ -20,3 +20,12 @@ k8s_resource(
|
|||
port_forwards="8082:8080",
|
||||
trigger_mode=TRIGGER_MODE_AUTO
|
||||
)
|
||||
|
||||
k8s_resource(
|
||||
workload="model-catalog-postgres",
|
||||
new_name="catalog-db",
|
||||
labels="backend",
|
||||
resource_deps=["kubeflow-namespace"],
|
||||
port_forwards="5432:5432",
|
||||
trigger_mode=TRIGGER_MODE_AUTO
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,15 @@ services:
|
|||
- "8081:8081"
|
||||
volumes:
|
||||
- ./catalog/internal/catalog/testdata:/testdata
|
||||
depends_on:
|
||||
- postgres
|
||||
profiles:
|
||||
- postgres
|
||||
environment:
|
||||
- PGHOST=postgres
|
||||
- PGDATABASE=model_catalog
|
||||
- PGUSER=postgres
|
||||
- PGPASSWORD=demo
|
||||
|
||||
model-registry:
|
||||
build:
|
||||
|
|
@ -82,7 +91,16 @@ services:
|
|||
start_period: 20s
|
||||
profiles:
|
||||
- postgres
|
||||
configs:
|
||||
- source: pginit
|
||||
target: /docker-entrypoint-initdb.d/pginit.sql
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
postgres_data:
|
||||
|
||||
configs:
|
||||
pginit:
|
||||
content: |
|
||||
CREATE DATABASE model_catalog;
|
||||
GRANT ALL PRIVILEGES ON model_catalog TO postgres;
|
||||
|
|
|
|||
|
|
@ -18,6 +18,15 @@ services:
|
|||
- "8081:8081"
|
||||
volumes:
|
||||
- ./catalog/internal/catalog/testdata:/testdata
|
||||
depends_on:
|
||||
- postgres
|
||||
profiles:
|
||||
- postgres
|
||||
environment:
|
||||
- PGHOST=postgres
|
||||
- PGDATABASE=model_catalog
|
||||
- PGUSER=postgres
|
||||
- PGPASSWORD=demo
|
||||
|
||||
model-registry:
|
||||
image: ghcr.io/kubeflow/model-registry/server:latest
|
||||
|
|
@ -80,7 +89,16 @@ services:
|
|||
start_period: 20s
|
||||
profiles:
|
||||
- postgres
|
||||
configs:
|
||||
- source: pginit
|
||||
target: /docker-entrypoint-initdb.d/pginit.sql
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
postgres_data:
|
||||
|
||||
configs:
|
||||
pginit:
|
||||
content: |
|
||||
CREATE DATABASE model_catalog;
|
||||
GRANT ALL PRIVILEGES ON model_catalog TO postgres;
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func setupTestDB(t *testing.T) (*gorm.DB, func()) {
|
||||
db, dbCleanup := testutils.SetupMySQLWithMigrations(t)
|
||||
db, dbCleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
|
||||
// Clean up test data before each test
|
||||
testutils.CleanupTestData(t, db)
|
||||
|
|
@ -69,7 +69,14 @@ func createModelRegistryService(t *testing.T, db *gorm.DB) *core.ModelRegistrySe
|
|||
typesMap := getTypeIDs(t, db)
|
||||
|
||||
// Create all repositories
|
||||
artifactRepo := service.NewArtifactRepository(db, typesMap[defaults.ModelArtifactTypeName], typesMap[defaults.DocArtifactTypeName], typesMap[defaults.DataSetTypeName], typesMap[defaults.MetricTypeName], typesMap[defaults.ParameterTypeName], typesMap[defaults.MetricHistoryTypeName])
|
||||
artifactRepo := service.NewArtifactRepository(db, map[string]int64{
|
||||
defaults.ModelArtifactTypeName: typesMap[defaults.ModelArtifactTypeName],
|
||||
defaults.DocArtifactTypeName: typesMap[defaults.DocArtifactTypeName],
|
||||
defaults.DataSetTypeName: typesMap[defaults.DataSetTypeName],
|
||||
defaults.MetricTypeName: typesMap[defaults.MetricTypeName],
|
||||
defaults.ParameterTypeName: typesMap[defaults.ParameterTypeName],
|
||||
defaults.MetricHistoryTypeName: typesMap[defaults.MetricHistoryTypeName],
|
||||
})
|
||||
modelArtifactRepo := service.NewModelArtifactRepository(db, typesMap[defaults.ModelArtifactTypeName])
|
||||
docArtifactRepo := service.NewDocArtifactRepository(db, typesMap[defaults.DocArtifactTypeName])
|
||||
registeredModelRepo := service.NewRegisteredModelRepository(db, typesMap[defaults.RegisteredModelTypeName])
|
||||
|
|
|
|||
|
|
@ -3,9 +3,8 @@ package datastore
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
|
||||
"github.com/kubeflow/model-registry/pkg/api"
|
||||
"maps"
|
||||
"slices"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -13,32 +12,29 @@ var (
|
|||
ErrUnsupportedDatastore = errors.New("unsupported datastore type")
|
||||
)
|
||||
|
||||
type TeardownFunc func() error
|
||||
|
||||
type Datastore struct {
|
||||
EmbedMD embedmd.EmbedMDConfig
|
||||
Type string
|
||||
}
|
||||
|
||||
type Connector interface {
|
||||
Connect() (api.ModelRegistryApi, error)
|
||||
Teardown() error
|
||||
Type() string
|
||||
Connect(spec *Spec) (RepoSet, error)
|
||||
}
|
||||
|
||||
func NewConnector(ds Datastore) (Connector, error) {
|
||||
switch ds.Type {
|
||||
case "embedmd":
|
||||
if err := ds.EmbedMD.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid EmbedMD config: %w", err)
|
||||
}
|
||||
var connectorTypes map[string]func(any) (Connector, error)
|
||||
|
||||
embedmd, err := embedmd.NewEmbedMDService(&ds.EmbedMD)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating EmbedMD service: %w", err)
|
||||
}
|
||||
|
||||
return embedmd, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s. Supported types: embedmd", ErrUnsupportedDatastore, ds.Type)
|
||||
func Register(t string, fn func(config any) (Connector, error)) {
|
||||
if connectorTypes == nil {
|
||||
connectorTypes = make(map[string]func(any) (Connector, error), 1)
|
||||
}
|
||||
|
||||
if _, exists := connectorTypes[t]; exists {
|
||||
panic(fmt.Sprintf("duplicate connector type: %s", t))
|
||||
}
|
||||
|
||||
connectorTypes[t] = fn
|
||||
}
|
||||
|
||||
func NewConnector(t string, config any) (Connector, error) {
|
||||
if fn, ok := connectorTypes[t]; ok {
|
||||
return fn(config)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: %s. Supported types: %v", ErrUnsupportedDatastore, t, slices.Sorted(maps.Keys(connectorTypes)))
|
||||
}
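
For orientation, a hedged sketch of how a backend plugs into this registry; the fakeConnector and fakeConfig names are hypothetical, and the Connector method set is assumed to be Type, Teardown, and the spec-based Connect shown above. A caller would then obtain it with NewConnector("fake", &fakeConfig{...}), mirroring how cmd/proxy.go constructs the embedmd connector.

// Hedged sketch with hypothetical names: a datastore backend registers a
// constructor for its own config type and is later resolved by name.
type fakeConfig struct{ DSN string }

type fakeConnector struct{ cfg *fakeConfig }

func (c *fakeConnector) Type() string { return "fake" }

func (c *fakeConnector) Teardown() error { return nil }

func (c *fakeConnector) Connect(spec *Spec) (RepoSet, error) {
	// A real backend would open its connection here and build a RepoSet
	// for the requested spec; this sketch only illustrates the shape.
	return nil, errors.New("fake connector: not implemented")
}

func init() {
	Register("fake", func(config any) (Connector, error) {
		cfg, ok := config.(*fakeConfig)
		if !ok {
			return nil, fmt.Errorf("unexpected config type %T", config)
		}
		return &fakeConnector{cfg: cfg}, nil
	})
}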
|
||||
|
|
|
|||
|
|
@ -23,19 +23,6 @@ func (Type) TableName() string {
|
|||
return "Type"
|
||||
}
|
||||
|
||||
// TypeProperty represents the TypeProperty table structure
|
||||
type TypeProperty struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
TypeID int64 `gorm:"column:type_id"`
|
||||
Name string `gorm:"column:name"`
|
||||
DataType string `gorm:"column:data_type"`
|
||||
Description string `gorm:"column:description"`
|
||||
}
|
||||
|
||||
func (TypeProperty) TableName() string {
|
||||
return "TypeProperty"
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(testutils.TestMainHelper(m))
|
||||
}
|
||||
|
|
@ -63,11 +50,6 @@ func TestMigrations(t *testing.T) {
|
|||
err = sharedDB.Model(&Type{}).Count(&count).Error
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, count, int64(0))
|
||||
|
||||
// Verify TypeProperty table has expected entries
|
||||
err = sharedDB.Model(&TypeProperty{}).Count(&count).Error
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, count, int64(0))
|
||||
}
|
||||
|
||||
func TestDownMigrations(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -7,12 +7,5 @@ DELETE FROM `Type` WHERE `name` IN (
|
|||
'mlmd.Transform',
|
||||
'mlmd.Process',
|
||||
'mlmd.Evaluate',
|
||||
'mlmd.Deploy',
|
||||
'kf.RegisteredModel',
|
||||
'kf.ModelVersion',
|
||||
'kf.DocArtifact',
|
||||
'kf.ModelArtifact',
|
||||
'kf.ServingEnvironment',
|
||||
'kf.InferenceService',
|
||||
'kf.ServeModel'
|
||||
'mlmd.Deploy'
|
||||
);
|
||||
|
|
|
|||
|
|
@ -9,13 +9,6 @@ SELECT t.* FROM (
|
|||
UNION ALL SELECT 'mlmd.Process', NULL, 0, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'mlmd.Evaluate', NULL, 0, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'mlmd.Deploy', NULL, 0, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.RegisteredModel', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ModelVersion', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.DocArtifact', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ModelArtifact', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ServingEnvironment', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.InferenceService', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ServeModel', NULL, 0, NULL, NULL, NULL, NULL
|
||||
) t
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM `Type`
|
||||
|
|
|
|||
|
|
@ -1,27 +1 @@
|
|||
DELETE FROM `TypeProperty`
|
||||
WHERE (`type_id`, `name`, `data_type`) IN (
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'owner', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'state', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'author', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'model_name', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'state', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'version', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.DocArtifact'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_key', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_path', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ServingEnvironment'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'desired_state', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'model_version_id', 1),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'registered_model_id', 1),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'runtime', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'model_version_id', 1)
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,25 +1 @@
|
|||
INSERT IGNORE INTO `TypeProperty` (`type_id`, `name`, `data_type`)
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'owner', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'author', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'model_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelVersion'), 'version', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DocArtifact'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_key', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ModelArtifact'), 'storage_path', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ServingEnvironment'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'desired_state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'model_version_id', 1 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'registered_model_id', 1 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'runtime', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ServeModel'), 'model_version_id', 1;
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,13 +1 @@
|
|||
DELETE FROM `TypeProperty` WHERE type_id=(
|
||||
SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'
|
||||
) AND `name` IN (
|
||||
'language',
|
||||
'library_name',
|
||||
'license_link',
|
||||
'license',
|
||||
'logo',
|
||||
'maturity',
|
||||
'provider',
|
||||
'readme',
|
||||
'tasks'
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,10 +1 @@
|
|||
INSERT IGNORE INTO `TypeProperty` (`type_id`, `name`, `data_type`)
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'language', 4 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'library_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'license_link', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'license', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'logo', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'maturity', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'provider', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'readme', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.RegisteredModel'), 'tasks', 4;
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,8 +1 @@
|
|||
DELETE FROM `Type` WHERE `name` IN (
|
||||
'kf.MetricHistory',
|
||||
'kf.Experiment',
|
||||
'kf.ExperimentRun',
|
||||
'kf.DataSet',
|
||||
'kf.Metric',
|
||||
'kf.Parameter'
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,13 +1 @@
|
|||
INSERT INTO `Type` (`name`, `version`, `type_kind`, `description`, `input_type`, `output_type`, `external_id`)
|
||||
SELECT t.* FROM (
|
||||
SELECT 'kf.MetricHistory' as name, NULL as version, 1 as type_kind, NULL as description, NULL as input_type, NULL as output_type, NULL as external_id
|
||||
UNION ALL SELECT 'kf.Experiment', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ExperimentRun', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.DataSet', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.Metric', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.Parameter', NULL, 1, NULL, NULL, NULL, NULL
|
||||
) t
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM `Type`
|
||||
WHERE `name` = t.name AND `version` IS NULL
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,30 +1 @@
|
|||
DELETE FROM `TypeProperty`
|
||||
WHERE (`type_id`, `name`, `data_type`) IN (
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'value', 5),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'timestamp', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'step', 1),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'value', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'parameter_type', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'value', 5),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'timestamp', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'step', 1),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'digest', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source_type', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'schema', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'profile', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'owner', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'state', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'description', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'owner', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'state', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'status', 3),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1),
|
||||
((SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1)
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,28 +1 @@
|
|||
INSERT IGNORE INTO `TypeProperty` (`type_id`, `name`, `data_type`)
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'value', 5 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'timestamp', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Metric'), 'step', 1 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'value', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Parameter'), 'parameter_type', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'value', 5 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'timestamp', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.MetricHistory'), 'step', 1 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'digest', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source_type', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'source', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'schema', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.DataSet'), 'profile', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'owner', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.Experiment'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'owner', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'status', 3 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1 UNION ALL
|
||||
SELECT (SELECT id FROM `Type` WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1;
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -28,19 +28,6 @@ func (Type) TableName() string {
|
|||
return "Type"
|
||||
}
|
||||
|
||||
// TypeProperty represents the TypeProperty table structure
|
||||
type TypeProperty struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
TypeID int64 `gorm:"column:type_id"`
|
||||
Name string `gorm:"column:name"`
|
||||
DataType string `gorm:"column:data_type"`
|
||||
Description string `gorm:"column:description"`
|
||||
}
|
||||
|
||||
func (TypeProperty) TableName() string {
|
||||
return "TypeProperty"
|
||||
}
|
||||
|
||||
// Package-level shared database instance
|
||||
var (
|
||||
sharedDB *gorm.DB
|
||||
|
|
@ -190,11 +177,6 @@ func TestMigrations(t *testing.T) {
|
|||
err = sharedDB.Model(&Type{}).Count(&count).Error
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, count, int64(0))
|
||||
|
||||
// Verify TypeProperty table has expected entries
|
||||
err = sharedDB.Model(&TypeProperty{}).Count(&count).Error
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, count, int64(0))
|
||||
}
|
||||
|
||||
func TestDownMigrations(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -7,12 +7,5 @@ DELETE FROM "Type" WHERE name IN (
|
|||
'mlmd.Transform',
|
||||
'mlmd.Process',
|
||||
'mlmd.Evaluate',
|
||||
'mlmd.Deploy',
|
||||
'kf.RegisteredModel',
|
||||
'kf.ModelVersion',
|
||||
'kf.DocArtifact',
|
||||
'kf.ModelArtifact',
|
||||
'kf.ServingEnvironment',
|
||||
'kf.InferenceService',
|
||||
'kf.ServeModel'
|
||||
);
|
||||
'mlmd.Deploy'
|
||||
);
|
||||
|
|
|
|||
|
|
@ -10,15 +10,8 @@ SELECT t.* FROM (
|
|||
UNION ALL SELECT 'mlmd.Process', NULL, 0, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'mlmd.Evaluate', NULL, 0, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'mlmd.Deploy', NULL, 0, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.RegisteredModel', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ModelVersion', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.DocArtifact', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ModelArtifact', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ServingEnvironment', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.InferenceService', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ServeModel', NULL, 0, NULL, NULL, NULL, NULL
|
||||
) t
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM "Type"
|
||||
WHERE name = t.name AND version IS NULL
|
||||
);
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,28 +1 @@
|
|||
-- Clear seeded TypeProperty data
|
||||
DELETE FROM "TypeProperty"
|
||||
WHERE (type_id, name, data_type) IN (
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'owner', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'state', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'author', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'model_name', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'state', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'version', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.DocArtifact'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_key', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_path', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ServingEnvironment'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'desired_state', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'model_version_id', 1),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'registered_model_id', 1),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'runtime', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'model_version_id', 1)
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,26 +1 @@
|
|||
-- Seed TypeProperty table
|
||||
INSERT INTO "TypeProperty" (type_id, name, data_type)
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'owner', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'author', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'model_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelVersion'), 'version', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DocArtifact'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'model_format_version', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'service_account_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_key', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ModelArtifact'), 'storage_path', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ServingEnvironment'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'desired_state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'model_version_id', 1 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'registered_model_id', 1 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'runtime', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.InferenceService'), 'serving_environment_id', 1 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ServeModel'), 'model_version_id', 1;
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,13 +1 @@
|
|||
DELETE FROM "TypeProperty" WHERE type_id=(
|
||||
SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'
|
||||
) AND name IN (
|
||||
'language',
|
||||
'library_name',
|
||||
'license_link',
|
||||
'license',
|
||||
'logo',
|
||||
'maturity',
|
||||
'provider',
|
||||
'readme',
|
||||
'tasks'
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,10 +1 @@
|
|||
INSERT INTO "TypeProperty" (type_id, name, data_type)
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'language', 4 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'library_name', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'license_link', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'license', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'logo', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'maturity', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'provider', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'readme', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.RegisteredModel'), 'tasks', 4;
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,8 +1 @@
|
|||
DELETE FROM "Type" WHERE name IN (
|
||||
'kf.MetricHistory',
|
||||
'kf.Experiment',
|
||||
'kf.ExperimentRun',
|
||||
'kf.DataSet',
|
||||
'kf.Metric',
|
||||
'kf.Parameter'
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,14 +1 @@
|
|||
-- Add missing types for experiment tracking and metric history
|
||||
INSERT INTO "Type" (name, version, type_kind, description, input_type, output_type, external_id)
|
||||
SELECT t.* FROM (
|
||||
SELECT 'kf.MetricHistory' as name, NULL as version, 1 as type_kind, NULL as description, NULL as input_type, NULL as output_type, NULL as external_id
|
||||
UNION ALL SELECT 'kf.Experiment', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.ExperimentRun', NULL, 2, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.DataSet', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.Metric', NULL, 1, NULL, NULL, NULL, NULL
|
||||
UNION ALL SELECT 'kf.Parameter', NULL, 1, NULL, NULL, NULL, NULL
|
||||
) t
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM "Type"
|
||||
WHERE name = t.name AND version IS NULL
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,31 +1 @@
|
|||
-- Remove TypeProperty entries for experiment-related types
|
||||
DELETE FROM "TypeProperty"
|
||||
WHERE (type_id, name, data_type) IN (
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'value', 5),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'timestamp', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'step', 1),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'value', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'parameter_type', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'value', 5),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'timestamp', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'step', 1),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'digest', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source_type', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'schema', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'profile', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'owner', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'state', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'description', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'owner', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'state', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'status', 3),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1),
|
||||
((SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1)
|
||||
);
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -1,29 +1 @@
|
|||
-- Add TypeProperty entries for experiment-related types
|
||||
INSERT INTO "TypeProperty" (type_id, name, data_type)
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'value', 5 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'timestamp', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Metric'), 'step', 1 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'value', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Parameter'), 'parameter_type', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'value', 5 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'timestamp', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.MetricHistory'), 'step', 1 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'digest', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source_type', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'source', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'schema', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.DataSet'), 'profile', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'owner', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.Experiment'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'description', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'owner', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'state', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'status', 3 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'start_time_since_epoch', 1 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'end_time_since_epoch', 1 UNION ALL
|
||||
SELECT (SELECT id FROM "Type" WHERE name = 'kf.ExperimentRun'), 'experiment_id', 1;
|
||||
-- Migration removed
|
||||
|
|
|
|||
|
|
@ -0,0 +1,196 @@
|
|||
package embedmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/glog"
|
||||
"github.com/kubeflow/model-registry/internal/datastore"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ datastore.RepoSet = (*repoSetImpl)(nil)
|
||||
|
||||
type repoSetImpl struct {
|
||||
db *gorm.DB
|
||||
spec *datastore.Spec
|
||||
nameIDMap map[string]int64
|
||||
repos map[reflect.Type]any
|
||||
}
|
||||
|
||||
func newRepoSet(db *gorm.DB, spec *datastore.Spec) (datastore.RepoSet, error) {
|
||||
typeRepository := service.NewTypeRepository(db)
|
||||
|
||||
glog.Infof("Getting types...")
|
||||
|
||||
types, err := typeRepository.GetAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nameIDMap := make(map[string]int64, len(types))
|
||||
for _, t := range types {
|
||||
nameIDMap[*t.GetAttributes().Name] = int64(*t.GetID())
|
||||
}
|
||||
|
||||
glog.Infof("Types retrieved")
|
||||
|
||||
// Add debug logging to see what types are actually available
|
||||
glog.V(2).Infof("DEBUG: Available types:")
|
||||
for typeName, typeID := range nameIDMap {
|
||||
glog.V(2).Infof(" %s = %d", typeName, typeID)
|
||||
}
|
||||
|
||||
// Validate that all required types are registered
|
||||
requiredTypes := spec.AllNames()
|
||||
for _, requiredType := range requiredTypes {
|
||||
if _, exists := nameIDMap[requiredType]; !exists {
|
||||
return nil, fmt.Errorf("required type '%s' not found in database. Please ensure all migrations have been applied", requiredType)
|
||||
}
|
||||
}
|
||||
|
||||
glog.Infof("All required types validated successfully")
|
||||
|
||||
rs := &repoSetImpl{
|
||||
db: db,
|
||||
spec: spec,
|
||||
nameIDMap: nameIDMap,
|
||||
repos: make(map[reflect.Type]any, len(requiredTypes)+1),
|
||||
}
|
||||
|
||||
artifactTypes := makeTypeMap[datastore.ArtifactTypeMap](spec.ArtifactTypes, nameIDMap)
|
||||
contextTypes := makeTypeMap[datastore.ContextTypeMap](spec.ContextTypes, nameIDMap)
|
||||
executionTypes := makeTypeMap[datastore.ExecutionTypeMap](spec.ExecutionTypes, nameIDMap)
|
||||
|
||||
args := map[reflect.Type]any{
|
||||
reflect.TypeOf(db): db,
|
||||
reflect.TypeOf(artifactTypes): artifactTypes,
|
||||
reflect.TypeOf(contextTypes): contextTypes,
|
||||
reflect.TypeOf(executionTypes): executionTypes,
|
||||
}
|
||||
|
||||
for i, fn := range spec.Others {
|
||||
repo, err := rs.call(fn, args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedmd: other %d: %w", i, err)
|
||||
}
|
||||
rs.put(repo)
|
||||
}
|
||||
|
||||
for name, specType := range spec.ArtifactTypes {
|
||||
args[reflect.TypeOf(nameIDMap[name])] = nameIDMap[name]
|
||||
|
||||
repo, err := rs.call(specType.InitFn, args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedmd: %s: %w", name, err)
|
||||
}
|
||||
rs.put(repo)
|
||||
}
|
||||
|
||||
for name, specType := range spec.ContextTypes {
|
||||
args[reflect.TypeOf(nameIDMap[name])] = nameIDMap[name]
|
||||
|
||||
repo, err := rs.call(specType.InitFn, args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedmd: %s: %w", name, err)
|
||||
}
|
||||
rs.put(repo)
|
||||
}
|
||||
|
||||
for name, specType := range spec.ExecutionTypes {
|
||||
args[reflect.TypeOf(nameIDMap[name])] = nameIDMap[name]
|
||||
repo, err := rs.call(specType.InitFn, args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedmd: %s: %w", name, err)
|
||||
}
|
||||
rs.put(repo)
|
||||
}
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// call invokes the function pointed to by fn. It matches fn's arguments to the
|
||||
// types in args. fn must return at least one value, and may optionally
|
||||
// return an error.
|
||||
func (rs *repoSetImpl) call(fn any, args map[reflect.Type]any) (any, error) {
|
||||
t := reflect.TypeOf(fn)
|
||||
if t.Kind() != reflect.Func {
|
||||
return nil, fmt.Errorf("initializer is not a function (got type %T)", fn)
|
||||
}
|
||||
|
||||
switch t.NumOut() {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("initializer has no return value")
|
||||
case 1, 2:
|
||||
// OK
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown initializer type, more than 2 return values")
|
||||
}
|
||||
|
||||
fnArgs := make([]reflect.Value, t.NumIn())
|
||||
for i := range t.NumIn() {
|
||||
v, ok := args[t.In(i)]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no initializer argument for type %v", t.In(i))
|
||||
}
|
||||
fnArgs[i] = reflect.ValueOf(v)
|
||||
}
|
||||
|
||||
out := reflect.ValueOf(fn).Call(fnArgs)
|
||||
|
||||
var err error
|
||||
if len(out) > 1 {
|
||||
ierr := out[1].Interface()
|
||||
if ierr != nil {
|
||||
var ok bool
|
||||
err, ok = ierr.(error)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown return value, expected error, got %T", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out[0].Interface(), err
|
||||
}
|
||||
|
||||
// put adds one repository to the set.
|
||||
func (rs *repoSetImpl) put(repo any) {
|
||||
rs.repos[reflect.TypeOf(repo)] = repo
|
||||
}
|
||||
|
||||
func (rs *repoSetImpl) Repository(t reflect.Type) (any, error) {
|
||||
// First try an exact match for the requested type.
|
||||
repo, ok := rs.repos[t]
|
||||
if ok {
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// If the attempt above failed and the requested type is an interface,
|
||||
// use the first repo that implements it.
|
||||
if t.Kind() == reflect.Interface {
|
||||
for repoType, repo := range rs.repos {
|
||||
if repoType.Implements(t) {
|
||||
return repo, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown repository type: %s", t.Name())
|
||||
}
|
||||
|
||||
func (rs *repoSetImpl) TypeMap() map[string]int64 {
|
||||
clone := make(map[string]int64, len(rs.nameIDMap))
|
||||
for k, v := range rs.nameIDMap {
|
||||
clone[k] = v
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func makeTypeMap[T ~map[string]int64](specMap map[string]*datastore.SpecType, nameIDMap map[string]int64) T {
|
||||
returnMap := make(T, len(specMap))
|
||||
for k := range specMap {
|
||||
returnMap[k] = nameIDMap[k]
|
||||
}
|
||||
return returnMap
|
||||
}
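
A minimal, self-contained sketch of the argument-matching convention that newRepoSet and call rely on: every parameter of an initializer is filled from a map keyed by reflect.Type, so an initializer only declares the dependencies it needs. demoRepo, newDemoRepo, and callByType are hypothetical names used purely for illustration; the real call additionally checks the return-value count and unwraps an optional trailing error.

package main

import (
	"fmt"
	"reflect"
)

// demoRepo stands in for a repository; typeID mirrors the per-type ID that
// repoSetImpl injects for artifact, context, and execution types.
type demoRepo struct{ typeID int64 }

// newDemoRepo is a hypothetical initializer: its parameter types are the
// only contract it has with the caller.
func newDemoRepo(typeID int64) *demoRepo { return &demoRepo{typeID: typeID} }

// callByType mirrors the core of repoSetImpl.call: look up each parameter
// by its reflect.Type in args, then invoke fn with those values.
func callByType(fn any, args map[reflect.Type]any) any {
	t := reflect.TypeOf(fn)
	in := make([]reflect.Value, t.NumIn())
	for i := 0; i < t.NumIn(); i++ {
		in[i] = reflect.ValueOf(args[t.In(i)])
	}
	return reflect.ValueOf(fn).Call(in)[0].Interface()
}

func main() {
	args := map[reflect.Type]any{reflect.TypeOf(int64(0)): int64(7)}
	repo := callByType(newDemoRepo, args).(*demoRepo)
	fmt.Println(repo.typeID) // prints 7
}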
|
||||
|
|
@ -0,0 +1,403 @@
|
|||
package embedmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/kubeflow/model-registry/internal/datastore"
|
||||
"github.com/kubeflow/model-registry/internal/datastore/embedmd/mysql"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/tls"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cont_mysql "github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Mock repository types for testing
|
||||
type mockArtifactRepo struct {
|
||||
typeID int64
|
||||
}
|
||||
|
||||
type mockContextRepo struct {
|
||||
typeID int64
|
||||
}
|
||||
|
||||
type mockExecutionRepo struct {
|
||||
typeID int64
|
||||
}
|
||||
|
||||
type mockOtherRepo struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// Mock initializer functions
|
||||
func newMockArtifactRepo(db *gorm.DB, typeID int64) *mockArtifactRepo {
|
||||
return &mockArtifactRepo{typeID: typeID}
|
||||
}
|
||||
|
||||
func newMockContextRepo(db *gorm.DB, typeID int64) *mockContextRepo {
|
||||
return &mockContextRepo{typeID: typeID}
|
||||
}
|
||||
|
||||
func newMockExecutionRepo(db *gorm.DB, typeID int64) *mockExecutionRepo {
|
||||
return &mockExecutionRepo{typeID: typeID}
|
||||
}
|
||||
|
||||
func newMockOtherRepo(db *gorm.DB) *mockOtherRepo {
|
||||
return &mockOtherRepo{db: db}
|
||||
}
|
||||
|
||||
func newMockOtherRepoWithError(db *gorm.DB) (*mockOtherRepo, error) {
|
||||
return nil, errors.New("mock initialization error")
|
||||
}
|
||||
|
||||
// Test helper to create a test database with types using a simplified MySQL setup
|
||||
func setupTestDB(t *testing.T) (*gorm.DB, func()) {
|
||||
// Create a simple MySQL container without config file
|
||||
ctx := context.Background()
|
||||
|
||||
mysqlContainer, err := cont_mysql.Run(ctx, "mysql:8.3",
|
||||
cont_mysql.WithDatabase("test"),
|
||||
cont_mysql.WithUsername("root"),
|
||||
cont_mysql.WithPassword("root"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get connection string
|
||||
dsn, err := mysqlContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect to database
|
||||
dbConnector := mysql.NewMySQLDBConnector(dsn, &tls.TLSConfig{})
|
||||
db, err := dbConnector.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create the required tables
|
||||
err = db.AutoMigrate(&schema.Type{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert test types
|
||||
testTypes := []schema.Type{
|
||||
{ID: 1, Name: "TestArtifact", TypeKind: 1},
|
||||
{ID: 2, Name: "TestDoc", TypeKind: 1},
|
||||
{ID: 3, Name: "TestContext", TypeKind: 2},
|
||||
{ID: 4, Name: "TestModel", TypeKind: 2},
|
||||
{ID: 5, Name: "TestExecution", TypeKind: 3},
|
||||
}
|
||||
|
||||
for _, typ := range testTypes {
|
||||
require.NoError(t, db.Create(&typ).Error)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
sqlDB, err := db.DB()
|
||||
if err == nil {
|
||||
//nolint:errcheck
|
||||
sqlDB.Close()
|
||||
}
|
||||
//nolint:errcheck
|
||||
mysqlContainer.Terminate(ctx)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
func TestNewRepoSet_Success(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
spec := datastore.NewSpec().
|
||||
AddArtifact("TestArtifact", datastore.NewSpecType(newMockArtifactRepo)).
|
||||
AddArtifact("TestDoc", datastore.NewSpecType(newMockArtifactRepo)).
|
||||
AddContext("TestContext", datastore.NewSpecType(newMockContextRepo)).
|
||||
AddContext("TestModel", datastore.NewSpecType(newMockContextRepo)).
|
||||
AddExecution("TestExecution", datastore.NewSpecType(newMockExecutionRepo)).
|
||||
AddOther(newMockOtherRepo)
|
||||
|
||||
repoSet, err := newRepoSet(db, spec)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, repoSet)
|
||||
|
||||
// Verify we can get repositories by type
|
||||
mockArtifact, err := repoSet.Repository(reflect.TypeOf(&mockArtifactRepo{}))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, mockArtifact)
|
||||
|
||||
mockContext, err := repoSet.Repository(reflect.TypeOf(&mockContextRepo{}))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, mockContext)
|
||||
|
||||
mockExecution, err := repoSet.Repository(reflect.TypeOf(&mockExecutionRepo{}))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, mockExecution)
|
||||
|
||||
mockOther, err := repoSet.Repository(reflect.TypeOf(&mockOtherRepo{}))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, mockOther)
|
||||
|
||||
// Verify TypeMap returns correct mappings
|
||||
typeMap := repoSet.TypeMap()
|
||||
assert.Equal(t, int64(1), typeMap["TestArtifact"])
|
||||
assert.Equal(t, int64(2), typeMap["TestDoc"])
|
||||
assert.Equal(t, int64(3), typeMap["TestContext"])
|
||||
assert.Equal(t, int64(4), typeMap["TestModel"])
|
||||
assert.Equal(t, int64(5), typeMap["TestExecution"])
|
||||
}
|
||||
|
||||
func TestNewRepoSet_MissingType(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
spec := datastore.NewSpec().
|
||||
AddArtifact("NonExistentType", datastore.NewSpecType(newMockArtifactRepo))
|
||||
|
||||
repoSet, err := newRepoSet(db, spec)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, repoSet)
|
||||
assert.Contains(t, err.Error(), "required type 'NonExistentType' not found in database")
|
||||
}
|
||||
|
||||
func TestNewRepoSet_InitializerError(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
spec := datastore.NewSpec().AddOther(newMockOtherRepoWithError)
|
||||
|
||||
repoSet, err := newRepoSet(db, spec)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, repoSet)
|
||||
assert.Contains(t, err.Error(), "mock initialization error")
|
||||
}
|
||||
|
||||
// Define a simple interface and implementation for interface testing
|
||||
type TestInterface interface {
|
||||
TestMethod() string
|
||||
}
|
||||
|
||||
type testImpl struct{}
|
||||
|
||||
func (ti *testImpl) TestMethod() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
func TestRepoSetImpl_Repository_InterfaceMatch(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create initializer that returns the implementation
|
||||
newTestImpl := func(db *gorm.DB) *testImpl {
|
||||
return &testImpl{}
|
||||
}
|
||||
|
||||
spec := datastore.NewSpec().AddOther(newTestImpl)
|
||||
|
||||
repoSet, err := newRepoSet(db, spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be able to get repository by interface type
|
||||
repo, err := repoSet.Repository(reflect.TypeOf((*TestInterface)(nil)).Elem())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, repo)
|
||||
|
||||
// Verify it's the correct implementation
|
||||
impl, ok := repo.(*testImpl)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test", impl.TestMethod())
|
||||
}
|
||||
|
||||
func TestRepoSetImpl_Repository_UnknownType(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
spec := datastore.NewSpec().
|
||||
AddArtifact("TestArtifact", datastore.NewSpecType(newMockArtifactRepo))
|
||||
|
||||
repoSet, err := newRepoSet(db, spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get a repository type that doesn't exist
|
||||
type unknownType struct{}
|
||||
repo, err := repoSet.Repository(reflect.TypeOf(&unknownType{}))
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, repo)
|
||||
assert.Contains(t, err.Error(), "unknown repository type")
|
||||
}
|
||||
|
||||
func TestRepoSetImpl_Call_InvalidInitializer(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
rs := &repoSetImpl{
|
||||
db: db,
|
||||
}
|
||||
|
||||
args := map[reflect.Type]any{
|
||||
reflect.TypeOf(db): db,
|
||||
}
|
||||
|
||||
// Test with non-function
|
||||
_, err := rs.call("not a function", args)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "initializer is not a function")
|
||||
|
||||
// Test with function that has no return values
|
||||
noReturnFunc := func() {}
|
||||
_, err = rs.call(noReturnFunc, args)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "initializer has no return value")
|
||||
|
||||
// Test with function that has too many return values
|
||||
tooManyReturnsFunc := func() (int, string, error) {
|
||||
return 0, "", nil
|
||||
}
|
||||
_, err = rs.call(tooManyReturnsFunc, args)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "more than 2 return values")
|
||||
|
||||
// Test with missing argument type
|
||||
needsIntFunc := func(i int) string {
|
||||
return "test"
|
||||
}
|
||||
_, err = rs.call(needsIntFunc, args)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no initializer argument for type")
|
||||
}
|
||||
|
||||
func TestRepoSetImpl_Call_ValidInitializers(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
rs := &repoSetImpl{
|
||||
db: db,
|
||||
}
|
||||
|
||||
args := map[reflect.Type]any{
|
||||
reflect.TypeOf(db): db,
|
||||
reflect.TypeOf(int64(0)): int64(42),
|
||||
}
|
||||
|
||||
// Test function with one return value
|
||||
oneReturnFunc := func(db *gorm.DB) *mockOtherRepo {
|
||||
return &mockOtherRepo{db: db}
|
||||
}
|
||||
result, err := rs.call(oneReturnFunc, args)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
mockRepo, ok := result.(*mockOtherRepo)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, db, mockRepo.db)
|
||||
|
||||
// Test function with two return values (success case)
|
||||
twoReturnSuccessFunc := func(db *gorm.DB, id int64) (*mockArtifactRepo, error) {
|
||||
return &mockArtifactRepo{typeID: id}, nil
|
||||
}
|
||||
result, err = rs.call(twoReturnSuccessFunc, args)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
mockArtifact, ok := result.(*mockArtifactRepo)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, int64(42), mockArtifact.typeID)
|
||||
|
||||
// Test function with two return values (error case)
|
||||
twoReturnErrorFunc := func(db *gorm.DB) (*mockOtherRepo, error) {
|
||||
return nil, errors.New("initialization failed")
|
||||
}
|
||||
result, err = rs.call(twoReturnErrorFunc, args)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "initialization failed")
|
||||
}
|
||||
|
||||
func TestMakeTypeMap(t *testing.T) {
|
||||
specMap := map[string]*datastore.SpecType{
|
||||
"type1": datastore.NewSpecType("func1"),
|
||||
"type2": datastore.NewSpecType("func2"),
|
||||
"type3": datastore.NewSpecType("func3"),
|
||||
}
|
||||
|
||||
nameIDMap := map[string]int64{
|
||||
"type1": 10,
|
||||
"type2": 20,
|
||||
"type3": 30,
|
||||
}
|
||||
|
||||
// Test ArtifactTypeMap
|
||||
artifactMap := makeTypeMap[datastore.ArtifactTypeMap](specMap, nameIDMap)
|
||||
assert.Equal(t, datastore.ArtifactTypeMap{"type1": 10, "type2": 20, "type3": 30}, artifactMap)
|
||||
|
||||
// Test ContextTypeMap
|
||||
contextMap := makeTypeMap[datastore.ContextTypeMap](specMap, nameIDMap)
|
||||
assert.Equal(t, datastore.ContextTypeMap{"type1": 10, "type2": 20, "type3": 30}, contextMap)
|
||||
|
||||
// Test ExecutionTypeMap
|
||||
executionMap := makeTypeMap[datastore.ExecutionTypeMap](specMap, nameIDMap)
|
||||
assert.Equal(t, datastore.ExecutionTypeMap{"type1": 10, "type2": 20, "type3": 30}, executionMap)
|
||||
}
|
||||
|
||||
func TestRepoSetImpl_TypeMapCloning(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
spec := datastore.NewSpec().
|
||||
AddArtifact("TestArtifact", datastore.NewSpecType(newMockArtifactRepo))
|
||||
|
||||
repoSet, err := newRepoSet(db, spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
rs := repoSet.(*repoSetImpl)
|
||||
|
||||
// Get the type map
|
||||
typeMap1 := rs.TypeMap()
|
||||
typeMap2 := rs.TypeMap()
|
||||
|
||||
// Verify they have the same values
|
||||
assert.Equal(t, typeMap1, typeMap2)
|
||||
|
||||
// Verify they are different objects (cloned)
|
||||
typeMap1["TestModification"] = 999
|
||||
assert.NotEqual(t, typeMap1, typeMap2)
|
||||
assert.NotContains(t, typeMap2, "TestModification")
|
||||
}
|
||||
|
||||
// Integration test with real service repositories
|
||||
func TestRepoSetImpl_WithRealRepositories(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create artifact types that match what the real repositories expect
|
||||
realArtifactTypes := []schema.Type{
|
||||
{ID: 10, Name: "model-artifact", TypeKind: 1},
|
||||
{ID: 11, Name: "doc-artifact", TypeKind: 1},
|
||||
}
|
||||
|
||||
for _, at := range realArtifactTypes {
|
||||
require.NoError(t, db.Create(&at).Error)
|
||||
}
|
||||
|
||||
spec := datastore.NewSpec().
|
||||
AddArtifact("model-artifact", datastore.NewSpecType(service.NewModelArtifactRepository)).
|
||||
AddArtifact("doc-artifact", datastore.NewSpecType(service.NewDocArtifactRepository)).
|
||||
AddOther(service.NewArtifactRepository)
|
||||
|
||||
repoSet, err := newRepoSet(db, spec)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, repoSet)
|
||||
|
||||
// Verify we can get the real repositories
|
||||
modelRepo, err := repoSet.Repository(reflect.TypeOf((*models.ModelArtifactRepository)(nil)).Elem())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, modelRepo)
|
||||
|
||||
docRepo, err := repoSet.Repository(reflect.TypeOf((*models.DocArtifactRepository)(nil)).Elem())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, docRepo)
|
||||
|
||||
artifactRepo, err := repoSet.Repository(reflect.TypeOf((*models.ArtifactRepository)(nil)).Elem())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, artifactRepo)
|
||||
}
|
||||
|
|
@ -1,41 +1,69 @@
|
|||
package embedmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/glog"
|
||||
"github.com/kubeflow/model-registry/internal/core"
|
||||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
"github.com/kubeflow/model-registry/internal/datastore"
|
||||
"github.com/kubeflow/model-registry/internal/db"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/db/types"
|
||||
"github.com/kubeflow/model-registry/internal/defaults"
|
||||
"github.com/kubeflow/model-registry/internal/tls"
|
||||
"github.com/kubeflow/model-registry/pkg/api"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const connectorType = "embedmd"
|
||||
|
||||
func init() {
|
||||
datastore.Register(connectorType, func(cfg any) (datastore.Connector, error) {
|
||||
emdbCfg, ok := cfg.(*EmbedMDConfig)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid EmbedMD config type (%T)", cfg)
|
||||
}
|
||||
|
||||
if err := emdbCfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid EmbedMD config: %w", err)
|
||||
}
|
||||
|
||||
return NewEmbedMDService(cfg.(*EmbedMDConfig))
|
||||
})
|
||||
}
|
||||
|
||||
type EmbedMDConfig struct {
|
||||
DatabaseType string
|
||||
DatabaseDSN string
|
||||
TLSConfig *tls.TLSConfig
|
||||
|
||||
// DB is an already connected database instance that, if provided, will
|
||||
// be used instead of making a new connection.
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func (c *EmbedMDConfig) Validate() error {
|
||||
if c.DatabaseType != types.DatabaseTypeMySQL && c.DatabaseType != types.DatabaseTypePostgres {
|
||||
return fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", c.DatabaseType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
|
||||
if c.DB == nil {
|
||||
if c.DatabaseType != types.DatabaseTypeMySQL && c.DatabaseType != types.DatabaseTypePostgres {
|
||||
return fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", c.DatabaseType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type EmbedMDService struct {
|
||||
*EmbedMDConfig
|
||||
dbConnector db.Connector
|
||||
}
|
||||
|
||||
func NewEmbedMDService(cfg *EmbedMDConfig) (*EmbedMDService, error) {
|
||||
err := db.Init(cfg.DatabaseType, cfg.DatabaseDSN, cfg.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize database connector: %w", err)
|
||||
if cfg.DB != nil {
|
||||
db.SetDB(cfg.DB)
|
||||
} else {
|
||||
err := db.Init(cfg.DatabaseType, cfg.DatabaseDSN, cfg.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize database connector: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
dbConnector, ok := db.GetConnector()
|
||||
|
|
@ -44,12 +72,11 @@ func NewEmbedMDService(cfg *EmbedMDConfig) (*EmbedMDService, error) {
|
|||
}
|
||||
|
||||
return &EmbedMDService{
|
||||
EmbedMDConfig: cfg,
|
||||
dbConnector: dbConnector,
|
||||
dbConnector: dbConnector,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *EmbedMDService) Connect() (api.ModelRegistryApi, error) {
|
||||
func (s *EmbedMDService) Connect(spec *datastore.Spec) (datastore.RepoSet, error) {
|
||||
glog.Infof("Connecting to EmbedMD service...")
|
||||
|
||||
connectedDB, err := s.dbConnector.Connect()
|
||||
|
|
@ -59,7 +86,7 @@ func (s *EmbedMDService) Connect() (api.ModelRegistryApi, error) {
|
|||
|
||||
glog.Infof("Connected to EmbedMD service")
|
||||
|
||||
migrator, err := db.NewDBMigrator(s.DatabaseType, connectedDB)
|
||||
migrator, err := db.NewDBMigrator(connectedDB)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -73,93 +100,88 @@ func (s *EmbedMDService) Connect() (api.ModelRegistryApi, error) {
|
|||
|
||||
glog.Infof("Migrations completed")
|
||||
|
||||
typeRepository := service.NewTypeRepository(connectedDB)
|
||||
|
||||
glog.Infof("Getting types...")
|
||||
|
||||
types, err := typeRepository.GetAll()
|
||||
glog.Infof("Syncing types...")
|
||||
err = s.syncTypes(connectedDB, spec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
glog.Infof("Syncing types completed")
|
||||
|
||||
typesMap := make(map[string]int64)
|
||||
return newRepoSet(connectedDB, spec)
|
||||
}
|
||||
|
||||
for _, t := range types {
|
||||
typesMap[*t.GetAttributes().Name] = int64(*t.GetID())
|
||||
func (s EmbedMDService) Type() string {
|
||||
return connectorType
|
||||
}
|
||||
|
||||
const (
|
||||
executionTypeKind int32 = iota
|
||||
artifactTypeKind
|
||||
contextTypeKind
|
||||
)
|
||||
|
||||
func (s *EmbedMDService) syncTypes(conn *gorm.DB, spec *datastore.Spec) error {
|
||||
idMap := make(map[string]int32, len(spec.ExecutionTypes)+len(spec.ArtifactTypes)+len(spec.ContextTypes))
|
||||
var errs []error
|
||||
|
||||
typeRepository := service.NewTypeRepository(conn)
|
||||
errs = append(errs, s.createTypes(typeRepository, spec.ExecutionTypes, executionTypeKind, idMap))
|
||||
errs = append(errs, s.createTypes(typeRepository, spec.ArtifactTypes, artifactTypeKind, idMap))
|
||||
errs = append(errs, s.createTypes(typeRepository, spec.ContextTypes, contextTypeKind, idMap))
|
||||
|
||||
typePropertyRepository := service.NewTypePropertyRepository(conn)
|
||||
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ExecutionTypes, idMap))
|
||||
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ArtifactTypes, idMap))
|
||||
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ContextTypes, idMap))
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (s *EmbedMDService) createTypes(repo models.TypeRepository, types map[string]*datastore.SpecType, kind int32, idMap map[string]int32) error {
|
||||
var errs []error
|
||||
for name := range types {
|
||||
t, err := repo.Save(&models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &name,
|
||||
TypeKind: &kind,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: unable to create type: %w", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
id := t.GetID()
|
||||
if id == nil {
|
||||
errs = append(errs, fmt.Errorf("%s: unable to determine type ID", name))
|
||||
continue
|
||||
}
|
||||
idMap[name] = *id
|
||||
}
|
||||
|
||||
glog.Infof("Types retrieved")
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// Add debug logging to see what types are actually available
|
||||
glog.V(2).Infof("DEBUG: Available types in typesMap:")
|
||||
for typeName, typeID := range typesMap {
|
||||
glog.V(2).Infof(" %s = %d", typeName, typeID)
|
||||
}
|
||||
func (s *EmbedMDService) createTypeProperties(repo models.TypePropertyRepository, types map[string]*datastore.SpecType, idMap map[string]int32) error {
|
||||
var errs []error
|
||||
for typeName, typeSpec := range types {
|
||||
typeID := idMap[typeName]
|
||||
if typeID == 0 {
|
||||
errs = append(errs, fmt.Errorf("%s: unknown type", typeName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate that all required types are registered
|
||||
requiredTypes := []string{
|
||||
defaults.ModelArtifactTypeName,
|
||||
defaults.DocArtifactTypeName,
|
||||
defaults.DataSetTypeName,
|
||||
defaults.MetricTypeName,
|
||||
defaults.ParameterTypeName,
|
||||
defaults.MetricHistoryTypeName,
|
||||
defaults.RegisteredModelTypeName,
|
||||
defaults.ModelVersionTypeName,
|
||||
defaults.ServingEnvironmentTypeName,
|
||||
defaults.InferenceServiceTypeName,
|
||||
defaults.ServeModelTypeName,
|
||||
defaults.ExperimentTypeName,
|
||||
defaults.ExperimentRunTypeName,
|
||||
}
|
||||
|
||||
for _, requiredType := range requiredTypes {
|
||||
if _, exists := typesMap[requiredType]; !exists {
|
||||
return nil, fmt.Errorf("required type '%s' not found in database. Please ensure all migrations have been applied", requiredType)
|
||||
for name, dataType := range typeSpec.Properties {
|
||||
_, err := repo.Save(&models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: name,
|
||||
DataType: apiutils.Of(int32(dataType)),
|
||||
})
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s-%s: %w", typeName, name, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
glog.Infof("All required types validated successfully")
|
||||
|
||||
artifactRepository := service.NewArtifactRepository(connectedDB, typesMap[defaults.ModelArtifactTypeName], typesMap[defaults.DocArtifactTypeName], typesMap[defaults.DataSetTypeName], typesMap[defaults.MetricTypeName], typesMap[defaults.ParameterTypeName], typesMap[defaults.MetricHistoryTypeName])
|
||||
modelArtifactRepository := service.NewModelArtifactRepository(connectedDB, typesMap[defaults.ModelArtifactTypeName])
|
||||
docArtifactRepository := service.NewDocArtifactRepository(connectedDB, typesMap[defaults.DocArtifactTypeName])
|
||||
registeredModelRepository := service.NewRegisteredModelRepository(connectedDB, typesMap[defaults.RegisteredModelTypeName])
|
||||
modelVersionRepository := service.NewModelVersionRepository(connectedDB, typesMap[defaults.ModelVersionTypeName])
|
||||
servingEnvironmentRepository := service.NewServingEnvironmentRepository(connectedDB, typesMap[defaults.ServingEnvironmentTypeName])
|
||||
inferenceServiceRepository := service.NewInferenceServiceRepository(connectedDB, typesMap[defaults.InferenceServiceTypeName])
|
||||
serveModelRepository := service.NewServeModelRepository(connectedDB, typesMap[defaults.ServeModelTypeName])
|
||||
experimentRepository := service.NewExperimentRepository(connectedDB, typesMap[defaults.ExperimentTypeName])
|
||||
experimentRunRepository := service.NewExperimentRunRepository(connectedDB, typesMap[defaults.ExperimentRunTypeName])
|
||||
|
||||
dataSetRepository := service.NewDataSetRepository(connectedDB, typesMap[defaults.DataSetTypeName])
|
||||
metricRepository := service.NewMetricRepository(connectedDB, typesMap[defaults.MetricTypeName])
|
||||
parameterRepository := service.NewParameterRepository(connectedDB, typesMap[defaults.ParameterTypeName])
|
||||
metricHistoryRepository := service.NewMetricHistoryRepository(connectedDB, typesMap[defaults.MetricHistoryTypeName])
|
||||
|
||||
modelRegistryService := core.NewModelRegistryService(
|
||||
artifactRepository,
|
||||
modelArtifactRepository,
|
||||
docArtifactRepository,
|
||||
registeredModelRepository,
|
||||
modelVersionRepository,
|
||||
servingEnvironmentRepository,
|
||||
inferenceServiceRepository,
|
||||
serveModelRepository,
|
||||
experimentRepository,
|
||||
experimentRunRepository,
|
||||
dataSetRepository,
|
||||
metricRepository,
|
||||
parameterRepository,
|
||||
metricHistoryRepository,
|
||||
typesMap,
|
||||
)
|
||||
|
||||
glog.Infof("EmbedMD service connected")
|
||||
|
||||
return modelRegistryService, nil
|
||||
}
|
||||
|
||||
func (s *EmbedMDService) Teardown() error {
|
||||
return nil
|
||||
return errors.Join(errs...)
|
||||
}
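
A hedged sketch of how a caller might wire the refactored pieces together end to end; the DSN is a placeholder and the internal import paths are assumed from the files touched in this change, not confirmed by it.

package main

import (
	"log"
	"reflect"

	"github.com/kubeflow/model-registry/internal/datastore/embedmd"
	"github.com/kubeflow/model-registry/internal/db/models"
	"github.com/kubeflow/model-registry/internal/db/service"
	"github.com/kubeflow/model-registry/internal/db/types"
	"github.com/kubeflow/model-registry/internal/tls"
)

func main() {
	// Placeholder DSN; assumes a reachable MySQL instance.
	cfg := &embedmd.EmbedMDConfig{
		DatabaseType: types.DatabaseTypeMySQL,
		DatabaseDSN:  "user:pass@tcp(localhost:3306)/model_registry",
		TLSConfig:    &tls.TLSConfig{},
	}

	svc, err := embedmd.NewEmbedMDService(cfg)
	if err != nil {
		log.Fatal(err)
	}

	// Connect runs migrations, syncs the spec's types and properties into the
	// Type/TypeProperty tables, and returns a RepoSet built from the spec's
	// initializer functions.
	repos, err := svc.Connect(service.DatastoreSpec())
	if err != nil {
		log.Fatal(err)
	}

	// Repositories are looked up by reflect.Type; an interface type matches
	// the first stored implementation that satisfies it.
	repo, err := repos.Repository(reflect.TypeOf((*models.ArtifactRepository)(nil)).Elem())
	if err != nil {
		log.Fatal(err)
	}
	_ = repo.(models.ArtifactRepository)
}

Compared with the previous Connect, which assembled and returned a complete api.ModelRegistryApi, the RepoSet leaves service construction to the caller.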
|
||||
|
|
|
|||
|
|
@ -0,0 +1,164 @@
|
|||
package datastore
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Spec is the specification for the datastore.
|
||||
// Each entry maps a type name to the initializer for its repository.
|
||||
type Spec struct {
|
||||
// Maps artifact type names to an initializer.
|
||||
ArtifactTypes map[string]*SpecType
|
||||
|
||||
// Maps context type names to an initializer.
|
||||
ContextTypes map[string]*SpecType
|
||||
|
||||
// Maps execution type names to an initializer.
|
||||
ExecutionTypes map[string]*SpecType
|
||||
|
||||
// Any repo initialization functions that don't map to a type can be
|
||||
// added here.
|
||||
Others []any
|
||||
}
|
||||
|
||||
// NewSpec returns an empty Spec instance.
|
||||
func NewSpec() *Spec {
|
||||
return &Spec{
|
||||
ArtifactTypes: map[string]*SpecType{},
|
||||
ContextTypes: map[string]*SpecType{},
|
||||
ExecutionTypes: map[string]*SpecType{},
|
||||
Others: []any{},
|
||||
}
|
||||
}
|
||||
|
||||
// AddArtifact adds an artifact type to the spec.
|
||||
func (s *Spec) AddArtifact(name string, t *SpecType) *Spec {
|
||||
s.ArtifactTypes[name] = t
|
||||
return s
|
||||
}
|
||||
|
||||
// AddContext adds a context type to the spec.
|
||||
func (s *Spec) AddContext(name string, t *SpecType) *Spec {
|
||||
s.ContextTypes[name] = t
|
||||
return s
|
||||
}
|
||||
|
||||
// AddExecution adds an execution type to the spec.
|
||||
func (s *Spec) AddExecution(name string, t *SpecType) *Spec {
|
||||
s.ExecutionTypes[name] = t
|
||||
return s
|
||||
}
|
||||
|
||||
// AddOther adds a repo initializer to the spec.
|
||||
func (s *Spec) AddOther(initFn any) *Spec {
|
||||
s.Others = append(s.Others, initFn)
|
||||
return s
|
||||
}
|
||||
|
||||
// AllNames returns all the type names in the spec.
|
||||
func (s *Spec) AllNames() []string {
|
||||
names := make([]string, 0, len(s.ArtifactTypes)+len(s.ContextTypes)+len(s.ExecutionTypes))
|
||||
for n := range s.ArtifactTypes {
|
||||
names = append(names, n)
|
||||
}
|
||||
for n := range s.ContextTypes {
|
||||
names = append(names, n)
|
||||
}
|
||||
for n := range s.ExecutionTypes {
|
||||
names = append(names, n)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// PropertyType is the data type of a property's value.
|
||||
type PropertyType int32
|
||||
|
||||
const (
|
||||
PropertyTypeUnknown PropertyType = iota
|
||||
PropertyTypeInt
|
||||
PropertyTypeDouble
|
||||
PropertyTypeString
|
||||
PropertyTypeStruct
|
||||
PropertyTypeProto
|
||||
PropertyTypeBoolean
|
||||
)
|
||||
|
||||
// SpecType is a single type in the spec.
|
||||
type SpecType struct {
|
||||
// InitFn is an initialization function that returns the
|
||||
// repository instance and an optional error.
|
||||
//
|
||||
// Data Store implementations must pass arguments to the initialization
|
||||
// functions that are required by the repository (database handles,
|
||||
// HTTP clients, etc.) and should also provide values for the following
|
||||
// types:
|
||||
//
|
||||
// - int64: type id
|
||||
// - ArtifactTypeMap: Map of all artifact names to type IDs
|
||||
// - ContextTypeMap: Map of all context names to type IDs
|
||||
// - ExecutionTypeMap: Map of all execution names to type IDs
|
||||
InitFn any
|
||||
|
||||
// Defined (non-custom) properties of the type.
|
||||
Properties map[string]PropertyType
|
||||
}
|
||||
|
||||
// NewSpecType creates a SpecType instance.
|
||||
func NewSpecType(initFn any) *SpecType {
|
||||
return &SpecType{
|
||||
InitFn: initFn,
|
||||
Properties: map[string]PropertyType{},
|
||||
}
|
||||
}
|
||||
|
||||
// AddInt adds an int property to the type spec.
|
||||
func (st *SpecType) AddInt(name string) *SpecType {
|
||||
st.Properties[name] = PropertyTypeInt
|
||||
return st
|
||||
}
|
||||
|
||||
// AddDouble adds a double property to the type spec.
|
||||
func (st *SpecType) AddDouble(name string) *SpecType {
|
||||
st.Properties[name] = PropertyTypeDouble
|
||||
return st
|
||||
}
|
||||
|
||||
// AddString adds a string property to the type spec.
|
||||
func (st *SpecType) AddString(name string) *SpecType {
|
||||
st.Properties[name] = PropertyTypeString
|
||||
return st
|
||||
}
|
||||
|
||||
// AddStruct adds a struct property to the type spec.
|
||||
func (st *SpecType) AddStruct(name string) *SpecType {
|
||||
st.Properties[name] = PropertyTypeStruct
|
||||
return st
|
||||
}
|
||||
|
||||
// AddProto adds a proto property to the type spec.
|
||||
func (st *SpecType) AddProto(name string) *SpecType {
|
||||
st.Properties[name] = PropertyTypeProto
|
||||
return st
|
||||
}
|
||||
|
||||
// AddBoolean adds a boolean property to the type spec.
|
||||
func (st *SpecType) AddBoolean(name string) *SpecType {
|
||||
st.Properties[name] = PropertyTypeBoolean
|
||||
return st
|
||||
}
|
||||
|
||||
// RepoSet holds repository implementations.
|
||||
type RepoSet interface {
|
||||
// TypeMap returns a map of type names to IDs
|
||||
TypeMap() map[string]int64
|
||||
|
||||
// Repository returns a repository instance of the specified type.
|
||||
Repository(t reflect.Type) (any, error)
|
||||
}
|
||||
|
||||
// ArtifactTypeMap maps artifact type names to IDs
|
||||
type ArtifactTypeMap map[string]int64
|
||||
|
||||
// ContextTypeMap maps context type names to IDs
|
||||
type ContextTypeMap map[string]int64
|
||||
|
||||
// ExecutionTypeMap maps execution type names to IDs
|
||||
type ExecutionTypeMap map[string]int64
|
||||
|
|
@ -30,9 +30,9 @@ func Init(dbType string, dsn string, tlsConfig *tls.TLSConfig) error {
|
|||
}
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
case types.DatabaseTypeMySQL:
|
||||
_connectorInstance = mysql.NewMySQLDBConnector(dsn, tlsConfig)
|
||||
case "postgres":
|
||||
case types.DatabaseTypePostgres:
|
||||
_connectorInstance = postgres.NewPostgresDBConnector(dsn, tlsConfig)
|
||||
default:
|
||||
return fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", dbType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
|
||||
|
|
@ -41,6 +41,12 @@ func Init(dbType string, dsn string, tlsConfig *tls.TLSConfig) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func SetDB(connectedDB *gorm.DB) {
|
||||
connectorMutex.Lock()
|
||||
defer connectorMutex.Unlock()
|
||||
_connectorInstance = ConnectedConnector{ConnectedDB: connectedDB}
|
||||
}
|
||||
|
||||
func GetConnector() (Connector, bool) {
|
||||
connectorMutex.RLock()
|
||||
defer connectorMutex.RUnlock()
|
||||
|
|
@ -54,3 +60,17 @@ func ClearConnector() {
|
|||
|
||||
_connectorInstance = nil
|
||||
}
|
||||
|
||||
// ConnectedConnector satisfies the Connector interface for an already connected
|
||||
// gorm.DB instance.
|
||||
type ConnectedConnector struct {
|
||||
ConnectedDB *gorm.DB
|
||||
}
|
||||
|
||||
func (c ConnectedConnector) Connect() (*gorm.DB, error) {
|
||||
return c.ConnectedDB, nil
|
||||
}
|
||||
|
||||
func (c ConnectedConnector) DB() *gorm.DB {
|
||||
return c.ConnectedDB
|
||||
}
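
A hedged sketch of the new DB field in practice: when an already-connected *gorm.DB is supplied (for example from a test fixture), NewEmbedMDService routes it through SetDB and ConnectedConnector instead of opening a new connection, and Validate skips the DatabaseType check. The import path is assumed from this diff.

package example

import (
	"github.com/kubeflow/model-registry/internal/datastore/embedmd"
	"gorm.io/gorm"
)

// serviceFromExistingDB reuses an existing connection instead of a DSN.
func serviceFromExistingDB(conn *gorm.DB) (*embedmd.EmbedMDService, error) {
	return embedmd.NewEmbedMDService(&embedmd.EmbedMDConfig{DB: conn})
}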
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ type DBMigrator interface {
|
|||
Down(steps *int) error
|
||||
}
|
||||
|
||||
func NewDBMigrator(dbType string, db *gorm.DB) (DBMigrator, error) {
|
||||
switch dbType {
|
||||
func NewDBMigrator(db *gorm.DB) (DBMigrator, error) {
|
||||
switch db.Name() {
|
||||
case types.DatabaseTypeMySQL:
|
||||
return mysql.NewMySQLMigrator(db)
|
||||
case types.DatabaseTypePostgres:
|
||||
return postgres.NewPostgresMigrator(db)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", dbType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
|
||||
return nil, fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", db.Name(), types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
|
||||
}
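
A hedged sketch of why the dbType parameter could be dropped: gorm's Dialector already reports its name, which is what NewDBMigrator now switches on, so callers no longer thread a separate database-type string through. The DSN is a placeholder.

package example

import (
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

// migratorKey opens a connection and returns the dialector name that
// NewDBMigrator will switch on ("postgres" here, "mysql" for the MySQL driver).
func migratorKey(dsn string) (string, error) {
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		return "", err
	}
	return db.Name(), nil
}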
|
||||
|
|
|
|||
|
|
@ -29,5 +29,9 @@ func (t *TypeImpl) GetAttributes() *TypeAttributes {
|
|||
}
|
||||
|
||||
type TypeRepository interface {
|
||||
// GetAll returns every registered type.
|
||||
GetAll() ([]Type, error)
|
||||
|
||||
// Save updates a type, if the definition differs from what's stored.
|
||||
Save(t Type) (Type, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
package models
|
||||
|
||||
type TypeProperty interface {
|
||||
GetTypeID() int32
|
||||
GetName() string
|
||||
GetDataType() *int32
|
||||
}
|
||||
|
||||
var _ TypeProperty = (*TypePropertyImpl)(nil)
|
||||
|
||||
type TypePropertyImpl struct {
|
||||
TypeID int32
|
||||
Name string
|
||||
DataType *int32
|
||||
}
|
||||
|
||||
func (tp *TypePropertyImpl) GetTypeID() int32 {
|
||||
return tp.TypeID
|
||||
}
|
||||
|
||||
func (tp *TypePropertyImpl) GetName() string {
|
||||
return tp.Name
|
||||
}
|
||||
|
||||
func (tp *TypePropertyImpl) GetDataType() *int32 {
|
||||
return tp.DataType
|
||||
}
|
||||
|
||||
type TypePropertyRepository interface {
|
||||
// Save stores a type property if it doesn't exist.
|
||||
Save(tp TypeProperty) (TypeProperty, error)
|
||||
}
|
||||
|
|
@ -4,10 +4,12 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/kubeflow/model-registry/internal/datastore"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/db/scopes"
|
||||
"github.com/kubeflow/model-registry/internal/db/utils"
|
||||
"github.com/kubeflow/model-registry/internal/defaults"
|
||||
"github.com/kubeflow/model-registry/pkg/api"
|
||||
"github.com/kubeflow/model-registry/pkg/openapi"
|
||||
"gorm.io/gorm"
|
||||
|
|
@ -16,24 +18,21 @@ import (
|
|||
var ErrArtifactNotFound = errors.New("artifact by id not found")
|
||||
|
||||
type ArtifactRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
modelArtifactTypeID int64
|
||||
docArtifactTypeID int64
|
||||
dataSetTypeID int64
|
||||
metricTypeID int64
|
||||
parameterTypeID int64
|
||||
metricHistoryTypeID int64
|
||||
db *gorm.DB
|
||||
idToName map[int64]string
|
||||
nameToID datastore.ArtifactTypeMap
|
||||
}
|
||||
|
||||
func NewArtifactRepository(db *gorm.DB, modelArtifactTypeID int64, docArtifactTypeID int64, dataSetTypeID int64, metricTypeID int64, parameterTypeID int64, metricHistoryTypeID int64) models.ArtifactRepository {
|
||||
func NewArtifactRepository(db *gorm.DB, artifactTypes datastore.ArtifactTypeMap) models.ArtifactRepository {
|
||||
idToName := make(map[int64]string, len(artifactTypes))
|
||||
for name, id := range artifactTypes {
|
||||
idToName[id] = name
|
||||
}
|
||||
|
||||
return &ArtifactRepositoryImpl{
|
||||
db: db,
|
||||
modelArtifactTypeID: modelArtifactTypeID,
|
||||
docArtifactTypeID: docArtifactTypeID,
|
||||
dataSetTypeID: dataSetTypeID,
|
||||
metricTypeID: metricTypeID,
|
||||
parameterTypeID: parameterTypeID,
|
||||
metricHistoryTypeID: metricHistoryTypeID,
|
||||
db: db,
|
||||
nameToID: artifactTypes,
|
||||
idToName: idToName,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -73,7 +72,9 @@ func (r *ArtifactRepositoryImpl) List(listOptions models.ArtifactListOptions) (*
|
|||
query := r.db.Model(&schema.Artifact{})
|
||||
|
||||
// Exclude metric history records - they should only be returned via metric history endpoints
|
||||
query = query.Where("type_id != ?", r.metricHistoryTypeID)
|
||||
if metricHistoryTypeID, ok := r.nameToID[defaults.MetricHistoryTypeName]; ok {
|
||||
query = query.Where("type_id != ?", metricHistoryTypeID)
|
||||
}
|
||||
|
||||
if listOptions.Name != nil {
|
||||
// Name is not prefixed with the parent resource id to allow for filtering by name only
|
||||
|
|
@ -170,17 +171,17 @@ func (r *ArtifactRepositoryImpl) List(listOptions models.ArtifactListOptions) (*
|
|||
|
||||
// getTypeIDFromArtifactType maps artifact type strings to their corresponding type IDs
|
||||
func (r *ArtifactRepositoryImpl) getTypeIDFromArtifactType(artifactType string) (int64, error) {
|
||||
switch artifactType {
|
||||
case string(openapi.ARTIFACTTYPEQUERYPARAM_MODEL_ARTIFACT):
|
||||
return r.modelArtifactTypeID, nil
|
||||
case string(openapi.ARTIFACTTYPEQUERYPARAM_DOC_ARTIFACT):
|
||||
return r.docArtifactTypeID, nil
|
||||
case string(openapi.ARTIFACTTYPEQUERYPARAM_DATASET_ARTIFACT):
|
||||
return r.dataSetTypeID, nil
|
||||
case string(openapi.ARTIFACTTYPEQUERYPARAM_METRIC):
|
||||
return r.metricTypeID, nil
|
||||
case string(openapi.ARTIFACTTYPEQUERYPARAM_PARAMETER):
|
||||
return r.parameterTypeID, nil
|
||||
switch openapi.ArtifactTypeQueryParam(artifactType) {
|
||||
case openapi.ARTIFACTTYPEQUERYPARAM_MODEL_ARTIFACT:
|
||||
return r.nameToID[defaults.ModelArtifactTypeName], nil
|
||||
case openapi.ARTIFACTTYPEQUERYPARAM_DOC_ARTIFACT:
|
||||
return r.nameToID[defaults.DocArtifactTypeName], nil
|
||||
case openapi.ARTIFACTTYPEQUERYPARAM_DATASET_ARTIFACT:
|
||||
return r.nameToID[defaults.DataSetTypeName], nil
|
||||
case openapi.ARTIFACTTYPEQUERYPARAM_METRIC:
|
||||
return r.nameToID[defaults.MetricTypeName], nil
|
||||
case openapi.ARTIFACTTYPEQUERYPARAM_PARAMETER:
|
||||
return r.nameToID[defaults.ParameterTypeName], nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported artifact type: %s: %w", artifactType, api.ErrBadRequest)
|
||||
}
|
||||
|
|
@ -189,25 +190,26 @@ func (r *ArtifactRepositoryImpl) getTypeIDFromArtifactType(artifactType string)
|
|||
func (r *ArtifactRepositoryImpl) mapDataLayerToArtifact(artifact schema.Artifact, properties []schema.ArtifactProperty) (models.Artifact, error) {
|
||||
artToReturn := models.Artifact{}
|
||||
|
||||
switch artifact.TypeID {
|
||||
case int32(r.modelArtifactTypeID):
|
||||
typeName := r.idToName[int64(artifact.TypeID)]
|
||||
|
||||
switch typeName {
|
||||
case defaults.ModelArtifactTypeName:
|
||||
modelArtifact := mapDataLayerToModelArtifact(artifact, properties)
|
||||
artToReturn.ModelArtifact = &modelArtifact
|
||||
case int32(r.docArtifactTypeID):
|
||||
case defaults.DocArtifactTypeName:
|
||||
docArtifact := mapDataLayerToDocArtifact(artifact, properties)
|
||||
artToReturn.DocArtifact = &docArtifact
|
||||
case int32(r.dataSetTypeID):
|
||||
case defaults.DataSetTypeName:
|
||||
dataSet := mapDataLayerToDataSet(artifact, properties)
|
||||
artToReturn.DataSet = &dataSet
|
||||
case int32(r.metricTypeID):
|
||||
case defaults.MetricTypeName:
|
||||
metric := mapDataLayerToMetric(artifact, properties)
|
||||
artToReturn.Metric = &metric
|
||||
case int32(r.parameterTypeID):
|
||||
case defaults.ParameterTypeName:
|
||||
parameter := mapDataLayerToParameter(artifact, properties)
|
||||
artToReturn.Parameter = ¶meter
|
||||
default:
|
||||
return models.Artifact{}, fmt.Errorf("invalid artifact type: %d (expected: modelArtifact=%d, docArtifact=%d, dataSet=%d, metric=%d, parameter=%d, metricHistory=%d [filtered])",
|
||||
artifact.TypeID, r.modelArtifactTypeID, r.docArtifactTypeID, r.dataSetTypeID, r.metricTypeID, r.parameterTypeID, r.metricHistoryTypeID)
|
||||
return models.Artifact{}, fmt.Errorf("invalid artifact type: %s=%d (expected: %v)", typeName, artifact.TypeID, r.idToName)
|
||||
}
|
||||
|
||||
return artToReturn, nil
|
||||
|
|
|
|||
|
|
@ -8,13 +8,14 @@ import (
|
|||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/defaults"
|
||||
"github.com/kubeflow/model-registry/internal/testutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestArtifactRepository(t *testing.T) {
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
defer cleanup()
|
||||
|
||||
// Get the actual type IDs from the database
|
||||
|
|
@ -24,7 +25,14 @@ func TestArtifactRepository(t *testing.T) {
|
|||
metricTypeID := getMetricTypeID(t, sharedDB)
|
||||
parameterTypeID := getParameterTypeID(t, sharedDB)
|
||||
metricHistoryTypeID := getMetricHistoryTypeID(t, sharedDB)
|
||||
repo := service.NewArtifactRepository(sharedDB, modelArtifactTypeID, docArtifactTypeID, dataSetTypeID, metricTypeID, parameterTypeID, metricHistoryTypeID)
|
||||
repo := service.NewArtifactRepository(sharedDB, map[string]int64{
|
||||
defaults.ModelArtifactTypeName: modelArtifactTypeID,
|
||||
defaults.DocArtifactTypeName: docArtifactTypeID,
|
||||
defaults.DataSetTypeName: dataSetTypeID,
|
||||
defaults.MetricTypeName: metricTypeID,
|
||||
defaults.ParameterTypeName: parameterTypeID,
|
||||
defaults.MetricHistoryTypeName: metricHistoryTypeID,
|
||||
})
|
||||
|
||||
// Also get other type IDs for creating related entities
|
||||
registeredModelTypeID := getRegisteredModelTypeID(t, sharedDB)
|
||||
|
|
|
|||
|
|
@@ -5,6 +5,7 @@ import (
 	"testing"

 	"github.com/kubeflow/model-registry/internal/db/schema"
+	"github.com/kubeflow/model-registry/internal/db/service"
 	"github.com/kubeflow/model-registry/internal/defaults"
 	"github.com/kubeflow/model-registry/internal/testutils"
 	"github.com/stretchr/testify/require"
@@ -16,7 +17,7 @@ func TestMain(m *testing.M) {
 }

 func setupTestDB(t *testing.T) (*gorm.DB, func()) {
-	db, dbCleanup := testutils.SetupMySQLWithMigrations(t)
+	db, dbCleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())

 	// Clean up test data before each test
 	testutils.CleanupTestData(t, db)
@@ -14,7 +14,7 @@ import (
 )

 func TestDocArtifactRepository(t *testing.T) {
-	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
+	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
 	defer cleanup()

 	// Get the actual DocArtifact type ID from the database
@@ -13,7 +13,7 @@ import (
 )

 func TestInferenceServiceRepository(t *testing.T) {
-	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
+	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
 	defer cleanup()

 	// Get the actual InferenceService type ID from the database
@@ -14,7 +14,7 @@ import (
 )

 func TestModelArtifactRepository(t *testing.T) {
-	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
+	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
 	defer cleanup()

 	// Get the actual ModelArtifact type ID from the database
@@ -14,7 +14,7 @@ import (
 )

 func TestModelVersionRepository(t *testing.T) {
-	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
+	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
 	defer cleanup()

 	// Get the actual ModelVersion type ID from the database
@@ -13,7 +13,7 @@ import (
 )

 func TestRegisteredModelRepository(t *testing.T) {
-	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
+	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
 	defer cleanup()

 	// Get the actual RegisteredModel type ID from the database
@@ -14,7 +14,7 @@ import (
 )

 func TestServeModelRepository(t *testing.T) {
-	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
+	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
 	defer cleanup()

 	// Get the actual ServeModel type ID from the database
@@ -13,7 +13,7 @@ import (
 )

 func TestServingEnvironmentRepository(t *testing.T) {
-	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
+	sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
 	defer cleanup()

 	// Get the actual ServingEnvironment type ID from the database
@@ -0,0 +1,97 @@
+package service
+
+import (
+	"github.com/kubeflow/model-registry/internal/datastore"
+	"github.com/kubeflow/model-registry/internal/defaults"
+)
+
+func DatastoreSpec() *datastore.Spec {
+	return datastore.NewSpec().
+		AddArtifact(defaults.ModelArtifactTypeName, datastore.NewSpecType(NewModelArtifactRepository).
+			AddString("description").
+			AddString("model_format_name").
+			AddString("model_format_version").
+			AddString("service_account_name").
+			AddString("storage_key").
+			AddString("storage_path"),
+		).
+		AddArtifact(defaults.DocArtifactTypeName, datastore.NewSpecType(NewDocArtifactRepository).
+			AddString("description"),
+		).
+		AddArtifact(defaults.DataSetTypeName, datastore.NewSpecType(NewDataSetRepository).
+			AddString("description").
+			AddString("digest").
+			AddString("source_type").
+			AddString("source").
+			AddString("schema").
+			AddString("profile"),
+		).
+		AddArtifact(defaults.MetricTypeName, datastore.NewSpecType(NewMetricRepository).
+			AddString("description").
+			AddProto("value").
+			AddString("timestamp").
+			AddInt("step"),
+		).
+		AddArtifact(defaults.ParameterTypeName, datastore.NewSpecType(NewParameterRepository).
+			AddString("description").
+			AddString("value").
+			AddString("parameter_type"),
+		).
+		AddArtifact(defaults.MetricHistoryTypeName, datastore.NewSpecType(NewMetricHistoryRepository).
+			AddString("description").
+			AddProto("value").
+			AddString("timestamp").
+			AddInt("step"),
+		).
+		AddContext(defaults.RegisteredModelTypeName, datastore.NewSpecType(NewRegisteredModelRepository).
+			AddString("description").
+			AddString("owner").
+			AddString("state").
+			AddStruct("language").
+			AddString("library_name").
+			AddString("license_link").
+			AddString("license").
+			AddString("logo").
+			AddString("maturity").
+			AddString("provider").
+			AddString("readme").
+			AddStruct("tasks"),
+		).
+		AddContext(defaults.ModelVersionTypeName, datastore.NewSpecType(NewModelVersionRepository).
+			AddString("author").
+			AddString("description").
+			AddString("model_name").
+			AddString("state").
+			AddString("version"),
+		).
+		AddContext(defaults.ServingEnvironmentTypeName, datastore.NewSpecType(NewServingEnvironmentRepository).
+			AddString("description"),
+		).
+		AddContext(defaults.InferenceServiceTypeName, datastore.NewSpecType(NewInferenceServiceRepository).
+			AddString("description").
+			AddString("desired_state").
+			AddInt("model_version_id").
+			AddInt("registered_model_id").
+			AddString("runtime").
+			AddInt("serving_environment_id"),
+		).
+		AddContext(defaults.ExperimentTypeName, datastore.NewSpecType(NewExperimentRepository).
+			AddString("description").
+			AddString("owner").
+			AddString("state"),
+		).
+		AddContext(defaults.ExperimentRunTypeName, datastore.NewSpecType(NewExperimentRunRepository).
+			AddString("description").
+			AddString("owner").
+			AddString("state").
+			AddString("status").
+			AddInt("start_time_since_epoch").
+			AddInt("end_time_since_epoch").
+			AddInt("experiment_id"),
+		).
+		AddExecution(defaults.ServeModelTypeName, datastore.NewSpecType(NewServeModelRepository).
+			AddString("description").
+			AddInt("model_version_id"),
+		).
+		AddOther(NewArtifactRepository)
+}
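The spec built above is what callers feed to the embedmd datastore connector; the testutils change later in this commit uses exactly these calls. A condensed sketch under that assumption (the wrapper function and error wrapping are illustrative, not part of this change):

```go
// Sketch mirroring the calls added to testutils in this commit: build the
// embedmd connector around an existing *gorm.DB and connect with the spec.
func connectDatastore(db *gorm.DB) error {
	ds, err := datastore.NewConnector("embedmd", &embedmd.EmbedMDConfig{DB: db})
	if err != nil {
		return fmt.Errorf("unable to get datastore connector: %w", err)
	}
	if _, err := ds.Connect(service.DatastoreSpec()); err != nil {
		return fmt.Errorf("unable to connect to datastore: %w", err)
	}
	return nil
}
```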
@@ -1,6 +1,9 @@
 package service

 import (
+	"errors"
+	"fmt"
+
 	"github.com/kubeflow/model-registry/internal/db/models"
 	"github.com/kubeflow/model-registry/internal/db/schema"
 	"gorm.io/gorm"
@@ -40,3 +43,58 @@ func (r *TypeRepositoryImpl) GetAll() ([]models.Type, error) {

 	return typesModels, nil
 }
+
+func (r *TypeRepositoryImpl) Save(t models.Type) (models.Type, error) {
+	attr := t.GetAttributes()
+	if attr == nil {
+		return t, errors.New("invalid type: missing attributes")
+	}
+	if attr.Name == nil {
+		return t, errors.New("invalid type: missing name")
+	}
+	if attr.TypeKind == nil {
+		return t, errors.New("invalid type: missing kind")
+	}
+
+	var st schema.Type
+	err := r.db.Where("name = ?", *attr.Name).First(&st).Error
+
+	if err == nil {
+		// Record already exists. We don't support updates, but we can return the full details.
+
+		// Catch this case in particular.
+		if st.TypeKind != *attr.TypeKind {
+			return t, fmt.Errorf("invalid type: kind is %d, cannot change to kind %d", st.TypeKind, *attr.TypeKind)
+		}
+	} else if errors.Is(err, gorm.ErrRecordNotFound) {
+		// Record doesn't exist, so we'll create it.
+		st = schema.Type{
+			Name:        *attr.Name,
+			Version:     attr.Version,
+			TypeKind:    *attr.TypeKind,
+			Description: attr.Description,
+			InputType:   attr.InputType,
+			OutputType:  attr.OutputType,
+			ExternalID:  attr.ExternalID,
+		}
+
+		if err := r.db.Create(&st).Error; err != nil {
+			return t, err
+		}
+	} else {
+		return t, err
+	}
+
+	return &models.TypeImpl{
+		ID: &st.ID,
+		Attributes: &models.TypeAttributes{
+			Name:        &st.Name,
+			Version:     st.Version,
+			TypeKind:    &st.TypeKind,
+			Description: st.Description,
+			InputType:   st.InputType,
+			OutputType:  st.OutputType,
+			ExternalID:  st.ExternalID,
+		},
+	}, nil
+}
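A brief usage sketch of `Save`, condensed from the repository tests added later in this change (it assumes a test context with `typeRepo := service.NewTypeRepository(sharedDB)`): creating a type is idempotent by name, and changing the kind of an existing name is rejected.

```go
// Condensed from the added tests: Save creates the type on first call,
// returns the existing record (same ID) on repeat calls, and rejects a
// TypeKind change for an existing name.
name, kind := "test-custom-type", int32(1)
created, err := typeRepo.Save(&models.TypeImpl{
	Attributes: &models.TypeAttributes{Name: &name, TypeKind: &kind},
})
require.NoError(t, err)

existing, err := typeRepo.Save(&models.TypeImpl{
	Attributes: &models.TypeAttributes{Name: &name, TypeKind: &kind},
})
require.NoError(t, err)
require.Equal(t, *created.GetID(), *existing.GetID())

otherKind := int32(2)
_, err = typeRepo.Save(&models.TypeImpl{
	Attributes: &models.TypeAttributes{Name: &name, TypeKind: &otherKind},
})
require.Error(t, err)
```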
@@ -0,0 +1,53 @@
+package service
+
+import (
+	"errors"
+	"fmt"
+	"strconv"
+
+	"github.com/kubeflow/model-registry/internal/db/models"
+	"github.com/kubeflow/model-registry/internal/db/schema"
+	"golang.org/x/exp/constraints"
+	"gorm.io/gorm"
+)
+
+type typePropertyRepositoryImpl struct {
+	db *gorm.DB
+}
+
+func NewTypePropertyRepository(db *gorm.DB) models.TypePropertyRepository {
+	return &typePropertyRepositoryImpl{db: db}
+}
+
+func (r *typePropertyRepositoryImpl) Save(tp models.TypeProperty) (models.TypeProperty, error) {
+	var stp schema.TypeProperty
+	err := r.db.Where("type_id=? AND name=?", tp.GetTypeID(), tp.GetName()).First(&stp).Error
+	if err == nil {
+		oldType := intPointerString(stp.DataType)
+		newType := intPointerString(tp.GetDataType())
+		if oldType != newType {
+			return tp, fmt.Errorf("invalid property type: data type is %s, cannot change to %s", oldType, newType)
+		}
+	} else if errors.Is(err, gorm.ErrRecordNotFound) {
+		stp.TypeID = tp.GetTypeID()
+		stp.Name = tp.GetName()
+		stp.DataType = tp.GetDataType()
+
+		if err := r.db.Create(&stp).Error; err != nil {
+			return tp, err
+		}
+	}
+
+	return &models.TypePropertyImpl{
+		TypeID:   stp.TypeID,
+		Name:     stp.Name,
+		DataType: stp.DataType,
+	}, nil
+}
+
+func intPointerString[T constraints.Integer](v *T) string {
+	if v == nil {
+		return "<nil>"
+	}
+	return strconv.Itoa(int(*v))
+}
|
||||
|
|
@ -0,0 +1,238 @@
|
|||
package service_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/testutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTypePropertyRepository(t *testing.T) {
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
defer cleanup()
|
||||
|
||||
repo := service.NewTypePropertyRepository(sharedDB)
|
||||
typeRepo := service.NewTypeRepository(sharedDB)
|
||||
|
||||
// Create a test type to use for properties
|
||||
testTypeName := "test-type-for-properties"
|
||||
testTypeKind := int32(1)
|
||||
testType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &testTypeName,
|
||||
TypeKind: &testTypeKind,
|
||||
},
|
||||
}
|
||||
|
||||
savedType, err := typeRepo.Save(testType)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedType)
|
||||
typeID := *savedType.GetID()
|
||||
|
||||
t.Run("TestSave", func(t *testing.T) {
|
||||
// Test saving a new type property
|
||||
propertyName := "test-property"
|
||||
dataType := int32(1)
|
||||
|
||||
property := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: &dataType,
|
||||
}
|
||||
|
||||
savedProperty, err := repo.Save(property)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedProperty)
|
||||
|
||||
// Verify the saved property
|
||||
assert.Equal(t, typeID, savedProperty.GetTypeID())
|
||||
assert.Equal(t, propertyName, savedProperty.GetName())
|
||||
assert.Equal(t, dataType, *savedProperty.GetDataType())
|
||||
})
|
||||
|
||||
t.Run("TestSaveExisting", func(t *testing.T) {
|
||||
// Test saving a property that already exists with same data type
|
||||
propertyName := "test-existing-property"
|
||||
dataType := int32(2)
|
||||
|
||||
// First, save the property
|
||||
firstProperty := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: &dataType,
|
||||
}
|
||||
|
||||
savedProperty1, err := repo.Save(firstProperty)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedProperty1)
|
||||
|
||||
// Now try to save the same property again
|
||||
secondProperty := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: &dataType,
|
||||
}
|
||||
|
||||
savedProperty2, err := repo.Save(secondProperty)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedProperty2)
|
||||
|
||||
// Should return the existing property
|
||||
assert.Equal(t, savedProperty1.GetTypeID(), savedProperty2.GetTypeID())
|
||||
assert.Equal(t, savedProperty1.GetName(), savedProperty2.GetName())
|
||||
assert.Equal(t, *savedProperty1.GetDataType(), *savedProperty2.GetDataType())
|
||||
})
|
||||
|
||||
t.Run("TestSaveInvalidDataTypeChange", func(t *testing.T) {
|
||||
// Test trying to save a property with different data type than existing
|
||||
propertyName := "test-datatype-change-property"
|
||||
originalDataType := int32(1)
|
||||
newDataType := int32(2)
|
||||
|
||||
// First, save the property with original data type
|
||||
originalProperty := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: &originalDataType,
|
||||
}
|
||||
|
||||
_, err := repo.Save(originalProperty)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to save with different data type - should fail
|
||||
changedProperty := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: &newDataType,
|
||||
}
|
||||
|
||||
_, err = repo.Save(changedProperty)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot change to")
|
||||
})
|
||||
|
||||
t.Run("TestSaveMultiplePropertiesForSameType", func(t *testing.T) {
|
||||
// Test saving multiple properties for the same type
|
||||
properties := []struct {
|
||||
name string
|
||||
dataType int32
|
||||
}{
|
||||
{"property1", 1},
|
||||
{"property2", 2},
|
||||
{"property3", 3},
|
||||
}
|
||||
|
||||
for _, prop := range properties {
|
||||
property := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: prop.name,
|
||||
DataType: &prop.dataType,
|
||||
}
|
||||
|
||||
savedProperty, err := repo.Save(property)
|
||||
require.NoError(t, err, "Failed to save property %s", prop.name)
|
||||
require.NotNil(t, savedProperty)
|
||||
|
||||
assert.Equal(t, typeID, savedProperty.GetTypeID())
|
||||
assert.Equal(t, prop.name, savedProperty.GetName())
|
||||
assert.Equal(t, prop.dataType, *savedProperty.GetDataType())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestSaveWithDifferentTypes", func(t *testing.T) {
|
||||
// Create another test type
|
||||
anotherTypeName := "another-test-type"
|
||||
anotherTypeKind := int32(2)
|
||||
anotherType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &anotherTypeName,
|
||||
TypeKind: &anotherTypeKind,
|
||||
},
|
||||
}
|
||||
|
||||
savedAnotherType, err := typeRepo.Save(anotherType)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedAnotherType)
|
||||
anotherTypeID := *savedAnotherType.GetID()
|
||||
|
||||
// Test saving properties with same name for different types
|
||||
propertyName := "shared-property-name"
|
||||
dataType := int32(1)
|
||||
|
||||
// Save property for first type
|
||||
property1 := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: &dataType,
|
||||
}
|
||||
|
||||
savedProperty1, err := repo.Save(property1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedProperty1)
|
||||
|
||||
// Save property with same name for second type - should work
|
||||
property2 := &models.TypePropertyImpl{
|
||||
TypeID: anotherTypeID,
|
||||
Name: propertyName,
|
||||
DataType: &dataType,
|
||||
}
|
||||
|
||||
savedProperty2, err := repo.Save(property2)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedProperty2)
|
||||
|
||||
// Verify both properties exist with different type IDs
|
||||
assert.Equal(t, typeID, savedProperty1.GetTypeID())
|
||||
assert.Equal(t, anotherTypeID, savedProperty2.GetTypeID())
|
||||
assert.Equal(t, propertyName, savedProperty1.GetName())
|
||||
assert.Equal(t, propertyName, savedProperty2.GetName())
|
||||
assert.Equal(t, dataType, *savedProperty1.GetDataType())
|
||||
assert.Equal(t, dataType, *savedProperty2.GetDataType())
|
||||
})
|
||||
|
||||
t.Run("TestSaveNilDataType", func(t *testing.T) {
|
||||
// Test saving a property with nil data type for a new property
|
||||
propertyName := "test-nil-datatype-property"
|
||||
|
||||
property := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: nil,
|
||||
}
|
||||
|
||||
savedProperty, err := repo.Save(property)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedProperty)
|
||||
|
||||
// Verify the saved property
|
||||
assert.Equal(t, typeID, savedProperty.GetTypeID())
|
||||
assert.Equal(t, propertyName, savedProperty.GetName())
|
||||
assert.Nil(t, savedProperty.GetDataType())
|
||||
})
|
||||
|
||||
t.Run("TestSaveValidDataTypes", func(t *testing.T) {
|
||||
// Test saving properties with various valid data types
|
||||
validDataTypes := []int32{0, 1, 2, 3, 4, 5, 10, 100}
|
||||
|
||||
for i, dataType := range validDataTypes {
|
||||
propertyName := fmt.Sprintf("test-datatype-%d-property", i)
|
||||
property := &models.TypePropertyImpl{
|
||||
TypeID: typeID,
|
||||
Name: propertyName,
|
||||
DataType: &dataType,
|
||||
}
|
||||
|
||||
savedProperty, err := repo.Save(property)
|
||||
require.NoError(t, err, "Failed to save property with data type %d", dataType)
|
||||
require.NotNil(t, savedProperty)
|
||||
|
||||
assert.Equal(t, typeID, savedProperty.GetTypeID())
|
||||
assert.Equal(t, propertyName, savedProperty.GetName())
|
||||
assert.Equal(t, dataType, *savedProperty.GetDataType())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package service_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/kubeflow/model-registry/internal/apiutils"
|
||||
"github.com/kubeflow/model-registry/internal/db/models"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
"github.com/kubeflow/model-registry/internal/defaults"
|
||||
|
|
@ -12,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
func TestTypeRepository(t *testing.T) {
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
defer cleanup()
|
||||
|
||||
repo := service.NewTypeRepository(sharedDB)
|
||||
|
|
@ -188,4 +189,168 @@ func TestTypeRepository(t *testing.T) {
|
|||
assert.NotEmpty(t, types, "Migrated database should have default types")
|
||||
assert.GreaterOrEqual(t, len(types), 1, "Should have at least one type after migration")
|
||||
})
|
||||
|
||||
t.Run("TestSave", func(t *testing.T) {
|
||||
// Test saving a new type
|
||||
newTypeName := "test-custom-type"
|
||||
newTypeKind := int32(1)
|
||||
newVersion := "1.0.0"
|
||||
newDescription := "Test custom type description"
|
||||
newInputType := "application/json"
|
||||
newOutputType := "application/json"
|
||||
newExternalID := "external-123"
|
||||
|
||||
newType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &newTypeName,
|
||||
TypeKind: &newTypeKind,
|
||||
Version: &newVersion,
|
||||
Description: &newDescription,
|
||||
InputType: &newInputType,
|
||||
OutputType: &newOutputType,
|
||||
ExternalID: &newExternalID,
|
||||
},
|
||||
}
|
||||
|
||||
savedType, err := repo.Save(newType)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedType)
|
||||
|
||||
// Verify the saved type has an ID
|
||||
assert.NotNil(t, savedType.GetID())
|
||||
assert.Greater(t, *savedType.GetID(), int32(0))
|
||||
|
||||
// Verify all attributes are preserved
|
||||
attrs := savedType.GetAttributes()
|
||||
assert.Equal(t, newTypeName, *attrs.Name)
|
||||
assert.Equal(t, newTypeKind, *attrs.TypeKind)
|
||||
assert.Equal(t, newVersion, *attrs.Version)
|
||||
assert.Equal(t, newDescription, *attrs.Description)
|
||||
assert.Equal(t, newInputType, *attrs.InputType)
|
||||
assert.Equal(t, newOutputType, *attrs.OutputType)
|
||||
assert.Equal(t, newExternalID, *attrs.ExternalID)
|
||||
})
|
||||
|
||||
t.Run("TestSaveExisting", func(t *testing.T) {
|
||||
// Test saving a type that already exists
|
||||
existingTypeName := "test-existing-type"
|
||||
existingTypeKind := int32(2)
|
||||
|
||||
// First, save the type
|
||||
firstType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &existingTypeName,
|
||||
TypeKind: &existingTypeKind,
|
||||
},
|
||||
}
|
||||
|
||||
savedType1, err := repo.Save(firstType)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedType1)
|
||||
|
||||
// Now try to save the same type again
|
||||
secondType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &existingTypeName,
|
||||
TypeKind: &existingTypeKind,
|
||||
},
|
||||
}
|
||||
|
||||
savedType2, err := repo.Save(secondType)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedType2)
|
||||
|
||||
// Should return the existing type with same ID
|
||||
assert.Equal(t, *savedType1.GetID(), *savedType2.GetID())
|
||||
assert.Equal(t, *savedType1.GetAttributes().Name, *savedType2.GetAttributes().Name)
|
||||
assert.Equal(t, *savedType1.GetAttributes().TypeKind, *savedType2.GetAttributes().TypeKind)
|
||||
})
|
||||
|
||||
t.Run("TestSaveInvalidTypeKindChange", func(t *testing.T) {
|
||||
// Test trying to save a type with different kind than existing
|
||||
typeName := "test-kind-change-type"
|
||||
originalKind := int32(1)
|
||||
newKind := int32(2)
|
||||
|
||||
// First, save the type with original kind
|
||||
originalType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &typeName,
|
||||
TypeKind: &originalKind,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := repo.Save(originalType)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to save with different kind - should fail
|
||||
changedType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &typeName,
|
||||
TypeKind: &newKind,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = repo.Save(changedType)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot change to kind")
|
||||
})
|
||||
|
||||
t.Run("TestSaveValidationErrors", func(t *testing.T) {
|
||||
// Test saving type without attributes
|
||||
emptyType := &models.TypeImpl{}
|
||||
_, err := repo.Save(emptyType)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing attributes")
|
||||
|
||||
// Test saving type without name
|
||||
noNameType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
TypeKind: apiutils.Of(int32(1)),
|
||||
},
|
||||
}
|
||||
_, err = repo.Save(noNameType)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing name")
|
||||
|
||||
// Test saving type without kind
|
||||
noKindType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: apiutils.Of("test-no-kind"),
|
||||
},
|
||||
}
|
||||
_, err = repo.Save(noKindType)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing kind")
|
||||
})
|
||||
|
||||
t.Run("TestSaveMinimalType", func(t *testing.T) {
|
||||
// Test saving type with only required fields
|
||||
typeName := "test-minimal-type"
|
||||
typeKind := int32(3)
|
||||
|
||||
minimalType := &models.TypeImpl{
|
||||
Attributes: &models.TypeAttributes{
|
||||
Name: &typeName,
|
||||
TypeKind: &typeKind,
|
||||
},
|
||||
}
|
||||
|
||||
savedType, err := repo.Save(minimalType)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedType)
|
||||
|
||||
// Verify required fields
|
||||
assert.NotNil(t, savedType.GetID())
|
||||
assert.Equal(t, typeName, *savedType.GetAttributes().Name)
|
||||
assert.Equal(t, typeKind, *savedType.GetAttributes().TypeKind)
|
||||
|
||||
// Optional fields should be nil
|
||||
attrs := savedType.GetAttributes()
|
||||
assert.Nil(t, attrs.Version)
|
||||
assert.Nil(t, attrs.Description)
|
||||
assert.Nil(t, attrs.InputType)
|
||||
assert.Nil(t, attrs.OutputType)
|
||||
assert.Nil(t, attrs.ExternalID)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,6 @@ import (
 	"net/http"
 	"time"

-	"github.com/kubeflow/model-registry/internal/datastore"
 	"github.com/kubeflow/model-registry/internal/db"
 	"github.com/kubeflow/model-registry/pkg/api"
 )
@@ -72,13 +71,10 @@ type HealthChecker interface {

 // DatabaseHealthChecker checks database connectivity and schema state
 type DatabaseHealthChecker struct {
-	datastore datastore.Datastore
 }

-func NewDatabaseHealthChecker(datastore datastore.Datastore) *DatabaseHealthChecker {
-	return &DatabaseHealthChecker{
-		datastore: datastore,
-	}
+func NewDatabaseHealthChecker() *DatabaseHealthChecker {
+	return &DatabaseHealthChecker{}
 }

 func (d *DatabaseHealthChecker) Check() HealthCheck {
@@ -87,22 +83,6 @@ func (d *DatabaseHealthChecker) Check() HealthCheck {
 		Details: make(map[string]interface{}),
 	}

-	// Skip embedmd check for mlmd datastore
-	if d.datastore.Type != "embedmd" {
-		check.Status = StatusPass
-		check.Message = "MLMD datastore - skipping database check"
-		check.Details[detailDatastoreType] = d.datastore.Type
-		return check
-	}
-
-	// Check DSN configuration
-	dsn := d.datastore.EmbedMD.DatabaseDSN
-	if dsn == "" {
-		check.Status = StatusFail
-		check.Message = "database DSN not configured"
-		return check
-	}
-
 	// Check database connector
 	dbConnector, ok := db.GetConnector()
 	if !ok {
@@ -259,7 +239,7 @@ func (m *ModelRegistryHealthChecker) Check() HealthCheck {
 }

 // GeneralReadinessHandler creates a general readiness handler with configurable health checks
-func GeneralReadinessHandler(datastore datastore.Datastore, additionalCheckers ...HealthChecker) http.Handler {
+func GeneralReadinessHandler(additionalCheckers ...HealthChecker) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Method != http.MethodGet {
 			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
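With the datastore argument removed from both constructors, readiness wiring no longer needs the datastore config at all. A hedged sketch of how a server might register the handler; the paths mirror the ones exercised by the readiness tests below, while the mux setup and the `modelRegistryService` variable are assumptions for illustration:

```go
// Sketch only: route paths match those used in the readiness tests
// ("/readyz/health", "/readyz/isDirty"); the mux wiring and the
// modelRegistryService variable are assumed, not part of this commit.
mux := http.NewServeMux()
dbChecker := NewDatabaseHealthChecker()
mrChecker := NewModelRegistryHealthChecker(modelRegistryService)
mux.Handle("/readyz/health", GeneralReadinessHandler(dbChecker, mrChecker))
mux.Handle("/readyz/isDirty", GeneralReadinessHandler(dbChecker))
```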
|
|
|
|||
|
|
@ -10,8 +10,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/kubeflow/model-registry/internal/core"
|
||||
"github.com/kubeflow/model-registry/internal/datastore"
|
||||
"github.com/kubeflow/model-registry/internal/datastore/embedmd"
|
||||
"github.com/kubeflow/model-registry/internal/db"
|
||||
"github.com/kubeflow/model-registry/internal/db/schema"
|
||||
"github.com/kubeflow/model-registry/internal/db/service"
|
||||
|
|
@ -29,7 +27,7 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func setupTestDB(t *testing.T) (*gorm.DB, string, api.ModelRegistryApi, func()) {
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t)
|
||||
sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec())
|
||||
dsn := testutils.GetSharedMySQLDSN(t)
|
||||
svc := setupModelRegistryService(sharedDB)
|
||||
|
||||
|
|
@ -78,7 +76,14 @@ func setupModelRegistryService(sharedDB *gorm.DB) api.ModelRegistryApi {
|
|||
typesMap := getTypeIDs(sharedDB)
|
||||
|
||||
// Create all repositories
|
||||
artifactRepo := service.NewArtifactRepository(sharedDB, typesMap[defaults.ModelArtifactTypeName], typesMap[defaults.DocArtifactTypeName], typesMap[defaults.DataSetTypeName], typesMap[defaults.MetricTypeName], typesMap[defaults.ParameterTypeName], typesMap[defaults.MetricHistoryTypeName])
|
||||
artifactRepo := service.NewArtifactRepository(sharedDB, map[string]int64{
|
||||
defaults.ModelArtifactTypeName: typesMap[defaults.ModelArtifactTypeName],
|
||||
defaults.DocArtifactTypeName: typesMap[defaults.DocArtifactTypeName],
|
||||
defaults.DataSetTypeName: typesMap[defaults.DataSetTypeName],
|
||||
defaults.MetricTypeName: typesMap[defaults.MetricTypeName],
|
||||
defaults.ParameterTypeName: typesMap[defaults.ParameterTypeName],
|
||||
defaults.MetricHistoryTypeName: typesMap[defaults.MetricHistoryTypeName],
|
||||
})
|
||||
modelArtifactRepo := service.NewModelArtifactRepository(sharedDB, typesMap[defaults.ModelArtifactTypeName])
|
||||
docArtifactRepo := service.NewDocArtifactRepository(sharedDB, typesMap[defaults.DocArtifactTypeName])
|
||||
registeredModelRepo := service.NewRegisteredModelRepository(sharedDB, typesMap[defaults.RegisteredModelTypeName])
|
||||
|
|
@ -129,28 +134,15 @@ func setDirtySchemaState(t *testing.T, sharedDB *gorm.DB) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// createTestDatastore creates a datastore config for testing
|
||||
func createTestDatastore(sharedDSN string) datastore.Datastore {
|
||||
return datastore.Datastore{
|
||||
Type: "embedmd",
|
||||
EmbedMD: embedmd.EmbedMDConfig{
|
||||
DatabaseType: "mysql",
|
||||
DatabaseDSN: sharedDSN,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadinessHandler_EmbedMD_Success(t *testing.T) {
|
||||
// Ensure clean state before test
|
||||
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
|
||||
sharedDB, _, _, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
handler := GeneralReadinessHandler(dbHealthChecker)
|
||||
req, err := http.NewRequest("GET", "/readyz/isDirty", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -163,16 +155,14 @@ func TestReadinessHandler_EmbedMD_Success(t *testing.T) {
|
|||
|
||||
func TestReadinessHandler_EmbedMD_Dirty(t *testing.T) {
|
||||
// Set dirty state for this test
|
||||
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
|
||||
sharedDB, _, _, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
setDirtySchemaState(t, sharedDB)
|
||||
defer cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
handler := GeneralReadinessHandler(dbHealthChecker)
|
||||
req, err := http.NewRequest("GET", "/readyz/isDirty", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -208,17 +198,15 @@ func TestReadinessHandler_EmbedMD_Dirty(t *testing.T) {
|
|||
|
||||
func TestGeneralReadinessHandler_WithModelRegistry_Success(t *testing.T) {
|
||||
// Ensure clean state before test
|
||||
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
// Create both health checkers
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
|
||||
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
|
||||
|
||||
req, err := http.NewRequest("GET", "/readyz/health", nil)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -232,17 +220,15 @@ func TestGeneralReadinessHandler_WithModelRegistry_Success(t *testing.T) {
|
|||
|
||||
func TestGeneralReadinessHandler_WithModelRegistry_JSONFormat(t *testing.T) {
|
||||
// Ensure clean state before test
|
||||
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
// Create both health checkers
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
|
||||
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
|
||||
|
||||
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -279,18 +265,16 @@ func TestGeneralReadinessHandler_WithModelRegistry_JSONFormat(t *testing.T) {
|
|||
|
||||
func TestGeneralReadinessHandler_WithModelRegistry_DatabaseFail(t *testing.T) {
|
||||
// Set dirty state to make database check fail
|
||||
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
setDirtySchemaState(t, sharedDB)
|
||||
defer cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
// Create both health checkers
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
|
||||
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
|
||||
|
||||
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -319,17 +303,15 @@ func TestGeneralReadinessHandler_WithModelRegistry_DatabaseFail(t *testing.T) {
|
|||
|
||||
func TestGeneralReadinessHandler_WithModelRegistry_ModelRegistryNil(t *testing.T) {
|
||||
// Ensure clean state before test
|
||||
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
|
||||
sharedDB, _, _, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
// Create health checkers - with nil model registry service
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
mrHealthChecker := NewModelRegistryHealthChecker(nil)
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
|
||||
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
|
||||
|
||||
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -358,18 +340,16 @@ func TestGeneralReadinessHandler_WithModelRegistry_ModelRegistryNil(t *testing.T
|
|||
|
||||
func TestGeneralReadinessHandler_SimpleTextResponse_Failure(t *testing.T) {
|
||||
// Set dirty state to make database check fail
|
||||
sharedDB, sharedDSN, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
sharedDB, _, sharedModelRegistryService, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
setDirtySchemaState(t, sharedDB)
|
||||
defer cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
// Create both health checkers
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
mrHealthChecker := NewModelRegistryHealthChecker(sharedModelRegistryService)
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
|
||||
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
|
||||
|
||||
req, err := http.NewRequest("GET", "/readyz/health", nil)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -382,36 +362,17 @@ func TestGeneralReadinessHandler_SimpleTextResponse_Failure(t *testing.T) {
|
|||
assert.Contains(t, rr.Body.String(), "database schema is in dirty state")
|
||||
}
|
||||
|
||||
func TestDatabaseHealthChecker_EmptyDSN(t *testing.T) {
|
||||
ds := datastore.Datastore{
|
||||
Type: "embedmd",
|
||||
EmbedMD: embedmd.EmbedMDConfig{
|
||||
DatabaseType: "mysql",
|
||||
DatabaseDSN: "", // Empty DSN
|
||||
},
|
||||
}
|
||||
|
||||
checker := NewDatabaseHealthChecker(ds)
|
||||
result := checker.Check()
|
||||
|
||||
assert.Equal(t, HealthCheckDatabase, result.Name)
|
||||
assert.Equal(t, StatusFail, result.Status)
|
||||
assert.Equal(t, "database DSN not configured", result.Message)
|
||||
}
|
||||
|
||||
func TestGeneralReadinessHandler_MultipleFailures(t *testing.T) {
|
||||
// Test with both database and model registry failing
|
||||
sharedDB, sharedDSN, _, cleanup := setupTestDB(t)
|
||||
sharedDB, _, _, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
setDirtySchemaState(t, sharedDB)
|
||||
defer cleanupSchemaState(t, sharedDB)
|
||||
|
||||
ds := createTestDatastore(sharedDSN)
|
||||
|
||||
dbHealthChecker := NewDatabaseHealthChecker(ds)
|
||||
dbHealthChecker := NewDatabaseHealthChecker()
|
||||
mrHealthChecker := NewModelRegistryHealthChecker(nil) // Nil service to make it fail
|
||||
handler := GeneralReadinessHandler(ds, dbHealthChecker, mrHealthChecker)
|
||||
handler := GeneralReadinessHandler(dbHealthChecker, mrHealthChecker)
|
||||
|
||||
req, err := http.NewRequest("GET", "/readyz/health?format=json", nil)
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
|
|
@@ -7,6 +7,8 @@ import (
 	"sync"
 	"testing"

+	"github.com/kubeflow/model-registry/internal/datastore"
+	"github.com/kubeflow/model-registry/internal/datastore/embedmd"
 	"github.com/kubeflow/model-registry/internal/datastore/embedmd/mysql"
 	"github.com/kubeflow/model-registry/internal/tls"
 	"github.com/stretchr/testify/require"
@@ -102,14 +104,18 @@ func CleanupSharedMySQL() {
 }

 // SetupMySQLWithMigrations returns a migrated MySQL database connection
-func SetupMySQLWithMigrations(t *testing.T) (*gorm.DB, func()) {
+func SetupMySQLWithMigrations(t *testing.T, spec *datastore.Spec) (*gorm.DB, func()) {
 	db, cleanup := GetSharedMySQLDB(t)

-	// Run migrations
-	migrator, err := mysql.NewMySQLMigrator(db)
-	require.NoError(t, err)
-	err = migrator.Migrate()
-	require.NoError(t, err)
+	ds, err := datastore.NewConnector("embedmd", &embedmd.EmbedMDConfig{DB: db})
+	if err != nil {
+		t.Fatalf("unable get datastore connector: %v", err)
+	}
+
+	_, err = ds.Connect(spec)
+	if err != nil {
+		t.Fatalf("unable to connect to datastore: %v", err)
+	}

 	return db, cleanup
 }
@@ -4,6 +4,7 @@ import (
 	"os"
 	"testing"

+	"github.com/kubeflow/model-registry/internal/db/service"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -32,7 +33,7 @@ func TestSharedMySQLUtility(t *testing.T) {
 	})

 	t.Run("SetupMySQLWithMigrations", func(t *testing.T) {
-		db, cleanup := SetupMySQLWithMigrations(t)
+		db, cleanup := SetupMySQLWithMigrations(t, service.DatastoreSpec())
 		defer cleanup()

 		// Test that migrations were applied
@@ -43,7 +44,7 @@ func TestSharedMySQLUtility(t *testing.T) {
 	})

 	t.Run("CleanupTestData", func(t *testing.T) {
-		db, cleanup := SetupMySQLWithMigrations(t)
+		db, cleanup := SetupMySQLWithMigrations(t, service.DatastoreSpec())
 		defer cleanup()

 		// Insert some test data
@@ -4,10 +4,16 @@ This directory contains manifests for deploying the Model Catalog using Kustomize

 ## Deployment

-To deploy the Model Catalog to your Kubernetes cluster, run the following command in the `base` directory:
+The model catalog manifests deploy a PostgreSQL database; set `POSTGRES_PASSWORD` in `postgres.env` before deploying the manifests. On Linux, you can generate a random password with:

 ```sh
-kubectl apply -k . -n <your-namespace>
+(echo POSTGRES_USER=postgres ; echo -n POSTGRES_PASSWORD=; dd if=/dev/random of=/dev/stdout bs=15 count=1 status=none | base64) >base/postgres.env
 ```
+
+To deploy the Model Catalog to your Kubernetes cluster (without Kubeflow; see below for Istio support), run the following command from this directory:
+
+```sh
+kubectl apply -k base -n <your-namespace>
+```

 Replace `<your-namespace>` with the Kubernetes namespace where you want to deploy the catalog.
@@ -16,11 +22,13 @@ This command will create:
 * A `Deployment` to run the Model Catalog server.
 * A `Service` to expose the Model Catalog server.
 * A `ConfigMap` named `model-catalog-sources` containing the configuration for the catalog sources.
+* A `StatefulSet` with a PostgreSQL database.
+* A `PersistentVolumeClaim` for PostgreSQL.

-For deployment in a Kubeflow environment with Istio support, use the overlay, running the following command in the `overlay` directory::
+For deployment in a Kubeflow environment with Istio support, use the `overlay` directory instead:

 ```sh
-kubectl apply -k . -n <your-namespace>
+kubectl apply -k overlay -n <your-namespace>
 ```

 ## Configuring Catalog Sources
@@ -118,4 +126,4 @@ catalogs:

 ### Sample Catalog File

-You can refer to `sample-catalog.yaml` in this directory for an example of how to structure your model definitions file.
+You can refer to `sample-catalog.yaml` in this directory for an example of how to structure your model definitions file.
@@ -3,17 +3,23 @@ kind: Deployment
 metadata:
   name: model-catalog-server
   labels:
-    component: model-catalog-server
+    app.kubernetes.io/name: model-catalog
+    app.kubernetes.io/part-of: model-catalog
+    app.kubernetes.io/component: server
 spec:
   replicas: 1
   selector:
     matchLabels:
-      component: model-catalog-server
+      app.kubernetes.io/name: model-catalog
+      app.kubernetes.io/part-of: model-catalog
+      app.kubernetes.io/component: server
   template:
     metadata:
       labels:
         sidecar.istio.io/inject: "true"
-        component: model-catalog-server
+        app.kubernetes.io/name: model-catalog
+        app.kubernetes.io/part-of: model-catalog
+        app.kubernetes.io/component: server
     spec:
       securityContext:
         seccompProfile:
@@ -25,6 +31,19 @@ spec:
            name: model-catalog-sources
      containers:
        - name: catalog
+         env:
+           - name: PGHOST
+             value: model-catalog-postgres
+           - name: PGUSER
+             valueFrom:
+               secretKeyRef:
+                 name: model-catalog-postgres
+                 key: POSTGRES_USER
+           - name: PGPASSWORD
+             valueFrom:
+               secretKeyRef:
+                 name: model-catalog-postgres
+                 key: POSTGRES_PASSWORD
          command:
            - /model-registry
            - catalog
@@ -4,6 +4,9 @@ kind: Kustomization
 resources:
   - deployment.yaml
   - service.yaml
+  - postgres-statefulset.yaml
+  - postgres-pvc.yaml
+  - postgres-service.yaml

 configMapGenerator:
   - behavior: create
@@ -13,3 +16,10 @@ configMapGenerator:
     name: model-catalog-sources
     options:
       disableNameSuffixHash: true
+
+secretGenerator:
+  - name: postgres
+    envs:
+      - postgres.env
+    options:
+      disableNameSuffixHash: true
@@ -0,0 +1,10 @@
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+  name: model-catalog-postgres
+spec:
+  accessModes:
+    - ReadWriteOnce
+  resources:
+    requests:
+      storage: 5Gi
@@ -0,0 +1,18 @@
+kind: Service
+apiVersion: v1
+metadata:
+  labels:
+    app.kubernetes.io/name: model-catalog
+    app.kubernetes.io/part-of: model-catalog
+    app.kubernetes.io/component: database
+  name: model-catalog-postgres
+spec:
+  selector:
+    app.kubernetes.io/part-of: model-catalog
+    app.kubernetes.io/name: postgres
+    app.kubernetes.io/component: database
+  type: ClusterIP
+  ports:
+    - port: 5432
+      protocol: TCP
+      name: postgres
@@ -0,0 +1,52 @@
+apiVersion: apps/v1
+kind: StatefulSet
+metadata:
+  name: model-catalog-postgres
+  labels:
+    component: model-catalog-server
+spec:
+  selector:
+    matchLabels:
+      app.kubernetes.io/name: postgres
+      app.kubernetes.io/part-of: model-catalog
+  replicas: 1
+  template:
+    metadata:
+      name: postgres
+      labels:
+        app.kubernetes.io/name: postgres
+        app.kubernetes.io/part-of: model-catalog
+        app.kubernetes.io/component: database
+        sidecar.istio.io/inject: "false"
+    spec:
+      securityContext:
+        seccompProfile:
+          type: RuntimeDefault
+        runAsNonRoot: true
+        fsGroup: 70
+      containers:
+        - name: postgres
+          image: postgres:17.6
+          env:
+            - name: PGDATA
+              value: /var/lib/postgresql/data/pgdata
+          envFrom:
+            - secretRef:
+                name: model-catalog-postgres
+          ports:
+            - name: postgres
+              containerPort: 5432
+          volumeMounts:
+            - name: model-catalog-postgres
+              mountPath: /var/lib/postgresql/data
+          securityContext:
+            runAsUser: 70
+            runAsGroup: 70
+            allowPrivilegeEscalation: false
+            capabilities:
+              drop:
+                - ALL
+      volumes:
+        - name: model-catalog-postgres
+          persistentVolumeClaim:
+            claimName: model-catalog-postgres
@@ -0,0 +1,2 @@
+POSTGRES_USER=postgres
+POSTGRES_PASSWORD=postgres
@@ -2,16 +2,14 @@ kind: Service
 apiVersion: v1
 metadata:
   labels:
-    app: model-catalog-service
-    app.kubernetes.io/component: model-catalog
-    app.kubernetes.io/instance: model-catalog-service
     app.kubernetes.io/name: model-catalog
+    app.kubernetes.io/component: server
+    app.kubernetes.io/part-of: model-catalog
-    component: model-catalog
   name: model-catalog
 spec:
   selector:
-    component: model-catalog-server
+    app.kubernetes.io/part-of: model-catalog
+    app.kubernetes.io/component: server
   type: ClusterIP
   ports:
     - port: 8080