217 lines
6.5 KiB
Go
217 lines
6.5 KiB
Go
package core
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"github.com/golang/glog"
|
|
"github.com/kubeflow/model-registry/internal/apiutils"
|
|
"github.com/kubeflow/model-registry/internal/converter"
|
|
"github.com/kubeflow/model-registry/internal/db/models"
|
|
"github.com/kubeflow/model-registry/pkg/api"
|
|
"github.com/kubeflow/model-registry/pkg/openapi"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func (b *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, registeredModelId *string) (*openapi.ModelVersion, error) {
|
|
if modelVersion == nil {
|
|
return nil, fmt.Errorf("invalid model version pointer, cannot be nil: %w", api.ErrBadRequest)
|
|
}
|
|
|
|
if modelVersion.Id != nil {
|
|
existing, err := b.GetModelVersionById(*modelVersion.Id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
withNotEditable, err := b.mapper.UpdateExistingModelVersion(converter.NewOpenapiUpdateWrapper(existing, modelVersion))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
|
|
}
|
|
modelVersion = &withNotEditable
|
|
}
|
|
|
|
if registeredModelId != nil {
|
|
modelVersion.RegisteredModelId = *registeredModelId
|
|
}
|
|
|
|
model, err := b.mapper.MapFromModelVersion(modelVersion, registeredModelId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
|
|
}
|
|
|
|
savedModel, err := b.modelVersionRepository.Save(model)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
|
return nil, fmt.Errorf("model version with name %s already exists: %w", modelVersion.Name, api.ErrConflict)
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
toReturn, err := b.mapper.MapToModelVersion(savedModel)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
|
|
}
|
|
|
|
return toReturn, nil
|
|
}
|
|
|
|
func (b *ModelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) {
|
|
glog.Infof("Getting ModelVersion by id %s", id)
|
|
|
|
convertedId, err := apiutils.ValidateIDAsInt32(id, "model version")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
model, err := b.modelVersionRepository.GetByID(convertedId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no model version found for id %s: %w", id, api.ErrNotFound)
|
|
}
|
|
|
|
toReturn, err := b.mapper.MapToModelVersion(model)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
|
|
}
|
|
|
|
return toReturn, nil
|
|
}
|
|
|
|
func (b *ModelRegistryService) GetModelVersionByInferenceService(inferenceServiceId string) (*openapi.ModelVersion, error) {
|
|
convertedId, err := apiutils.ValidateIDAsInt32(inferenceServiceId, "inference service")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
infSvc, err := b.inferenceServiceRepository.GetByID(convertedId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no inference service found for id %s: %w", inferenceServiceId, api.ErrNotFound)
|
|
}
|
|
|
|
infSvcProps := infSvc.GetProperties()
|
|
|
|
if infSvcProps == nil {
|
|
return nil, fmt.Errorf("no registered model found for inference service")
|
|
}
|
|
|
|
modelVersionID := int32(0)
|
|
|
|
for _, prop := range *infSvcProps {
|
|
if prop.Name == "model_version_id" {
|
|
modelVersionID = *prop.IntValue
|
|
break
|
|
}
|
|
}
|
|
|
|
if modelVersionID != 0 {
|
|
return b.GetModelVersionById(strconv.Itoa(int(modelVersionID)))
|
|
}
|
|
|
|
registeredModelID := ""
|
|
|
|
for _, prop := range *infSvcProps {
|
|
if prop.Name == "registered_model_id" {
|
|
registeredModelID = strconv.Itoa(int(*prop.IntValue))
|
|
break
|
|
}
|
|
}
|
|
// modelVersionId: ID of the ModelVersion to serve. If it's unspecified, then the latest ModelVersion by creation order will be served.
|
|
orderByCreateTime := "CREATE_TIME"
|
|
sortOrderDesc := "DESC"
|
|
versions, err := b.GetModelVersions(api.ListOptions{OrderBy: &orderByCreateTime, SortOrder: &sortOrderDesc}, ®isteredModelID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(versions.Items) == 0 {
|
|
return nil, fmt.Errorf("no model versions found for id %s: %w", inferenceServiceId, api.ErrNotFound)
|
|
}
|
|
|
|
return &versions.Items[0], nil
|
|
}
|
|
|
|
func (b *ModelRegistryService) GetModelVersionByParams(versionName *string, registeredModelId *string, externalId *string) (*openapi.ModelVersion, error) {
|
|
if (versionName == nil || registeredModelId == nil) && externalId == nil {
|
|
return nil, fmt.Errorf("invalid parameters call, supply either (versionName and registeredModelId), or externalId: %w", api.ErrBadRequest)
|
|
}
|
|
|
|
var parentResourceID *int32
|
|
if registeredModelId != nil {
|
|
var err error
|
|
parentResourceID, err = apiutils.ValidateIDAsInt32Ptr(registeredModelId, "registered model")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
versionsList, err := b.modelVersionRepository.List(models.ModelVersionListOptions{
|
|
Name: versionName,
|
|
ExternalID: externalId,
|
|
ParentResourceID: parentResourceID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(versionsList.Items) > 1 {
|
|
return nil, fmt.Errorf("multiple model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound)
|
|
}
|
|
|
|
if len(versionsList.Items) == 0 {
|
|
return nil, fmt.Errorf("no model versions found for versionName=%v, registeredModelId=%v, externalId=%v: %w", apiutils.ZeroIfNil(versionName), apiutils.ZeroIfNil(registeredModelId), apiutils.ZeroIfNil(externalId), api.ErrNotFound)
|
|
}
|
|
|
|
toReturn, err := b.mapper.MapToModelVersion(versionsList.Items[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
|
|
}
|
|
|
|
return toReturn, nil
|
|
}
|
|
|
|
func (b *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, registeredModelId *string) (*openapi.ModelVersionList, error) {
|
|
var parentResourceID *int32
|
|
|
|
if registeredModelId != nil {
|
|
var err error
|
|
parentResourceID, err = apiutils.ValidateIDAsInt32Ptr(registeredModelId, "registered model")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
versionsList, err := b.modelVersionRepository.List(models.ModelVersionListOptions{
|
|
Pagination: models.Pagination{
|
|
PageSize: listOptions.PageSize,
|
|
OrderBy: listOptions.OrderBy,
|
|
SortOrder: listOptions.SortOrder,
|
|
NextPageToken: listOptions.NextPageToken,
|
|
FilterQuery: listOptions.FilterQuery,
|
|
},
|
|
ParentResourceID: parentResourceID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
modelVersionList := &openapi.ModelVersionList{
|
|
Items: []openapi.ModelVersion{},
|
|
}
|
|
|
|
for _, model := range versionsList.Items {
|
|
modelVersion, err := b.mapper.MapToModelVersion(model)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
|
|
}
|
|
modelVersionList.Items = append(modelVersionList.Items, *modelVersion)
|
|
}
|
|
|
|
modelVersionList.NextPageToken = versionsList.NextPageToken
|
|
modelVersionList.PageSize = versionsList.PageSize
|
|
modelVersionList.Size = int32(versionsList.Size)
|
|
|
|
return modelVersionList, nil
|
|
}
|