Improve core layer testing (#85)
* Improve core layer testing * Treat ids as string on service layer * Moved testutils inside internal package * Adapt test to name prefix implementation
This commit is contained in:
parent
270521ddcc
commit
a309537e8b
4
Makefile
4
Makefile
|
|
@ -167,6 +167,10 @@ test: gen
|
||||||
test-nocache: gen
|
test-nocache: gen
|
||||||
go test ./internal/... -count=1
|
go test ./internal/... -count=1
|
||||||
|
|
||||||
|
.PHONY: test-cover
|
||||||
|
test-cover: gen
|
||||||
|
go test ./internal/... -cover -count=1
|
||||||
|
|
||||||
.PHONY: run/migrate
|
.PHONY: run/migrate
|
||||||
run/migrate: gen
|
run/migrate: gen
|
||||||
go run main.go migrate --logtostderr=true -m config/metadata-library
|
go run main.go migrate --logtostderr=true -m config/metadata-library
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ type ModelRegistryApi interface {
|
||||||
// approach used by MLMD gRPC api. If Id is provided update the entity otherwise create a new one.
|
// approach used by MLMD gRPC api. If Id is provided update the entity otherwise create a new one.
|
||||||
UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error)
|
UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error)
|
||||||
|
|
||||||
GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error)
|
GetRegisteredModelById(id string) (*openapi.RegisteredModel, error)
|
||||||
GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error)
|
GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error)
|
||||||
GetRegisteredModels(listOptions ListOptions) (*openapi.RegisteredModelList, error)
|
GetRegisteredModels(listOptions ListOptions) (*openapi.RegisteredModelList, error)
|
||||||
|
|
||||||
|
|
@ -27,19 +27,19 @@ type ModelRegistryApi interface {
|
||||||
|
|
||||||
// Create a new Model Version
|
// Create a new Model Version
|
||||||
// or update a Model Version associated to a specific RegisteredModel identified by parentResourceId parameter
|
// or update a Model Version associated to a specific RegisteredModel identified by parentResourceId parameter
|
||||||
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error)
|
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error)
|
||||||
|
|
||||||
GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error)
|
GetModelVersionById(id string) (*openapi.ModelVersion, error)
|
||||||
GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error)
|
GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error)
|
||||||
GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error)
|
GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error)
|
||||||
|
|
||||||
// MODEL ARTIFACT
|
// MODEL ARTIFACT
|
||||||
|
|
||||||
// Create a new Artifact
|
// Create a new Artifact
|
||||||
// or update an Artifact associated to a specific ModelVersion identified by parentResourceId parameter
|
// or update an Artifact associated to a specific ModelVersion identified by parentResourceId parameter
|
||||||
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error)
|
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error)
|
||||||
|
|
||||||
GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error)
|
GetModelArtifactById(id string) (*openapi.ModelArtifact, error)
|
||||||
GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error)
|
GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error)
|
||||||
GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error)
|
GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/opendatahub-io/model-registry/internal/core/mapper"
|
"github.com/opendatahub-io/model-registry/internal/core/mapper"
|
||||||
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
|
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
|
||||||
|
|
@ -80,7 +79,11 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (ModelRegistryApi, err
|
||||||
// REGISTERED MODELS
|
// REGISTERED MODELS
|
||||||
|
|
||||||
func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) {
|
func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) {
|
||||||
log.Printf("Creating or updating registered model for %s", *registeredModel.Name)
|
if registeredModel.Id == nil {
|
||||||
|
log.Printf("Creating registered model for %s", *registeredModel.Name)
|
||||||
|
} else {
|
||||||
|
log.Printf("Updating registered model %s for %s", *registeredModel.Id, *registeredModel.Name)
|
||||||
|
}
|
||||||
|
|
||||||
modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel)
|
modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -96,8 +99,8 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
modelId := &modelCtxResp.ContextIds[0]
|
idAsString := mapper.IdToString(modelCtxResp.ContextIds[0])
|
||||||
model, err := serv.GetRegisteredModelById((*BaseResourceId)(modelId))
|
model, err := serv.GetRegisteredModelById(*idAsString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -105,18 +108,23 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (serv *modelRegistryService) GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error) {
|
func (serv *modelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) {
|
||||||
log.Printf("Getting registered model %d", *id)
|
log.Printf("Getting registered model %s", id)
|
||||||
|
|
||||||
|
idAsInt, err := mapper.IdToInt64(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
|
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
|
||||||
ContextIds: []int64{int64(*id)},
|
ContextIds: []int64{int64(*idAsInt)},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(getByIdResp.Contexts) != 1 {
|
if len(getByIdResp.Contexts) != 1 {
|
||||||
return nil, fmt.Errorf("multiple registered models found for id %d", *id)
|
return nil, fmt.Errorf("multiple registered models found for id %s", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0])
|
regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0])
|
||||||
|
|
@ -191,10 +199,20 @@ func (serv *modelRegistryService) GetRegisteredModels(listOptions ListOptions) (
|
||||||
|
|
||||||
// MODEL VERSIONS
|
// MODEL VERSIONS
|
||||||
|
|
||||||
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error) {
|
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error) {
|
||||||
registeredModel, err := serv.GetRegisteredModelById(parentResourceId)
|
if modelVersion.Id == nil {
|
||||||
|
log.Printf("Creating model version")
|
||||||
|
} else {
|
||||||
|
log.Printf("Updating model version %s", *modelVersion.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parentResourceId == nil {
|
||||||
|
return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model")
|
||||||
|
}
|
||||||
|
|
||||||
|
registeredModel, err := serv.GetRegisteredModelById(*parentResourceId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("not a valid registered model id: %d", *parentResourceId)
|
return nil, fmt.Errorf("not a valid registered model id: %s", *parentResourceId)
|
||||||
}
|
}
|
||||||
registeredModelIdCtxID, err := mapper.IdToInt64(*registeredModel.Id)
|
registeredModelIdCtxID, err := mapper.IdToInt64(*registeredModel.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -216,17 +234,20 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
modelId := &modelCtxResp.ContextIds[0]
|
modelId := &modelCtxResp.ContextIds[0]
|
||||||
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
|
if modelVersion.Id == nil {
|
||||||
ParentContexts: []*proto.ParentContext{{
|
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
|
||||||
ChildId: modelId,
|
ParentContexts: []*proto.ParentContext{{
|
||||||
ParentId: registeredModelIdCtxID}},
|
ChildId: modelId,
|
||||||
TransactionOptions: &proto.TransactionOptions{},
|
ParentId: registeredModelIdCtxID}},
|
||||||
})
|
TransactionOptions: &proto.TransactionOptions{},
|
||||||
if err != nil {
|
})
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
model, err := serv.GetModelVersionById((*BaseResourceId)(modelId))
|
idAsString := mapper.IdToString(*modelId)
|
||||||
|
model, err := serv.GetModelVersionById(*idAsString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -234,16 +255,21 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error) {
|
func (serv *modelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) {
|
||||||
|
idAsInt, err := mapper.IdToInt64(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
|
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
|
||||||
ContextIds: []int64{int64(*id)},
|
ContextIds: []int64{int64(*idAsInt)},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(getByIdResp.Contexts) != 1 {
|
if len(getByIdResp.Contexts) != 1 {
|
||||||
return nil, fmt.Errorf("multiple model versions found for id %d", *id)
|
return nil, fmt.Errorf("multiple model versions found for id %s", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0])
|
modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0])
|
||||||
|
|
@ -254,10 +280,14 @@ func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*open
|
||||||
return modelVer, nil
|
return modelVer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error) {
|
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error) {
|
||||||
filterQuery := ""
|
filterQuery := ""
|
||||||
if versionName != nil && parentResourceId != nil {
|
if versionName != nil && parentResourceId != nil {
|
||||||
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *versionName))
|
idAsInt, err := mapper.IdToInt64(*parentResourceId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *versionName))
|
||||||
} else if externalId != nil {
|
} else if externalId != nil {
|
||||||
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
|
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
|
||||||
}
|
}
|
||||||
|
|
@ -273,7 +303,7 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(getByParamsResp.Contexts) != 1 {
|
if len(getByParamsResp.Contexts) != 1 {
|
||||||
return nil, fmt.Errorf("multiple registered models found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
|
return nil, fmt.Errorf("multiple model versions found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
|
||||||
}
|
}
|
||||||
|
|
||||||
modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0])
|
modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0])
|
||||||
|
|
@ -283,14 +313,14 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
|
||||||
return modelVer, nil
|
return modelVer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error) {
|
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error) {
|
||||||
listOperationOptions, err := BuildListOperationOptions(listOptions)
|
listOperationOptions, err := BuildListOperationOptions(listOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if parentResourceId != nil {
|
if parentResourceId != nil {
|
||||||
queryParentCtxId := fmt.Sprintf("parent_contexts_a.type = %d", *parentResourceId)
|
queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *parentResourceId)
|
||||||
listOperationOptions.FilterQuery = &queryParentCtxId
|
listOperationOptions.FilterQuery = &queryParentCtxId
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -322,8 +352,18 @@ func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, pare
|
||||||
|
|
||||||
// MODEL ARTIFACTS
|
// MODEL ARTIFACTS
|
||||||
|
|
||||||
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error) {
|
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error) {
|
||||||
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, (*int64)(parentResourceId))
|
if modelArtifact.Id == nil {
|
||||||
|
log.Printf("Creating model artifact")
|
||||||
|
} else {
|
||||||
|
log.Printf("Updating model artifact %s", *modelArtifact.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
idAsInt, err := mapper.IdToInt64(*parentResourceId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, idAsInt)
|
||||||
|
|
||||||
artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{
|
artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{
|
||||||
Artifacts: []*proto.Artifact{artifact},
|
Artifacts: []*proto.Artifact{artifact},
|
||||||
|
|
@ -331,16 +371,17 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
idString := strconv.FormatInt(artifactsResp.ArtifactIds[0], 10)
|
|
||||||
modelArtifact.Id = &idString
|
|
||||||
|
|
||||||
// add explicit association between artifacts and model version
|
// add explicit association between artifacts and model version
|
||||||
if parentResourceId != nil {
|
if parentResourceId != nil && modelArtifact.Id == nil {
|
||||||
modelVersionIdCtx := int64(*parentResourceId)
|
modelVersionIdCtx, err := mapper.IdToInt64(*parentResourceId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
attributions := []*proto.Attribution{}
|
attributions := []*proto.Attribution{}
|
||||||
for _, a := range artifactsResp.ArtifactIds {
|
for _, a := range artifactsResp.ArtifactIds {
|
||||||
attributions = append(attributions, &proto.Attribution{
|
attributions = append(attributions, &proto.Attribution{
|
||||||
ContextId: &modelVersionIdCtx,
|
ContextId: modelVersionIdCtx,
|
||||||
ArtifactId: &a,
|
ArtifactId: &a,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -353,12 +394,22 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelArtifact, nil
|
idAsString := mapper.IdToString(artifactsResp.ArtifactIds[0])
|
||||||
|
mapped, err := serv.GetModelArtifactById(*idAsString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mapped, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error) {
|
func (serv *modelRegistryService) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) {
|
||||||
|
idAsInt, err := mapper.IdToInt64(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{
|
artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{
|
||||||
ArtifactIds: []int64{int64(*id)},
|
ArtifactIds: []int64{int64(*idAsInt)},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -372,14 +423,18 @@ func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*ope
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error) {
|
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error) {
|
||||||
var artifact0 *proto.Artifact
|
var artifact0 *proto.Artifact
|
||||||
|
|
||||||
filterQuery := ""
|
filterQuery := ""
|
||||||
if externalId != nil {
|
if externalId != nil {
|
||||||
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
|
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
|
||||||
} else if artifactName != nil && parentResourceId != nil {
|
} else if artifactName != nil && parentResourceId != nil {
|
||||||
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *artifactName))
|
idAsInt, err := mapper.IdToInt64(*parentResourceId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *artifactName))
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and parentResourceId), or externalId")
|
return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and parentResourceId), or externalId")
|
||||||
}
|
}
|
||||||
|
|
@ -406,7 +461,7 @@ func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string,
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error) {
|
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error) {
|
||||||
listOperationOptions, err := BuildListOperationOptions(listOptions)
|
listOperationOptions, err := BuildListOperationOptions(listOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -415,9 +470,12 @@ func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, par
|
||||||
var artifacts []*proto.Artifact
|
var artifacts []*proto.Artifact
|
||||||
var nextPageToken *string
|
var nextPageToken *string
|
||||||
if parentResourceId != nil {
|
if parentResourceId != nil {
|
||||||
ctxId := int64(*parentResourceId)
|
ctxId, err := mapper.IdToInt64(*parentResourceId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
|
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
|
||||||
ContextId: &ctxId,
|
ContextId: ctxId,
|
||||||
Options: listOperationOptions,
|
Options: listOperationOptions,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -98,6 +98,15 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str
|
||||||
return props, nil
|
return props, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Mapper) MapToArtifactState(oapiState *openapi.ArtifactState) *proto.Artifact_State {
|
||||||
|
if oapiState == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state := (proto.Artifact_State)(proto.Artifact_State_value[string(*oapiState)])
|
||||||
|
return &state
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
|
func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
|
||||||
|
|
||||||
var idInt *int64
|
var idInt *int64
|
||||||
|
|
@ -129,9 +138,20 @@ func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registe
|
||||||
if modelVersion.CustomProperties != nil {
|
if modelVersion.CustomProperties != nil {
|
||||||
customProps, _ = m.MapToProperties(*modelVersion.CustomProperties)
|
customProps, _ = m.MapToProperties(*modelVersion.CustomProperties)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var idAsInt *int64
|
||||||
|
if modelVersion.Id != nil {
|
||||||
|
var err error
|
||||||
|
idAsInt, err = IdToInt64(*modelVersion.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
ctx := &proto.Context{
|
ctx := &proto.Context{
|
||||||
Name: &fullName,
|
Id: idAsInt,
|
||||||
TypeId: &m.ModelVersionTypeId,
|
Name: &fullName,
|
||||||
|
TypeId: &m.ModelVersionTypeId,
|
||||||
|
ExternalId: modelVersion.ExternalID,
|
||||||
Properties: map[string]*proto.Value{
|
Properties: map[string]*proto.Value{
|
||||||
"model_name": {
|
"model_name": {
|
||||||
Value: &proto.Value_StringValue{
|
Value: &proto.Value_StringValue{
|
||||||
|
|
@ -170,10 +190,25 @@ func (m *Mapper) MapFromModelArtifact(modelArtifact openapi.ModelArtifact, model
|
||||||
}
|
}
|
||||||
// build fullName for mlmd storage
|
// build fullName for mlmd storage
|
||||||
fullName := PrefixWhenOwned(modelVersionId, artifactName)
|
fullName := PrefixWhenOwned(modelVersionId, artifactName)
|
||||||
|
|
||||||
|
customProps := make(map[string]*proto.Value)
|
||||||
|
if modelArtifact.CustomProperties != nil {
|
||||||
|
customProps, _ = m.MapToProperties(*modelArtifact.CustomProperties)
|
||||||
|
}
|
||||||
|
|
||||||
|
var idAsInt *int64
|
||||||
|
if modelArtifact.Id != nil {
|
||||||
|
idAsInt, _ = IdToInt64(*modelArtifact.Id)
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.Artifact{
|
return &proto.Artifact{
|
||||||
TypeId: &m.ModelArtifactTypeId,
|
Id: idAsInt,
|
||||||
Name: &fullName,
|
TypeId: &m.ModelArtifactTypeId,
|
||||||
Uri: modelArtifact.Uri,
|
Name: &fullName,
|
||||||
|
Uri: modelArtifact.Uri,
|
||||||
|
ExternalId: modelArtifact.ExternalID,
|
||||||
|
State: m.MapToArtifactState(modelArtifact.State),
|
||||||
|
CustomProperties: customProps,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -236,12 +271,21 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Mapper) MapFromArtifactState(mlmdState *proto.Artifact_State) *openapi.ArtifactState {
|
||||||
|
if mlmdState == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state := mlmdState.String()
|
||||||
|
return (*openapi.ArtifactState)(&state)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
|
func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
|
||||||
if ctx.GetTypeId() != m.RegisteredModelTypeId {
|
if ctx.GetTypeId() != m.RegisteredModelTypeId {
|
||||||
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.RegisteredModelTypeId, ctx.GetTypeId())
|
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.RegisteredModelTypeId, ctx.GetTypeId())
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.MapFromProperties(ctx.CustomProperties)
|
customProps, err := m.MapFromProperties(ctx.CustomProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -249,9 +293,10 @@ func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredMo
|
||||||
idString := strconv.FormatInt(*ctx.Id, 10)
|
idString := strconv.FormatInt(*ctx.Id, 10)
|
||||||
|
|
||||||
model := &openapi.RegisteredModel{
|
model := &openapi.RegisteredModel{
|
||||||
Id: &idString,
|
Id: &idString,
|
||||||
Name: ctx.Name,
|
Name: ctx.Name,
|
||||||
ExternalID: ctx.ExternalId,
|
ExternalID: ctx.ExternalId,
|
||||||
|
CustomProperties: &customProps,
|
||||||
}
|
}
|
||||||
|
|
||||||
return model, nil
|
return model, nil
|
||||||
|
|
@ -276,8 +321,9 @@ func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, e
|
||||||
name := NameFromOwned(*ctx.Name)
|
name := NameFromOwned(*ctx.Name)
|
||||||
modelVersion := &openapi.ModelVersion{
|
modelVersion := &openapi.ModelVersion{
|
||||||
// ModelName: &modelName,
|
// ModelName: &modelName,
|
||||||
Id: &idString,
|
Id: &idString,
|
||||||
Name: &name,
|
Name: &name,
|
||||||
|
ExternalID: ctx.ExternalId,
|
||||||
// Author: &author,
|
// Author: &author,
|
||||||
CustomProperties: &metadata,
|
CustomProperties: &metadata,
|
||||||
}
|
}
|
||||||
|
|
@ -290,7 +336,7 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
|
||||||
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.ModelArtifactTypeId, artifact.GetTypeId())
|
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.ModelArtifactTypeId, artifact.GetTypeId())
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.MapFromProperties(artifact.CustomProperties)
|
customProps, err := m.MapFromProperties(artifact.CustomProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -302,8 +348,12 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
|
||||||
|
|
||||||
name := NameFromOwned(*artifact.Name)
|
name := NameFromOwned(*artifact.Name)
|
||||||
modelArtifact := &openapi.ModelArtifact{
|
modelArtifact := &openapi.ModelArtifact{
|
||||||
Uri: artifact.Uri,
|
Id: IdToString(*artifact.Id),
|
||||||
Name: &name,
|
Uri: artifact.Uri,
|
||||||
|
Name: &name,
|
||||||
|
ExternalID: artifact.ExternalId,
|
||||||
|
State: m.MapFromArtifactState(artifact.State),
|
||||||
|
CustomProperties: &customProps,
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelArtifact, nil
|
return modelArtifact, nil
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,106 @@
|
||||||
|
package testutils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
|
||||||
|
"github.com/testcontainers/testcontainers-go"
|
||||||
|
"github.com/testcontainers/testcontainers-go/wait"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
useProvider = testcontainers.ProviderDefault // or explicit to testcontainers.ProviderPodman if needed
|
||||||
|
mlmdImage = "gcr.io/tfx-oss-public/ml_metadata_store_server:1.14.0"
|
||||||
|
sqliteFile = "metadata.sqlite.db"
|
||||||
|
testConfigFolder = "test/config/ml-metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
func clearMetadataSqliteDB(wd string) error {
|
||||||
|
if err := os.Remove(fmt.Sprintf("%s/%s", wd, sqliteFile)); err != nil {
|
||||||
|
return fmt.Errorf("expected to clear sqlite file but didn't find: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupMLMDTestContainer creates a MLMD gRPC test container
|
||||||
|
// Returns
|
||||||
|
// - gRPC client connection to the test container
|
||||||
|
// - ml-metadata client used to double check the database
|
||||||
|
// - teardown function
|
||||||
|
func SetupMLMDTestContainer(t *testing.T) (*grpc.ClientConn, proto.MetadataStoreServiceClient, func(t *testing.T)) {
|
||||||
|
ctx := context.Background()
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error getting working directory: %v", err)
|
||||||
|
}
|
||||||
|
wd = fmt.Sprintf("%s/../../%s", wd, testConfigFolder)
|
||||||
|
t.Logf("using working directory: %s", wd)
|
||||||
|
|
||||||
|
req := testcontainers.ContainerRequest{
|
||||||
|
Image: mlmdImage,
|
||||||
|
ExposedPorts: []string{"8080/tcp"},
|
||||||
|
Env: map[string]string{
|
||||||
|
"METADATA_STORE_SERVER_CONFIG_FILE": "/tmp/shared/conn_config.pb",
|
||||||
|
},
|
||||||
|
Mounts: testcontainers.ContainerMounts{
|
||||||
|
testcontainers.ContainerMount{
|
||||||
|
Source: testcontainers.GenericBindMountSource{
|
||||||
|
HostPath: wd,
|
||||||
|
},
|
||||||
|
Target: "/tmp/shared",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WaitingFor: wait.ForLog("Server listening on"),
|
||||||
|
}
|
||||||
|
|
||||||
|
mlmdgrpc, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||||
|
ProviderType: useProvider,
|
||||||
|
ContainerRequest: req,
|
||||||
|
Started: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error setting up mlmd grpc container: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedHost, err := mlmdgrpc.Host(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
mappedPort, err := mlmdgrpc.MappedPort(ctx, "8080")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
mlmdAddr := fmt.Sprintf("%s:%s", mappedHost, mappedPort.Port())
|
||||||
|
t.Log("MLMD test container setup at: ", mlmdAddr)
|
||||||
|
|
||||||
|
// setup grpc connection
|
||||||
|
conn, err := grpc.DialContext(
|
||||||
|
context.Background(),
|
||||||
|
mlmdAddr,
|
||||||
|
grpc.WithReturnConnectionError(),
|
||||||
|
grpc.WithBlock(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mlmdClient := proto.NewMetadataStoreServiceClient(conn)
|
||||||
|
|
||||||
|
return conn, mlmdClient, func(t *testing.T) {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if err := mlmdgrpc.Terminate(ctx); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if err := clearMetadataSqliteDB(wd); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cucumber/godog"
|
"github.com/cucumber/godog"
|
||||||
|
|
@ -80,38 +79,19 @@ func iStoreARegisteredModelWithNameAndAChildModelVersionWithNameAndAChildArtifac
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
registeredModelId, err := idToInt64(*registeredModel.Id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var modelVersion *openapi.ModelVersion
|
var modelVersion *openapi.ModelVersion
|
||||||
if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, (*core.BaseResourceId)(registeredModelId)); err != nil {
|
if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, registeredModel.Id); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
modelVersionId, err := idToInt64(*modelVersion.Id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, (*core.BaseResourceId)(modelVersionId)); err != nil {
|
if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, modelVersion.Id); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func idToInt64(idString string) (*int64, error) {
|
|
||||||
idInt, err := strconv.Atoi(idString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
idInt64 := int64(idInt)
|
|
||||||
|
|
||||||
return &idInt64, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func thereShouldBeAMlmdContextOfTypeNamed(ctx context.Context, arg1, arg2 string) error {
|
func thereShouldBeAMlmdContextOfTypeNamed(ctx context.Context, arg1, arg2 string) error {
|
||||||
conn := ctx.Value(connCtxKey{}).(*grpc.ClientConn)
|
conn := ctx.Value(connCtxKey{}).(*grpc.ClientConn)
|
||||||
client := proto.NewMetadataStoreServiceClient(conn)
|
client := proto.NewMetadataStoreServiceClient(conn)
|
||||||
|
|
@ -183,7 +163,7 @@ func InitializeScenario(ctx *godog.ScenarioContext) {
|
||||||
return ctx, err
|
return ctx, err
|
||||||
}
|
}
|
||||||
wd := ctx.Value(wdCtxKey{}).(string)
|
wd := ctx.Value(wdCtxKey{}).(string)
|
||||||
clearMetadataSqliteDB(wd)
|
_ = clearMetadataSqliteDB(wd)
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue