Improve core layer testing (#85)
* Improve core layer testing * Treat ids as string on service layer * Moved testutils inside internal package * Adapt test to name prefix implementation
This commit is contained in:
parent
270521ddcc
commit
a309537e8b
4
Makefile
4
Makefile
|
|
@ -167,6 +167,10 @@ test: gen
|
|||
test-nocache: gen
|
||||
go test ./internal/... -count=1
|
||||
|
||||
.PHONY: test-cover
|
||||
test-cover: gen
|
||||
go test ./internal/... -cover -count=1
|
||||
|
||||
.PHONY: run/migrate
|
||||
run/migrate: gen
|
||||
go run main.go migrate --logtostderr=true -m config/metadata-library
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ type ModelRegistryApi interface {
|
|||
// approach used by MLMD gRPC api. If Id is provided update the entity otherwise create a new one.
|
||||
UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error)
|
||||
|
||||
GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error)
|
||||
GetRegisteredModelById(id string) (*openapi.RegisteredModel, error)
|
||||
GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error)
|
||||
GetRegisteredModels(listOptions ListOptions) (*openapi.RegisteredModelList, error)
|
||||
|
||||
|
|
@ -27,19 +27,19 @@ type ModelRegistryApi interface {
|
|||
|
||||
// Create a new Model Version
|
||||
// or update a Model Version associated to a specific RegisteredModel identified by parentResourceId parameter
|
||||
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error)
|
||||
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error)
|
||||
|
||||
GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error)
|
||||
GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error)
|
||||
GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error)
|
||||
GetModelVersionById(id string) (*openapi.ModelVersion, error)
|
||||
GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error)
|
||||
GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error)
|
||||
|
||||
// MODEL ARTIFACT
|
||||
|
||||
// Create a new Artifact
|
||||
// or update an Artifact associated to a specific ModelVersion identified by parentResourceId parameter
|
||||
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error)
|
||||
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error)
|
||||
|
||||
GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error)
|
||||
GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error)
|
||||
GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error)
|
||||
GetModelArtifactById(id string) (*openapi.ModelArtifact, error)
|
||||
GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error)
|
||||
GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
|
||||
"github.com/opendatahub-io/model-registry/internal/core/mapper"
|
||||
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
|
||||
|
|
@ -80,7 +79,11 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (ModelRegistryApi, err
|
|||
// REGISTERED MODELS
|
||||
|
||||
func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) {
|
||||
log.Printf("Creating or updating registered model for %s", *registeredModel.Name)
|
||||
if registeredModel.Id == nil {
|
||||
log.Printf("Creating registered model for %s", *registeredModel.Name)
|
||||
} else {
|
||||
log.Printf("Updating registered model %s for %s", *registeredModel.Id, *registeredModel.Name)
|
||||
}
|
||||
|
||||
modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel)
|
||||
if err != nil {
|
||||
|
|
@ -96,8 +99,8 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
|
|||
return nil, err
|
||||
}
|
||||
|
||||
modelId := &modelCtxResp.ContextIds[0]
|
||||
model, err := serv.GetRegisteredModelById((*BaseResourceId)(modelId))
|
||||
idAsString := mapper.IdToString(modelCtxResp.ContextIds[0])
|
||||
model, err := serv.GetRegisteredModelById(*idAsString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -105,18 +108,23 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
|
|||
return model, nil
|
||||
}
|
||||
|
||||
func (serv *modelRegistryService) GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error) {
|
||||
log.Printf("Getting registered model %d", *id)
|
||||
func (serv *modelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) {
|
||||
log.Printf("Getting registered model %s", id)
|
||||
|
||||
idAsInt, err := mapper.IdToInt64(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
|
||||
ContextIds: []int64{int64(*id)},
|
||||
ContextIds: []int64{int64(*idAsInt)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(getByIdResp.Contexts) != 1 {
|
||||
return nil, fmt.Errorf("multiple registered models found for id %d", *id)
|
||||
return nil, fmt.Errorf("multiple registered models found for id %s", id)
|
||||
}
|
||||
|
||||
regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0])
|
||||
|
|
@ -191,10 +199,20 @@ func (serv *modelRegistryService) GetRegisteredModels(listOptions ListOptions) (
|
|||
|
||||
// MODEL VERSIONS
|
||||
|
||||
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error) {
|
||||
registeredModel, err := serv.GetRegisteredModelById(parentResourceId)
|
||||
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error) {
|
||||
if modelVersion.Id == nil {
|
||||
log.Printf("Creating model version")
|
||||
} else {
|
||||
log.Printf("Updating model version %s", *modelVersion.Id)
|
||||
}
|
||||
|
||||
if parentResourceId == nil {
|
||||
return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model")
|
||||
}
|
||||
|
||||
registeredModel, err := serv.GetRegisteredModelById(*parentResourceId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("not a valid registered model id: %d", *parentResourceId)
|
||||
return nil, fmt.Errorf("not a valid registered model id: %s", *parentResourceId)
|
||||
}
|
||||
registeredModelIdCtxID, err := mapper.IdToInt64(*registeredModel.Id)
|
||||
if err != nil {
|
||||
|
|
@ -216,17 +234,20 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
|
|||
}
|
||||
|
||||
modelId := &modelCtxResp.ContextIds[0]
|
||||
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
|
||||
ParentContexts: []*proto.ParentContext{{
|
||||
ChildId: modelId,
|
||||
ParentId: registeredModelIdCtxID}},
|
||||
TransactionOptions: &proto.TransactionOptions{},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if modelVersion.Id == nil {
|
||||
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
|
||||
ParentContexts: []*proto.ParentContext{{
|
||||
ChildId: modelId,
|
||||
ParentId: registeredModelIdCtxID}},
|
||||
TransactionOptions: &proto.TransactionOptions{},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
model, err := serv.GetModelVersionById((*BaseResourceId)(modelId))
|
||||
idAsString := mapper.IdToString(*modelId)
|
||||
model, err := serv.GetModelVersionById(*idAsString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -234,16 +255,21 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
|
|||
return model, nil
|
||||
}
|
||||
|
||||
func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error) {
|
||||
func (serv *modelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) {
|
||||
idAsInt, err := mapper.IdToInt64(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
|
||||
ContextIds: []int64{int64(*id)},
|
||||
ContextIds: []int64{int64(*idAsInt)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(getByIdResp.Contexts) != 1 {
|
||||
return nil, fmt.Errorf("multiple model versions found for id %d", *id)
|
||||
return nil, fmt.Errorf("multiple model versions found for id %s", id)
|
||||
}
|
||||
|
||||
modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0])
|
||||
|
|
@ -254,10 +280,14 @@ func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*open
|
|||
return modelVer, nil
|
||||
}
|
||||
|
||||
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error) {
|
||||
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error) {
|
||||
filterQuery := ""
|
||||
if versionName != nil && parentResourceId != nil {
|
||||
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *versionName))
|
||||
idAsInt, err := mapper.IdToInt64(*parentResourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *versionName))
|
||||
} else if externalId != nil {
|
||||
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
|
||||
}
|
||||
|
|
@ -273,7 +303,7 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
|
|||
}
|
||||
|
||||
if len(getByParamsResp.Contexts) != 1 {
|
||||
return nil, fmt.Errorf("multiple registered models found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
|
||||
return nil, fmt.Errorf("multiple model versions found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
|
||||
}
|
||||
|
||||
modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0])
|
||||
|
|
@ -283,14 +313,14 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
|
|||
return modelVer, nil
|
||||
}
|
||||
|
||||
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error) {
|
||||
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error) {
|
||||
listOperationOptions, err := BuildListOperationOptions(listOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parentResourceId != nil {
|
||||
queryParentCtxId := fmt.Sprintf("parent_contexts_a.type = %d", *parentResourceId)
|
||||
queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *parentResourceId)
|
||||
listOperationOptions.FilterQuery = &queryParentCtxId
|
||||
}
|
||||
|
||||
|
|
@ -322,8 +352,18 @@ func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, pare
|
|||
|
||||
// MODEL ARTIFACTS
|
||||
|
||||
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error) {
|
||||
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, (*int64)(parentResourceId))
|
||||
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error) {
|
||||
if modelArtifact.Id == nil {
|
||||
log.Printf("Creating model artifact")
|
||||
} else {
|
||||
log.Printf("Updating model artifact %s", *modelArtifact.Id)
|
||||
}
|
||||
|
||||
idAsInt, err := mapper.IdToInt64(*parentResourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, idAsInt)
|
||||
|
||||
artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{
|
||||
Artifacts: []*proto.Artifact{artifact},
|
||||
|
|
@ -331,16 +371,17 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idString := strconv.FormatInt(artifactsResp.ArtifactIds[0], 10)
|
||||
modelArtifact.Id = &idString
|
||||
|
||||
// add explicit association between artifacts and model version
|
||||
if parentResourceId != nil {
|
||||
modelVersionIdCtx := int64(*parentResourceId)
|
||||
if parentResourceId != nil && modelArtifact.Id == nil {
|
||||
modelVersionIdCtx, err := mapper.IdToInt64(*parentResourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attributions := []*proto.Attribution{}
|
||||
for _, a := range artifactsResp.ArtifactIds {
|
||||
attributions = append(attributions, &proto.Attribution{
|
||||
ContextId: &modelVersionIdCtx,
|
||||
ContextId: modelVersionIdCtx,
|
||||
ArtifactId: &a,
|
||||
})
|
||||
}
|
||||
|
|
@ -353,12 +394,22 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
|
|||
}
|
||||
}
|
||||
|
||||
return modelArtifact, nil
|
||||
idAsString := mapper.IdToString(artifactsResp.ArtifactIds[0])
|
||||
mapped, err := serv.GetModelArtifactById(*idAsString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mapped, nil
|
||||
}
|
||||
|
||||
func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error) {
|
||||
func (serv *modelRegistryService) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) {
|
||||
idAsInt, err := mapper.IdToInt64(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{
|
||||
ArtifactIds: []int64{int64(*id)},
|
||||
ArtifactIds: []int64{int64(*idAsInt)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -372,14 +423,18 @@ func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*ope
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error) {
|
||||
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error) {
|
||||
var artifact0 *proto.Artifact
|
||||
|
||||
filterQuery := ""
|
||||
if externalId != nil {
|
||||
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
|
||||
} else if artifactName != nil && parentResourceId != nil {
|
||||
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *artifactName))
|
||||
idAsInt, err := mapper.IdToInt64(*parentResourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *artifactName))
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and parentResourceId), or externalId")
|
||||
}
|
||||
|
|
@ -406,7 +461,7 @@ func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string,
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error) {
|
||||
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error) {
|
||||
listOperationOptions, err := BuildListOperationOptions(listOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -415,9 +470,12 @@ func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, par
|
|||
var artifacts []*proto.Artifact
|
||||
var nextPageToken *string
|
||||
if parentResourceId != nil {
|
||||
ctxId := int64(*parentResourceId)
|
||||
ctxId, err := mapper.IdToInt64(*parentResourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
|
||||
ContextId: &ctxId,
|
||||
ContextId: ctxId,
|
||||
Options: listOperationOptions,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -98,6 +98,15 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str
|
|||
return props, nil
|
||||
}
|
||||
|
||||
func (m *Mapper) MapToArtifactState(oapiState *openapi.ArtifactState) *proto.Artifact_State {
|
||||
if oapiState == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := (proto.Artifact_State)(proto.Artifact_State_value[string(*oapiState)])
|
||||
return &state
|
||||
}
|
||||
|
||||
func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
|
||||
|
||||
var idInt *int64
|
||||
|
|
@ -129,9 +138,20 @@ func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registe
|
|||
if modelVersion.CustomProperties != nil {
|
||||
customProps, _ = m.MapToProperties(*modelVersion.CustomProperties)
|
||||
}
|
||||
|
||||
var idAsInt *int64
|
||||
if modelVersion.Id != nil {
|
||||
var err error
|
||||
idAsInt, err = IdToInt64(*modelVersion.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
ctx := &proto.Context{
|
||||
Name: &fullName,
|
||||
TypeId: &m.ModelVersionTypeId,
|
||||
Id: idAsInt,
|
||||
Name: &fullName,
|
||||
TypeId: &m.ModelVersionTypeId,
|
||||
ExternalId: modelVersion.ExternalID,
|
||||
Properties: map[string]*proto.Value{
|
||||
"model_name": {
|
||||
Value: &proto.Value_StringValue{
|
||||
|
|
@ -170,10 +190,25 @@ func (m *Mapper) MapFromModelArtifact(modelArtifact openapi.ModelArtifact, model
|
|||
}
|
||||
// build fullName for mlmd storage
|
||||
fullName := PrefixWhenOwned(modelVersionId, artifactName)
|
||||
|
||||
customProps := make(map[string]*proto.Value)
|
||||
if modelArtifact.CustomProperties != nil {
|
||||
customProps, _ = m.MapToProperties(*modelArtifact.CustomProperties)
|
||||
}
|
||||
|
||||
var idAsInt *int64
|
||||
if modelArtifact.Id != nil {
|
||||
idAsInt, _ = IdToInt64(*modelArtifact.Id)
|
||||
}
|
||||
|
||||
return &proto.Artifact{
|
||||
TypeId: &m.ModelArtifactTypeId,
|
||||
Name: &fullName,
|
||||
Uri: modelArtifact.Uri,
|
||||
Id: idAsInt,
|
||||
TypeId: &m.ModelArtifactTypeId,
|
||||
Name: &fullName,
|
||||
Uri: modelArtifact.Uri,
|
||||
ExternalId: modelArtifact.ExternalID,
|
||||
State: m.MapToArtifactState(modelArtifact.State),
|
||||
CustomProperties: customProps,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -236,12 +271,21 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op
|
|||
return data, nil
|
||||
}
|
||||
|
||||
func (m *Mapper) MapFromArtifactState(mlmdState *proto.Artifact_State) *openapi.ArtifactState {
|
||||
if mlmdState == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := mlmdState.String()
|
||||
return (*openapi.ArtifactState)(&state)
|
||||
}
|
||||
|
||||
func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
|
||||
if ctx.GetTypeId() != m.RegisteredModelTypeId {
|
||||
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.RegisteredModelTypeId, ctx.GetTypeId())
|
||||
}
|
||||
|
||||
_, err := m.MapFromProperties(ctx.CustomProperties)
|
||||
customProps, err := m.MapFromProperties(ctx.CustomProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -249,9 +293,10 @@ func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredMo
|
|||
idString := strconv.FormatInt(*ctx.Id, 10)
|
||||
|
||||
model := &openapi.RegisteredModel{
|
||||
Id: &idString,
|
||||
Name: ctx.Name,
|
||||
ExternalID: ctx.ExternalId,
|
||||
Id: &idString,
|
||||
Name: ctx.Name,
|
||||
ExternalID: ctx.ExternalId,
|
||||
CustomProperties: &customProps,
|
||||
}
|
||||
|
||||
return model, nil
|
||||
|
|
@ -276,8 +321,9 @@ func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, e
|
|||
name := NameFromOwned(*ctx.Name)
|
||||
modelVersion := &openapi.ModelVersion{
|
||||
// ModelName: &modelName,
|
||||
Id: &idString,
|
||||
Name: &name,
|
||||
Id: &idString,
|
||||
Name: &name,
|
||||
ExternalID: ctx.ExternalId,
|
||||
// Author: &author,
|
||||
CustomProperties: &metadata,
|
||||
}
|
||||
|
|
@ -290,7 +336,7 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
|
|||
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.ModelArtifactTypeId, artifact.GetTypeId())
|
||||
}
|
||||
|
||||
_, err := m.MapFromProperties(artifact.CustomProperties)
|
||||
customProps, err := m.MapFromProperties(artifact.CustomProperties)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -302,8 +348,12 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
|
|||
|
||||
name := NameFromOwned(*artifact.Name)
|
||||
modelArtifact := &openapi.ModelArtifact{
|
||||
Uri: artifact.Uri,
|
||||
Name: &name,
|
||||
Id: IdToString(*artifact.Id),
|
||||
Uri: artifact.Uri,
|
||||
Name: &name,
|
||||
ExternalID: artifact.ExternalId,
|
||||
State: m.MapFromArtifactState(artifact.State),
|
||||
CustomProperties: &customProps,
|
||||
}
|
||||
|
||||
return modelArtifact, nil
|
||||
|
|
|
|||
|
|
@ -0,0 +1,106 @@
|
|||
package testutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
const (
|
||||
useProvider = testcontainers.ProviderDefault // or explicit to testcontainers.ProviderPodman if needed
|
||||
mlmdImage = "gcr.io/tfx-oss-public/ml_metadata_store_server:1.14.0"
|
||||
sqliteFile = "metadata.sqlite.db"
|
||||
testConfigFolder = "test/config/ml-metadata"
|
||||
)
|
||||
|
||||
func clearMetadataSqliteDB(wd string) error {
|
||||
if err := os.Remove(fmt.Sprintf("%s/%s", wd, sqliteFile)); err != nil {
|
||||
return fmt.Errorf("expected to clear sqlite file but didn't find: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupMLMDTestContainer creates a MLMD gRPC test container
|
||||
// Returns
|
||||
// - gRPC client connection to the test container
|
||||
// - ml-metadata client used to double check the database
|
||||
// - teardown function
|
||||
func SetupMLMDTestContainer(t *testing.T) (*grpc.ClientConn, proto.MetadataStoreServiceClient, func(t *testing.T)) {
|
||||
ctx := context.Background()
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Errorf("error getting working directory: %v", err)
|
||||
}
|
||||
wd = fmt.Sprintf("%s/../../%s", wd, testConfigFolder)
|
||||
t.Logf("using working directory: %s", wd)
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: mlmdImage,
|
||||
ExposedPorts: []string{"8080/tcp"},
|
||||
Env: map[string]string{
|
||||
"METADATA_STORE_SERVER_CONFIG_FILE": "/tmp/shared/conn_config.pb",
|
||||
},
|
||||
Mounts: testcontainers.ContainerMounts{
|
||||
testcontainers.ContainerMount{
|
||||
Source: testcontainers.GenericBindMountSource{
|
||||
HostPath: wd,
|
||||
},
|
||||
Target: "/tmp/shared",
|
||||
},
|
||||
},
|
||||
WaitingFor: wait.ForLog("Server listening on"),
|
||||
}
|
||||
|
||||
mlmdgrpc, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ProviderType: useProvider,
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("error setting up mlmd grpc container: %v", err)
|
||||
}
|
||||
|
||||
mappedHost, err := mlmdgrpc.Host(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
mappedPort, err := mlmdgrpc.MappedPort(ctx, "8080")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
mlmdAddr := fmt.Sprintf("%s:%s", mappedHost, mappedPort.Port())
|
||||
t.Log("MLMD test container setup at: ", mlmdAddr)
|
||||
|
||||
// setup grpc connection
|
||||
conn, err := grpc.DialContext(
|
||||
context.Background(),
|
||||
mlmdAddr,
|
||||
grpc.WithReturnConnectionError(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err)
|
||||
}
|
||||
|
||||
mlmdClient := proto.NewMetadataStoreServiceClient(conn)
|
||||
|
||||
return conn, mlmdClient, func(t *testing.T) {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := mlmdgrpc.Terminate(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := clearMetadataSqliteDB(wd); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/cucumber/godog"
|
||||
|
|
@ -80,38 +79,19 @@ func iStoreARegisteredModelWithNameAndAChildModelVersionWithNameAndAChildArtifac
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registeredModelId, err := idToInt64(*registeredModel.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var modelVersion *openapi.ModelVersion
|
||||
if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, (*core.BaseResourceId)(registeredModelId)); err != nil {
|
||||
return err
|
||||
}
|
||||
modelVersionId, err := idToInt64(*modelVersion.Id)
|
||||
if err != nil {
|
||||
if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, registeredModel.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, (*core.BaseResourceId)(modelVersionId)); err != nil {
|
||||
if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, modelVersion.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func idToInt64(idString string) (*int64, error) {
|
||||
idInt, err := strconv.Atoi(idString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idInt64 := int64(idInt)
|
||||
|
||||
return &idInt64, nil
|
||||
}
|
||||
|
||||
func thereShouldBeAMlmdContextOfTypeNamed(ctx context.Context, arg1, arg2 string) error {
|
||||
conn := ctx.Value(connCtxKey{}).(*grpc.ClientConn)
|
||||
client := proto.NewMetadataStoreServiceClient(conn)
|
||||
|
|
@ -183,7 +163,7 @@ func InitializeScenario(ctx *godog.ScenarioContext) {
|
|||
return ctx, err
|
||||
}
|
||||
wd := ctx.Value(wdCtxKey{}).(string)
|
||||
clearMetadataSqliteDB(wd)
|
||||
_ = clearMetadataSqliteDB(wd)
|
||||
return ctx, nil
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue