model-registry/cmd/proxy.go

252 lines
8.2 KiB
Go

package cmd
import (
	"fmt"
	"net/http"
	"reflect"
	"strings"
	"sync"
	"time"

	"github.com/golang/glog"
	"github.com/kubeflow/model-registry/internal/core"
	"github.com/kubeflow/model-registry/internal/datastore"
	"github.com/kubeflow/model-registry/internal/datastore/embedmd"
	"github.com/kubeflow/model-registry/internal/db/models"
	"github.com/kubeflow/model-registry/internal/db/service"
	"github.com/kubeflow/model-registry/internal/proxy"
	"github.com/kubeflow/model-registry/internal/server/middleware"
	"github.com/kubeflow/model-registry/internal/server/openapi"
	"github.com/kubeflow/model-registry/internal/tls"
	"github.com/kubeflow/model-registry/pkg/api"
	"github.com/spf13/cobra"
)
// ProxyConfig groups the settings consumed by the proxy command.
type ProxyConfig struct {
// EmbedMD is the configuration passed to datastore.NewConnector.
EmbedMD embedmd.EmbedMDConfig
// DatastoreType names the datastore backend; it is bound to the
// --datastore-type flag and also passed to datastore.NewConnector.
DatastoreType string
}
// datastoreUnavailableMessage is the response body served with HTTP 503 for
// requests that arrive while the datastore-backed router is not installed yet.
const datastoreUnavailableMessage = "Datastore service is down or unavailable. Please check that the database is reachable and try again later."
var (
// proxyCfg holds the proxy settings. It defaults to the "embedmd" datastore
// and pre-allocates TLSConfig so that init can bind flags to its fields.
proxyCfg = ProxyConfig{
DatastoreType: "embedmd",
EmbedMD: embedmd.EmbedMDConfig{
TLSConfig: &tls.TLSConfig{},
},
}
// proxyCmd is the cobra command for "proxy"; it is registered on rootCmd in
// init and executes runProxyServer.
proxyCmd = &cobra.Command{
Use: "proxy",
Short: "Starts the go OpenAPI proxy server to connect to a metadata store",
Long: `This command launches the go OpenAPI proxy server.
The server connects to a metadata store, currently only the internal store is supported. It supports options to customize the
hostname and port where it listens.`,
RunE: runProxyServer,
}
)
// ModelRegistryServiceHolder safely holds the model registry service.
// Access to the stored service goes through Set and Get, which take the
// embedded RWMutex.
type ModelRegistryServiceHolder struct {
// mu guards service.
mu sync.RWMutex
// service is the zero value (nil) until Set stores one.
service api.ModelRegistryApi
}
// Set stores service in the holder, replacing whatever was held before.
// The write lock is held for the duration of the assignment.
func (h *ModelRegistryServiceHolder) Set(service api.ModelRegistryApi) {
	h.mu.Lock()
	h.service = service
	h.mu.Unlock()
}
// Get returns the service currently held, or nil when none has been set.
// The value is read under the read lock.
func (h *ModelRegistryServiceHolder) Get() api.ModelRegistryApi {
	h.mu.RLock()
	current := h.service
	h.mu.RUnlock()

	return current
}
// ConditionalModelRegistryHealthChecker checks model registry health only if service is available.
// Its Check method reports a failing check while the holder has no service.
type ConditionalModelRegistryHealthChecker struct {
// holder supplies the model registry service, if one has been set.
holder *ModelRegistryServiceHolder
}
// Check reports the model registry health. When the holder already has a
// service, the result comes from the checker built by
// proxy.NewModelRegistryHealthChecker; otherwise a failing check with
// "service_ready": false is returned.
func (c *ConditionalModelRegistryHealthChecker) Check() proxy.HealthCheck {
	if svc := c.holder.Get(); svc != nil {
		checker := proxy.NewModelRegistryHealthChecker(svc)
		return checker.Check()
	}

	// No service has been stored in the holder yet.
	return proxy.HealthCheck{
		Name:    proxy.HealthCheckModelRegistry,
		Status:  proxy.StatusFail,
		Message: "model registry service not yet initialized",
		Details: map[string]any{
			"service_ready": false,
		},
	}
}
// runProxyServer is the RunE handler of proxyCmd. It starts the HTTP proxy
// server immediately and connects to the datastore concurrently, so health
// endpoints (and a 503 fallback for everything else) are served while the
// connection is being established.
//
// cmd and args are required by the cobra RunE signature and are not used.
//
// It returns the first error reported by either the datastore goroutine or the
// HTTP server goroutine, or nil if both goroutines finish without reporting one.
func runProxyServer(cmd *cobra.Command, args []string) error {
	// readHeaderTimeout bounds how long a client may take to send its request
	// headers. Without it a client can hold a connection open indefinitely by
	// trickling headers. Request bodies and responses are deliberately left
	// without a deadline so long-running API calls keep working.
	const readHeaderTimeout = 30 * time.Second

	var wg sync.WaitGroup

	serviceHolder := &ModelRegistryServiceHolder{}

	// Until the datastore-backed router is installed, every request that is
	// not a health endpoint is answered with 503.
	router := proxy.NewDynamicRouter()
	router.SetRouter(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, datastoreUnavailableMessage, http.StatusServiceUnavailable)
	}))

	readyChecks := []proxy.HealthChecker{}
	generalChecks := []proxy.HealthChecker{
		&ConditionalModelRegistryHealthChecker{holder: serviceHolder},
	}

	if proxyCfg.DatastoreType == "embedmd" {
		dbHealthChecker := proxy.NewDatabaseHealthChecker()
		readyChecks = append(readyChecks, dbHealthChecker)
		generalChecks = append(generalChecks, dbHealthChecker)
	}

	generalReadinessHandler := proxy.GeneralReadinessHandler(generalChecks...)
	readinessHandler := proxy.GeneralReadinessHandler(readyChecks...)

	// route health endpoints appropriately
	mainHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasSuffix(r.URL.Path, "/readyz/isDirty") {
			readinessHandler.ServeHTTP(w, r)
			return
		}

		if strings.HasSuffix(r.URL.Path, "/readyz/health") {
			generalReadinessHandler.ServeHTTP(w, r)
			return
		}

		router.ServeHTTP(w, r)
	})

	// Buffered so that the first failing goroutine never blocks on send.
	errChan := make(chan error, 1)

	wg.Add(2)

	// Close errChan once both workers are done so the final receive below
	// yields nil when neither of them reported an error.
	go func() {
		defer close(errChan)
		wg.Wait()
	}()

	// Start the connection to the Datastore server in a separate goroutine, so that
	// we can start the proxy server and start serving requests while we wait
	// for the connection to be established.
	go func() {
		defer wg.Done()

		ds, err := datastore.NewConnector(proxyCfg.DatastoreType, &proxyCfg.EmbedMD)
		if err != nil {
			errChan <- fmt.Errorf("error creating datastore: %w", err)
			return
		}

		conn, err := newModelRegistryService(ds)
		if err != nil {
			// {{ALERT}} is used to identify this error in pod logs, DO NOT REMOVE
			errChan <- fmt.Errorf("{{ALERT}} error connecting to datastore: %w", err)
			return
		}

		// Set the model registry service in the holder for health checks
		serviceHolder.Set(conn)

		// Swap the 503 fallback for the real API router.
		apiService := openapi.NewModelRegistryServiceAPIService(conn)
		apiController := openapi.NewModelRegistryServiceAPIController(apiService)
		router.SetRouter(middleware.WrapWithValidation(apiController))
	}()

	// Start the proxy server in a separate goroutine so that we can handle
	// errors from both the proxy server and the connection to the Datastore server.
	go func() {
		defer wg.Done()

		srv := &http.Server{
			Addr:              fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port),
			Handler:           mainHandler,
			ReadHeaderTimeout: readHeaderTimeout,
		}

		glog.Infof("Proxy server started at %s:%v", cfg.Hostname, cfg.Port)

		if err := srv.ListenAndServe(); err != nil {
			errChan <- fmt.Errorf("error starting proxy server: %w", err)
		}
	}()

	// Wait for either the Datastore server connection or the proxy server to return an error
	// or for both to finish successfully.
	return <-errChan
}
// newModelRegistryService connects ds using service.DatastoreSpec() and wires
// the repositories of the resulting repo set, together with its type map, into
// a core model registry service.
//
// An error from Connect is returned unchanged. A repository that cannot be
// resolved makes getRepo panic.
func newModelRegistryService(ds datastore.Connector) (api.ModelRegistryApi, error) {
	repos, connectErr := ds.Connect(service.DatastoreSpec())
	if connectErr != nil {
		return nil, connectErr
	}

	registry := core.NewModelRegistryService(
		getRepo[models.ArtifactRepository](repos),
		getRepo[models.ModelArtifactRepository](repos),
		getRepo[models.DocArtifactRepository](repos),
		getRepo[models.RegisteredModelRepository](repos),
		getRepo[models.ModelVersionRepository](repos),
		getRepo[models.ServingEnvironmentRepository](repos),
		getRepo[models.InferenceServiceRepository](repos),
		getRepo[models.ServeModelRepository](repos),
		getRepo[models.ExperimentRepository](repos),
		getRepo[models.ExperimentRunRepository](repos),
		getRepo[models.DataSetRepository](repos),
		getRepo[models.MetricRepository](repos),
		getRepo[models.ParameterRepository](repos),
		getRepo[models.MetricHistoryRepository](repos),
		repos.TypeMap(),
	)

	glog.Infof("EmbedMD service connected")

	return registry, nil
}
// getRepo looks up the repository registered in repoSet for type T and returns
// it as a T. It panics if the lookup fails or if the value found is not a T.
func getRepo[T any](repoSet datastore.RepoSet) T {
	wanted := reflect.TypeFor[T]()

	found, lookupErr := repoSet.Repository(wanted)
	if lookupErr == nil {
		return found.(T)
	}

	panic(fmt.Sprintf("unable to get repository: %v", lookupErr))
}
// init registers proxyCmd on rootCmd and declares its command-line flags,
// binding them to cfg (listen address) and proxyCfg (datastore settings).
func init() {
	rootCmd.AddCommand(proxyCmd)

	flags := proxyCmd.Flags()

	// Listen address of the proxy server.
	flags.StringVarP(&cfg.Hostname, "hostname", "n", cfg.Hostname, "Proxy server listen hostname")
	flags.IntVarP(&cfg.Port, "port", "p", cfg.Port, "Proxy server listen port")

	// EmbedMD database connection.
	embedCfg := &proxyCfg.EmbedMD
	flags.StringVar(&embedCfg.DatabaseType, "embedmd-database-type", "mysql", "EmbedMD database type")
	flags.StringVar(&embedCfg.DatabaseDSN, "embedmd-database-dsn", "", "EmbedMD database DSN")

	// EmbedMD database TLS settings; TLSConfig is allocated in proxyCfg's initializer.
	tlsCfg := embedCfg.TLSConfig
	flags.StringVar(&tlsCfg.CertPath, "embedmd-database-ssl-cert", "", "EmbedMD SSL cert path")
	flags.StringVar(&tlsCfg.KeyPath, "embedmd-database-ssl-key", "", "EmbedMD SSL key path")
	flags.StringVar(&tlsCfg.RootCertPath, "embedmd-database-ssl-root-cert", "", "EmbedMD SSL root cert path")
	flags.StringVar(&tlsCfg.CAPath, "embedmd-database-ssl-ca", "", "EmbedMD SSL CA path")
	flags.StringVar(&tlsCfg.Cipher, "embedmd-database-ssl-cipher", "", "Colon-separated list of allowed TLS ciphers for the EmbedMD database connection. Values are from the list at https://pkg.go.dev/crypto/tls#pkg-constants e.g. 'TLS_AES_128_GCM_SHA256:TLS_CHACHA20_POLY1305_SHA256'")
	flags.BoolVar(&tlsCfg.VerifyServerCert, "embedmd-database-ssl-verify-server-cert", false, "EmbedMD SSL verify server cert")

	// Datastore backend selection.
	flags.StringVar(&proxyCfg.DatastoreType, "datastore-type", proxyCfg.DatastoreType, "Datastore type")
}