188 lines
4.9 KiB
Go
188 lines
4.9 KiB
Go
package embedmd
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/golang/glog"
|
|
"github.com/kubeflow/model-registry/internal/apiutils"
|
|
"github.com/kubeflow/model-registry/internal/datastore"
|
|
"github.com/kubeflow/model-registry/internal/db"
|
|
"github.com/kubeflow/model-registry/internal/db/models"
|
|
"github.com/kubeflow/model-registry/internal/db/service"
|
|
"github.com/kubeflow/model-registry/internal/db/types"
|
|
"github.com/kubeflow/model-registry/internal/tls"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// connectorType is the name this connector is registered under with the
// datastore registry (see init) and the value returned by Type.
const (
	connectorType = "embedmd"
)
|
|
|
|
func init() {
|
|
datastore.Register(connectorType, func(cfg any) (datastore.Connector, error) {
|
|
emdbCfg, ok := cfg.(*EmbedMDConfig)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid EmbedMD config type (%T)", cfg)
|
|
}
|
|
|
|
if err := emdbCfg.Validate(); err != nil {
|
|
return nil, fmt.Errorf("invalid EmbedMD config: %w", err)
|
|
}
|
|
|
|
return NewEmbedMDService(cfg.(*EmbedMDConfig))
|
|
})
|
|
}
|
|
|
|
type EmbedMDConfig struct {
|
|
DatabaseType string
|
|
DatabaseDSN string
|
|
TLSConfig *tls.TLSConfig
|
|
|
|
// DB is an already connected database instance that, if provided, will
|
|
// be used instead of making a new connection.
|
|
DB *gorm.DB
|
|
}
|
|
|
|
func (c *EmbedMDConfig) Validate() error {
|
|
if c.DB == nil {
|
|
if c.DatabaseType != types.DatabaseTypeMySQL && c.DatabaseType != types.DatabaseTypePostgres {
|
|
return fmt.Errorf("unsupported database type: %s. Supported types: %s, %s", c.DatabaseType, types.DatabaseTypeMySQL, types.DatabaseTypePostgres)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type EmbedMDService struct {
|
|
dbConnector db.Connector
|
|
}
|
|
|
|
func NewEmbedMDService(cfg *EmbedMDConfig) (*EmbedMDService, error) {
|
|
if cfg.DB != nil {
|
|
db.SetDB(cfg.DB)
|
|
} else {
|
|
err := db.Init(cfg.DatabaseType, cfg.DatabaseDSN, cfg.TLSConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to initialize database connector: %w", err)
|
|
}
|
|
}
|
|
|
|
dbConnector, ok := db.GetConnector()
|
|
if !ok {
|
|
return nil, fmt.Errorf("database connector not initialized")
|
|
}
|
|
|
|
return &EmbedMDService{
|
|
dbConnector: dbConnector,
|
|
}, nil
|
|
}
|
|
|
|
func (s *EmbedMDService) Connect(spec *datastore.Spec) (datastore.RepoSet, error) {
|
|
glog.Infof("Connecting to EmbedMD service...")
|
|
|
|
connectedDB, err := s.dbConnector.Connect()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
glog.Infof("Connected to EmbedMD service")
|
|
|
|
migrator, err := db.NewDBMigrator(connectedDB)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
glog.Infof("Running migrations...")
|
|
|
|
err = migrator.Migrate()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
glog.Infof("Migrations completed")
|
|
|
|
glog.Infof("Syncing types...")
|
|
err = s.syncTypes(connectedDB, spec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
glog.Infof("Syncing types completed")
|
|
|
|
return newRepoSet(connectedDB, spec)
|
|
}
|
|
|
|
func (s EmbedMDService) Type() string {
|
|
return connectorType
|
|
}
|
|
|
|
// Type-kind discriminators passed as TypeKind when types are saved. The
// numeric values are spelled out because they are persisted and must not
// shift. NOTE(review): they appear to mirror the ML Metadata execution /
// artifact / context type kinds — confirm against the schema.
const (
	executionTypeKind int32 = 0
	artifactTypeKind  int32 = 1
	contextTypeKind   int32 = 2
)
|
|
|
|
func (s *EmbedMDService) syncTypes(conn *gorm.DB, spec *datastore.Spec) error {
|
|
idMap := make(map[string]int32, len(spec.ExecutionTypes)+len(spec.ArtifactTypes)+len(spec.ContextTypes))
|
|
var errs []error
|
|
|
|
typeRepository := service.NewTypeRepository(conn)
|
|
errs = append(errs, s.createTypes(typeRepository, spec.ExecutionTypes, executionTypeKind, idMap))
|
|
errs = append(errs, s.createTypes(typeRepository, spec.ArtifactTypes, artifactTypeKind, idMap))
|
|
errs = append(errs, s.createTypes(typeRepository, spec.ContextTypes, contextTypeKind, idMap))
|
|
|
|
typePropertyRepository := service.NewTypePropertyRepository(conn)
|
|
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ExecutionTypes, idMap))
|
|
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ArtifactTypes, idMap))
|
|
errs = append(errs, s.createTypeProperties(typePropertyRepository, spec.ContextTypes, idMap))
|
|
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
func (s *EmbedMDService) createTypes(repo models.TypeRepository, types map[string]*datastore.SpecType, kind int32, idMap map[string]int32) error {
|
|
var errs []error
|
|
for name := range types {
|
|
t, err := repo.Save(&models.TypeImpl{
|
|
Attributes: &models.TypeAttributes{
|
|
Name: &name,
|
|
TypeKind: &kind,
|
|
},
|
|
})
|
|
if err != nil {
|
|
errs = append(errs, fmt.Errorf("%s: unable to create type: %w", name, err))
|
|
continue
|
|
}
|
|
|
|
id := t.GetID()
|
|
if id == nil {
|
|
errs = append(errs, fmt.Errorf("%s: unable to determine type ID", name))
|
|
continue
|
|
}
|
|
idMap[name] = *id
|
|
}
|
|
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
func (s *EmbedMDService) createTypeProperties(repo models.TypePropertyRepository, types map[string]*datastore.SpecType, idMap map[string]int32) error {
|
|
var errs []error
|
|
for typeName, typeSpec := range types {
|
|
typeID := idMap[typeName]
|
|
if typeID == 0 {
|
|
errs = append(errs, fmt.Errorf("%s: unknown type", typeName))
|
|
continue
|
|
}
|
|
|
|
for name, dataType := range typeSpec.Properties {
|
|
_, err := repo.Save(&models.TypePropertyImpl{
|
|
TypeID: typeID,
|
|
Name: name,
|
|
DataType: apiutils.Of(int32(dataType)),
|
|
})
|
|
if err != nil {
|
|
errs = append(errs, fmt.Errorf("%s-%s: %w", typeName, name, err))
|
|
}
|
|
}
|
|
}
|
|
|
|
return errors.Join(errs...)
|
|
}
|