Improve core layer testing (#85)

* Improve core layer testing

* Treat ids as string on service layer

* Moved testutils inside internal package

* Adapt test to name prefix implementation
This commit is contained in:
Andrea Lamparelli 2023-10-30 08:52:43 +01:00 committed by GitHub
parent 270521ddcc
commit a309537e8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1395 additions and 207 deletions

View File

@ -167,6 +167,10 @@ test: gen
test-nocache: gen
go test ./internal/... -count=1
.PHONY: test-cover
test-cover: gen
go test ./internal/... -cover -count=1
.PHONY: run/migrate
run/migrate: gen
go run main.go migrate --logtostderr=true -m config/metadata-library

View File

@ -19,7 +19,7 @@ type ModelRegistryApi interface {
// approach used by MLMD gRPC api. If Id is provided update the entity otherwise create a new one.
UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error)
GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error)
GetRegisteredModelById(id string) (*openapi.RegisteredModel, error)
GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error)
GetRegisteredModels(listOptions ListOptions) (*openapi.RegisteredModelList, error)
@ -27,19 +27,19 @@ type ModelRegistryApi interface {
// Create a new Model Version
// or update a Model Version associated to a specific RegisteredModel identified by parentResourceId parameter
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error)
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error)
GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error)
GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error)
GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error)
GetModelVersionById(id string) (*openapi.ModelVersion, error)
GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error)
GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error)
// MODEL ARTIFACT
// Create a new Artifact
// or update an Artifact associated to a specific ModelVersion identified by parentResourceId parameter
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error)
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error)
GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error)
GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error)
GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error)
GetModelArtifactById(id string) (*openapi.ModelArtifact, error)
GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error)
GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error)
}

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log"
"strconv"
"github.com/opendatahub-io/model-registry/internal/core/mapper"
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
@ -80,7 +79,11 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (ModelRegistryApi, err
// REGISTERED MODELS
func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) {
log.Printf("Creating or updating registered model for %s", *registeredModel.Name)
if registeredModel.Id == nil {
log.Printf("Creating registered model for %s", *registeredModel.Name)
} else {
log.Printf("Updating registered model %s for %s", *registeredModel.Id, *registeredModel.Name)
}
modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel)
if err != nil {
@ -96,8 +99,8 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
return nil, err
}
modelId := &modelCtxResp.ContextIds[0]
model, err := serv.GetRegisteredModelById((*BaseResourceId)(modelId))
idAsString := mapper.IdToString(modelCtxResp.ContextIds[0])
model, err := serv.GetRegisteredModelById(*idAsString)
if err != nil {
return nil, err
}
@ -105,18 +108,23 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
return model, nil
}
func (serv *modelRegistryService) GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error) {
log.Printf("Getting registered model %d", *id)
func (serv *modelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) {
log.Printf("Getting registered model %s", id)
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
ContextIds: []int64{int64(*id)},
ContextIds: []int64{int64(*idAsInt)},
})
if err != nil {
return nil, err
}
if len(getByIdResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple registered models found for id %d", *id)
return nil, fmt.Errorf("multiple registered models found for id %s", id)
}
regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0])
@ -191,10 +199,20 @@ func (serv *modelRegistryService) GetRegisteredModels(listOptions ListOptions) (
// MODEL VERSIONS
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error) {
registeredModel, err := serv.GetRegisteredModelById(parentResourceId)
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error) {
if modelVersion.Id == nil {
log.Printf("Creating model version")
} else {
log.Printf("Updating model version %s", *modelVersion.Id)
}
if parentResourceId == nil {
return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model")
}
registeredModel, err := serv.GetRegisteredModelById(*parentResourceId)
if err != nil {
return nil, fmt.Errorf("not a valid registered model id: %d", *parentResourceId)
return nil, fmt.Errorf("not a valid registered model id: %s", *parentResourceId)
}
registeredModelIdCtxID, err := mapper.IdToInt64(*registeredModel.Id)
if err != nil {
@ -216,17 +234,20 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
}
modelId := &modelCtxResp.ContextIds[0]
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
ParentContexts: []*proto.ParentContext{{
ChildId: modelId,
ParentId: registeredModelIdCtxID}},
TransactionOptions: &proto.TransactionOptions{},
})
if err != nil {
return nil, err
if modelVersion.Id == nil {
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
ParentContexts: []*proto.ParentContext{{
ChildId: modelId,
ParentId: registeredModelIdCtxID}},
TransactionOptions: &proto.TransactionOptions{},
})
if err != nil {
return nil, err
}
}
model, err := serv.GetModelVersionById((*BaseResourceId)(modelId))
idAsString := mapper.IdToString(*modelId)
model, err := serv.GetModelVersionById(*idAsString)
if err != nil {
return nil, err
}
@ -234,16 +255,21 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
return model, nil
}
func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error) {
func (serv *modelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) {
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}
getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
ContextIds: []int64{int64(*id)},
ContextIds: []int64{int64(*idAsInt)},
})
if err != nil {
return nil, err
}
if len(getByIdResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple model versions found for id %d", *id)
return nil, fmt.Errorf("multiple model versions found for id %s", id)
}
modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0])
@ -254,10 +280,14 @@ func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*open
return modelVer, nil
}
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error) {
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error) {
filterQuery := ""
if versionName != nil && parentResourceId != nil {
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *versionName))
idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *versionName))
} else if externalId != nil {
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
}
@ -273,7 +303,7 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
}
if len(getByParamsResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple registered models found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
return nil, fmt.Errorf("multiple model versions found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
}
modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0])
@ -283,14 +313,14 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
return modelVer, nil
}
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error) {
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error) {
listOperationOptions, err := BuildListOperationOptions(listOptions)
if err != nil {
return nil, err
}
if parentResourceId != nil {
queryParentCtxId := fmt.Sprintf("parent_contexts_a.type = %d", *parentResourceId)
queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *parentResourceId)
listOperationOptions.FilterQuery = &queryParentCtxId
}
@ -322,8 +352,18 @@ func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, pare
// MODEL ARTIFACTS
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error) {
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, (*int64)(parentResourceId))
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error) {
if modelArtifact.Id == nil {
log.Printf("Creating model artifact")
} else {
log.Printf("Updating model artifact %s", *modelArtifact.Id)
}
idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, idAsInt)
artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{
Artifacts: []*proto.Artifact{artifact},
@ -331,16 +371,17 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
if err != nil {
return nil, err
}
idString := strconv.FormatInt(artifactsResp.ArtifactIds[0], 10)
modelArtifact.Id = &idString
// add explicit association between artifacts and model version
if parentResourceId != nil {
modelVersionIdCtx := int64(*parentResourceId)
if parentResourceId != nil && modelArtifact.Id == nil {
modelVersionIdCtx, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
attributions := []*proto.Attribution{}
for _, a := range artifactsResp.ArtifactIds {
attributions = append(attributions, &proto.Attribution{
ContextId: &modelVersionIdCtx,
ContextId: modelVersionIdCtx,
ArtifactId: &a,
})
}
@ -353,12 +394,22 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
}
}
return modelArtifact, nil
idAsString := mapper.IdToString(artifactsResp.ArtifactIds[0])
mapped, err := serv.GetModelArtifactById(*idAsString)
if err != nil {
return nil, err
}
return mapped, nil
}
func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error) {
func (serv *modelRegistryService) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) {
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}
artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{
ArtifactIds: []int64{int64(*id)},
ArtifactIds: []int64{int64(*idAsInt)},
})
if err != nil {
return nil, err
@ -372,14 +423,18 @@ func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*ope
return result, nil
}
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error) {
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error) {
var artifact0 *proto.Artifact
filterQuery := ""
if externalId != nil {
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
} else if artifactName != nil && parentResourceId != nil {
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *artifactName))
idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *artifactName))
} else {
return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and parentResourceId), or externalId")
}
@ -406,7 +461,7 @@ func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string,
return result, nil
}
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error) {
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error) {
listOperationOptions, err := BuildListOperationOptions(listOptions)
if err != nil {
return nil, err
@ -415,9 +470,12 @@ func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, par
var artifacts []*proto.Artifact
var nextPageToken *string
if parentResourceId != nil {
ctxId := int64(*parentResourceId)
ctxId, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
ContextId: &ctxId,
ContextId: ctxId,
Options: listOperationOptions,
})
if err != nil {

File diff suppressed because it is too large Load Diff

View File

@ -98,6 +98,15 @@ func (m *Mapper) MapToProperties(data map[string]openapi.MetadataValue) (map[str
return props, nil
}
func (m *Mapper) MapToArtifactState(oapiState *openapi.ArtifactState) *proto.Artifact_State {
if oapiState == nil {
return nil
}
state := (proto.Artifact_State)(proto.Artifact_State_value[string(*oapiState)])
return &state
}
func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
var idInt *int64
@ -129,9 +138,20 @@ func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registe
if modelVersion.CustomProperties != nil {
customProps, _ = m.MapToProperties(*modelVersion.CustomProperties)
}
var idAsInt *int64
if modelVersion.Id != nil {
var err error
idAsInt, err = IdToInt64(*modelVersion.Id)
if err != nil {
return nil, err
}
}
ctx := &proto.Context{
Name: &fullName,
TypeId: &m.ModelVersionTypeId,
Id: idAsInt,
Name: &fullName,
TypeId: &m.ModelVersionTypeId,
ExternalId: modelVersion.ExternalID,
Properties: map[string]*proto.Value{
"model_name": {
Value: &proto.Value_StringValue{
@ -170,10 +190,25 @@ func (m *Mapper) MapFromModelArtifact(modelArtifact openapi.ModelArtifact, model
}
// build fullName for mlmd storage
fullName := PrefixWhenOwned(modelVersionId, artifactName)
customProps := make(map[string]*proto.Value)
if modelArtifact.CustomProperties != nil {
customProps, _ = m.MapToProperties(*modelArtifact.CustomProperties)
}
var idAsInt *int64
if modelArtifact.Id != nil {
idAsInt, _ = IdToInt64(*modelArtifact.Id)
}
return &proto.Artifact{
TypeId: &m.ModelArtifactTypeId,
Name: &fullName,
Uri: modelArtifact.Uri,
Id: idAsInt,
TypeId: &m.ModelArtifactTypeId,
Name: &fullName,
Uri: modelArtifact.Uri,
ExternalId: modelArtifact.ExternalID,
State: m.MapToArtifactState(modelArtifact.State),
CustomProperties: customProps,
}
}
@ -236,12 +271,21 @@ func (m *Mapper) MapFromProperties(props map[string]*proto.Value) (map[string]op
return data, nil
}
func (m *Mapper) MapFromArtifactState(mlmdState *proto.Artifact_State) *openapi.ArtifactState {
if mlmdState == nil {
return nil
}
state := mlmdState.String()
return (*openapi.ArtifactState)(&state)
}
func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
if ctx.GetTypeId() != m.RegisteredModelTypeId {
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.RegisteredModelTypeId, ctx.GetTypeId())
}
_, err := m.MapFromProperties(ctx.CustomProperties)
customProps, err := m.MapFromProperties(ctx.CustomProperties)
if err != nil {
return nil, err
}
@ -249,9 +293,10 @@ func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredMo
idString := strconv.FormatInt(*ctx.Id, 10)
model := &openapi.RegisteredModel{
Id: &idString,
Name: ctx.Name,
ExternalID: ctx.ExternalId,
Id: &idString,
Name: ctx.Name,
ExternalID: ctx.ExternalId,
CustomProperties: &customProps,
}
return model, nil
@ -276,8 +321,9 @@ func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, e
name := NameFromOwned(*ctx.Name)
modelVersion := &openapi.ModelVersion{
// ModelName: &modelName,
Id: &idString,
Name: &name,
Id: &idString,
Name: &name,
ExternalID: ctx.ExternalId,
// Author: &author,
CustomProperties: &metadata,
}
@ -290,7 +336,7 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
return nil, fmt.Errorf("invalid TypeId, exptected %d but received %d", m.ModelArtifactTypeId, artifact.GetTypeId())
}
_, err := m.MapFromProperties(artifact.CustomProperties)
customProps, err := m.MapFromProperties(artifact.CustomProperties)
if err != nil {
return nil, err
}
@ -302,8 +348,12 @@ func (m *Mapper) MapToModelArtifact(artifact *proto.Artifact) (*openapi.ModelArt
name := NameFromOwned(*artifact.Name)
modelArtifact := &openapi.ModelArtifact{
Uri: artifact.Uri,
Name: &name,
Id: IdToString(*artifact.Id),
Uri: artifact.Uri,
Name: &name,
ExternalID: artifact.ExternalId,
State: m.MapFromArtifactState(artifact.State),
CustomProperties: &customProps,
}
return modelArtifact, nil

View File

@ -0,0 +1,106 @@
package testutils
import (
"context"
"fmt"
"os"
"testing"
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
const (
useProvider = testcontainers.ProviderDefault // or explicit to testcontainers.ProviderPodman if needed
mlmdImage = "gcr.io/tfx-oss-public/ml_metadata_store_server:1.14.0"
sqliteFile = "metadata.sqlite.db"
testConfigFolder = "test/config/ml-metadata"
)
func clearMetadataSqliteDB(wd string) error {
if err := os.Remove(fmt.Sprintf("%s/%s", wd, sqliteFile)); err != nil {
return fmt.Errorf("expected to clear sqlite file but didn't find: %v", err)
}
return nil
}
// SetupMLMDTestContainer creates a MLMD gRPC test container
// Returns
// - gRPC client connection to the test container
// - ml-metadata client used to double check the database
// - teardown function
func SetupMLMDTestContainer(t *testing.T) (*grpc.ClientConn, proto.MetadataStoreServiceClient, func(t *testing.T)) {
ctx := context.Background()
wd, err := os.Getwd()
if err != nil {
t.Errorf("error getting working directory: %v", err)
}
wd = fmt.Sprintf("%s/../../%s", wd, testConfigFolder)
t.Logf("using working directory: %s", wd)
req := testcontainers.ContainerRequest{
Image: mlmdImage,
ExposedPorts: []string{"8080/tcp"},
Env: map[string]string{
"METADATA_STORE_SERVER_CONFIG_FILE": "/tmp/shared/conn_config.pb",
},
Mounts: testcontainers.ContainerMounts{
testcontainers.ContainerMount{
Source: testcontainers.GenericBindMountSource{
HostPath: wd,
},
Target: "/tmp/shared",
},
},
WaitingFor: wait.ForLog("Server listening on"),
}
mlmdgrpc, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ProviderType: useProvider,
ContainerRequest: req,
Started: true,
})
if err != nil {
t.Errorf("error setting up mlmd grpc container: %v", err)
}
mappedHost, err := mlmdgrpc.Host(ctx)
if err != nil {
t.Error(err)
}
mappedPort, err := mlmdgrpc.MappedPort(ctx, "8080")
if err != nil {
t.Error(err)
}
mlmdAddr := fmt.Sprintf("%s:%s", mappedHost, mappedPort.Port())
t.Log("MLMD test container setup at: ", mlmdAddr)
// setup grpc connection
conn, err := grpc.DialContext(
context.Background(),
mlmdAddr,
grpc.WithReturnConnectionError(),
grpc.WithBlock(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
t.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err)
}
mlmdClient := proto.NewMetadataStoreServiceClient(conn)
return conn, mlmdClient, func(t *testing.T) {
if err := conn.Close(); err != nil {
t.Error(err)
}
if err := mlmdgrpc.Terminate(ctx); err != nil {
t.Error(err)
}
if err := clearMetadataSqliteDB(wd); err != nil {
t.Error(err)
}
}
}

View File

@ -5,7 +5,6 @@ import (
"fmt"
"log"
"os"
"strconv"
"testing"
"github.com/cucumber/godog"
@ -80,38 +79,19 @@ func iStoreARegisteredModelWithNameAndAChildModelVersionWithNameAndAChildArtifac
if err != nil {
return err
}
registeredModelId, err := idToInt64(*registeredModel.Id)
if err != nil {
return err
}
var modelVersion *openapi.ModelVersion
if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, (*core.BaseResourceId)(registeredModelId)); err != nil {
return err
}
modelVersionId, err := idToInt64(*modelVersion.Id)
if err != nil {
if modelVersion, err = service.UpsertModelVersion(&openapi.ModelVersion{Name: &modelVersionName}, registeredModel.Id); err != nil {
return err
}
if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, (*core.BaseResourceId)(modelVersionId)); err != nil {
if _, err = service.UpsertModelArtifact(&openapi.ModelArtifact{Uri: &artifactURI}, modelVersion.Id); err != nil {
return err
}
return nil
}
func idToInt64(idString string) (*int64, error) {
idInt, err := strconv.Atoi(idString)
if err != nil {
return nil, err
}
idInt64 := int64(idInt)
return &idInt64, nil
}
func thereShouldBeAMlmdContextOfTypeNamed(ctx context.Context, arg1, arg2 string) error {
conn := ctx.Value(connCtxKey{}).(*grpc.ClientConn)
client := proto.NewMetadataStoreServiceClient(conn)
@ -183,7 +163,7 @@ func InitializeScenario(ctx *godog.ScenarioContext) {
return ctx, err
}
wd := ctx.Value(wdCtxKey{}).(string)
clearMetadataSqliteDB(wd)
_ = clearMetadataSqliteDB(wd)
return ctx, nil
})
}