Improve core layer testing (#85)

* Improve core layer testing

* Treat ids as string on service layer

* Moved testutils inside internal package

* Adapt test to name prefix implementation
This commit is contained in:
Andrea Lamparelli 2023-10-30 08:52:43 +01:00 committed by GitHub
parent 270521ddcc
commit a309537e8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1395 additions and 207 deletions

View File

@ -167,6 +167,10 @@ test: gen
test-nocache: gen test-nocache: gen
go test ./internal/... -count=1 go test ./internal/... -count=1
.PHONY: test-cover
test-cover: gen
go test ./internal/... -cover -count=1
.PHONY: run/migrate .PHONY: run/migrate
run/migrate: gen run/migrate: gen
go run main.go migrate --logtostderr=true -m config/metadata-library go run main.go migrate --logtostderr=true -m config/metadata-library

View File

@ -19,7 +19,7 @@ type ModelRegistryApi interface {
// approach used by MLMD gRPC api. If Id is provided update the entity otherwise create a new one. // approach used by MLMD gRPC api. If Id is provided update the entity otherwise create a new one.
UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error)
GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error)
GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error) GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error)
GetRegisteredModels(listOptions ListOptions) (*openapi.RegisteredModelList, error) GetRegisteredModels(listOptions ListOptions) (*openapi.RegisteredModelList, error)
@ -27,19 +27,19 @@ type ModelRegistryApi interface {
// Create a new Model Version // Create a new Model Version
// or update a Model Version associated to a specific RegisteredModel identified by parentResourceId parameter // or update a Model Version associated to a specific RegisteredModel identified by parentResourceId parameter
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error)
GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error) GetModelVersionById(id string) (*openapi.ModelVersion, error)
GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error) GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error)
GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error) GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error)
// MODEL ARTIFACT // MODEL ARTIFACT
// Create a new Artifact // Create a new Artifact
// or update an Artifact associated to a specific ModelVersion identified by parentResourceId parameter // or update an Artifact associated to a specific ModelVersion identified by parentResourceId parameter
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error)
GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error) GetModelArtifactById(id string) (*openapi.ModelArtifact, error)
GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error) GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error)
GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error) GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error)
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"strconv"
"github.com/opendatahub-io/model-registry/internal/core/mapper" "github.com/opendatahub-io/model-registry/internal/core/mapper"
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto" "github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
@ -80,7 +79,11 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (ModelRegistryApi, err
// REGISTERED MODELS // REGISTERED MODELS
func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) { func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) {
log.Printf("Creating or updating registered model for %s", *registeredModel.Name) if registeredModel.Id == nil {
log.Printf("Creating registered model for %s", *registeredModel.Name)
} else {
log.Printf("Updating registered model %s for %s", *registeredModel.Id, *registeredModel.Name)
}
modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel) modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel)
if err != nil { if err != nil {
@ -96,8 +99,8 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
return nil, err return nil, err
} }
modelId := &modelCtxResp.ContextIds[0] idAsString := mapper.IdToString(modelCtxResp.ContextIds[0])
model, err := serv.GetRegisteredModelById((*BaseResourceId)(modelId)) model, err := serv.GetRegisteredModelById(*idAsString)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,18 +108,23 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
return model, nil return model, nil
} }
func (serv *modelRegistryService) GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error) { func (serv *modelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) {
log.Printf("Getting registered model %d", *id) log.Printf("Getting registered model %s", id)
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
ContextIds: []int64{int64(*id)}, ContextIds: []int64{int64(*idAsInt)},
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(getByIdResp.Contexts) != 1 { if len(getByIdResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple registered models found for id %d", *id) return nil, fmt.Errorf("multiple registered models found for id %s", id)
} }
regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0]) regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0])
@ -191,10 +199,20 @@ func (serv *modelRegistryService) GetRegisteredModels(listOptions ListOptions) (
// MODEL VERSIONS // MODEL VERSIONS
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error) { func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error) {
registeredModel, err := serv.GetRegisteredModelById(parentResourceId) if modelVersion.Id == nil {
log.Printf("Creating model version")
} else {
log.Printf("Updating model version %s", *modelVersion.Id)
}
if parentResourceId == nil {
return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model")
}
registeredModel, err := serv.GetRegisteredModelById(*parentResourceId)
if err != nil { if err != nil {
return nil, fmt.Errorf("not a valid registered model id: %d", *parentResourceId) return nil, fmt.Errorf("not a valid registered model id: %s", *parentResourceId)
} }
registeredModelIdCtxID, err := mapper.IdToInt64(*registeredModel.Id) registeredModelIdCtxID, err := mapper.IdToInt64(*registeredModel.Id)
if err != nil { if err != nil {
@ -216,17 +234,20 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
} }
modelId := &modelCtxResp.ContextIds[0] modelId := &modelCtxResp.ContextIds[0]
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{ if modelVersion.Id == nil {
ParentContexts: []*proto.ParentContext{{ _, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
ChildId: modelId, ParentContexts: []*proto.ParentContext{{
ParentId: registeredModelIdCtxID}}, ChildId: modelId,
TransactionOptions: &proto.TransactionOptions{}, ParentId: registeredModelIdCtxID}},
}) TransactionOptions: &proto.TransactionOptions{},
if err != nil { })
return nil, err if err != nil {
return nil, err
}
} }
model, err := serv.GetModelVersionById((*BaseResourceId)(modelId)) idAsString := mapper.IdToString(*modelId)
model, err := serv.GetModelVersionById(*idAsString)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -234,16 +255,21 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
return model, nil return model, nil
} }
func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error) { func (serv *modelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) {
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{ getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
ContextIds: []int64{int64(*id)}, ContextIds: []int64{int64(*idAsInt)},
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(getByIdResp.Contexts) != 1 { if len(getByIdResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple model versions found for id %d", *id) return nil, fmt.Errorf("multiple model versions found for id %s", id)
} }
modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0]) modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0])
@ -254,10 +280,14 @@ func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*open
return modelVer, nil return modelVer, nil
} }
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error) { func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error) {
filterQuery := "" filterQuery := ""
if versionName != nil && parentResourceId != nil { if versionName != nil && parentResourceId != nil {
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *versionName)) idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *versionName))
} else if externalId != nil { } else if externalId != nil {
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
} }
@ -273,7 +303,7 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
} }
if len(getByParamsResp.Contexts) != 1 { if len(getByParamsResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple registered models found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId)) return nil, fmt.Errorf("multiple model versions found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
} }
modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0]) modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0])
@ -283,14 +313,14 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
return modelVer, nil return modelVer, nil
} }
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error) { func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error) {
listOperationOptions, err := BuildListOperationOptions(listOptions) listOperationOptions, err := BuildListOperationOptions(listOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if parentResourceId != nil { if parentResourceId != nil {
queryParentCtxId := fmt.Sprintf("parent_contexts_a.type = %d", *parentResourceId) queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *parentResourceId)
listOperationOptions.FilterQuery = &queryParentCtxId listOperationOptions.FilterQuery = &queryParentCtxId
} }
@ -322,8 +352,18 @@ func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, pare
// MODEL ARTIFACTS // MODEL ARTIFACTS
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error) { func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error) {
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, (*int64)(parentResourceId)) if modelArtifact.Id == nil {
log.Printf("Creating model artifact")
} else {
log.Printf("Updating model artifact %s", *modelArtifact.Id)
}
idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, idAsInt)
artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{ artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{
Artifacts: []*proto.Artifact{artifact}, Artifacts: []*proto.Artifact{artifact},
@ -331,16 +371,17 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
if err != nil { if err != nil {
return nil, err return nil, err
} }
idString := strconv.FormatInt(artifactsResp.ArtifactIds[0], 10)
modelArtifact.Id = &idString
// add explicit association between artifacts and model version // add explicit association between artifacts and model version
if parentResourceId != nil { if parentResourceId != nil && modelArtifact.Id == nil {
modelVersionIdCtx := int64(*parentResourceId) modelVersionIdCtx, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
attributions := []*proto.Attribution{} attributions := []*proto.Attribution{}
for _, a := range artifactsResp.ArtifactIds { for _, a := range artifactsResp.ArtifactIds {
attributions = append(attributions, &proto.Attribution{ attributions = append(attributions, &proto.Attribution{
ContextId: &modelVersionIdCtx, ContextId: modelVersionIdCtx,
ArtifactId: &a, ArtifactId: &a,
}) })
} }
@ -353,12 +394,22 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
} }
} }
return modelArtifact, nil idAsString := mapper.IdToString(artifactsResp.ArtifactIds[0])
mapped, err := serv.GetModelArtifactById(*idAsString)
if err != nil {
return nil, err
}
return mapped, nil
} }
func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error) { func (serv *modelRegistryService) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) {
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}
artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{ artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{
ArtifactIds: []int64{int64(*id)}, ArtifactIds: []int64{int64(*idAsInt)},
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -372,14 +423,18 @@ func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*ope
return result, nil return result, nil
} }
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error) { func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error) {
var artifact0 *proto.Artifact var artifact0 *proto.Artifact
filterQuery := "" filterQuery := ""
if externalId != nil { if externalId != nil {
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId) filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
} else if artifactName != nil && parentResourceId != nil { } else if artifactName != nil && parentResourceId != nil {
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *artifactName)) idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *artifactName))
} else { } else {
return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and parentResourceId), or externalId") return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and parentResourceId), or externalId")
} }
@ -406,7 +461,7 @@ func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string,
return result, nil return result, nil
} }
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error) { func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error) {
listOperationOptions, err := BuildListOperationOptions(listOptions) listOperationOptions, err := BuildListOperationOptions(listOptions)
if err != nil { if err != nil {
return nil, err return nil, err
@ -415,9 +470,12 @@ func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, par
var artifacts []*proto.Artifact var artifacts []*proto.Artifact
var nextPageToken *string var nextPageToken *string
if parentResourceId != nil { if parentResourceId != nil {
ctxId := int64(*parentResourceId) ctxId, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{ artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
ContextId: &ctxId, ContextId: ctxId,
Options: listOperationOptions, Options: listOperationOptions,
}) })
if err != nil { if err != nil {

File diff suppressed because it is too large Load Diff

View File

@ -98,6 +98,15 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str
return props, nil return props, nil
} }
func (m *Mapper) MapToArtifactState(oapiState *openapi.ArtifactState) *proto.Artifact_State {
if oapiState == nil {
return nil
}
state := (proto.Artifact_State)(proto.Artifact_State_value[string(*oapiState)])
return &state
}
func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) { func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
var idInt *int64 var idInt *int64
@ -129,9 +138,20 @@ func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registe
if modelVersion.CustomProperties != nil { if modelVersion.CustomProperties != nil {
customProps, _ = m.MapToProperties(*modelVersion.CustomProperties) customProps, _ = m.MapToProperties(*modelVersion.CustomProperties)
} }
var idAsInt *int64
if modelVersion.Id != nil {
var err error
idAsInt, err = IdToInt64(*modelVersion.Id)
if err != nil {
return nil, err
}
}
ctx := &proto.Context{ ctx := &proto.Context{
Name: &fullName, Id: idAsInt,
TypeId: &m.ModelVersionTypeId, Name: &fullName,
TypeId: &m.ModelVersionTypeId,
ExternalId: modelVersion.ExternalID,
Properties: map[string]*proto.Value{ Properties: map[string]*proto.Value{
"model_name": { "model_name": {
Value: &proto.Value_StringValue{ Value: &proto.Value_StringValue{
@ -170,10 +190,25 @@ func (m *Mapper) MapFromModelArtifact(modelArtifact openapi.ModelArtifact, model
} }
// build fullName for mlmd storage // build fullName for mlmd storage
fullName := PrefixWhenOwned(modelVersionId, artifactName) fullName := PrefixWhenOwned(modelVersionId, artifactName)
customProps := make(map[string]*proto.Value)
if modelArtifact.CustomProperties != nil {
customProps, _ = m.MapToProperties(*modelArtifact.CustomProperties)
}
var idAsInt *int64
if modelArtifact.Id != nil {
idAsInt, _ = IdToInt64(*modelArtifact.Id)
}
return &proto.Artifact{ return &proto.Artifact{
TypeId: &m.ModelArtifactTypeId, Id: idAsInt,
Name: &fullName, TypeId: &m.ModelArtifactTypeId,
Uri: modelArtifact.Uri, Name: &fullName,
Uri: modelArtifact.Uri,
ExternalId: modelArtifact.ExternalID,
State: m.MapToArtifactState(modelArtifact.State),
CustomProperties: customProps,
} }
} }
@ -236,12 +271,21 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op
return data, nil return data, nil
} }
func (m *Mapper) MapFromArtifactState(mlmdState *proto.Artifact_State) *openapi.ArtifactState {
if mlmdState == nil {
return nil
}
state := mlmdState.String()
return (*openapi.ArtifactState)(&state)
}
func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) { func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
if ctx.GetTypeId() != m.RegisteredModelTypeId { if ctx.GetTypeId() != m.RegisteredModelTypeId {
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.RegisteredModelTypeId, ctx.GetTypeId()) return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.RegisteredModelTypeId, ctx.GetTypeId())
} }
_, err := m.MapFromProperties(ctx.CustomProperties) customProps, err := m.MapFromProperties(ctx.CustomProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -249,9 +293,10 @@ func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredMo
idString := strconv.FormatInt(*ctx.Id, 10) idString := strconv.FormatInt(*ctx.Id, 10)
model := &openapi.RegisteredModel{ model := &openapi.RegisteredModel{
Id: &idString, Id: &idString,
Name: ctx.Name, Name: ctx.Name,
ExternalID: ctx.ExternalId, ExternalID: ctx.ExternalId,
CustomProperties: &customProps,
} }
return model, nil return model, nil
@ -276,8 +321,9 @@ func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, e
name := NameFromOwned(*ctx.Name) name := NameFromOwned(*ctx.Name)
modelVersion := &openapi.ModelVersion{ modelVersion := &openapi.ModelVersion{
// ModelName: &modelName, // ModelName: &modelName,
Id: &idString, Id: &idString,
Name: &name, Name: &name,
ExternalID: ctx.ExternalId,
// Author: &author, // Author: &author,
CustomProperties: &metadata, CustomProperties: &metadata,
} }
@ -290,7 +336,7 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.ModelArtifactTypeId, artifact.GetTypeId()) return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.ModelArtifactTypeId, artifact.GetTypeId())
} }
_, err := m.MapFromProperties(artifact.CustomProperties) customProps, err := m.MapFromProperties(artifact.CustomProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -302,8 +348,12 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
name := NameFromOwned(*artifact.Name) name := NameFromOwned(*artifact.Name)
modelArtifact := &openapi.ModelArtifact{ modelArtifact := &openapi.ModelArtifact{
Uri: artifact.Uri, Id: IdToString(*artifact.Id),
Name: &name, Uri: artifact.Uri,
Name: &name,
ExternalID: artifact.ExternalId,
State: m.MapFromArtifactState(artifact.State),
CustomProperties: &customProps,
} }
return modelArtifact, nil return modelArtifact, nil

View File

@ -0,0 +1,106 @@
package testutils
import (
"context"
"fmt"
"os"
"testing"
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
const (
useProvider = testcontainers.ProviderDefault // or explicit to testcontainers.ProviderPodman if needed
mlmdImage = "gcr.io/tfx-oss-public/ml_metadata_store_server:1.14.0"
sqliteFile = "metadata.sqlite.db"
testConfigFolder = "test/config/ml-metadata"
)
func clearMetadataSqliteDB(wd string) error {
if err := os.Remove(fmt.Sprintf("%s/%s", wd, sqliteFile)); err != nil {
return fmt.Errorf("expected to clear sqlite file but didn't find: %v", err)
}
return nil
}
// SetupMLMDTestContainer creates a MLMD gRPC test container
// Returns
// - gRPC client connection to the test container
// - ml-metadata client used to double check the database
// - teardown function
func SetupMLMDTestContainer(t *testing.T) (*grpc.ClientConn, proto.MetadataStoreServiceClient, func(t *testing.T)) {
ctx := context.Background()
wd, err := os.Getwd()
if err != nil {
t.Errorf("error getting working directory: %v", err)
}
wd = fmt.Sprintf("%s/../../%s", wd, testConfigFolder)
t.Logf("using working directory: %s", wd)
req := testcontainers.ContainerRequest{
Image: mlmdImage,
ExposedPorts: []string{"8080/tcp"},
Env: map[string]string{
"METADATA_STORE_SERVER_CONFIG_FILE": "/tmp/shared/conn_config.pb",
},
Mounts: testcontainers.ContainerMounts{
testcontainers.ContainerMount{
Source: testcontainers.GenericBindMountSource{
HostPath: wd,
},
Target: "/tmp/shared",
},
},
WaitingFor: wait.ForLog("Server listening on"),
}
mlmdgrpc, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ProviderType: useProvider,
ContainerRequest: req,
Started: true,
})
if err != nil {
t.Errorf("error setting up mlmd grpc container: %v", err)
}
mappedHost, err := mlmdgrpc.Host(ctx)
if err != nil {
t.Error(err)
}
mappedPort, err := mlmdgrpc.MappedPort(ctx, "8080")
if err != nil {
t.Error(err)
}
mlmdAddr := fmt.Sprintf("%s:%s", mappedHost, mappedPort.Port())
t.Log("MLMD test container setup at: ", mlmdAddr)
// setup grpc connection
conn, err := grpc.DialContext(
context.Background(),
mlmdAddr,
grpc.WithReturnConnectionError(),
grpc.WithBlock(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
t.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err)
}
mlmdClient := proto.NewMetadataStoreServiceClient(conn)
return conn, mlmdClient, func(t *testing.T) {
if err := conn.Close(); err != nil {
t.Error(err)
}
if err := mlmdgrpc.Terminate(ctx); err != nil {
t.Error(err)
}
if err := clearMetadataSqliteDB(wd); err != nil {
t.Error(err)
}
}
}

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"strconv"
"testing" "testing"
"github.com/cucumber/godog" "github.com/cucumber/godog"
@ -80,38 +79,19 @@ func iStoreARegisteredModelWithNameAndAChildModelVersionWithNameAndAChildArtifac
if err != nil { if err != nil {
return err return err
} }
registeredModelId, err := idToInt64(*registeredModel.Id)
if err != nil {
return err
}
var modelVersion *openapi.ModelVersion var modelVersion *openapi.ModelVersion
if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, (*core.BaseResourceId)(registeredModelId)); err != nil { if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, registeredModel.Id); err != nil {
return err
}
modelVersionId, err := idToInt64(*modelVersion.Id)
if err != nil {
return err return err
} }
if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, (*core.BaseResourceId)(modelVersionId)); err != nil { if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, modelVersion.Id); err != nil {
return err return err
} }
return nil return nil
} }
func idToInt64(idString string) (*int64, error) {
idInt, err := strconv.Atoi(idString)
if err != nil {
return nil, err
}
idInt64 := int64(idInt)
return &idInt64, nil
}
func thereShouldBeAMlmdContextOfTypeNamed(ctx context.Context, arg1, arg2 string) error { func thereShouldBeAMlmdContextOfTypeNamed(ctx context.Context, arg1, arg2 string) error {
conn := ctx.Value(connCtxKey{}).(*grpc.ClientConn) conn := ctx.Value(connCtxKey{}).(*grpc.ClientConn)
client := proto.NewMetadataStoreServiceClient(conn) client := proto.NewMetadataStoreServiceClient(conn)
@ -183,7 +163,7 @@ func InitializeScenario(ctx *godog.ScenarioContext) {
return ctx, err return ctx, err
} }
wd := ctx.Value(wdCtxKey{}).(string) wd := ctx.Value(wdCtxKey{}).(string)
clearMetadataSqliteDB(wd) _ = clearMetadataSqliteDB(wd)
return ctx, nil return ctx, nil
}) })
} }