From ad0d5f857b7f79ddbfc33fdc63eb823e848bb064 Mon Sep 17 00:00:00 2001 From: Alessio Pragliola <83355398+Al-Pragliola@users.noreply.github.com> Date: Wed, 1 Oct 2025 13:58:33 +0200 Subject: [PATCH] feat: use db as source in model catalog (#1667) * feat: switch to db usage WIP Signed-off-by: Alessio Pragliola * fix: partially fix model_catalog_service tests Signed-off-by: Alessio Pragliola * fix: tests Signed-off-by: Alessio Pragliola * fix: change mustache template to support array of primitive type plus reflect removal Signed-off-by: Alessio Pragliola * fix: /models endpoint not working Signed-off-by: Alessio Pragliola * fix: artifacts route 500 status code Signed-off-by: Alessio Pragliola * feat: make artifacts run on a unified repository Signed-off-by: Alessio Pragliola * feat: add custom properties to catalog artifacts Signed-off-by: Alessio Pragliola * feat: add tests files to db_catalog and ctalog_artifact Signed-off-by: Alessio Pragliola * fix: pagination issues Signed-off-by: Alessio Pragliola * fix: return 404 is the model does not exist in the Get model Signed-off-by: Alessio Pragliola * fix: 500 error code on bad user input Signed-off-by: Alessio Pragliola * fix: q parameter not working Signed-off-by: Alessio Pragliola * chore: better function/alias naming Signed-off-by: Alessio Pragliola Co-authored-by: Paul Boyd --------- Signed-off-by: Alessio Pragliola Co-authored-by: Paul Boyd --- api/openapi/catalog.yaml | 53 +- api/openapi/src/catalog.yaml | 24 +- api/openapi/src/lib/common.yaml | 29 + api/openapi/src/model-registry.yaml | 29 - catalog/cmd/catalog.go | 18 +- catalog/internal/catalog/catalog.go | 21 +- catalog/internal/catalog/db_catalog.go | 415 +++++++++++ catalog/internal/catalog/db_catalog_test.go | 704 ++++++++++++++++++ catalog/internal/catalog/hf_catalog.go | 6 +- catalog/internal/catalog/hf_catalog_test.go | 6 +- catalog/internal/catalog/testdata/testdb.cnf | 5 + catalog/internal/catalog/yaml_catalog.go | 12 +- catalog/internal/catalog/yaml_catalog_test.go | 14 +- .../internal/db/models/catalog_artifact.go | 41 + .../db/models/catalog_metrics_artifact.go | 6 +- catalog/internal/db/models/catalog_model.go | 2 + .../db/models/catalog_model_artifact.go | 5 +- .../internal/db/service/catalog_artifact.go | 217 ++++++ .../db/service/catalog_artifact_test.go | 378 ++++++++++ .../db/service/catalog_metrics_artifact.go | 1 + catalog/internal/db/service/catalog_model.go | 52 +- .../db/service/catalog_model_artifact.go | 3 +- catalog/internal/db/service/spec.go | 3 +- .../server/openapi/.openapi-generator/FILES | 1 + catalog/internal/server/openapi/api.go | 4 +- .../openapi/api_model_catalog_service.go | 9 +- .../api_model_catalog_service_service.go | 107 ++- .../api_model_catalog_service_service_test.go | 167 ++++- .../internal/server/openapi/type_asserts.go | 10 + catalog/pkg/openapi/.openapi-generator/FILES | 1 + .../pkg/openapi/api_model_catalog_service.go | 70 +- catalog/pkg/openapi/model_base_resource.go | 347 +++++++++ .../openapi/model_catalog_metrics_artifact.go | 222 +++++- catalog/pkg/openapi/model_catalog_model.go | 224 ++++-- .../openapi/model_catalog_model_artifact.go | 222 +++++- go.mod | 2 +- templates/go-server/api.mustache | 5 +- 37 files changed, 3110 insertions(+), 325 deletions(-) create mode 100644 catalog/internal/catalog/db_catalog.go create mode 100644 catalog/internal/catalog/db_catalog_test.go create mode 100644 catalog/internal/catalog/testdata/testdb.cnf create mode 100644 catalog/internal/db/models/catalog_artifact.go create mode 100644 catalog/internal/db/service/catalog_artifact.go create mode 100644 catalog/internal/db/service/catalog_artifact_test.go create mode 100644 catalog/pkg/openapi/model_base_resource.go diff --git a/api/openapi/catalog.yaml b/api/openapi/catalog.yaml index 89772c18..313682fe 100644 --- a/api/openapi/catalog.yaml +++ b/api/openapi/catalog.yaml @@ -20,12 +20,18 @@ paths: parameters: - name: source description: |- - Filter models by source. This parameter is currently required and - may only be specified once. + Filter models by source. This parameter can be specified multiple times + to filter by multiple sources (OR logic). For example: + ?source=huggingface&source=local will return models from either + huggingface OR local sources. schema: - type: string + type: array + items: + type: string + style: form + explode: true in: query - required: true + required: false - name: q description: Free-form keyword search used to filter the response. schema: @@ -135,6 +141,10 @@ paths: type: string in: path required: true + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/orderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" components: schemas: ArtifactTypeQueryParam: @@ -202,6 +212,35 @@ components: type: object additionalProperties: $ref: "#/components/schemas/MetadataValue" + BaseResource: + allOf: + - type: object + properties: + customProperties: + description: User provided custom properties which are not defined by its type. + type: object + additionalProperties: + $ref: "#/components/schemas/MetadataValue" + description: + description: |- + An optional description about the resource. + type: string + externalId: + description: |- + The external id that come from the clients’ system. This field is optional. + If set, it must be unique among all resources within a database instance. + type: string + name: + description: |- + The client provided name of the artifact. This field is optional. If set, + it must be unique among all the artifacts of the same artifact type within + a database instance and cannot be changed once set. + type: string + id: + format: int64 + description: The unique server generated id of the resource. + type: string + - $ref: "#/components/schemas/BaseResourceDates" BaseResourceDates: description: Common timestamp fields for resources type: object @@ -278,7 +317,7 @@ components: type: object additionalProperties: $ref: "#/components/schemas/MetadataValue" - - $ref: "#/components/schemas/BaseResourceDates" + - $ref: "#/components/schemas/BaseResource" CatalogModel: description: A model in the model catalog. allOf: @@ -293,8 +332,8 @@ components: source_id: type: string description: ID of the source this model belongs to. - - $ref: "#/components/schemas/BaseResourceDates" - $ref: "#/components/schemas/BaseModel" + - $ref: "#/components/schemas/BaseResource" CatalogModelArtifact: description: A Catalog Model Artifact Entity. allOf: @@ -315,7 +354,7 @@ components: type: object additionalProperties: $ref: "#/components/schemas/MetadataValue" - - $ref: "#/components/schemas/BaseResourceDates" + - $ref: "#/components/schemas/BaseResource" CatalogModelList: description: List of CatalogModel entities. allOf: diff --git a/api/openapi/src/catalog.yaml b/api/openapi/src/catalog.yaml index 88d2afb6..85fe8828 100644 --- a/api/openapi/src/catalog.yaml +++ b/api/openapi/src/catalog.yaml @@ -20,12 +20,18 @@ paths: parameters: - name: source description: |- - Filter models by source. This parameter is currently required and - may only be specified once. + Filter models by source. This parameter can be specified multiple times + to filter by multiple sources (OR logic). For example: + ?source=huggingface&source=local will return models from either + huggingface OR local sources. schema: - type: string + type: array + items: + type: string + style: form + explode: true in: query - required: true + required: false - name: q description: Free-form keyword search used to filter the response. schema: @@ -135,6 +141,10 @@ paths: type: string in: path required: true + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/orderBy" + - $ref: "#/components/parameters/sortOrder" + - $ref: "#/components/parameters/nextPageToken" components: schemas: CatalogArtifact: @@ -181,7 +191,7 @@ components: type: object additionalProperties: $ref: "#/components/schemas/MetadataValue" - - $ref: "#/components/schemas/BaseResourceDates" + - $ref: "#/components/schemas/BaseResource" CatalogModel: description: A model in the model catalog. allOf: @@ -196,8 +206,8 @@ components: source_id: type: string description: ID of the source this model belongs to. - - $ref: "#/components/schemas/BaseResourceDates" - $ref: "#/components/schemas/BaseModel" + - $ref: "#/components/schemas/BaseResource" CatalogModelArtifact: description: A Catalog Model Artifact Entity. allOf: @@ -218,7 +228,7 @@ components: type: object additionalProperties: $ref: "#/components/schemas/MetadataValue" - - $ref: "#/components/schemas/BaseResourceDates" + - $ref: "#/components/schemas/BaseResource" CatalogModelList: description: List of CatalogModel entities. allOf: diff --git a/api/openapi/src/lib/common.yaml b/api/openapi/src/lib/common.yaml index 0e4e32e8..f0a34b68 100644 --- a/api/openapi/src/lib/common.yaml +++ b/api/openapi/src/lib/common.yaml @@ -79,6 +79,35 @@ components: description: Output only. Last update time of the resource since epoch in millisecond since epoch. type: string readOnly: true + BaseResource: + allOf: + - type: object + properties: + customProperties: + description: User provided custom properties which are not defined by its type. + type: object + additionalProperties: + $ref: "#/components/schemas/MetadataValue" + description: + description: |- + An optional description about the resource. + type: string + externalId: + description: |- + The external id that come from the clients’ system. This field is optional. + If set, it must be unique among all resources within a database instance. + type: string + name: + description: |- + The client provided name of the artifact. This field is optional. If set, + it must be unique among all the artifacts of the same artifact type within + a database instance and cannot be changed once set. + type: string + id: + format: int64 + description: The unique server generated id of the resource. + type: string + - $ref: "#/components/schemas/BaseResourceDates" BaseResourceList: required: - nextPageToken diff --git a/api/openapi/src/model-registry.yaml b/api/openapi/src/model-registry.yaml index dd780ed1..158dac7a 100644 --- a/api/openapi/src/model-registry.yaml +++ b/api/openapi/src/model-registry.yaml @@ -1676,35 +1676,6 @@ components: dataset-artifact: "#/components/schemas/DataSetUpdate" metric: "#/components/schemas/MetricUpdate" parameter: "#/components/schemas/ParameterUpdate" - BaseResource: - allOf: - - type: object - properties: - customProperties: - description: User provided custom properties which are not defined by its type. - type: object - additionalProperties: - $ref: "#/components/schemas/MetadataValue" - description: - description: |- - An optional description about the resource. - type: string - externalId: - description: |- - The external id that come from the clients’ system. This field is optional. - If set, it must be unique among all resources within a database instance. - type: string - name: - description: |- - The client provided name of the artifact. This field is optional. If set, - it must be unique among all the artifacts of the same artifact type within - a database instance and cannot be changed once set. - type: string - id: - format: int64 - description: The unique server generated id of the resource. - type: string - - $ref: "#/components/schemas/BaseResourceDates" BaseResourceCreate: type: object properties: diff --git a/catalog/cmd/catalog.go b/catalog/cmd/catalog.go index acc452d1..2ec80bcc 100644 --- a/catalog/cmd/catalog.go +++ b/catalog/cmd/catalog.go @@ -3,9 +3,11 @@ package cmd import ( "fmt" "net/http" + "reflect" "github.com/golang/glog" "github.com/kubeflow/model-registry/catalog/internal/catalog" + "github.com/kubeflow/model-registry/catalog/internal/db/models" "github.com/kubeflow/model-registry/catalog/internal/db/service" "github.com/kubeflow/model-registry/catalog/internal/server/openapi" "github.com/kubeflow/model-registry/internal/datastore" @@ -46,7 +48,7 @@ func runCatalogServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("error creating datastore: %w", err) } - _, err = ds.Connect(service.DatastoreSpec()) + repoSet, err := ds.Connect(service.DatastoreSpec()) if err != nil { return fmt.Errorf("error initializing datastore: %v", err) } @@ -56,9 +58,21 @@ func runCatalogServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("error loading catalog sources: %v", err) } - svc := openapi.NewModelCatalogServiceAPIService(sources) + svc := openapi.NewModelCatalogServiceAPIService(catalog.NewDBCatalog( + getRepo[models.CatalogModelRepository](repoSet), + getRepo[models.CatalogArtifactRepository](repoSet), + ), sources) ctrl := openapi.NewModelCatalogServiceAPIController(svc) glog.Infof("Catalog API server listening on %s", catalogCfg.ListenAddress) return http.ListenAndServe(catalogCfg.ListenAddress, openapi.NewRouter(ctrl)) } + +func getRepo[T any](repoSet datastore.RepoSet) T { + repo, err := repoSet.Repository(reflect.TypeFor[T]()) + if err != nil { + panic(fmt.Sprintf("unable to get repository: %v", err)) + } + + return repo.(T) +} diff --git a/catalog/internal/catalog/catalog.go b/catalog/internal/catalog/catalog.go index 6f4d9d88..6e611e83 100644 --- a/catalog/internal/catalog/catalog.go +++ b/catalog/internal/catalog/catalog.go @@ -14,9 +14,19 @@ import ( ) type ListModelsParams struct { - Query string - OrderBy model.OrderByField - SortOrder model.SortOrder + Query string + SourceIDs []string + PageSize int32 + OrderBy model.OrderByField + SortOrder model.SortOrder + NextPageToken *string +} + +type ListArtifactsParams struct { + PageSize int32 + OrderBy model.OrderByField + SortOrder model.SortOrder + NextPageToken *string } // CatalogSourceProvider is implemented by catalog source types, e.g. YamlCatalog @@ -24,16 +34,17 @@ type CatalogSourceProvider interface { // GetModel returns model metadata for a single model by its name. If // nothing is found with the name provided it returns nil, without an // error. - GetModel(ctx context.Context, name string) (*model.CatalogModel, error) + GetModel(ctx context.Context, modelName string, sourceID string) (*model.CatalogModel, error) // ListModels returns all models according to the parameters. If // nothing suitable is found, it returns an empty list. + // If sourceIDs is provided, filter models by source IDs. If not provided, return all models. ListModels(ctx context.Context, params ListModelsParams) (model.CatalogModelList, error) // GetArtifacts returns all artifacts for a particular model. If no // model is found with that name, it returns nil. If the model is // found, but has no artifacts, an empty list is returned. - GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error) + GetArtifacts(ctx context.Context, modelName string, sourceID string, params ListArtifactsParams) (model.CatalogArtifactList, error) } // CatalogSourceConfig is a single entry from the catalog sources YAML file. diff --git a/catalog/internal/catalog/db_catalog.go b/catalog/internal/catalog/db_catalog.go new file mode 100644 index 00000000..b0b01aad --- /dev/null +++ b/catalog/internal/catalog/db_catalog.go @@ -0,0 +1,415 @@ +package catalog + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + + dbmodels "github.com/kubeflow/model-registry/catalog/internal/db/models" + apimodels "github.com/kubeflow/model-registry/catalog/pkg/openapi" + "github.com/kubeflow/model-registry/internal/converter" + mrmodels "github.com/kubeflow/model-registry/internal/db/models" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +type dbCatalogImpl struct { + catalogModelRepository dbmodels.CatalogModelRepository + catalogArtifactRepository dbmodels.CatalogArtifactRepository +} + +func NewDBCatalog( + catalogModelRepository dbmodels.CatalogModelRepository, + catalogArtifactRepository dbmodels.CatalogArtifactRepository, +) CatalogSourceProvider { + return &dbCatalogImpl{ + catalogModelRepository: catalogModelRepository, + catalogArtifactRepository: catalogArtifactRepository, + } +} + +func (d *dbCatalogImpl) GetModel(ctx context.Context, modelName string, sourceID string) (*apimodels.CatalogModel, error) { + modelsList, err := d.catalogModelRepository.List(dbmodels.CatalogModelListOptions{ + Name: &modelName, + SourceIDs: &[]string{sourceID}, + }) + if err != nil { + return nil, err + } + + if len(modelsList.Items) == 0 { + return nil, fmt.Errorf("no models found for name=%v: %w", modelName, api.ErrNotFound) + } + + if len(modelsList.Items) > 1 { + return nil, fmt.Errorf("multiple models found for name=%v: %w", modelName, api.ErrNotFound) + } + + model := mapDBModelToAPIModel(modelsList.Items[0]) + + return &model, nil +} + +func (d *dbCatalogImpl) ListModels(ctx context.Context, params ListModelsParams) (apimodels.CatalogModelList, error) { + pageSize := int32(params.PageSize) + + // Use consistent defaults to match pagination logic + orderBy := string(params.OrderBy) + if orderBy == "" { + orderBy = mrmodels.DefaultOrderBy + } + + sortOrder := string(params.SortOrder) + if sortOrder == "" { + sortOrder = mrmodels.DefaultSortOrder + } + + nextPageToken := params.NextPageToken + + var queryPtr *string + if params.Query != "" { + queryPtr = ¶ms.Query + } + + modelsList, err := d.catalogModelRepository.List(dbmodels.CatalogModelListOptions{ + SourceIDs: ¶ms.SourceIDs, + Query: queryPtr, + Pagination: mrmodels.Pagination{ + PageSize: &pageSize, + OrderBy: &orderBy, + SortOrder: &sortOrder, + NextPageToken: nextPageToken, + }, + }) + if err != nil { + return apimodels.CatalogModelList{}, err + } + + modelList := &apimodels.CatalogModelList{ + Items: make([]apimodels.CatalogModel, 0), + } + + for _, model := range modelsList.Items { + modelList.Items = append(modelList.Items, mapDBModelToAPIModel(model)) + } + + modelList.NextPageToken = modelsList.NextPageToken + modelList.PageSize = pageSize + modelList.Size = int32(len(modelsList.Items)) + + return *modelList, nil +} + +func (d *dbCatalogImpl) GetArtifacts(ctx context.Context, modelName string, sourceID string, params ListArtifactsParams) (apimodels.CatalogArtifactList, error) { + pageSize := int32(params.PageSize) + + // Use consistent defaults to match pagination logic + orderBy := string(params.OrderBy) + if orderBy == "" { + orderBy = mrmodels.DefaultOrderBy + } + + sortOrder := string(params.SortOrder) + if sortOrder == "" { + sortOrder = mrmodels.DefaultSortOrder + } + + nextPageToken := params.NextPageToken + + m, err := d.GetModel(ctx, modelName, sourceID) + if err != nil { + if errors.Is(err, api.ErrNotFound) { + return apimodels.CatalogArtifactList{}, fmt.Errorf("invalid model name '%s' for source '%s': %w", modelName, sourceID, api.ErrBadRequest) + } + return apimodels.CatalogArtifactList{}, err + } + + parentResourceID, err := strconv.ParseInt(*m.Id, 10, 32) + if err != nil { + return apimodels.CatalogArtifactList{}, err + } + + parentResourceID32 := int32(parentResourceID) + + artifactsList, err := d.catalogArtifactRepository.List(dbmodels.CatalogArtifactListOptions{ + ParentResourceID: &parentResourceID32, + Pagination: mrmodels.Pagination{ + PageSize: &pageSize, + OrderBy: &orderBy, + SortOrder: &sortOrder, + NextPageToken: nextPageToken, + }, + }) + if err != nil { + return apimodels.CatalogArtifactList{}, err + } + + artifactList := &apimodels.CatalogArtifactList{ + Items: make([]apimodels.CatalogArtifact, 0), + } + + for _, artifact := range artifactsList.Items { + mappedArtifact, err := mapDBArtifactToAPIArtifact(artifact) + if err != nil { + return apimodels.CatalogArtifactList{}, err + } + artifactList.Items = append(artifactList.Items, mappedArtifact) + } + + artifactList.NextPageToken = artifactsList.NextPageToken + artifactList.PageSize = pageSize + artifactList.Size = int32(len(artifactList.Items)) + + return *artifactList, nil +} + +func mapDBModelToAPIModel(m dbmodels.CatalogModel) apimodels.CatalogModel { + res := apimodels.CatalogModel{} + + id := strconv.FormatInt(int64(*m.GetID()), 10) + res.Id = &id + + if m.GetAttributes() != nil { + res.Name = *m.GetAttributes().Name + res.ExternalId = m.GetAttributes().ExternalID + + if m.GetAttributes().CreateTimeSinceEpoch != nil { + createTimeSinceEpoch := strconv.FormatInt(*m.GetAttributes().CreateTimeSinceEpoch, 10) + res.CreateTimeSinceEpoch = &createTimeSinceEpoch + } + if m.GetAttributes().LastUpdateTimeSinceEpoch != nil { + lastUpdateTimeSinceEpoch := strconv.FormatInt(*m.GetAttributes().LastUpdateTimeSinceEpoch, 10) + res.LastUpdateTimeSinceEpoch = &lastUpdateTimeSinceEpoch + } + } + + if m.GetProperties() != nil { + for _, prop := range *m.GetProperties() { + switch prop.Name { + case "source_id": + if prop.StringValue != nil { + res.SourceId = prop.StringValue + } + case "description": + if prop.StringValue != nil { + res.Description = prop.StringValue + } + case "library_name": + if prop.StringValue != nil { + res.LibraryName = prop.StringValue + } + case "license_link": + if prop.StringValue != nil { + res.LicenseLink = prop.StringValue + } + case "license": + if prop.StringValue != nil { + res.License = prop.StringValue + } + case "logo": + if prop.StringValue != nil { + res.Logo = prop.StringValue + } + case "maturity": + if prop.StringValue != nil { + res.Maturity = prop.StringValue + } + case "provider": + if prop.StringValue != nil { + res.Provider = prop.StringValue + } + case "readme": + if prop.StringValue != nil { + res.Readme = prop.StringValue + } + case "language": + if prop.StringValue != nil { + var languages []string + if err := json.Unmarshal([]byte(*prop.StringValue), &languages); err == nil { + res.Language = languages + } + } + case "tasks": + if prop.StringValue != nil { + var tasks []string + if err := json.Unmarshal([]byte(*prop.StringValue), &tasks); err == nil { + res.Tasks = tasks + } + } + } + } + } + + return res +} + +func mapDBArtifactToAPIArtifact(a dbmodels.CatalogArtifact) (apimodels.CatalogArtifact, error) { + if a.CatalogModelArtifact != nil { + return mapToModelArtifact(*a.CatalogModelArtifact) + } else if a.CatalogMetricsArtifact != nil { + metricsTypeValue := string((*a.CatalogMetricsArtifact).GetAttributes().MetricsType) + return mapToMetricsArtifact(*a.CatalogMetricsArtifact, metricsTypeValue) + } + + return apimodels.CatalogArtifact{}, fmt.Errorf("invalid catalog artifact type: %v", a) +} + +func mapToModelArtifact(a dbmodels.CatalogModelArtifact) (apimodels.CatalogArtifact, error) { + catalogModelArtifact := &apimodels.CatalogModelArtifact{ + ArtifactType: dbmodels.CatalogModelArtifactType, + } + + if a.GetID() != nil { + id := strconv.FormatInt(int64(*a.GetID()), 10) + catalogModelArtifact.Id = &id + } + + if a.GetAttributes() != nil { + attrs := a.GetAttributes() + + catalogModelArtifact.Name = attrs.Name + catalogModelArtifact.ExternalId = attrs.ExternalID + + if attrs.URI != nil { + catalogModelArtifact.Uri = *attrs.URI + } + + if attrs.CreateTimeSinceEpoch != nil { + createTime := strconv.FormatInt(*attrs.CreateTimeSinceEpoch, 10) + catalogModelArtifact.CreateTimeSinceEpoch = &createTime + } + + if attrs.LastUpdateTimeSinceEpoch != nil { + updateTime := strconv.FormatInt(*attrs.LastUpdateTimeSinceEpoch, 10) + catalogModelArtifact.LastUpdateTimeSinceEpoch = &updateTime + } + } + + if a.GetProperties() != nil { + for _, prop := range *a.GetProperties() { + switch prop.Name { + case "description": + if prop.StringValue != nil { + catalogModelArtifact.Description = prop.StringValue + } + case "artifactType": + if prop.StringValue != nil { + catalogModelArtifact.ArtifactType = *prop.StringValue + } + } + } + } + + // Map custom properties + if a.GetCustomProperties() != nil && len(*a.GetCustomProperties()) > 0 { + customPropsMap, err := converter.MapEmbedMDCustomProperties(*a.GetCustomProperties()) + if err != nil { + return apimodels.CatalogArtifact{}, fmt.Errorf("error mapping custom properties: %w", err) + } + + catalogCustomProps := convertMetadataValueMap(customPropsMap) + catalogModelArtifact.CustomProperties = &catalogCustomProps + } + + return apimodels.CatalogArtifact{ + CatalogModelArtifact: catalogModelArtifact, + }, nil +} + +func mapToMetricsArtifact(a dbmodels.CatalogMetricsArtifact, metricsType string) (apimodels.CatalogArtifact, error) { + catalogMetricsArtifact := &apimodels.CatalogMetricsArtifact{ + ArtifactType: dbmodels.CatalogMetricsArtifactType, + MetricsType: metricsType, + } + + if a.GetID() != nil { + id := strconv.FormatInt(int64(*a.GetID()), 10) + catalogMetricsArtifact.Id = &id + } + + if a.GetAttributes() != nil { + attrs := a.GetAttributes() + + catalogMetricsArtifact.Name = attrs.Name + catalogMetricsArtifact.ExternalId = attrs.ExternalID + + if attrs.CreateTimeSinceEpoch != nil { + createTime := strconv.FormatInt(*attrs.CreateTimeSinceEpoch, 10) + catalogMetricsArtifact.CreateTimeSinceEpoch = &createTime + } + + if attrs.LastUpdateTimeSinceEpoch != nil { + updateTime := strconv.FormatInt(*attrs.LastUpdateTimeSinceEpoch, 10) + catalogMetricsArtifact.LastUpdateTimeSinceEpoch = &updateTime + } + } + + if a.GetProperties() != nil { + for _, prop := range *a.GetProperties() { + switch prop.Name { + case "description": + if prop.StringValue != nil { + catalogMetricsArtifact.Description = prop.StringValue + } + } + } + } + + // Map custom properties + if a.GetCustomProperties() != nil && len(*a.GetCustomProperties()) > 0 { + customPropsMap, err := converter.MapEmbedMDCustomProperties(*a.GetCustomProperties()) + if err != nil { + return apimodels.CatalogArtifact{}, fmt.Errorf("error mapping custom properties: %w", err) + } + + catalogCustomProps := convertMetadataValueMap(customPropsMap) + catalogMetricsArtifact.CustomProperties = &catalogCustomProps + + } + + return apimodels.CatalogArtifact{ + CatalogMetricsArtifact: catalogMetricsArtifact, + }, nil +} + +// convertMetadataValueMap converts from pkg/openapi.MetadataValue to catalog/pkg/openapi.MetadataValue +func convertMetadataValueMap(source map[string]openapi.MetadataValue) map[string]apimodels.MetadataValue { + result := make(map[string]apimodels.MetadataValue) + + for key, value := range source { + catalogValue := apimodels.MetadataValue{} + + if value.MetadataStringValue != nil { + catalogValue.MetadataStringValue = &apimodels.MetadataStringValue{ + StringValue: value.MetadataStringValue.StringValue, + MetadataType: value.MetadataStringValue.MetadataType, + } + } else if value.MetadataIntValue != nil { + catalogValue.MetadataIntValue = &apimodels.MetadataIntValue{ + IntValue: value.MetadataIntValue.IntValue, + MetadataType: value.MetadataIntValue.MetadataType, + } + } else if value.MetadataDoubleValue != nil { + catalogValue.MetadataDoubleValue = &apimodels.MetadataDoubleValue{ + DoubleValue: value.MetadataDoubleValue.DoubleValue, + MetadataType: value.MetadataDoubleValue.MetadataType, + } + } else if value.MetadataBoolValue != nil { + catalogValue.MetadataBoolValue = &apimodels.MetadataBoolValue{ + BoolValue: value.MetadataBoolValue.BoolValue, + MetadataType: value.MetadataBoolValue.MetadataType, + } + } else if value.MetadataStructValue != nil { + catalogValue.MetadataStructValue = &apimodels.MetadataStructValue{ + StructValue: value.MetadataStructValue.StructValue, + MetadataType: value.MetadataStructValue.MetadataType, + } + } + + result[key] = catalogValue + } + + return result +} diff --git a/catalog/internal/catalog/db_catalog_test.go b/catalog/internal/catalog/db_catalog_test.go new file mode 100644 index 00000000..ceb82f99 --- /dev/null +++ b/catalog/internal/catalog/db_catalog_test.go @@ -0,0 +1,704 @@ +package catalog + +import ( + "context" + "fmt" + "os" + "strconv" + "testing" + "time" + + "github.com/kubeflow/model-registry/catalog/internal/db/models" + "github.com/kubeflow/model-registry/catalog/internal/db/service" + model "github.com/kubeflow/model-registry/catalog/pkg/openapi" + "github.com/kubeflow/model-registry/internal/apiutils" + mr_models "github.com/kubeflow/model-registry/internal/db/models" + "github.com/kubeflow/model-registry/internal/db/schema" + "github.com/kubeflow/model-registry/internal/testutils" + "github.com/kubeflow/model-registry/pkg/api" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + os.Exit(testutils.TestMainHelper(m)) +} + +func TestDBCatalog(t *testing.T) { + // Setup test database + sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec()) + defer cleanup() + + // Get type IDs + catalogModelTypeID := getCatalogModelTypeIDForDBTest(t, sharedDB) + modelArtifactTypeID := getCatalogModelArtifactTypeIDForDBTest(t, sharedDB) + metricsArtifactTypeID := getCatalogMetricsArtifactTypeIDForDBTest(t, sharedDB) + + // Create repositories + catalogModelRepo := service.NewCatalogModelRepository(sharedDB, catalogModelTypeID) + catalogArtifactRepo := service.NewCatalogArtifactRepository(sharedDB, map[string]int64{ + service.CatalogModelArtifactTypeName: modelArtifactTypeID, + service.CatalogMetricsArtifactTypeName: metricsArtifactTypeID, + }) + modelArtifactRepo := service.NewCatalogModelArtifactRepository(sharedDB, modelArtifactTypeID) + metricsArtifactRepo := service.NewCatalogMetricsArtifactRepository(sharedDB, metricsArtifactTypeID) + + // Create DB catalog instance + dbCatalog := NewDBCatalog(catalogModelRepo, catalogArtifactRepo) + ctx := context.Background() + + t.Run("TestNewDBCatalog", func(t *testing.T) { + catalog := NewDBCatalog(catalogModelRepo, catalogArtifactRepo) + require.NotNil(t, catalog) + + // Verify it implements the interface + var _ CatalogSourceProvider = catalog + }) + + t.Run("TestGetModel_Success", func(t *testing.T) { + // Create test model + testModel := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("test-get-model"), + ExternalID: apiutils.Of("test-get-model-ext"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("test-source-id")}, + {Name: "description", StringValue: apiutils.Of("Test model description")}, + }, + } + + savedModel, err := catalogModelRepo.Save(testModel) + require.NoError(t, err) + + // Test GetModel + retrievedModel, err := dbCatalog.GetModel(ctx, "test-get-model", "test-source-id") + require.NoError(t, err) + require.NotNil(t, retrievedModel) + + assert.Equal(t, "test-get-model", retrievedModel.Name) + assert.Equal(t, strconv.FormatInt(int64(*savedModel.GetID()), 10), *retrievedModel.Id) + assert.Equal(t, "test-get-model-ext", *retrievedModel.ExternalId) + assert.Equal(t, "test-source-id", *retrievedModel.SourceId) + assert.Equal(t, "Test model description", *retrievedModel.Description) + }) + + t.Run("TestGetModel_NotFound", func(t *testing.T) { + // Test with non-existent model + _, err := dbCatalog.GetModel(ctx, "non-existent-model", "test-source-id") + require.Error(t, err) + assert.Contains(t, err.Error(), "no models found") + assert.ErrorIs(t, err, api.ErrNotFound) + }) + + t.Run("TestGetModel_DatabaseConstraints", func(t *testing.T) { + // Test database constraint behavior - attempting to create duplicate models should fail + // Using timestamp to ensure uniqueness across test runs + timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10) + modelName := "constraint-test-model-" + timestamp + sourceID := "constraint-test-source-" + timestamp + + model1 := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of(modelName), + ExternalID: apiutils.Of("constraint-test-1-" + timestamp), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of(sourceID)}, + }, + } + + // First model should save successfully + savedModel1, err := catalogModelRepo.Save(model1) + require.NoError(t, err) + require.NotNil(t, savedModel1) + + // Now test that GetModel works correctly with the single saved model + retrievedModel, err := dbCatalog.GetModel(ctx, modelName, sourceID) + require.NoError(t, err) + require.NotNil(t, retrievedModel) + assert.Equal(t, modelName, retrievedModel.Name) + + // Test attempting to create a duplicate with same name but different external ID + // This should fail due to database constraints (which is expected behavior) + model2 := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of(modelName), // Same name + ExternalID: apiutils.Of("constraint-test-2-" + timestamp), // Different external ID + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of(sourceID)}, + }, + } + + _, err = catalogModelRepo.Save(model2) + // This should fail due to database constraints preventing duplicate names + assert.Error(t, err) + assert.Contains(t, err.Error(), "duplicated key") + }) + + t.Run("TestListModels_Success", func(t *testing.T) { + // Create test models + sourceIDs := []string{"list-test-source"} + + model1 := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("list-test-model-1"), + ExternalID: apiutils.Of("list-test-1"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("list-test-source")}, + {Name: "description", StringValue: apiutils.Of("First test model")}, + }, + } + + model2 := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("list-test-model-2"), + ExternalID: apiutils.Of("list-test-2"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("list-test-source")}, + {Name: "description", StringValue: apiutils.Of("Second test model")}, + }, + } + + _, err := catalogModelRepo.Save(model1) + require.NoError(t, err) + _, err = catalogModelRepo.Save(model2) + require.NoError(t, err) + + // Test ListModels + params := ListModelsParams{ + SourceIDs: sourceIDs, + PageSize: 10, + OrderBy: model.ORDERBYFIELD_CREATE_TIME, + SortOrder: model.SORTORDER_ASC, + NextPageToken: apiutils.Of(""), + } + + result, err := dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.GreaterOrEqual(t, len(result.Items), 2, "Should return at least 2 models") + assert.Equal(t, int32(10), result.PageSize) + assert.GreaterOrEqual(t, result.Size, int32(2)) + + // Verify models are properly mapped + modelNames := make(map[string]bool) + for _, model := range result.Items { + modelNames[model.Name] = true + // Verify required fields are present + assert.NotEmpty(t, *model.Id) + assert.NotEmpty(t, *model.SourceId) + } + + // Should contain our test models + foundCount := 0 + if modelNames["list-test-model-1"] { + foundCount++ + } + if modelNames["list-test-model-2"] { + foundCount++ + } + assert.GreaterOrEqual(t, foundCount, 2, "Should find our test models") + }) + + t.Run("TestListModels_WithPagination", func(t *testing.T) { + // Test pagination + sourceIDs := []string{"pagination-test-source"} + + // Create multiple models + for i := 0; i < 5; i++ { + model := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of(fmt.Sprintf("pagination-test-model-%d", i)), + ExternalID: apiutils.Of(fmt.Sprintf("pagination-test-%d", i)), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("pagination-test-source")}, + }, + } + _, err := catalogModelRepo.Save(model) + require.NoError(t, err) + } + + params := ListModelsParams{ + SourceIDs: sourceIDs, + PageSize: 3, + OrderBy: model.ORDERBYFIELD_CREATE_TIME, + SortOrder: model.SORTORDER_ASC, + NextPageToken: apiutils.Of(""), + } + + result, err := dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.LessOrEqual(t, len(result.Items), 3, "Should respect page size") + assert.Equal(t, int32(3), result.PageSize) + }) + + t.Run("TestListModels_WithQuery", func(t *testing.T) { + // Create test models with different properties for query filtering + sourceIDs := []string{"query-test-source"} + + model1 := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("BERT-base-model"), + ExternalID: apiutils.Of("bert-base-1"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("query-test-source")}, + {Name: "description", StringValue: apiutils.Of("BERT base model for NLP tasks")}, + {Name: "provider", StringValue: apiutils.Of("Hugging Face")}, + {Name: "tasks", StringValue: apiutils.Of(`["text-classification", "question-answering"]`)}, + }, + } + + model2 := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("GPT-3.5-turbo"), + ExternalID: apiutils.Of("gpt-35-turbo-1"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("query-test-source")}, + {Name: "description", StringValue: apiutils.Of("OpenAI GPT model for text generation")}, + {Name: "provider", StringValue: apiutils.Of("OpenAI")}, + {Name: "tasks", StringValue: apiutils.Of(`["text-generation", "conversational"]`)}, + }, + } + + model3 := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("ResNet-50-image"), + ExternalID: apiutils.Of("resnet-50-1"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("query-test-source")}, + {Name: "description", StringValue: apiutils.Of("Deep learning model for image classification")}, + {Name: "provider", StringValue: apiutils.Of("PyTorch")}, + {Name: "tasks", StringValue: apiutils.Of(`["image-classification", "computer-vision"]`)}, + }, + } + + _, err := catalogModelRepo.Save(model1) + require.NoError(t, err) + _, err = catalogModelRepo.Save(model2) + require.NoError(t, err) + _, err = catalogModelRepo.Save(model3) + require.NoError(t, err) + + // Test query filtering by name + params := ListModelsParams{ + Query: "BERT", + SourceIDs: sourceIDs, + PageSize: 10, + OrderBy: model.ORDERBYFIELD_CREATE_TIME, + SortOrder: model.SORTORDER_ASC, + NextPageToken: apiutils.Of(""), + } + + result, err := dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(1), result.Size, "Should return 1 model matching 'BERT'") + assert.Contains(t, result.Items[0].Name, "BERT", "Should contain BERT model") + + // Test query filtering by description + params.Query = "NLP" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(1), result.Size, "Should return 1 model with 'NLP' in description") + assert.Contains(t, result.Items[0].Name, "BERT", "Should contain BERT model") + + // Test query filtering by provider + params.Query = "OpenAI" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(1), result.Size, "Should return 1 model from 'OpenAI' provider") + assert.Contains(t, result.Items[0].Name, "GPT", "Should contain GPT model") + + // Test query filtering that should match multiple models + params.Query = "model" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.GreaterOrEqual(t, result.Size, int32(3), "Should return at least 3 models matching 'model'") + + // Test query that should return no results + params.Query = "nonexistent" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(0), result.Size, "Should return 0 models for nonexistent query") + + // Test query filtering by tasks - text-classification + params.Query = "text-classification" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(1), result.Size, "Should return 1 model with 'text-classification' task") + assert.Contains(t, result.Items[0].Name, "BERT", "Should contain BERT model") + + // Test query filtering by tasks - image-classification + params.Query = "image-classification" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(1), result.Size, "Should return 1 model with 'image-classification' task") + assert.Contains(t, result.Items[0].Name, "ResNet", "Should contain ResNet model") + + // Test query filtering by tasks - conversational + params.Query = "conversational" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(1), result.Size, "Should return 1 model with 'conversational' task") + assert.Contains(t, result.Items[0].Name, "GPT", "Should contain GPT model") + + // Test query filtering by tasks - partial match on "classification" + params.Query = "classification" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(2), result.Size, "Should return 2 models with 'classification' in their tasks") + + // Test query filtering by tasks - computer-vision + params.Query = "computer-vision" + result, err = dbCatalog.ListModels(ctx, params) + require.NoError(t, err) + + assert.Equal(t, int32(1), result.Size, "Should return 1 model with 'computer-vision' task") + assert.Contains(t, result.Items[0].Name, "ResNet", "Should contain ResNet model") + }) + + t.Run("TestGetArtifacts_Success", func(t *testing.T) { + // Create test model + testModel := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("artifact-test-model"), + ExternalID: apiutils.Of("artifact-test-model-ext"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("artifact-test-source")}, + }, + } + + savedModel, err := catalogModelRepo.Save(testModel) + require.NoError(t, err) + + // Create test artifacts + modelArtifact := &models.CatalogModelArtifactImpl{ + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of("test-model-artifact"), + ExternalID: apiutils.Of("test-model-artifact-ext"), + URI: apiutils.Of("s3://test/model.bin"), + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), + }, + } + + metricsArtifact := &models.CatalogMetricsArtifactImpl{ + TypeID: apiutils.Of(int32(metricsArtifactTypeID)), + Attributes: &models.CatalogMetricsArtifactAttributes{ + Name: apiutils.Of("test-metrics-artifact"), + ExternalID: apiutils.Of("test-metrics-artifact-ext"), + MetricsType: models.MetricsTypeAccuracy, + ArtifactType: apiutils.Of("metrics-artifact"), + }, + } + + savedModelArt, err := modelArtifactRepo.Save(modelArtifact, savedModel.GetID()) + require.NoError(t, err) + savedMetricsArt, err := metricsArtifactRepo.Save(metricsArtifact, savedModel.GetID()) + require.NoError(t, err) + + // Test GetArtifacts + params := ListArtifactsParams{ + PageSize: 10, + OrderBy: model.ORDERBYFIELD_CREATE_TIME, + SortOrder: model.SORTORDER_ASC, + NextPageToken: apiutils.Of(""), + } + + result, err := dbCatalog.GetArtifacts(ctx, "artifact-test-model", "artifact-test-source", params) + require.NoError(t, err) + + assert.GreaterOrEqual(t, len(result.Items), 2, "Should return at least 2 artifacts") + assert.Equal(t, int32(10), result.PageSize) + + // Verify both types of artifacts are returned + var modelArtifactFound, metricsArtifactFound bool + artifactIDs := make(map[string]bool) + + for _, artifact := range result.Items { + if artifact.CatalogModelArtifact != nil { + modelArtifactFound = true + artifactIDs[*artifact.CatalogModelArtifact.Id] = true + assert.Equal(t, "model-artifact", artifact.CatalogModelArtifact.ArtifactType) + } + if artifact.CatalogMetricsArtifact != nil { + metricsArtifactFound = true + artifactIDs[*artifact.CatalogMetricsArtifact.Id] = true + assert.Equal(t, "metrics-artifact", artifact.CatalogMetricsArtifact.ArtifactType) + } + } + + assert.True(t, modelArtifactFound, "Should find model artifact") + assert.True(t, metricsArtifactFound, "Should find metrics artifact") + + // Verify our specific artifacts are in the results + modelArtifactIDStr := strconv.FormatInt(int64(*savedModelArt.GetID()), 10) + metricsArtifactIDStr := strconv.FormatInt(int64(*savedMetricsArt.GetID()), 10) + assert.True(t, artifactIDs[modelArtifactIDStr], "Should contain our model artifact") + assert.True(t, artifactIDs[metricsArtifactIDStr], "Should contain our metrics artifact") + }) + + t.Run("TestGetArtifacts_ModelNotFound", func(t *testing.T) { + // Test with non-existent model + params := ListArtifactsParams{ + PageSize: 10, + OrderBy: model.ORDERBYFIELD_CREATE_TIME, + SortOrder: model.SORTORDER_ASC, + NextPageToken: apiutils.Of(""), + } + + _, err := dbCatalog.GetArtifacts(ctx, "non-existent-model", "test-source", params) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid model name") + }) + + t.Run("TestGetArtifacts_WithCustomProperties", func(t *testing.T) { + // Create model + testModel := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("custom-props-model"), + ExternalID: apiutils.Of("custom-props-model-ext"), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("custom-props-source")}, + }, + } + + savedModel, err := catalogModelRepo.Save(testModel) + require.NoError(t, err) + + // Create artifact with custom properties + customProps := []mr_models.Properties{ + {Name: "custom_prop_1", StringValue: apiutils.Of("value_1")}, + {Name: "custom_prop_2", StringValue: apiutils.Of("value_2")}, + } + + artifactWithProps := &models.CatalogModelArtifactImpl{ + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of("artifact-with-props"), + ExternalID: apiutils.Of("artifact-with-props-ext"), + URI: apiutils.Of("s3://test/props.bin"), + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), + }, + CustomProperties: &customProps, + } + + _, err = modelArtifactRepo.Save(artifactWithProps, savedModel.GetID()) + require.NoError(t, err) + + // Get artifacts and verify custom properties + params := ListArtifactsParams{ + PageSize: 10, + OrderBy: model.ORDERBYFIELD_CREATE_TIME, + SortOrder: model.SORTORDER_ASC, + NextPageToken: apiutils.Of(""), + } + + result, err := dbCatalog.GetArtifacts(ctx, "custom-props-model", "custom-props-source", params) + require.NoError(t, err) + + // Find our artifact and check custom properties + found := false + for _, artifact := range result.Items { + if artifact.CatalogModelArtifact != nil && + artifact.CatalogModelArtifact.Name != nil && + *artifact.CatalogModelArtifact.Name == "artifact-with-props" { + + found = true + assert.NotNil(t, artifact.CatalogModelArtifact.CustomProperties) + + // Verify custom properties are present and properly converted + customPropsMap := *artifact.CatalogModelArtifact.CustomProperties + assert.Contains(t, customPropsMap, "custom_prop_1") + assert.Contains(t, customPropsMap, "custom_prop_2") + + // Verify the values are properly converted to MetadataValue + prop1 := customPropsMap["custom_prop_1"] + assert.NotNil(t, prop1.MetadataStringValue) + assert.Equal(t, "value_1", prop1.MetadataStringValue.StringValue) + + break + } + } + assert.True(t, found, "Should find artifact with custom properties") + }) + + t.Run("TestMappingFunctions", func(t *testing.T) { + t.Run("TestMapCatalogModelToCatalogModel", func(t *testing.T) { + // Create a catalog model with various properties + catalogModel := &models.CatalogModelImpl{ + ID: apiutils.Of(int32(123)), + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("mapping-test-model"), + ExternalID: apiutils.Of("mapping-test-ext"), + CreateTimeSinceEpoch: apiutils.Of(int64(1234567890)), + LastUpdateTimeSinceEpoch: apiutils.Of(int64(1234567891)), + }, + Properties: &[]mr_models.Properties{ + {Name: "source_id", StringValue: apiutils.Of("test-source")}, + {Name: "description", StringValue: apiutils.Of("Test description")}, + {Name: "library_name", StringValue: apiutils.Of("pytorch")}, + {Name: "language", StringValue: apiutils.Of("[\"python\", \"go\"]")}, + {Name: "tasks", StringValue: apiutils.Of("[\"classification\", \"regression\"]")}, + }, + } + + result := mapDBModelToAPIModel(catalogModel) + + assert.Equal(t, "123", *result.Id) + assert.Equal(t, "mapping-test-model", result.Name) + assert.Equal(t, "mapping-test-ext", *result.ExternalId) + assert.Equal(t, "test-source", *result.SourceId) + assert.Equal(t, "Test description", *result.Description) + assert.Equal(t, "pytorch", *result.LibraryName) + assert.Equal(t, "1234567890", *result.CreateTimeSinceEpoch) + assert.Equal(t, "1234567891", *result.LastUpdateTimeSinceEpoch) + + // Verify JSON arrays are properly parsed + assert.Equal(t, []string{"python", "go"}, result.Language) + assert.Equal(t, []string{"classification", "regression"}, result.Tasks) + }) + + t.Run("TestMapCatalogArtifactToCatalogArtifact", func(t *testing.T) { + // Test model artifact mapping + var catalogModelArtifact models.CatalogModelArtifact = &models.CatalogModelArtifactImpl{ + ID: apiutils.Of(int32(456)), + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of("test-model-artifact"), + ExternalID: apiutils.Of("test-model-artifact-ext"), + URI: apiutils.Of("s3://test/model.bin"), + }, + } + + catalogArtifact := models.CatalogArtifact{ + CatalogModelArtifact: &catalogModelArtifact, + } + + result, err := mapDBArtifactToAPIArtifact(catalogArtifact) + require.NoError(t, err) + + assert.NotNil(t, result.CatalogModelArtifact) + assert.Nil(t, result.CatalogMetricsArtifact) + assert.Equal(t, "456", *result.CatalogModelArtifact.Id) + assert.Equal(t, "test-model-artifact", *result.CatalogModelArtifact.Name) + assert.Equal(t, "s3://test/model.bin", result.CatalogModelArtifact.Uri) + + // Test metrics artifact mapping + var catalogMetricsArtifact models.CatalogMetricsArtifact = &models.CatalogMetricsArtifactImpl{ + ID: apiutils.Of(int32(789)), + TypeID: apiutils.Of(int32(metricsArtifactTypeID)), + Attributes: &models.CatalogMetricsArtifactAttributes{ + Name: apiutils.Of("test-metrics-artifact"), + ExternalID: apiutils.Of("test-metrics-artifact-ext"), + MetricsType: models.MetricsTypePerformance, + }, + } + + catalogArtifact2 := models.CatalogArtifact{ + CatalogMetricsArtifact: &catalogMetricsArtifact, + } + + result2, err := mapDBArtifactToAPIArtifact(catalogArtifact2) + require.NoError(t, err) + + assert.Nil(t, result2.CatalogModelArtifact) + assert.NotNil(t, result2.CatalogMetricsArtifact) + assert.Equal(t, "789", *result2.CatalogMetricsArtifact.Id) + assert.Equal(t, "test-metrics-artifact", *result2.CatalogMetricsArtifact.Name) + assert.Equal(t, "performance-metrics", result2.CatalogMetricsArtifact.MetricsType) + }) + + t.Run("TestMapCatalogArtifact_EmptyArtifact", func(t *testing.T) { + // Test with empty catalog artifact + emptyCatalogArtifact := models.CatalogArtifact{} + + _, err := mapDBArtifactToAPIArtifact(emptyCatalogArtifact) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid catalog artifact type") + }) + }) + + t.Run("TestErrorHandling", func(t *testing.T) { + t.Run("TestGetArtifacts_InvalidModelID", func(t *testing.T) { + // Create a model with invalid ID format for testing + // This would be an edge case where the ID isn't a valid integer + + // We can't easily test this directly since IDs are generated as integers + // But we can test the error case by mocking a scenario + + // For now, let's test a scenario where the model exists but has some issue + params := ListArtifactsParams{ + PageSize: 10, + OrderBy: model.ORDERBYFIELD_CREATE_TIME, + SortOrder: model.SORTORDER_ASC, + NextPageToken: apiutils.Of(""), + } + + _, err := dbCatalog.GetArtifacts(ctx, "non-existent-model", "test-source", params) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid model name") + }) + }) +} + +// Helper functions to get type IDs from database + +func getCatalogModelTypeIDForDBTest(t *testing.T, db *gorm.DB) int64 { + var typeRecord schema.Type + err := db.Where("name = ?", service.CatalogModelTypeName).First(&typeRecord).Error + if err != nil { + require.NoError(t, err, "Failed to query CatalogModel type") + } + return int64(typeRecord.ID) +} + +func getCatalogModelArtifactTypeIDForDBTest(t *testing.T, db *gorm.DB) int64 { + var typeRecord schema.Type + err := db.Where("name = ?", service.CatalogModelArtifactTypeName).First(&typeRecord).Error + if err != nil { + require.NoError(t, err, "Failed to query CatalogModelArtifact type") + } + return int64(typeRecord.ID) +} + +func getCatalogMetricsArtifactTypeIDForDBTest(t *testing.T, db *gorm.DB) int64 { + var typeRecord schema.Type + err := db.Where("name = ?", service.CatalogMetricsArtifactTypeName).First(&typeRecord).Error + if err != nil { + require.NoError(t, err, "Failed to query CatalogMetricsArtifact type") + } + return int64(typeRecord.ID) +} diff --git a/catalog/internal/catalog/hf_catalog.go b/catalog/internal/catalog/hf_catalog.go index 20219d4a..6a638d2b 100644 --- a/catalog/internal/catalog/hf_catalog.go +++ b/catalog/internal/catalog/hf_catalog.go @@ -24,7 +24,7 @@ const ( defaultHuggingFaceURL = "https://huggingface.co" ) -func (h *hfCatalogImpl) GetModel(ctx context.Context, name string) (*openapi.CatalogModel, error) { +func (h *hfCatalogImpl) GetModel(ctx context.Context, modelName string, sourceID string) (*openapi.CatalogModel, error) { // TODO: Implement HuggingFace model retrieval return nil, fmt.Errorf("HuggingFace model retrieval not yet implemented") } @@ -39,10 +39,10 @@ func (h *hfCatalogImpl) ListModels(ctx context.Context, params ListModelsParams) }, nil } -func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, name string) (*openapi.CatalogArtifactList, error) { +func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, modelName string, sourceID string, params ListArtifactsParams) (openapi.CatalogArtifactList, error) { // TODO: Implement HuggingFace model artifacts retrieval // For now, return empty list to satisfy interface - return &openapi.CatalogArtifactList{ + return openapi.CatalogArtifactList{ Items: []openapi.CatalogArtifact{}, PageSize: 0, Size: 0, diff --git a/catalog/internal/catalog/hf_catalog_test.go b/catalog/internal/catalog/hf_catalog_test.go index 5be51037..e3d956b9 100644 --- a/catalog/internal/catalog/hf_catalog_test.go +++ b/catalog/internal/catalog/hf_catalog_test.go @@ -76,7 +76,7 @@ func TestNewHfCatalog_WithValidCredentials(t *testing.T) { ctx := context.Background() // Test GetModel - should return not implemented error - model, err := hfCatalog.GetModel(ctx, "test-model") + model, err := hfCatalog.GetModel(ctx, "test-model", "") if err == nil { t.Fatal("Expected not implemented error, got nil") } @@ -99,11 +99,11 @@ func TestNewHfCatalog_WithValidCredentials(t *testing.T) { } // Test GetArtifacts - should return empty list - artifacts, err := hfCatalog.GetArtifacts(ctx, "test-model") + artifacts, err := hfCatalog.GetArtifacts(ctx, "test-model", "", ListArtifactsParams{}) if err != nil { t.Fatalf("Failed to get artifacts: %v", err) } - if artifacts == nil { + if artifacts.Items == nil { t.Fatal("Expected artifacts list, got nil") } if len(artifacts.Items) != 0 { diff --git a/catalog/internal/catalog/testdata/testdb.cnf b/catalog/internal/catalog/testdata/testdb.cnf new file mode 100644 index 00000000..873958cd --- /dev/null +++ b/catalog/internal/catalog/testdata/testdb.cnf @@ -0,0 +1,5 @@ +[mysqld] +character-set-server = utf8mb4 +collation-server = utf8mb4_general_ci + +!includedir /etc/mysql/conf.d/ diff --git a/catalog/internal/catalog/yaml_catalog.go b/catalog/internal/catalog/yaml_catalog.go index 5c05830b..b6f5d008 100644 --- a/catalog/internal/catalog/yaml_catalog.go +++ b/catalog/internal/catalog/yaml_catalog.go @@ -68,11 +68,11 @@ type yamlCatalogImpl struct { var _ CatalogSourceProvider = &yamlCatalogImpl{} -func (y *yamlCatalogImpl) GetModel(ctx context.Context, name string) (*model.CatalogModel, error) { +func (y *yamlCatalogImpl) GetModel(ctx context.Context, modelName string, sourceID string) (*model.CatalogModel, error) { y.modelsLock.RLock() defer y.modelsLock.RUnlock() - ym := y.models[name] + ym := y.models[modelName] if ym == nil { return nil, nil } @@ -159,13 +159,13 @@ func (y *yamlCatalogImpl) ListModels(ctx context.Context, params ListModelsParam return list, nil // Return the struct value directly } -func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error) { +func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, modelName string, sourceID string, params ListArtifactsParams) (model.CatalogArtifactList, error) { y.modelsLock.RLock() defer y.modelsLock.RUnlock() - ym := y.models[name] + ym := y.models[modelName] if ym == nil { - return nil, nil + return model.CatalogArtifactList{}, nil } count := len(ym.Artifacts) @@ -181,7 +181,7 @@ func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model for i := range list.Items { list.Items[i] = ym.Artifacts[i].CatalogArtifact } - return &list, nil + return list, nil } func isModelExcluded(modelName string, patterns []string) bool { diff --git a/catalog/internal/catalog/yaml_catalog_test.go b/catalog/internal/catalog/yaml_catalog_test.go index e77199a3..e01f0c29 100644 --- a/catalog/internal/catalog/yaml_catalog_test.go +++ b/catalog/internal/catalog/yaml_catalog_test.go @@ -12,20 +12,20 @@ func TestYAMLCatalogGetModel(t *testing.T) { assert := assert.New(t) provider := testYAMLProvider(t, "testdata/test-yaml-catalog.yaml") - model, err := provider.GetModel(context.Background(), "rhelai1/granite-8b-code-base") + model, err := provider.GetModel(context.Background(), "rhelai1/granite-8b-code-base", "") if assert.NoError(err) { assert.Equal("rhelai1/granite-8b-code-base", model.Name) newLogo := "foobar" model.Logo = &newLogo - model2, err := provider.GetModel(context.Background(), "rhelai1/granite-8b-code-base") + model2, err := provider.GetModel(context.Background(), "rhelai1/granite-8b-code-base", "") if assert.NoError(err) { assert.NotEqual(model2.Logo, model.Logo, "changes to one returned object should not affect other return values") } } - notFound, err := provider.GetModel(context.Background(), "foo") + notFound, err := provider.GetModel(context.Background(), "foo", "") assert.NoError(err) assert.Nil(notFound) } @@ -35,7 +35,7 @@ func TestYAMLCatalogGetArtifacts(t *testing.T) { provider := testYAMLProvider(t, "testdata/test-yaml-catalog.yaml") // Test case 1: Model with artifacts - artifacts, err := provider.GetArtifacts(context.Background(), "rhelai1/granite-8b-code-base") + artifacts, err := provider.GetArtifacts(context.Background(), "rhelai1/granite-8b-code-base", "", ListArtifactsParams{}) if assert.NoError(err) { assert.NotNil(artifacts) assert.Equal(int32(2), artifacts.Size) @@ -52,7 +52,7 @@ func TestYAMLCatalogGetArtifacts(t *testing.T) { } // Test case 2: Model with no artifacts - noArtifactsModel, err := provider.GetArtifacts(context.Background(), "model-with-no-artifacts") + noArtifactsModel, err := provider.GetArtifacts(context.Background(), "model-with-no-artifacts", "", ListArtifactsParams{}) if assert.NoError(err) { assert.NotNil(noArtifactsModel) assert.Equal(int32(0), noArtifactsModel.Size) @@ -61,9 +61,9 @@ func TestYAMLCatalogGetArtifacts(t *testing.T) { } // Test case 3: Model not found - notFoundArtifacts, err := provider.GetArtifacts(context.Background(), "non-existent-model") + notFoundArtifacts, err := provider.GetArtifacts(context.Background(), "non-existent-model", "", ListArtifactsParams{}) assert.NoError(err) - assert.Nil(notFoundArtifacts) + assert.Equal(int32(0), notFoundArtifacts.Size) } func TestYAMLCatalogListModels(t *testing.T) { diff --git a/catalog/internal/db/models/catalog_artifact.go b/catalog/internal/db/models/catalog_artifact.go new file mode 100644 index 00000000..34f5e3cc --- /dev/null +++ b/catalog/internal/db/models/catalog_artifact.go @@ -0,0 +1,41 @@ +package models + +import ( + "github.com/kubeflow/model-registry/internal/db/filter" + "github.com/kubeflow/model-registry/internal/db/models" +) + +type CatalogArtifactListOptions struct { + models.Pagination + Name *string + ExternalID *string + ParentResourceID *int32 + ArtifactType *string +} + +// GetRestEntityType implements the FilterApplier interface +// This enables advanced filtering support for catalog artifacts +func (c *CatalogArtifactListOptions) GetRestEntityType() filter.RestEntityType { + // Determine the appropriate REST entity type based on artifact type + if c.ArtifactType != nil { + switch *c.ArtifactType { + case "model-artifact": + return filter.RestEntityModelArtifact + case "metrics-artifact": + return filter.RestEntityModelArtifact // Reusing existing filter type + } + } + // Default to ModelArtifact if no specific type is provided + return filter.RestEntityModelArtifact +} + +// CatalogArtifact is a discriminated union that can hold different catalog artifact types +type CatalogArtifact struct { + CatalogModelArtifact *CatalogModelArtifact + CatalogMetricsArtifact *CatalogMetricsArtifact +} + +type CatalogArtifactRepository interface { + GetByID(id int32) (CatalogArtifact, error) + List(listOptions CatalogArtifactListOptions) (*models.ListWrapper[CatalogArtifact], error) +} diff --git a/catalog/internal/db/models/catalog_metrics_artifact.go b/catalog/internal/db/models/catalog_metrics_artifact.go index a9443e54..4ba0ba33 100644 --- a/catalog/internal/db/models/catalog_metrics_artifact.go +++ b/catalog/internal/db/models/catalog_metrics_artifact.go @@ -8,8 +8,9 @@ import ( type MetricsType string const ( - MetricsTypePerformance MetricsType = "performance-metrics" - MetricsTypeAccuracy MetricsType = "accuracy-metrics" + MetricsTypePerformance MetricsType = "performance-metrics" + MetricsTypeAccuracy MetricsType = "accuracy-metrics" + CatalogMetricsArtifactType = "metrics-artifact" ) type CatalogMetricsArtifactListOptions struct { @@ -26,6 +27,7 @@ func (c *CatalogMetricsArtifactListOptions) GetRestEntityType() filter.RestEntit type CatalogMetricsArtifactAttributes struct { Name *string + ArtifactType *string MetricsType MetricsType ExternalID *string CreateTimeSinceEpoch *int64 diff --git a/catalog/internal/db/models/catalog_model.go b/catalog/internal/db/models/catalog_model.go index 855b5dc4..f7268f97 100644 --- a/catalog/internal/db/models/catalog_model.go +++ b/catalog/internal/db/models/catalog_model.go @@ -9,6 +9,8 @@ type CatalogModelListOptions struct { models.Pagination Name *string ExternalID *string + SourceIDs *[]string + Query *string } // GetRestEntityType implements the FilterApplier interface diff --git a/catalog/internal/db/models/catalog_model_artifact.go b/catalog/internal/db/models/catalog_model_artifact.go index 38c3e9b4..c39b3b24 100644 --- a/catalog/internal/db/models/catalog_model_artifact.go +++ b/catalog/internal/db/models/catalog_model_artifact.go @@ -5,7 +5,7 @@ import ( "github.com/kubeflow/model-registry/internal/db/models" ) -const CatalogModelArtifactType = "catalog-model-artifact" +const CatalogModelArtifactType = "model-artifact" type CatalogModelArtifactListOptions struct { models.Pagination @@ -22,6 +22,7 @@ func (c *CatalogModelArtifactListOptions) GetRestEntityType() filter.RestEntityT type CatalogModelArtifactAttributes struct { Name *string URI *string + ArtifactType *string ExternalID *string CreateTimeSinceEpoch *int64 LastUpdateTimeSinceEpoch *int64 @@ -37,4 +38,4 @@ type CatalogModelArtifactRepository interface { GetByID(id int32) (CatalogModelArtifact, error) List(listOptions CatalogModelArtifactListOptions) (*models.ListWrapper[CatalogModelArtifact], error) Save(modelArtifact CatalogModelArtifact, parentResourceID *int32) (CatalogModelArtifact, error) -} \ No newline at end of file +} diff --git a/catalog/internal/db/service/catalog_artifact.go b/catalog/internal/db/service/catalog_artifact.go new file mode 100644 index 00000000..a5f6980b --- /dev/null +++ b/catalog/internal/db/service/catalog_artifact.go @@ -0,0 +1,217 @@ +package service + +import ( + "errors" + "fmt" + + "github.com/kubeflow/model-registry/catalog/internal/db/models" + "github.com/kubeflow/model-registry/internal/datastore" + dbmodels "github.com/kubeflow/model-registry/internal/db/models" + "github.com/kubeflow/model-registry/internal/db/schema" + "github.com/kubeflow/model-registry/internal/db/scopes" + "github.com/kubeflow/model-registry/internal/db/utils" + "gorm.io/gorm" +) + +var ErrCatalogArtifactNotFound = errors.New("catalog artifact by id not found") + +type CatalogArtifactRepositoryImpl struct { + db *gorm.DB + idToName map[int64]string + nameToID datastore.ArtifactTypeMap +} + +func NewCatalogArtifactRepository(db *gorm.DB, artifactTypes datastore.ArtifactTypeMap) models.CatalogArtifactRepository { + idToName := make(map[int64]string, len(artifactTypes)) + for name, id := range artifactTypes { + idToName[id] = name + } + + return &CatalogArtifactRepositoryImpl{ + db: db, + nameToID: artifactTypes, + idToName: idToName, + } +} + +func (r *CatalogArtifactRepositoryImpl) GetByID(id int32) (models.CatalogArtifact, error) { + artifact := &schema.Artifact{} + properties := []schema.ArtifactProperty{} + + if err := r.db.Where("id = ?", id).First(artifact).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return models.CatalogArtifact{}, fmt.Errorf("%w: %v", ErrCatalogArtifactNotFound, err) + } + return models.CatalogArtifact{}, fmt.Errorf("error getting catalog artifact by id: %w", err) + } + + if err := r.db.Where("artifact_id = ?", artifact.ID).Find(&properties).Error; err != nil { + return models.CatalogArtifact{}, fmt.Errorf("error getting properties by artifact id: %w", err) + } + + // Use the same logic as mapDataLayerToCatalogArtifact to handle artifact types + mappedArtifact, err := r.mapDataLayerToCatalogArtifact(*artifact, properties) + if err != nil { + return models.CatalogArtifact{}, fmt.Errorf("error mapping catalog artifact: %w", err) + } + + return mappedArtifact, nil +} + +func (r *CatalogArtifactRepositoryImpl) List(listOptions models.CatalogArtifactListOptions) (*dbmodels.ListWrapper[models.CatalogArtifact], error) { + list := dbmodels.ListWrapper[models.CatalogArtifact]{ + PageSize: listOptions.GetPageSize(), + } + + artifacts := []models.CatalogArtifact{} + artifactsArt := []schema.Artifact{} + + query := r.db.Model(&schema.Artifact{}) + + // Apply filters similar to the internal artifact service + if listOptions.Name != nil { + // Name is not prefixed with the parent resource id to allow for filtering by name only + // Parent resource Id is used later to filter by Attribution.context_id + query = query.Where("name LIKE ?", fmt.Sprintf("%%:%s", *listOptions.Name)) + } else if listOptions.ExternalID != nil { + query = query.Where("external_id = ?", listOptions.ExternalID) + } + + // Filter by artifact type if specified + if listOptions.ArtifactType != nil { + typeID, err := r.getTypeIDFromArtifactType(*listOptions.ArtifactType) + if err != nil { + return nil, fmt.Errorf("invalid catalog artifact type %s: %w", *listOptions.ArtifactType, err) + } + query = query.Where("type_id = ?", typeID) + } else { + // Only include catalog artifact types + catalogTypeIDs := []int64{} + for _, typeID := range r.nameToID { + catalogTypeIDs = append(catalogTypeIDs, typeID) + } + query = query.Where("type_id IN ?", catalogTypeIDs) + } + + // Apply parent resource filtering if specified + if listOptions.ParentResourceID != nil { + // Proper GORM JOIN: Use helper that respects naming strategy + query = query.Joins(utils.BuildAttributionJoin(query)). + Where(utils.GetColumnRef(query, &schema.Attribution{}, "context_id")+" = ?", listOptions.ParentResourceID). + Select(utils.GetTableName(query, &schema.Artifact{}) + ".*") // Explicitly select from Artifact table to avoid ambiguity + } + + orderBy := listOptions.GetOrderBy() + sortOrder := listOptions.GetSortOrder() + nextPageToken := listOptions.GetNextPageToken() + pageSize := listOptions.GetPageSize() + + pagination := &dbmodels.Pagination{ + PageSize: &pageSize, + OrderBy: &orderBy, + SortOrder: &sortOrder, + NextPageToken: &nextPageToken, + } + + query = query.Scopes(scopes.PaginateWithTablePrefix(artifactsArt, pagination, r.db, "Artifact")) + + if err := query.Find(&artifactsArt).Error; err != nil { + return nil, fmt.Errorf("error listing catalog artifacts: %w", err) + } + + hasMore := false + if pageSize > 0 { + hasMore = len(artifactsArt) > int(pageSize) + if hasMore { + artifactsArt = artifactsArt[:len(artifactsArt)-1] // Remove the extra item used for hasMore detection + } + } + + // Map each artifact with its properties + for _, artifactArt := range artifactsArt { + properties := []schema.ArtifactProperty{} + if err := r.db.Where("artifact_id = ?", artifactArt.ID).Find(&properties).Error; err != nil { + return nil, fmt.Errorf("error getting properties by artifact id: %w", err) + } + + artifact, err := r.mapDataLayerToCatalogArtifact(artifactArt, properties) + if err != nil { + return nil, fmt.Errorf("error mapping catalog artifact: %w", err) + } + artifacts = append(artifacts, artifact) + } + + // Handle pagination token - generate token when there are more pages + if hasMore && len(artifactsArt) > 0 { + // Use the last artifact to generate pagination token + lastArtifact := artifactsArt[len(artifactsArt)-1] + nextToken := r.createPaginationToken(lastArtifact, listOptions) + listOptions.NextPageToken = &nextToken + } else { + listOptions.NextPageToken = nil + } + + list.Items = artifacts + list.NextPageToken = listOptions.GetNextPageToken() + list.Size = int32(len(artifacts)) + + return &list, nil +} + +// getTypeIDFromArtifactType maps catalog artifact type strings to their corresponding type IDs +func (r *CatalogArtifactRepositoryImpl) getTypeIDFromArtifactType(artifactType string) (int64, error) { + switch artifactType { + case "model-artifact": + return r.nameToID[CatalogModelArtifactTypeName], nil + case "metrics-artifact": + return r.nameToID[CatalogMetricsArtifactTypeName], nil + default: + return 0, fmt.Errorf("unsupported catalog artifact type: %s", artifactType) + } +} + +func (r *CatalogArtifactRepositoryImpl) mapDataLayerToCatalogArtifact(artifact schema.Artifact, properties []schema.ArtifactProperty) (models.CatalogArtifact, error) { + artToReturn := models.CatalogArtifact{} + + typeName := r.idToName[int64(artifact.TypeID)] + + switch typeName { + case CatalogModelArtifactTypeName: + modelArtifact := mapDataLayerToCatalogModelArtifact(artifact, properties) + artToReturn.CatalogModelArtifact = &modelArtifact + case CatalogMetricsArtifactTypeName: + metricsArtifact := mapDataLayerToCatalogMetricsArtifact(artifact, properties) + artToReturn.CatalogMetricsArtifact = &metricsArtifact + default: + return models.CatalogArtifact{}, fmt.Errorf("invalid catalog artifact type: %s=%d (expected: %v)", typeName, artifact.TypeID, r.idToName) + } + + return artToReturn, nil +} + +// createPaginationToken generates a pagination token based on the last artifact and ordering +func (r *CatalogArtifactRepositoryImpl) createPaginationToken(artifact schema.Artifact, listOptions models.CatalogArtifactListOptions) string { + orderBy := listOptions.GetOrderBy() + value := "" + + // Generate token value based on ordering field + switch orderBy { + case "ID": + value = fmt.Sprintf("%d", artifact.ID) + case "CREATE_TIME": + value = fmt.Sprintf("%d", artifact.CreateTimeSinceEpoch) + case "LAST_UPDATE_TIME": + value = fmt.Sprintf("%d", artifact.LastUpdateTimeSinceEpoch) + case "NAME": + if artifact.Name != nil { + value = *artifact.Name + } else { + value = fmt.Sprintf("%d", artifact.ID) // Fallback to ID if name is nil + } + default: + // Default to ID ordering + value = fmt.Sprintf("%d", artifact.ID) + } + + return scopes.CreateNextPageToken(artifact.ID, value) +} diff --git a/catalog/internal/db/service/catalog_artifact_test.go b/catalog/internal/db/service/catalog_artifact_test.go new file mode 100644 index 00000000..a02c73ec --- /dev/null +++ b/catalog/internal/db/service/catalog_artifact_test.go @@ -0,0 +1,378 @@ +package service_test + +import ( + "fmt" + "testing" + + "github.com/kubeflow/model-registry/catalog/internal/db/models" + "github.com/kubeflow/model-registry/catalog/internal/db/service" + "github.com/kubeflow/model-registry/internal/apiutils" + dbmodels "github.com/kubeflow/model-registry/internal/db/models" + "github.com/kubeflow/model-registry/internal/testutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCatalogArtifactRepository(t *testing.T) { + sharedDB, cleanup := testutils.SetupMySQLWithMigrations(t, service.DatastoreSpec()) + defer cleanup() + + // Get the catalog artifact type IDs + modelArtifactTypeID := getCatalogModelArtifactTypeID(t, sharedDB) + metricsArtifactTypeID := getCatalogMetricsArtifactTypeID(t, sharedDB) + + // Create unified artifact repository with both types + artifactTypeMap := map[string]int64{ + service.CatalogModelArtifactTypeName: modelArtifactTypeID, + service.CatalogMetricsArtifactTypeName: metricsArtifactTypeID, + } + repo := service.NewCatalogArtifactRepository(sharedDB, artifactTypeMap) + + // Also get CatalogModel type ID for creating parent entities + catalogModelTypeID := getCatalogModelTypeID(t, sharedDB) + catalogModelRepo := service.NewCatalogModelRepository(sharedDB, catalogModelTypeID) + modelArtifactRepo := service.NewCatalogModelArtifactRepository(sharedDB, modelArtifactTypeID) + metricsArtifactRepo := service.NewCatalogMetricsArtifactRepository(sharedDB, metricsArtifactTypeID) + + // Create shared test data + catalogModel := &models.CatalogModelImpl{ + TypeID: apiutils.Of(int32(catalogModelTypeID)), + Attributes: &models.CatalogModelAttributes{ + Name: apiutils.Of("test-catalog-model-for-artifacts"), + ExternalID: apiutils.Of("catalog-model-artifacts-ext-123"), + }, + } + savedCatalogModel, err := catalogModelRepo.Save(catalogModel) + require.NoError(t, err) + + t.Run("GetByID_ModelArtifact", func(t *testing.T) { + // Create a model artifact using the specific repository + modelArtifact := &models.CatalogModelArtifactImpl{ + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of("test-model-artifact-getbyid"), + ExternalID: apiutils.Of("model-art-getbyid-ext-123"), + URI: apiutils.Of("s3://test-bucket/model.bin"), + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), + }, + } + savedModelArtifact, err := modelArtifactRepo.Save(modelArtifact, savedCatalogModel.GetID()) + require.NoError(t, err) + + // Retrieve using unified repository + retrieved, err := repo.GetByID(*savedModelArtifact.GetID()) + require.NoError(t, err) + + // Verify it's a model artifact + assert.NotNil(t, retrieved.CatalogModelArtifact) + assert.Nil(t, retrieved.CatalogMetricsArtifact) + assert.Equal(t, "test-model-artifact-getbyid", *(*retrieved.CatalogModelArtifact).GetAttributes().Name) + assert.Equal(t, "model-art-getbyid-ext-123", *(*retrieved.CatalogModelArtifact).GetAttributes().ExternalID) + assert.Equal(t, "s3://test-bucket/model.bin", *(*retrieved.CatalogModelArtifact).GetAttributes().URI) + }) + + t.Run("GetByID_MetricsArtifact", func(t *testing.T) { + // Create a metrics artifact using the specific repository + metricsArtifact := &models.CatalogMetricsArtifactImpl{ + TypeID: apiutils.Of(int32(metricsArtifactTypeID)), + Attributes: &models.CatalogMetricsArtifactAttributes{ + Name: apiutils.Of("test-metrics-artifact-getbyid"), + ExternalID: apiutils.Of("metrics-art-getbyid-ext-123"), + MetricsType: models.MetricsTypeAccuracy, + ArtifactType: apiutils.Of("metrics-artifact"), + }, + } + savedMetricsArtifact, err := metricsArtifactRepo.Save(metricsArtifact, savedCatalogModel.GetID()) + require.NoError(t, err) + + // Retrieve using unified repository + retrieved, err := repo.GetByID(*savedMetricsArtifact.GetID()) + require.NoError(t, err) + + // Verify it's a metrics artifact + assert.Nil(t, retrieved.CatalogModelArtifact) + assert.NotNil(t, retrieved.CatalogMetricsArtifact) + assert.Equal(t, "test-metrics-artifact-getbyid", *(*retrieved.CatalogMetricsArtifact).GetAttributes().Name) + assert.Equal(t, "metrics-art-getbyid-ext-123", *(*retrieved.CatalogMetricsArtifact).GetAttributes().ExternalID) + assert.Equal(t, models.MetricsTypeAccuracy, (*retrieved.CatalogMetricsArtifact).GetAttributes().MetricsType) + }) + + t.Run("GetByID_NotFound", func(t *testing.T) { + nonExistentID := int32(99999) + _, err := repo.GetByID(nonExistentID) + require.Error(t, err) + assert.Contains(t, err.Error(), "catalog artifact by id not found") + }) + + t.Run("List_AllArtifacts", func(t *testing.T) { + // Create test artifacts of both types + modelArtifact1 := &models.CatalogModelArtifactImpl{ + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of("test-model-artifact-list-1"), + ExternalID: apiutils.Of("model-list-1-ext"), + URI: apiutils.Of("s3://test/model1.bin"), + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), + }, + } + + modelArtifact2 := &models.CatalogModelArtifactImpl{ + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of("test-model-artifact-list-2"), + ExternalID: apiutils.Of("model-list-2-ext"), + URI: apiutils.Of("s3://test/model2.bin"), + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), + }, + } + + metricsArtifact1 := &models.CatalogMetricsArtifactImpl{ + TypeID: apiutils.Of(int32(metricsArtifactTypeID)), + Attributes: &models.CatalogMetricsArtifactAttributes{ + Name: apiutils.Of("test-metrics-artifact-list-1"), + ExternalID: apiutils.Of("metrics-list-1-ext"), + MetricsType: models.MetricsTypeAccuracy, + ArtifactType: apiutils.Of("metrics-artifact"), + }, + } + + // Save artifacts + savedModelArt1, err := modelArtifactRepo.Save(modelArtifact1, savedCatalogModel.GetID()) + require.NoError(t, err) + savedModelArt2, err := modelArtifactRepo.Save(modelArtifact2, savedCatalogModel.GetID()) + require.NoError(t, err) + savedMetricsArt1, err := metricsArtifactRepo.Save(metricsArtifact1, savedCatalogModel.GetID()) + require.NoError(t, err) + + // List all artifacts for the parent resource + listOptions := models.CatalogArtifactListOptions{ + ParentResourceID: savedCatalogModel.GetID(), + } + + result, err := repo.List(listOptions) + require.NoError(t, err) + require.NotNil(t, result) + + // Should return all 3 artifacts (2 model + 1 metrics) + assert.GreaterOrEqual(t, len(result.Items), 3, "Should return at least the 3 artifacts we created") + + // Verify we got both types + var modelArtifactCount, metricsArtifactCount int + artifactIDs := make(map[int32]bool) + + for _, artifact := range result.Items { + if artifact.CatalogModelArtifact != nil { + modelArtifactCount++ + artifactIDs[*(*artifact.CatalogModelArtifact).GetID()] = true + } else if artifact.CatalogMetricsArtifact != nil { + metricsArtifactCount++ + artifactIDs[*(*artifact.CatalogMetricsArtifact).GetID()] = true + } + } + + assert.GreaterOrEqual(t, modelArtifactCount, 2, "Should have at least 2 model artifacts") + assert.GreaterOrEqual(t, metricsArtifactCount, 1, "Should have at least 1 metrics artifact") + + // Verify our specific artifacts are in the results + assert.True(t, artifactIDs[*savedModelArt1.GetID()], "Should contain first model artifact") + assert.True(t, artifactIDs[*savedModelArt2.GetID()], "Should contain second model artifact") + assert.True(t, artifactIDs[*savedMetricsArt1.GetID()], "Should contain metrics artifact") + }) + + t.Run("List_FilterByArtifactType_ModelArtifact", func(t *testing.T) { + // Filter by model artifact type only + artifactType := "model-artifact" + listOptions := models.CatalogArtifactListOptions{ + ParentResourceID: savedCatalogModel.GetID(), + ArtifactType: &artifactType, + } + + result, err := repo.List(listOptions) + require.NoError(t, err) + require.NotNil(t, result) + + // All results should be model artifacts + for _, artifact := range result.Items { + assert.NotNil(t, artifact.CatalogModelArtifact, "Should only return model artifacts") + assert.Nil(t, artifact.CatalogMetricsArtifact, "Should not return metrics artifacts") + } + }) + + t.Run("List_FilterByArtifactType_MetricsArtifact", func(t *testing.T) { + // Filter by metrics artifact type only + artifactType := "metrics-artifact" + listOptions := models.CatalogArtifactListOptions{ + ParentResourceID: savedCatalogModel.GetID(), + ArtifactType: &artifactType, + } + + result, err := repo.List(listOptions) + require.NoError(t, err) + require.NotNil(t, result) + + // All results should be metrics artifacts + for _, artifact := range result.Items { + assert.Nil(t, artifact.CatalogModelArtifact, "Should not return model artifacts") + assert.NotNil(t, artifact.CatalogMetricsArtifact, "Should only return metrics artifacts") + } + }) + + t.Run("List_FilterByExternalID", func(t *testing.T) { + // Create artifact with specific external ID for filtering + testArtifact := &models.CatalogMetricsArtifactImpl{ + TypeID: apiutils.Of(int32(metricsArtifactTypeID)), + Attributes: &models.CatalogMetricsArtifactAttributes{ + Name: apiutils.Of("external-id-filter-test"), + ExternalID: apiutils.Of("unique-external-id-123"), + MetricsType: models.MetricsTypePerformance, + ArtifactType: apiutils.Of("metrics-artifact"), + }, + } + savedArtifact, err := metricsArtifactRepo.Save(testArtifact, savedCatalogModel.GetID()) + require.NoError(t, err) + + // Filter by external ID + externalID := "unique-external-id-123" + listOptions := models.CatalogArtifactListOptions{ + ExternalID: &externalID, + } + + result, err := repo.List(listOptions) + require.NoError(t, err) + require.NotNil(t, result) + assert.Len(t, result.Items, 1, "Should find exactly one artifact with the external ID") + + // Verify it's the correct artifact + artifact := result.Items[0] + assert.NotNil(t, artifact.CatalogMetricsArtifact) + assert.Equal(t, *savedArtifact.GetID(), *(*artifact.CatalogMetricsArtifact).GetID()) + assert.Equal(t, "unique-external-id-123", *(*artifact.CatalogMetricsArtifact).GetAttributes().ExternalID) + }) + + t.Run("List_WithPagination", func(t *testing.T) { + // Create multiple artifacts for pagination testing + for i := 0; i < 5; i++ { + artifact := &models.CatalogModelArtifactImpl{ + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of(fmt.Sprintf("pagination-test-%d", i)), + ExternalID: apiutils.Of(fmt.Sprintf("pagination-ext-%d", i)), + URI: apiutils.Of(fmt.Sprintf("s3://test/pagination-%d.bin", i)), + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), + }, + } + _, err := modelArtifactRepo.Save(artifact, savedCatalogModel.GetID()) + require.NoError(t, err) + } + + // Test pagination + pageSize := int32(3) + listOptions := models.CatalogArtifactListOptions{ + ParentResourceID: savedCatalogModel.GetID(), + Pagination: dbmodels.Pagination{ + PageSize: &pageSize, + OrderBy: apiutils.Of("ID"), + }, + } + + result, err := repo.List(listOptions) + require.NoError(t, err) + require.NotNil(t, result) + assert.LessOrEqual(t, len(result.Items), 3, "Should respect page size limit") + assert.GreaterOrEqual(t, len(result.Items), 1, "Should return at least one item") + }) + + t.Run("List_InvalidArtifactType", func(t *testing.T) { + // Test with invalid artifact type + invalidType := "invalid-artifact-type" + listOptions := models.CatalogArtifactListOptions{ + ParentResourceID: savedCatalogModel.GetID(), + ArtifactType: &invalidType, + } + + _, err := repo.List(listOptions) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid catalog artifact type") + assert.Contains(t, err.Error(), "invalid-artifact-type") + }) + + t.Run("List_WithCustomProperties", func(t *testing.T) { + // Create artifacts with custom properties + customProps := []dbmodels.Properties{ + { + Name: "custom_prop_1", + StringValue: apiutils.Of("custom_value_1"), + }, + { + Name: "custom_prop_2", + StringValue: apiutils.Of("custom_value_2"), + }, + } + + artifactWithCustomProps := &models.CatalogModelArtifactImpl{ + TypeID: apiutils.Of(int32(modelArtifactTypeID)), + Attributes: &models.CatalogModelArtifactAttributes{ + Name: apiutils.Of("artifact-with-custom-props"), + ExternalID: apiutils.Of("custom-props-ext"), + URI: apiutils.Of("s3://test/custom-props.bin"), + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), + }, + CustomProperties: &customProps, + } + + savedArtifact, err := modelArtifactRepo.Save(artifactWithCustomProps, savedCatalogModel.GetID()) + require.NoError(t, err) + + // Retrieve using unified repository + retrieved, err := repo.GetByID(*savedArtifact.GetID()) + require.NoError(t, err) + + // Verify custom properties are preserved + assert.NotNil(t, retrieved.CatalogModelArtifact) + assert.NotNil(t, (*retrieved.CatalogModelArtifact).GetCustomProperties()) + + customPropsMap := make(map[string]string) + for _, prop := range *(*retrieved.CatalogModelArtifact).GetCustomProperties() { + if prop.StringValue != nil { + customPropsMap[prop.Name] = *prop.StringValue + } + } + + assert.Equal(t, "custom_value_1", customPropsMap["custom_prop_1"]) + assert.Equal(t, "custom_value_2", customPropsMap["custom_prop_2"]) + }) + + t.Run("MappingErrors", func(t *testing.T) { + // Test error handling for invalid type mapping + // This would typically happen if there's data inconsistency in the database + + // We can't easily test this without directly manipulating the database + // but we can test the GetByID with an artifact that has an unknown type + // by temporarily modifying the repository's type mapping + + // Create a repository with incomplete type mapping + incompleteTypeMap := map[string]int64{ + service.CatalogModelArtifactTypeName: modelArtifactTypeID, + // Missing CatalogMetricsArtifactTypeName intentionally + } + incompleteRepo := service.NewCatalogArtifactRepository(sharedDB, incompleteTypeMap) + + // Create a metrics artifact first using the complete repo + metricsArtifact := &models.CatalogMetricsArtifactImpl{ + TypeID: apiutils.Of(int32(metricsArtifactTypeID)), + Attributes: &models.CatalogMetricsArtifactAttributes{ + Name: apiutils.Of("test-mapping-error"), + ExternalID: apiutils.Of("mapping-error-ext"), + MetricsType: models.MetricsTypeAccuracy, + ArtifactType: apiutils.Of("metrics-artifact"), + }, + } + savedMetricsArtifact, err := metricsArtifactRepo.Save(metricsArtifact, savedCatalogModel.GetID()) + require.NoError(t, err) + + // Try to retrieve using incomplete repo - should get mapping error + _, err = incompleteRepo.GetByID(*savedMetricsArtifact.GetID()) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid catalog artifact type") + }) +} diff --git a/catalog/internal/db/service/catalog_metrics_artifact.go b/catalog/internal/db/service/catalog_metrics_artifact.go index 6eb4f4e6..369ba110 100644 --- a/catalog/internal/db/service/catalog_metrics_artifact.go +++ b/catalog/internal/db/service/catalog_metrics_artifact.go @@ -131,6 +131,7 @@ func mapDataLayerToCatalogMetricsArtifact(artifact schema.Artifact, artPropertie TypeID: &artifact.TypeID, Attributes: &models.CatalogMetricsArtifactAttributes{ Name: artifact.Name, + ArtifactType: apiutils.Of(models.CatalogMetricsArtifactType), ExternalID: artifact.ExternalID, CreateTimeSinceEpoch: &artifact.CreateTimeSinceEpoch, LastUpdateTimeSinceEpoch: &artifact.LastUpdateTimeSinceEpoch, diff --git a/catalog/internal/db/service/catalog_model.go b/catalog/internal/db/service/catalog_model.go index 9c261a81..2b5b7ac3 100644 --- a/catalog/internal/db/service/catalog_model.go +++ b/catalog/internal/db/service/catalog_model.go @@ -2,11 +2,14 @@ package service import ( "errors" + "fmt" + "strings" "github.com/kubeflow/model-registry/catalog/internal/db/models" dbmodels "github.com/kubeflow/model-registry/internal/db/models" "github.com/kubeflow/model-registry/internal/db/schema" "github.com/kubeflow/model-registry/internal/db/service" + "github.com/kubeflow/model-registry/internal/db/utils" "gorm.io/gorm" ) @@ -45,11 +48,54 @@ func (r *CatalogModelRepositoryImpl) List(listOptions models.CatalogModelListOpt } func applyCatalogModelListFilters(query *gorm.DB, listOptions *models.CatalogModelListOptions) *gorm.DB { + contextTable := utils.GetTableName(query.Statement.DB, &schema.Context{}) + if listOptions.Name != nil { - query = query.Where("name LIKE ?", listOptions.Name) + query = query.Where(fmt.Sprintf("%s.name LIKE ?", contextTable), listOptions.Name) } else if listOptions.ExternalID != nil { - query = query.Where("external_id = ?", listOptions.ExternalID) + query = query.Where(fmt.Sprintf("%s.external_id = ?", contextTable), listOptions.ExternalID) } + + if listOptions.Query != nil && *listOptions.Query != "" { + queryPattern := fmt.Sprintf("%%%s%%", strings.ToLower(*listOptions.Query)) + propertyTable := utils.GetTableName(query.Statement.DB, &schema.ContextProperty{}) + + // Search in name (context table) + nameCondition := fmt.Sprintf("LOWER(%s.name) LIKE ?", contextTable) + + // Search in description, provider, libraryName properties + propertyCondition := fmt.Sprintf("EXISTS (SELECT 1 FROM %s cp WHERE cp.context_id = %s.id AND cp.name IN (?, ?, ?) AND LOWER(cp.string_value) LIKE ?)", + propertyTable, contextTable) + + // Search in tasks (assuming tasks are stored as comma-separated or multiple properties) + tasksCondition := fmt.Sprintf("EXISTS (SELECT 1 FROM %s cp WHERE cp.context_id = %s.id AND cp.name = ? AND LOWER(cp.string_value) LIKE ?)", + propertyTable, contextTable) + + query = query.Where(fmt.Sprintf("(%s OR %s OR %s)", nameCondition, propertyCondition, tasksCondition), + queryPattern, // for name + "description", "provider", "libraryName", queryPattern, // for properties + "tasks", queryPattern, // for tasks + ) + } + + // Filter out empty strings from SourceIDs, for some reason it's passed if no sources are specified + var nonEmptySourceIDs []string + if listOptions.SourceIDs != nil { + for _, sourceID := range *listOptions.SourceIDs { + if sourceID != "" { + nonEmptySourceIDs = append(nonEmptySourceIDs, sourceID) + } + } + } + + if len(nonEmptySourceIDs) > 0 { + propertyTable := utils.GetTableName(query.Statement.DB, &schema.ContextProperty{}) + + joinClause := fmt.Sprintf("JOIN %s cp ON cp.context_id = %s.id", propertyTable, contextTable) + query = query.Joins(joinClause). + Where("cp.name = ? AND cp.string_value IN ?", "source_id", nonEmptySourceIDs) + } + return query } @@ -126,4 +172,4 @@ func mapDataLayerToCatalogModel(modelCtx schema.Context, propertiesCtx []schema. catalogModel.CustomProperties = &customProperties return catalogModel -} \ No newline at end of file +} diff --git a/catalog/internal/db/service/catalog_model_artifact.go b/catalog/internal/db/service/catalog_model_artifact.go index cc9ed38f..1dec9431 100644 --- a/catalog/internal/db/service/catalog_model_artifact.go +++ b/catalog/internal/db/service/catalog_model_artifact.go @@ -108,6 +108,7 @@ func mapDataLayerToCatalogModelArtifact(artifact schema.Artifact, artProperties Attributes: &models.CatalogModelArtifactAttributes{ Name: artifact.Name, URI: artifact.URI, + ArtifactType: apiutils.Of(models.CatalogModelArtifactType), ExternalID: artifact.ExternalID, CreateTimeSinceEpoch: &artifact.CreateTimeSinceEpoch, LastUpdateTimeSinceEpoch: &artifact.LastUpdateTimeSinceEpoch, @@ -129,4 +130,4 @@ func mapDataLayerToCatalogModelArtifact(artifact schema.Artifact, artProperties catalogModelArtifact.Properties = &properties return &catalogModelArtifact -} \ No newline at end of file +} diff --git a/catalog/internal/db/service/spec.go b/catalog/internal/db/service/spec.go index 7207459e..50195786 100644 --- a/catalog/internal/db/service/spec.go +++ b/catalog/internal/db/service/spec.go @@ -32,5 +32,6 @@ func DatastoreSpec() *datastore.Spec { ). AddArtifact(CatalogMetricsArtifactTypeName, datastore.NewSpecType(NewCatalogMetricsArtifactRepository). AddString("metricsType"), - ) + ). + AddOther(NewCatalogArtifactRepository) } diff --git a/catalog/internal/server/openapi/.openapi-generator/FILES b/catalog/internal/server/openapi/.openapi-generator/FILES index 003202ad..7130db6b 100644 --- a/catalog/internal/server/openapi/.openapi-generator/FILES +++ b/catalog/internal/server/openapi/.openapi-generator/FILES @@ -6,6 +6,7 @@ impl.go logger.go model_artifact_type_query_param.go model_base_model.go +model_base_resource.go model_base_resource_dates.go model_base_resource_list.go model_catalog_artifact.go diff --git a/catalog/internal/server/openapi/api.go b/catalog/internal/server/openapi/api.go index 5a58ecb9..5dc53ff8 100644 --- a/catalog/internal/server/openapi/api.go +++ b/catalog/internal/server/openapi/api.go @@ -32,8 +32,8 @@ type ModelCatalogServiceAPIRouter interface { // while the service implementation can be ignored with the .openapi-generator-ignore file // and updated with the logic required for the API. type ModelCatalogServiceAPIServicer interface { - FindModels(context.Context, string, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) + FindModels(context.Context, []string, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) FindSources(context.Context, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) GetModel(context.Context, string, string) (ImplResponse, error) - GetAllModelArtifacts(context.Context, string, string) (ImplResponse, error) + GetAllModelArtifacts(context.Context, string, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) } diff --git a/catalog/internal/server/openapi/api_model_catalog_service.go b/catalog/internal/server/openapi/api_model_catalog_service.go index 0f2e3365..39e80d0e 100644 --- a/catalog/internal/server/openapi/api_model_catalog_service.go +++ b/catalog/internal/server/openapi/api_model_catalog_service.go @@ -78,7 +78,7 @@ func (c *ModelCatalogServiceAPIController) Routes() Routes { // FindModels - Search catalog models across sources. func (c *ModelCatalogServiceAPIController) FindModels(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() - sourceParam := query.Get("source") + sourceParam := strings.Split(query.Get("source"), ",") qParam := query.Get("q") pageSizeParam := query.Get("pageSize") orderByParam := query.Get("orderBy") @@ -128,9 +128,14 @@ func (c *ModelCatalogServiceAPIController) GetModel(w http.ResponseWriter, r *ht // GetAllModelArtifacts - List CatalogArtifacts. func (c *ModelCatalogServiceAPIController) GetAllModelArtifacts(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() sourceIdParam := chi.URLParam(r, "source_id") modelNameParam := chi.URLParam(r, "model_name") - result, err := c.service.GetAllModelArtifacts(r.Context(), sourceIdParam, modelNameParam) + pageSizeParam := query.Get("pageSize") + orderByParam := query.Get("orderBy") + sortOrderParam := query.Get("sortOrder") + nextPageTokenParam := query.Get("nextPageToken") + result, err := c.service.GetAllModelArtifacts(r.Context(), sourceIdParam, modelNameParam, pageSizeParam, model.OrderByField(orderByParam), model.SortOrder(sortOrderParam), nextPageTokenParam) // If an error occurred, encode the error with the status code if err != nil { c.errorHandler(w, r, err, &result) diff --git a/catalog/internal/server/openapi/api_model_catalog_service_service.go b/catalog/internal/server/openapi/api_model_catalog_service_service.go index 37f2cc52..9c632f6b 100644 --- a/catalog/internal/server/openapi/api_model_catalog_service_service.go +++ b/catalog/internal/server/openapi/api_model_catalog_service_service.go @@ -8,87 +8,111 @@ import ( "net/http" "net/url" "slices" + "strconv" "strings" "github.com/kubeflow/model-registry/catalog/internal/catalog" model "github.com/kubeflow/model-registry/catalog/pkg/openapi" + "github.com/kubeflow/model-registry/pkg/api" ) // ModelCatalogServiceAPIService is a service that implements the logic for the ModelCatalogServiceAPIServicer // This service should implement the business logic for every endpoint for the ModelCatalogServiceAPI s.coreApi. // Include any external packages or services that will be required by this service. type ModelCatalogServiceAPIService struct { - sources *catalog.SourceCollection + provider catalog.CatalogSourceProvider + sources *catalog.SourceCollection } // GetAllModelArtifacts retrieves all model artifacts for a given model from the specified source. -func (m *ModelCatalogServiceAPIService) GetAllModelArtifacts(ctx context.Context, sourceID string, name string) (ImplResponse, error) { - source, ok := m.sources.Get(sourceID) - if !ok { - return notFound("Unknown source"), nil +func (m *ModelCatalogServiceAPIService) GetAllModelArtifacts(ctx context.Context, sourceID string, modelName string, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { + if newName, err := url.PathUnescape(modelName); err == nil { + modelName = newName } - if newName, err := url.PathUnescape(name); err == nil { - name = newName + var err error + pageSizeInt := int32(10) + + if pageSize != "" { + parsed, err := strconv.ParseInt(pageSize, 10, 32) + if err != nil { + return Response(http.StatusBadRequest, err), err + } + pageSizeInt = int32(parsed) } - artifacts, err := source.Provider.GetArtifacts(ctx, name) + artifacts, err := m.provider.GetArtifacts(ctx, modelName, sourceID, catalog.ListArtifactsParams{ + PageSize: pageSizeInt, + OrderBy: orderBy, + SortOrder: sortOrder, + NextPageToken: &nextPageToken, + }) if err != nil { - return Response(http.StatusInternalServerError, err), err + statusCode := api.ErrToStatus(err) + var errorMsg string + if errors.Is(err, api.ErrBadRequest) { + errorMsg = fmt.Sprintf("Invalid model name '%s' for source '%s'", modelName, sourceID) + } else if errors.Is(err, api.ErrNotFound) { + errorMsg = fmt.Sprintf("No model found '%s' in source '%s'", modelName, sourceID) + } else { + errorMsg = err.Error() + } + return ErrorResponse(statusCode, errors.New(errorMsg)), err } return Response(http.StatusOK, artifacts), nil } -func (m *ModelCatalogServiceAPIService) FindModels(ctx context.Context, sourceID string, q string, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { - source, ok := m.sources.Get(sourceID) - if !ok { - return notFound("Unknown source"), errors.New("Unknown source") - } +func (m *ModelCatalogServiceAPIService) FindModels(ctx context.Context, sourceIDs []string, q string, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) { + var err error + pageSizeInt := int32(10) - p, err := newPaginator[model.CatalogModel](pageSize, orderBy, sortOrder, nextPageToken) - if err != nil { - return ErrorResponse(http.StatusBadRequest, err), err + if pageSize != "" { + parsed, err := strconv.ParseInt(pageSize, 10, 32) + if err != nil { + return Response(http.StatusBadRequest, err), err + } + pageSizeInt = int32(parsed) } listModelsParams := catalog.ListModelsParams{ - Query: q, - OrderBy: p.OrderBy, - SortOrder: p.SortOrder, + Query: q, + SourceIDs: sourceIDs, + PageSize: pageSizeInt, + OrderBy: orderBy, + SortOrder: sortOrder, + NextPageToken: &nextPageToken, } - models, err := source.Provider.ListModels(ctx, listModelsParams) + models, err := m.provider.ListModels(ctx, listModelsParams) if err != nil { return ErrorResponse(http.StatusInternalServerError, err), err } - page, next := p.Paginate(models.Items) - - models.Items = page - models.PageSize = p.PageSize - models.NextPageToken = next.Token() - return Response(http.StatusOK, models), nil } -func (m *ModelCatalogServiceAPIService) GetModel(ctx context.Context, sourceID string, name string) (ImplResponse, error) { - if name, ok := strings.CutSuffix(name, "/artifacts"); ok { - return m.GetAllModelArtifacts(ctx, sourceID, name) +func (m *ModelCatalogServiceAPIService) GetModel(ctx context.Context, sourceID string, modelName string) (ImplResponse, error) { + if name, ok := strings.CutSuffix(modelName, "/artifacts"); ok { + return m.GetAllModelArtifacts(ctx, sourceID, name, "10", model.OrderByField(model.ORDERBYFIELD_CREATE_TIME), model.SortOrder(model.SORTORDER_ASC), "") } - source, ok := m.sources.Get(sourceID) - if !ok { - return notFound("Unknown source"), nil + if newName, err := url.PathUnescape(modelName); err == nil { + modelName = newName } - if newName, err := url.PathUnescape(name); err == nil { - name = newName - } - - model, err := source.Provider.GetModel(ctx, name) + model, err := m.provider.GetModel(ctx, modelName, sourceID) if err != nil { - return Response(http.StatusInternalServerError, err), err + statusCode := api.ErrToStatus(err) + var errorMsg string + if errors.Is(err, api.ErrNotFound) { + errorMsg = fmt.Sprintf("No model found '%s' in source '%s'", modelName, sourceID) + } else { + errorMsg = err.Error() + } + return ErrorResponse(statusCode, errors.New(errorMsg)), err } + if model == nil { return notFound("Unknown model or version"), nil } @@ -167,9 +191,10 @@ func genCatalogCmpFunc(orderBy model.OrderByField, sortOrder model.SortOrder) (f var _ ModelCatalogServiceAPIServicer = &ModelCatalogServiceAPIService{} // NewModelCatalogServiceAPIService creates a default api service -func NewModelCatalogServiceAPIService(sources *catalog.SourceCollection) ModelCatalogServiceAPIServicer { +func NewModelCatalogServiceAPIService(provider catalog.CatalogSourceProvider, sources *catalog.SourceCollection) ModelCatalogServiceAPIServicer { return &ModelCatalogServiceAPIService{ - sources: sources, + provider: provider, + sources: sources, } } diff --git a/catalog/internal/server/openapi/api_model_catalog_service_service_test.go b/catalog/internal/server/openapi/api_model_catalog_service_service_test.go index ac577fea..4b84ede9 100644 --- a/catalog/internal/server/openapi/api_model_catalog_service_service_test.go +++ b/catalog/internal/server/openapi/api_model_catalog_service_service_test.go @@ -181,15 +181,20 @@ func TestFindModels(t *testing.T) { }, }, { - name: "Invalid source ID", - sourceID: "unknown-source", - mockModels: map[string]*model.CatalogModel{}, - q: "", - pageSize: "10", - orderBy: model.ORDERBYFIELD_ID, - sortOrder: model.SORTORDER_ASC, - expectedStatus: http.StatusNotFound, - expectedModelList: nil, + name: "Invalid source ID", + sourceID: "unknown-source", + mockModels: map[string]*model.CatalogModel{}, + q: "", + pageSize: "10", + orderBy: model.ORDERBYFIELD_ID, + sortOrder: model.SORTORDER_ASC, + expectedStatus: http.StatusOK, // Changed from http.StatusNotFound to http.StatusOK with an empty list -- now the source ID is just a field in the CatalogModel + expectedModelList: &model.CatalogModelList{ + Items: []model.CatalogModel{}, + Size: 0, + PageSize: 10, + NextPageToken: "", + }, }, { name: "Invalid pageSize string", @@ -210,12 +215,19 @@ func TestFindModels(t *testing.T) { mockModels: map[string]*model.CatalogModel{ "modelA": modelA, }, - q: "", - pageSize: "10", - orderBy: "UNSUPPORTED_FIELD", - sortOrder: model.SORTORDER_ASC, - expectedStatus: http.StatusBadRequest, - expectedModelList: nil, + q: "", + pageSize: "10", + orderBy: "UNSUPPORTED_FIELD", + sortOrder: model.SORTORDER_ASC, + expectedStatus: http.StatusOK, // Changed from http.StatusBadRequest to http.StatusOK -- in model registry we fallback to ID if the order by field is unsupported + expectedModelList: &model.CatalogModelList{ + Items: []model.CatalogModel{ + *modelA, + }, + Size: 1, + PageSize: 10, + NextPageToken: "", + }, }, { name: "Unsupported sortOrder field", @@ -223,12 +235,19 @@ func TestFindModels(t *testing.T) { mockModels: map[string]*model.CatalogModel{ "modelA": modelA, }, - q: "", - pageSize: "10", - orderBy: model.ORDERBYFIELD_ID, - sortOrder: "UNSUPPORTED_ORDER", - expectedStatus: http.StatusBadRequest, - expectedModelList: nil, + q: "", + pageSize: "10", + orderBy: model.ORDERBYFIELD_ID, + sortOrder: "UNSUPPORTED_ORDER", + expectedStatus: http.StatusOK, // Changed from http.StatusBadRequest to http.StatusOK -- in model registry we fallback to ASC if the sort order field is unsupported + expectedModelList: &model.CatalogModelList{ + Items: []model.CatalogModel{ + *modelA, + }, + Size: 1, + PageSize: 10, + NextPageToken: "", + }, }, { name: "Empty models in source", @@ -277,11 +296,16 @@ func TestFindModels(t *testing.T) { }, }, }) - service := NewModelCatalogServiceAPIService(sources) + + provider := &mockModelProvider{ + models: tc.mockModels, + } + + service := NewModelCatalogServiceAPIService(provider, sources) resp, err := service.FindModels( context.Background(), - tc.sourceID, + []string{tc.sourceID}, tc.q, tc.pageSize, tc.orderBy, @@ -635,7 +659,7 @@ func TestFindSources(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create service with test catalogs - service := NewModelCatalogServiceAPIService(catalog.NewSourceCollection(tc.catalogs)) + service := NewModelCatalogServiceAPIService(&mockModelProvider{}, catalog.NewSourceCollection(tc.catalogs)) // Call FindSources resp, err := service.FindSources( @@ -729,7 +753,7 @@ type mockModelProvider struct { } // Implement GetModel method for the mock provider -func (m *mockModelProvider) GetModel(ctx context.Context, name string) (*model.CatalogModel, error) { +func (m *mockModelProvider) GetModel(ctx context.Context, name string, sourceID string) (*model.CatalogModel, error) { model, exists := m.models[name] if !exists { return nil, nil @@ -771,30 +795,49 @@ func (m *mockModelProvider) ListModels(ctx context.Context, params catalog.ListM return cmp < 0 }) - items := make([]model.CatalogModel, len(filteredModels)) - for i, mdl := range filteredModels { + totalSize := int32(len(filteredModels)) + pageSize := params.PageSize + if pageSize <= 0 { + pageSize = 10 + } + + // Apply pagination - limit items to page size + endIndex := int(pageSize) + if endIndex > len(filteredModels) { + endIndex = len(filteredModels) + } + + pagedModels := filteredModels[:endIndex] + items := make([]model.CatalogModel, len(pagedModels)) + for i, mdl := range pagedModels { items[i] = *mdl } + nextPageToken := "" + if len(filteredModels) > int(pageSize) { + lastItem := pagedModels[len(pagedModels)-1] + nextPageToken = (&stringCursor{Value: lastItem.Name, ID: lastItem.Name}).String() + } + return model.CatalogModelList{ Items: items, - Size: int32(len(items)), - PageSize: int32(len(items)), // Mock returns all filtered items as one "page" - NextPageToken: "", + Size: totalSize, + PageSize: pageSize, + NextPageToken: nextPageToken, }, nil } -func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string) (*model.CatalogArtifactList, error) { +func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string, sourceID string, params catalog.ListArtifactsParams) (model.CatalogArtifactList, error) { artifacts, exists := m.artifacts[name] if !exists { - return &model.CatalogArtifactList{ + return model.CatalogArtifactList{ Items: []model.CatalogArtifact{}, Size: 0, PageSize: 0, // Or a default page size if applicable NextPageToken: "", }, nil } - return &model.CatalogArtifactList{ + return model.CatalogArtifactList{ Items: artifacts, Size: int32(len(artifacts)), PageSize: int32(len(artifacts)), @@ -810,6 +853,7 @@ func TestGetModel(t *testing.T) { modelName string expectedStatus int expectedModel *model.CatalogModel + provider catalog.CatalogSourceProvider }{ { name: "Existing model in source", @@ -825,6 +869,13 @@ func TestGetModel(t *testing.T) { }, }, }, + provider: &mockModelProvider{ + models: map[string]*model.CatalogModel{ + "test-model": { + Name: "test-model", + }, + }, + }, sourceID: "source1", modelName: "test-model", expectedStatus: http.StatusOK, @@ -839,6 +890,9 @@ func TestGetModel(t *testing.T) { Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"}, }, }, + provider: &mockModelProvider{ + models: map[string]*model.CatalogModel{}, + }, sourceID: "source2", modelName: "test-model", expectedStatus: http.StatusNotFound, @@ -854,6 +908,9 @@ func TestGetModel(t *testing.T) { }, }, }, + provider: &mockModelProvider{ + models: map[string]*model.CatalogModel{}, + }, sourceID: "source1", modelName: "test-model", expectedStatus: http.StatusNotFound, @@ -873,6 +930,13 @@ func TestGetModel(t *testing.T) { }, }, }, + provider: &mockModelProvider{ + models: map[string]*model.CatalogModel{ + "some/model:v1.0.0": { + Name: "some/model:v1.0.0", + }, + }, + }, sourceID: "source1", modelName: "some%2Fmodel%3Av1.0.0", expectedStatus: http.StatusOK, @@ -885,7 +949,7 @@ func TestGetModel(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create service with test sources - service := NewModelCatalogServiceAPIService(catalog.NewSourceCollection(tc.sources)) + service := NewModelCatalogServiceAPIService(tc.provider, catalog.NewSourceCollection(tc.sources)) // Call GetModel resp, _ := service.GetModel( @@ -923,6 +987,7 @@ func TestGetAllModelArtifacts(t *testing.T) { modelName string expectedStatus int expectedArtifacts []model.CatalogArtifact + provider catalog.CatalogSourceProvider }{ { name: "Existing artifacts for model in source", @@ -947,6 +1012,22 @@ func TestGetAllModelArtifacts(t *testing.T) { }, }, }, + provider: &mockModelProvider{ + artifacts: map[string][]model.CatalogArtifact{ + "test-model": { + { + CatalogModelArtifact: &model.CatalogModelArtifact{ + Uri: "s3://bucket/artifact1", + }, + }, + { + CatalogModelArtifact: &model.CatalogModelArtifact{ + Uri: "s3://bucket/artifact2", + }, + }, + }, + }, + }, sourceID: "source1", modelName: "test-model", expectedStatus: http.StatusOK, @@ -970,10 +1051,13 @@ func TestGetAllModelArtifacts(t *testing.T) { Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"}, }, }, + provider: &mockModelProvider{ + artifacts: map[string][]model.CatalogArtifact{}, + }, sourceID: "source2", modelName: "test-model", - expectedStatus: http.StatusNotFound, - expectedArtifacts: nil, + expectedStatus: http.StatusOK, // Changed from http.StatusNotFound to http.StatusOK -- having the same behavior as the model registry + expectedArtifacts: []model.CatalogArtifact{}, }, { name: "Existing source, no artifacts for model", @@ -985,6 +1069,9 @@ func TestGetAllModelArtifacts(t *testing.T) { }, }, }, + provider: &mockModelProvider{ + artifacts: map[string][]model.CatalogArtifact{}, + }, sourceID: "source1", modelName: "test-model", expectedStatus: http.StatusOK, @@ -995,13 +1082,17 @@ func TestGetAllModelArtifacts(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create service with test sources - service := NewModelCatalogServiceAPIService(catalog.NewSourceCollection(tc.sources)) + service := NewModelCatalogServiceAPIService(tc.provider, catalog.NewSourceCollection(tc.sources)) // Call GetAllModelArtifacts resp, _ := service.GetAllModelArtifacts( context.Background(), tc.sourceID, tc.modelName, + "10", + model.ORDERBYFIELD_CREATE_TIME, + model.SORTORDER_ASC, + "", ) // Check response status @@ -1016,7 +1107,7 @@ func TestGetAllModelArtifacts(t *testing.T) { require.NotNil(t, resp.Body) // Type assertion to access the list of artifacts - artifactList, ok := resp.Body.(*model.CatalogArtifactList) + artifactList, ok := resp.Body.(model.CatalogArtifactList) require.True(t, ok, "Response body should be a CatalogArtifactList") // Check the artifacts diff --git a/catalog/internal/server/openapi/type_asserts.go b/catalog/internal/server/openapi/type_asserts.go index ec6b68b1..b54eae4e 100644 --- a/catalog/internal/server/openapi/type_asserts.go +++ b/catalog/internal/server/openapi/type_asserts.go @@ -36,6 +36,11 @@ func AssertBaseModelRequired(obj model.BaseModel) error { return nil } +// AssertBaseResourceConstraints checks if the values respects the defined constraints +func AssertBaseResourceConstraints(obj model.BaseResource) error { + return nil +} + // AssertBaseResourceDatesConstraints checks if the values respects the defined constraints func AssertBaseResourceDatesConstraints(obj model.BaseResourceDates) error { return nil @@ -67,6 +72,11 @@ func AssertBaseResourceListRequired(obj model.BaseResourceList) error { return nil } +// AssertBaseResourceRequired checks if the required fields are not zero-ed +func AssertBaseResourceRequired(obj model.BaseResource) error { + return nil +} + // AssertCatalogArtifactListConstraints checks if the values respects the defined constraints func AssertCatalogArtifactListConstraints(obj model.CatalogArtifactList) error { return nil diff --git a/catalog/pkg/openapi/.openapi-generator/FILES b/catalog/pkg/openapi/.openapi-generator/FILES index db65885f..7291aeab 100644 --- a/catalog/pkg/openapi/.openapi-generator/FILES +++ b/catalog/pkg/openapi/.openapi-generator/FILES @@ -3,6 +3,7 @@ client.go configuration.go model_artifact_type_query_param.go model_base_model.go +model_base_resource.go model_base_resource_dates.go model_base_resource_list.go model_catalog_artifact.go diff --git a/catalog/pkg/openapi/api_model_catalog_service.go b/catalog/pkg/openapi/api_model_catalog_service.go index d6bd1c19..9bb15d09 100644 --- a/catalog/pkg/openapi/api_model_catalog_service.go +++ b/catalog/pkg/openapi/api_model_catalog_service.go @@ -16,6 +16,7 @@ import ( "io" "net/http" "net/url" + "reflect" "strings" ) @@ -25,7 +26,7 @@ type ModelCatalogServiceAPIService service type ApiFindModelsRequest struct { ctx context.Context ApiService *ModelCatalogServiceAPIService - source *string + source *[]string q *string pageSize *string orderBy *OrderByField @@ -33,8 +34,8 @@ type ApiFindModelsRequest struct { nextPageToken *string } -// Filter models by source. This parameter is currently required and may only be specified once. -func (r ApiFindModelsRequest) Source(source string) ApiFindModelsRequest { +// Filter models by source. This parameter can be specified multiple times to filter by multiple sources (OR logic). For example: ?source=huggingface&source=local will return models from either huggingface OR local sources. +func (r ApiFindModelsRequest) Source(source []string) ApiFindModelsRequest { r.source = &source return r } @@ -107,11 +108,18 @@ func (a *ModelCatalogServiceAPIService) FindModelsExecute(r ApiFindModelsRequest localVarHeaderParams := make(map[string]string) localVarQueryParams := url.Values{} localVarFormParams := url.Values{} - if r.source == nil { - return localVarReturnValue, nil, reportError("source is required and must be specified") - } - parameterAddToHeaderOrQuery(localVarQueryParams, "source", r.source, "") + if r.source != nil { + t := *r.source + if reflect.TypeOf(t).Kind() == reflect.Slice { + s := reflect.ValueOf(t) + for i := 0; i < s.Len(); i++ { + parameterAddToHeaderOrQuery(localVarQueryParams, "source", s.Index(i).Interface(), "multi") + } + } else { + parameterAddToHeaderOrQuery(localVarQueryParams, "source", t, "multi") + } + } if r.q != nil { parameterAddToHeaderOrQuery(localVarQueryParams, "q", r.q, "") } @@ -418,10 +426,38 @@ func (a *ModelCatalogServiceAPIService) FindSourcesExecute(r ApiFindSourcesReque } type ApiGetAllModelArtifactsRequest struct { - ctx context.Context - ApiService *ModelCatalogServiceAPIService - sourceId string - modelName string + ctx context.Context + ApiService *ModelCatalogServiceAPIService + sourceId string + modelName string + pageSize *string + orderBy *OrderByField + sortOrder *SortOrder + nextPageToken *string +} + +// Number of entities in each page. +func (r ApiGetAllModelArtifactsRequest) PageSize(pageSize string) ApiGetAllModelArtifactsRequest { + r.pageSize = &pageSize + return r +} + +// Specifies the order by criteria for listing entities. +func (r ApiGetAllModelArtifactsRequest) OrderBy(orderBy OrderByField) ApiGetAllModelArtifactsRequest { + r.orderBy = &orderBy + return r +} + +// Specifies the sort order for listing entities, defaults to ASC. +func (r ApiGetAllModelArtifactsRequest) SortOrder(sortOrder SortOrder) ApiGetAllModelArtifactsRequest { + r.sortOrder = &sortOrder + return r +} + +// Token to use to retrieve next page of results. +func (r ApiGetAllModelArtifactsRequest) NextPageToken(nextPageToken string) ApiGetAllModelArtifactsRequest { + r.nextPageToken = &nextPageToken + return r } func (r ApiGetAllModelArtifactsRequest) Execute() (*CatalogArtifactList, *http.Response, error) { @@ -469,6 +505,18 @@ func (a *ModelCatalogServiceAPIService) GetAllModelArtifactsExecute(r ApiGetAllM localVarQueryParams := url.Values{} localVarFormParams := url.Values{} + if r.pageSize != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "pageSize", r.pageSize, "") + } + if r.orderBy != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "orderBy", r.orderBy, "") + } + if r.sortOrder != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "sortOrder", r.sortOrder, "") + } + if r.nextPageToken != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "nextPageToken", r.nextPageToken, "") + } // to determine the Content-Type header localVarHTTPContentTypes := []string{} diff --git a/catalog/pkg/openapi/model_base_resource.go b/catalog/pkg/openapi/model_base_resource.go new file mode 100644 index 00000000..6ed5ad2e --- /dev/null +++ b/catalog/pkg/openapi/model_base_resource.go @@ -0,0 +1,347 @@ +/* +Model Catalog REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha1 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package openapi + +import ( + "encoding/json" +) + +// checks if the BaseResource type satisfies the MappedNullable interface at compile time +var _ MappedNullable = &BaseResource{} + +// BaseResource struct for BaseResource +type BaseResource struct { + // Output only. Create time of the resource in millisecond since epoch. + CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"` + // Output only. Last update time of the resource since epoch in millisecond since epoch. + LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"` + // User provided custom properties which are not defined by its type. + CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` + // An optional description about the resource. + Description *string `json:"description,omitempty"` + // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. + ExternalId *string `json:"externalId,omitempty"` + // The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set. + Name *string `json:"name,omitempty"` + // The unique server generated id of the resource. + Id *string `json:"id,omitempty"` +} + +// NewBaseResource instantiates a new BaseResource object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewBaseResource() *BaseResource { + this := BaseResource{} + return &this +} + +// NewBaseResourceWithDefaults instantiates a new BaseResource object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewBaseResourceWithDefaults() *BaseResource { + this := BaseResource{} + return &this +} + +// GetCreateTimeSinceEpoch returns the CreateTimeSinceEpoch field value if set, zero value otherwise. +func (o *BaseResource) GetCreateTimeSinceEpoch() string { + if o == nil || IsNil(o.CreateTimeSinceEpoch) { + var ret string + return ret + } + return *o.CreateTimeSinceEpoch +} + +// GetCreateTimeSinceEpochOk returns a tuple with the CreateTimeSinceEpoch field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *BaseResource) GetCreateTimeSinceEpochOk() (*string, bool) { + if o == nil || IsNil(o.CreateTimeSinceEpoch) { + return nil, false + } + return o.CreateTimeSinceEpoch, true +} + +// HasCreateTimeSinceEpoch returns a boolean if a field has been set. +func (o *BaseResource) HasCreateTimeSinceEpoch() bool { + if o != nil && !IsNil(o.CreateTimeSinceEpoch) { + return true + } + + return false +} + +// SetCreateTimeSinceEpoch gets a reference to the given string and assigns it to the CreateTimeSinceEpoch field. +func (o *BaseResource) SetCreateTimeSinceEpoch(v string) { + o.CreateTimeSinceEpoch = &v +} + +// GetLastUpdateTimeSinceEpoch returns the LastUpdateTimeSinceEpoch field value if set, zero value otherwise. +func (o *BaseResource) GetLastUpdateTimeSinceEpoch() string { + if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) { + var ret string + return ret + } + return *o.LastUpdateTimeSinceEpoch +} + +// GetLastUpdateTimeSinceEpochOk returns a tuple with the LastUpdateTimeSinceEpoch field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *BaseResource) GetLastUpdateTimeSinceEpochOk() (*string, bool) { + if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) { + return nil, false + } + return o.LastUpdateTimeSinceEpoch, true +} + +// HasLastUpdateTimeSinceEpoch returns a boolean if a field has been set. +func (o *BaseResource) HasLastUpdateTimeSinceEpoch() bool { + if o != nil && !IsNil(o.LastUpdateTimeSinceEpoch) { + return true + } + + return false +} + +// SetLastUpdateTimeSinceEpoch gets a reference to the given string and assigns it to the LastUpdateTimeSinceEpoch field. +func (o *BaseResource) SetLastUpdateTimeSinceEpoch(v string) { + o.LastUpdateTimeSinceEpoch = &v +} + +// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. +func (o *BaseResource) GetCustomProperties() map[string]MetadataValue { + if o == nil || IsNil(o.CustomProperties) { + var ret map[string]MetadataValue + return ret + } + return *o.CustomProperties +} + +// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *BaseResource) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) { + if o == nil || IsNil(o.CustomProperties) { + return nil, false + } + return o.CustomProperties, true +} + +// HasCustomProperties returns a boolean if a field has been set. +func (o *BaseResource) HasCustomProperties() bool { + if o != nil && !IsNil(o.CustomProperties) { + return true + } + + return false +} + +// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field. +func (o *BaseResource) SetCustomProperties(v map[string]MetadataValue) { + o.CustomProperties = &v +} + +// GetDescription returns the Description field value if set, zero value otherwise. +func (o *BaseResource) GetDescription() string { + if o == nil || IsNil(o.Description) { + var ret string + return ret + } + return *o.Description +} + +// GetDescriptionOk returns a tuple with the Description field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *BaseResource) GetDescriptionOk() (*string, bool) { + if o == nil || IsNil(o.Description) { + return nil, false + } + return o.Description, true +} + +// HasDescription returns a boolean if a field has been set. +func (o *BaseResource) HasDescription() bool { + if o != nil && !IsNil(o.Description) { + return true + } + + return false +} + +// SetDescription gets a reference to the given string and assigns it to the Description field. +func (o *BaseResource) SetDescription(v string) { + o.Description = &v +} + +// GetExternalId returns the ExternalId field value if set, zero value otherwise. +func (o *BaseResource) GetExternalId() string { + if o == nil || IsNil(o.ExternalId) { + var ret string + return ret + } + return *o.ExternalId +} + +// GetExternalIdOk returns a tuple with the ExternalId field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *BaseResource) GetExternalIdOk() (*string, bool) { + if o == nil || IsNil(o.ExternalId) { + return nil, false + } + return o.ExternalId, true +} + +// HasExternalId returns a boolean if a field has been set. +func (o *BaseResource) HasExternalId() bool { + if o != nil && !IsNil(o.ExternalId) { + return true + } + + return false +} + +// SetExternalId gets a reference to the given string and assigns it to the ExternalId field. +func (o *BaseResource) SetExternalId(v string) { + o.ExternalId = &v +} + +// GetName returns the Name field value if set, zero value otherwise. +func (o *BaseResource) GetName() string { + if o == nil || IsNil(o.Name) { + var ret string + return ret + } + return *o.Name +} + +// GetNameOk returns a tuple with the Name field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *BaseResource) GetNameOk() (*string, bool) { + if o == nil || IsNil(o.Name) { + return nil, false + } + return o.Name, true +} + +// HasName returns a boolean if a field has been set. +func (o *BaseResource) HasName() bool { + if o != nil && !IsNil(o.Name) { + return true + } + + return false +} + +// SetName gets a reference to the given string and assigns it to the Name field. +func (o *BaseResource) SetName(v string) { + o.Name = &v +} + +// GetId returns the Id field value if set, zero value otherwise. +func (o *BaseResource) GetId() string { + if o == nil || IsNil(o.Id) { + var ret string + return ret + } + return *o.Id +} + +// GetIdOk returns a tuple with the Id field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *BaseResource) GetIdOk() (*string, bool) { + if o == nil || IsNil(o.Id) { + return nil, false + } + return o.Id, true +} + +// HasId returns a boolean if a field has been set. +func (o *BaseResource) HasId() bool { + if o != nil && !IsNil(o.Id) { + return true + } + + return false +} + +// SetId gets a reference to the given string and assigns it to the Id field. +func (o *BaseResource) SetId(v string) { + o.Id = &v +} + +func (o BaseResource) MarshalJSON() ([]byte, error) { + toSerialize, err := o.ToMap() + if err != nil { + return []byte{}, err + } + return json.Marshal(toSerialize) +} + +func (o BaseResource) ToMap() (map[string]interface{}, error) { + toSerialize := map[string]interface{}{} + if !IsNil(o.CreateTimeSinceEpoch) { + toSerialize["createTimeSinceEpoch"] = o.CreateTimeSinceEpoch + } + if !IsNil(o.LastUpdateTimeSinceEpoch) { + toSerialize["lastUpdateTimeSinceEpoch"] = o.LastUpdateTimeSinceEpoch + } + if !IsNil(o.CustomProperties) { + toSerialize["customProperties"] = o.CustomProperties + } + if !IsNil(o.Description) { + toSerialize["description"] = o.Description + } + if !IsNil(o.ExternalId) { + toSerialize["externalId"] = o.ExternalId + } + if !IsNil(o.Name) { + toSerialize["name"] = o.Name + } + if !IsNil(o.Id) { + toSerialize["id"] = o.Id + } + return toSerialize, nil +} + +type NullableBaseResource struct { + value *BaseResource + isSet bool +} + +func (v NullableBaseResource) Get() *BaseResource { + return v.value +} + +func (v *NullableBaseResource) Set(val *BaseResource) { + v.value = val + v.isSet = true +} + +func (v NullableBaseResource) IsSet() bool { + return v.isSet +} + +func (v *NullableBaseResource) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableBaseResource(val *BaseResource) *NullableBaseResource { + return &NullableBaseResource{value: val, isSet: true} +} + +func (v NullableBaseResource) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableBaseResource) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/catalog/pkg/openapi/model_catalog_metrics_artifact.go b/catalog/pkg/openapi/model_catalog_metrics_artifact.go index c983d0fd..5c309b54 100644 --- a/catalog/pkg/openapi/model_catalog_metrics_artifact.go +++ b/catalog/pkg/openapi/model_catalog_metrics_artifact.go @@ -19,14 +19,22 @@ var _ MappedNullable = &CatalogMetricsArtifact{} // CatalogMetricsArtifact A metadata Artifact Entity. type CatalogMetricsArtifact struct { + // User provided custom properties which are not defined by its type. + CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` + // An optional description about the resource. + Description *string `json:"description,omitempty"` + // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. + ExternalId *string `json:"externalId,omitempty"` + // The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set. + Name *string `json:"name,omitempty"` + // The unique server generated id of the resource. + Id *string `json:"id,omitempty"` // Output only. Create time of the resource in millisecond since epoch. CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"` // Output only. Last update time of the resource since epoch in millisecond since epoch. LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"` ArtifactType string `json:"artifactType"` MetricsType string `json:"metricsType"` - // User provided custom properties which are not defined by its type. - CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` } // NewCatalogMetricsArtifact instantiates a new CatalogMetricsArtifact object @@ -50,6 +58,166 @@ func NewCatalogMetricsArtifactWithDefaults() *CatalogMetricsArtifact { return &this } +// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. +func (o *CatalogMetricsArtifact) GetCustomProperties() map[string]MetadataValue { + if o == nil || IsNil(o.CustomProperties) { + var ret map[string]MetadataValue + return ret + } + return *o.CustomProperties +} + +// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogMetricsArtifact) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) { + if o == nil || IsNil(o.CustomProperties) { + return nil, false + } + return o.CustomProperties, true +} + +// HasCustomProperties returns a boolean if a field has been set. +func (o *CatalogMetricsArtifact) HasCustomProperties() bool { + if o != nil && !IsNil(o.CustomProperties) { + return true + } + + return false +} + +// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field. +func (o *CatalogMetricsArtifact) SetCustomProperties(v map[string]MetadataValue) { + o.CustomProperties = &v +} + +// GetDescription returns the Description field value if set, zero value otherwise. +func (o *CatalogMetricsArtifact) GetDescription() string { + if o == nil || IsNil(o.Description) { + var ret string + return ret + } + return *o.Description +} + +// GetDescriptionOk returns a tuple with the Description field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogMetricsArtifact) GetDescriptionOk() (*string, bool) { + if o == nil || IsNil(o.Description) { + return nil, false + } + return o.Description, true +} + +// HasDescription returns a boolean if a field has been set. +func (o *CatalogMetricsArtifact) HasDescription() bool { + if o != nil && !IsNil(o.Description) { + return true + } + + return false +} + +// SetDescription gets a reference to the given string and assigns it to the Description field. +func (o *CatalogMetricsArtifact) SetDescription(v string) { + o.Description = &v +} + +// GetExternalId returns the ExternalId field value if set, zero value otherwise. +func (o *CatalogMetricsArtifact) GetExternalId() string { + if o == nil || IsNil(o.ExternalId) { + var ret string + return ret + } + return *o.ExternalId +} + +// GetExternalIdOk returns a tuple with the ExternalId field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogMetricsArtifact) GetExternalIdOk() (*string, bool) { + if o == nil || IsNil(o.ExternalId) { + return nil, false + } + return o.ExternalId, true +} + +// HasExternalId returns a boolean if a field has been set. +func (o *CatalogMetricsArtifact) HasExternalId() bool { + if o != nil && !IsNil(o.ExternalId) { + return true + } + + return false +} + +// SetExternalId gets a reference to the given string and assigns it to the ExternalId field. +func (o *CatalogMetricsArtifact) SetExternalId(v string) { + o.ExternalId = &v +} + +// GetName returns the Name field value if set, zero value otherwise. +func (o *CatalogMetricsArtifact) GetName() string { + if o == nil || IsNil(o.Name) { + var ret string + return ret + } + return *o.Name +} + +// GetNameOk returns a tuple with the Name field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogMetricsArtifact) GetNameOk() (*string, bool) { + if o == nil || IsNil(o.Name) { + return nil, false + } + return o.Name, true +} + +// HasName returns a boolean if a field has been set. +func (o *CatalogMetricsArtifact) HasName() bool { + if o != nil && !IsNil(o.Name) { + return true + } + + return false +} + +// SetName gets a reference to the given string and assigns it to the Name field. +func (o *CatalogMetricsArtifact) SetName(v string) { + o.Name = &v +} + +// GetId returns the Id field value if set, zero value otherwise. +func (o *CatalogMetricsArtifact) GetId() string { + if o == nil || IsNil(o.Id) { + var ret string + return ret + } + return *o.Id +} + +// GetIdOk returns a tuple with the Id field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogMetricsArtifact) GetIdOk() (*string, bool) { + if o == nil || IsNil(o.Id) { + return nil, false + } + return o.Id, true +} + +// HasId returns a boolean if a field has been set. +func (o *CatalogMetricsArtifact) HasId() bool { + if o != nil && !IsNil(o.Id) { + return true + } + + return false +} + +// SetId gets a reference to the given string and assigns it to the Id field. +func (o *CatalogMetricsArtifact) SetId(v string) { + o.Id = &v +} + // GetCreateTimeSinceEpoch returns the CreateTimeSinceEpoch field value if set, zero value otherwise. func (o *CatalogMetricsArtifact) GetCreateTimeSinceEpoch() string { if o == nil || IsNil(o.CreateTimeSinceEpoch) { @@ -162,38 +330,6 @@ func (o *CatalogMetricsArtifact) SetMetricsType(v string) { o.MetricsType = v } -// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. -func (o *CatalogMetricsArtifact) GetCustomProperties() map[string]MetadataValue { - if o == nil || IsNil(o.CustomProperties) { - var ret map[string]MetadataValue - return ret - } - return *o.CustomProperties -} - -// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise -// and a boolean to check if the value has been set. -func (o *CatalogMetricsArtifact) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) { - if o == nil || IsNil(o.CustomProperties) { - return nil, false - } - return o.CustomProperties, true -} - -// HasCustomProperties returns a boolean if a field has been set. -func (o *CatalogMetricsArtifact) HasCustomProperties() bool { - if o != nil && !IsNil(o.CustomProperties) { - return true - } - - return false -} - -// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field. -func (o *CatalogMetricsArtifact) SetCustomProperties(v map[string]MetadataValue) { - o.CustomProperties = &v -} - func (o CatalogMetricsArtifact) MarshalJSON() ([]byte, error) { toSerialize, err := o.ToMap() if err != nil { @@ -204,6 +340,21 @@ func (o CatalogMetricsArtifact) MarshalJSON() ([]byte, error) { func (o CatalogMetricsArtifact) ToMap() (map[string]interface{}, error) { toSerialize := map[string]interface{}{} + if !IsNil(o.CustomProperties) { + toSerialize["customProperties"] = o.CustomProperties + } + if !IsNil(o.Description) { + toSerialize["description"] = o.Description + } + if !IsNil(o.ExternalId) { + toSerialize["externalId"] = o.ExternalId + } + if !IsNil(o.Name) { + toSerialize["name"] = o.Name + } + if !IsNil(o.Id) { + toSerialize["id"] = o.Id + } if !IsNil(o.CreateTimeSinceEpoch) { toSerialize["createTimeSinceEpoch"] = o.CreateTimeSinceEpoch } @@ -212,9 +363,6 @@ func (o CatalogMetricsArtifact) ToMap() (map[string]interface{}, error) { } toSerialize["artifactType"] = o.ArtifactType toSerialize["metricsType"] = o.MetricsType - if !IsNil(o.CustomProperties) { - toSerialize["customProperties"] = o.CustomProperties - } return toSerialize, nil } diff --git a/catalog/pkg/openapi/model_catalog_model.go b/catalog/pkg/openapi/model_catalog_model.go index 0ece8c75..d1b5b114 100644 --- a/catalog/pkg/openapi/model_catalog_model.go +++ b/catalog/pkg/openapi/model_catalog_model.go @@ -19,11 +19,7 @@ var _ MappedNullable = &CatalogModel{} // CatalogModel A model in the model catalog. type CatalogModel struct { - // Output only. Create time of the resource in millisecond since epoch. - CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"` - // Output only. Last update time of the resource since epoch in millisecond since epoch. - LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"` - // Human-readable description of the model. + // An optional description about the resource. Description *string `json:"description,omitempty"` // Model documentation in Markdown. Readme *string `json:"readme,omitempty"` @@ -44,8 +40,16 @@ type CatalogModel struct { LibraryName *string `json:"libraryName,omitempty"` // User provided custom properties which are not defined by its type. CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` + // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. + ExternalId *string `json:"externalId,omitempty"` // Name of the model. Must be unique within a source. Name string `json:"name"` + // The unique server generated id of the resource. + Id *string `json:"id,omitempty"` + // Output only. Create time of the resource in millisecond since epoch. + CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"` + // Output only. Last update time of the resource since epoch in millisecond since epoch. + LastUpdateTimeSinceEpoch *string `json:"lastUpdateTimeSinceEpoch,omitempty"` // ID of the source this model belongs to. SourceId *string `json:"source_id,omitempty"` } @@ -68,70 +72,6 @@ func NewCatalogModelWithDefaults() *CatalogModel { return &this } -// GetCreateTimeSinceEpoch returns the CreateTimeSinceEpoch field value if set, zero value otherwise. -func (o *CatalogModel) GetCreateTimeSinceEpoch() string { - if o == nil || IsNil(o.CreateTimeSinceEpoch) { - var ret string - return ret - } - return *o.CreateTimeSinceEpoch -} - -// GetCreateTimeSinceEpochOk returns a tuple with the CreateTimeSinceEpoch field value if set, nil otherwise -// and a boolean to check if the value has been set. -func (o *CatalogModel) GetCreateTimeSinceEpochOk() (*string, bool) { - if o == nil || IsNil(o.CreateTimeSinceEpoch) { - return nil, false - } - return o.CreateTimeSinceEpoch, true -} - -// HasCreateTimeSinceEpoch returns a boolean if a field has been set. -func (o *CatalogModel) HasCreateTimeSinceEpoch() bool { - if o != nil && !IsNil(o.CreateTimeSinceEpoch) { - return true - } - - return false -} - -// SetCreateTimeSinceEpoch gets a reference to the given string and assigns it to the CreateTimeSinceEpoch field. -func (o *CatalogModel) SetCreateTimeSinceEpoch(v string) { - o.CreateTimeSinceEpoch = &v -} - -// GetLastUpdateTimeSinceEpoch returns the LastUpdateTimeSinceEpoch field value if set, zero value otherwise. -func (o *CatalogModel) GetLastUpdateTimeSinceEpoch() string { - if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) { - var ret string - return ret - } - return *o.LastUpdateTimeSinceEpoch -} - -// GetLastUpdateTimeSinceEpochOk returns a tuple with the LastUpdateTimeSinceEpoch field value if set, nil otherwise -// and a boolean to check if the value has been set. -func (o *CatalogModel) GetLastUpdateTimeSinceEpochOk() (*string, bool) { - if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) { - return nil, false - } - return o.LastUpdateTimeSinceEpoch, true -} - -// HasLastUpdateTimeSinceEpoch returns a boolean if a field has been set. -func (o *CatalogModel) HasLastUpdateTimeSinceEpoch() bool { - if o != nil && !IsNil(o.LastUpdateTimeSinceEpoch) { - return true - } - - return false -} - -// SetLastUpdateTimeSinceEpoch gets a reference to the given string and assigns it to the LastUpdateTimeSinceEpoch field. -func (o *CatalogModel) SetLastUpdateTimeSinceEpoch(v string) { - o.LastUpdateTimeSinceEpoch = &v -} - // GetDescription returns the Description field value if set, zero value otherwise. func (o *CatalogModel) GetDescription() string { if o == nil || IsNil(o.Description) { @@ -484,6 +424,38 @@ func (o *CatalogModel) SetCustomProperties(v map[string]MetadataValue) { o.CustomProperties = &v } +// GetExternalId returns the ExternalId field value if set, zero value otherwise. +func (o *CatalogModel) GetExternalId() string { + if o == nil || IsNil(o.ExternalId) { + var ret string + return ret + } + return *o.ExternalId +} + +// GetExternalIdOk returns a tuple with the ExternalId field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModel) GetExternalIdOk() (*string, bool) { + if o == nil || IsNil(o.ExternalId) { + return nil, false + } + return o.ExternalId, true +} + +// HasExternalId returns a boolean if a field has been set. +func (o *CatalogModel) HasExternalId() bool { + if o != nil && !IsNil(o.ExternalId) { + return true + } + + return false +} + +// SetExternalId gets a reference to the given string and assigns it to the ExternalId field. +func (o *CatalogModel) SetExternalId(v string) { + o.ExternalId = &v +} + // GetName returns the Name field value func (o *CatalogModel) GetName() string { if o == nil { @@ -508,6 +480,102 @@ func (o *CatalogModel) SetName(v string) { o.Name = v } +// GetId returns the Id field value if set, zero value otherwise. +func (o *CatalogModel) GetId() string { + if o == nil || IsNil(o.Id) { + var ret string + return ret + } + return *o.Id +} + +// GetIdOk returns a tuple with the Id field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModel) GetIdOk() (*string, bool) { + if o == nil || IsNil(o.Id) { + return nil, false + } + return o.Id, true +} + +// HasId returns a boolean if a field has been set. +func (o *CatalogModel) HasId() bool { + if o != nil && !IsNil(o.Id) { + return true + } + + return false +} + +// SetId gets a reference to the given string and assigns it to the Id field. +func (o *CatalogModel) SetId(v string) { + o.Id = &v +} + +// GetCreateTimeSinceEpoch returns the CreateTimeSinceEpoch field value if set, zero value otherwise. +func (o *CatalogModel) GetCreateTimeSinceEpoch() string { + if o == nil || IsNil(o.CreateTimeSinceEpoch) { + var ret string + return ret + } + return *o.CreateTimeSinceEpoch +} + +// GetCreateTimeSinceEpochOk returns a tuple with the CreateTimeSinceEpoch field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModel) GetCreateTimeSinceEpochOk() (*string, bool) { + if o == nil || IsNil(o.CreateTimeSinceEpoch) { + return nil, false + } + return o.CreateTimeSinceEpoch, true +} + +// HasCreateTimeSinceEpoch returns a boolean if a field has been set. +func (o *CatalogModel) HasCreateTimeSinceEpoch() bool { + if o != nil && !IsNil(o.CreateTimeSinceEpoch) { + return true + } + + return false +} + +// SetCreateTimeSinceEpoch gets a reference to the given string and assigns it to the CreateTimeSinceEpoch field. +func (o *CatalogModel) SetCreateTimeSinceEpoch(v string) { + o.CreateTimeSinceEpoch = &v +} + +// GetLastUpdateTimeSinceEpoch returns the LastUpdateTimeSinceEpoch field value if set, zero value otherwise. +func (o *CatalogModel) GetLastUpdateTimeSinceEpoch() string { + if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) { + var ret string + return ret + } + return *o.LastUpdateTimeSinceEpoch +} + +// GetLastUpdateTimeSinceEpochOk returns a tuple with the LastUpdateTimeSinceEpoch field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModel) GetLastUpdateTimeSinceEpochOk() (*string, bool) { + if o == nil || IsNil(o.LastUpdateTimeSinceEpoch) { + return nil, false + } + return o.LastUpdateTimeSinceEpoch, true +} + +// HasLastUpdateTimeSinceEpoch returns a boolean if a field has been set. +func (o *CatalogModel) HasLastUpdateTimeSinceEpoch() bool { + if o != nil && !IsNil(o.LastUpdateTimeSinceEpoch) { + return true + } + + return false +} + +// SetLastUpdateTimeSinceEpoch gets a reference to the given string and assigns it to the LastUpdateTimeSinceEpoch field. +func (o *CatalogModel) SetLastUpdateTimeSinceEpoch(v string) { + o.LastUpdateTimeSinceEpoch = &v +} + // GetSourceId returns the SourceId field value if set, zero value otherwise. func (o *CatalogModel) GetSourceId() string { if o == nil || IsNil(o.SourceId) { @@ -550,12 +618,6 @@ func (o CatalogModel) MarshalJSON() ([]byte, error) { func (o CatalogModel) ToMap() (map[string]interface{}, error) { toSerialize := map[string]interface{}{} - if !IsNil(o.CreateTimeSinceEpoch) { - toSerialize["createTimeSinceEpoch"] = o.CreateTimeSinceEpoch - } - if !IsNil(o.LastUpdateTimeSinceEpoch) { - toSerialize["lastUpdateTimeSinceEpoch"] = o.LastUpdateTimeSinceEpoch - } if !IsNil(o.Description) { toSerialize["description"] = o.Description } @@ -589,7 +651,19 @@ func (o CatalogModel) ToMap() (map[string]interface{}, error) { if !IsNil(o.CustomProperties) { toSerialize["customProperties"] = o.CustomProperties } + if !IsNil(o.ExternalId) { + toSerialize["externalId"] = o.ExternalId + } toSerialize["name"] = o.Name + if !IsNil(o.Id) { + toSerialize["id"] = o.Id + } + if !IsNil(o.CreateTimeSinceEpoch) { + toSerialize["createTimeSinceEpoch"] = o.CreateTimeSinceEpoch + } + if !IsNil(o.LastUpdateTimeSinceEpoch) { + toSerialize["lastUpdateTimeSinceEpoch"] = o.LastUpdateTimeSinceEpoch + } if !IsNil(o.SourceId) { toSerialize["source_id"] = o.SourceId } diff --git a/catalog/pkg/openapi/model_catalog_model_artifact.go b/catalog/pkg/openapi/model_catalog_model_artifact.go index 97709aee..bfa5ff17 100644 --- a/catalog/pkg/openapi/model_catalog_model_artifact.go +++ b/catalog/pkg/openapi/model_catalog_model_artifact.go @@ -19,6 +19,16 @@ var _ MappedNullable = &CatalogModelArtifact{} // CatalogModelArtifact A Catalog Model Artifact Entity. type CatalogModelArtifact struct { + // User provided custom properties which are not defined by its type. + CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` + // An optional description about the resource. + Description *string `json:"description,omitempty"` + // The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance. + ExternalId *string `json:"externalId,omitempty"` + // The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set. + Name *string `json:"name,omitempty"` + // The unique server generated id of the resource. + Id *string `json:"id,omitempty"` // Output only. Create time of the resource in millisecond since epoch. CreateTimeSinceEpoch *string `json:"createTimeSinceEpoch,omitempty"` // Output only. Last update time of the resource since epoch in millisecond since epoch. @@ -26,8 +36,6 @@ type CatalogModelArtifact struct { ArtifactType string `json:"artifactType"` // URI where the model can be retrieved. Uri string `json:"uri"` - // User provided custom properties which are not defined by its type. - CustomProperties *map[string]MetadataValue `json:"customProperties,omitempty"` } // NewCatalogModelArtifact instantiates a new CatalogModelArtifact object @@ -51,6 +59,166 @@ func NewCatalogModelArtifactWithDefaults() *CatalogModelArtifact { return &this } +// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. +func (o *CatalogModelArtifact) GetCustomProperties() map[string]MetadataValue { + if o == nil || IsNil(o.CustomProperties) { + var ret map[string]MetadataValue + return ret + } + return *o.CustomProperties +} + +// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModelArtifact) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) { + if o == nil || IsNil(o.CustomProperties) { + return nil, false + } + return o.CustomProperties, true +} + +// HasCustomProperties returns a boolean if a field has been set. +func (o *CatalogModelArtifact) HasCustomProperties() bool { + if o != nil && !IsNil(o.CustomProperties) { + return true + } + + return false +} + +// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field. +func (o *CatalogModelArtifact) SetCustomProperties(v map[string]MetadataValue) { + o.CustomProperties = &v +} + +// GetDescription returns the Description field value if set, zero value otherwise. +func (o *CatalogModelArtifact) GetDescription() string { + if o == nil || IsNil(o.Description) { + var ret string + return ret + } + return *o.Description +} + +// GetDescriptionOk returns a tuple with the Description field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModelArtifact) GetDescriptionOk() (*string, bool) { + if o == nil || IsNil(o.Description) { + return nil, false + } + return o.Description, true +} + +// HasDescription returns a boolean if a field has been set. +func (o *CatalogModelArtifact) HasDescription() bool { + if o != nil && !IsNil(o.Description) { + return true + } + + return false +} + +// SetDescription gets a reference to the given string and assigns it to the Description field. +func (o *CatalogModelArtifact) SetDescription(v string) { + o.Description = &v +} + +// GetExternalId returns the ExternalId field value if set, zero value otherwise. +func (o *CatalogModelArtifact) GetExternalId() string { + if o == nil || IsNil(o.ExternalId) { + var ret string + return ret + } + return *o.ExternalId +} + +// GetExternalIdOk returns a tuple with the ExternalId field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModelArtifact) GetExternalIdOk() (*string, bool) { + if o == nil || IsNil(o.ExternalId) { + return nil, false + } + return o.ExternalId, true +} + +// HasExternalId returns a boolean if a field has been set. +func (o *CatalogModelArtifact) HasExternalId() bool { + if o != nil && !IsNil(o.ExternalId) { + return true + } + + return false +} + +// SetExternalId gets a reference to the given string and assigns it to the ExternalId field. +func (o *CatalogModelArtifact) SetExternalId(v string) { + o.ExternalId = &v +} + +// GetName returns the Name field value if set, zero value otherwise. +func (o *CatalogModelArtifact) GetName() string { + if o == nil || IsNil(o.Name) { + var ret string + return ret + } + return *o.Name +} + +// GetNameOk returns a tuple with the Name field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModelArtifact) GetNameOk() (*string, bool) { + if o == nil || IsNil(o.Name) { + return nil, false + } + return o.Name, true +} + +// HasName returns a boolean if a field has been set. +func (o *CatalogModelArtifact) HasName() bool { + if o != nil && !IsNil(o.Name) { + return true + } + + return false +} + +// SetName gets a reference to the given string and assigns it to the Name field. +func (o *CatalogModelArtifact) SetName(v string) { + o.Name = &v +} + +// GetId returns the Id field value if set, zero value otherwise. +func (o *CatalogModelArtifact) GetId() string { + if o == nil || IsNil(o.Id) { + var ret string + return ret + } + return *o.Id +} + +// GetIdOk returns a tuple with the Id field value if set, nil otherwise +// and a boolean to check if the value has been set. +func (o *CatalogModelArtifact) GetIdOk() (*string, bool) { + if o == nil || IsNil(o.Id) { + return nil, false + } + return o.Id, true +} + +// HasId returns a boolean if a field has been set. +func (o *CatalogModelArtifact) HasId() bool { + if o != nil && !IsNil(o.Id) { + return true + } + + return false +} + +// SetId gets a reference to the given string and assigns it to the Id field. +func (o *CatalogModelArtifact) SetId(v string) { + o.Id = &v +} + // GetCreateTimeSinceEpoch returns the CreateTimeSinceEpoch field value if set, zero value otherwise. func (o *CatalogModelArtifact) GetCreateTimeSinceEpoch() string { if o == nil || IsNil(o.CreateTimeSinceEpoch) { @@ -163,38 +331,6 @@ func (o *CatalogModelArtifact) SetUri(v string) { o.Uri = v } -// GetCustomProperties returns the CustomProperties field value if set, zero value otherwise. -func (o *CatalogModelArtifact) GetCustomProperties() map[string]MetadataValue { - if o == nil || IsNil(o.CustomProperties) { - var ret map[string]MetadataValue - return ret - } - return *o.CustomProperties -} - -// GetCustomPropertiesOk returns a tuple with the CustomProperties field value if set, nil otherwise -// and a boolean to check if the value has been set. -func (o *CatalogModelArtifact) GetCustomPropertiesOk() (*map[string]MetadataValue, bool) { - if o == nil || IsNil(o.CustomProperties) { - return nil, false - } - return o.CustomProperties, true -} - -// HasCustomProperties returns a boolean if a field has been set. -func (o *CatalogModelArtifact) HasCustomProperties() bool { - if o != nil && !IsNil(o.CustomProperties) { - return true - } - - return false -} - -// SetCustomProperties gets a reference to the given map[string]MetadataValue and assigns it to the CustomProperties field. -func (o *CatalogModelArtifact) SetCustomProperties(v map[string]MetadataValue) { - o.CustomProperties = &v -} - func (o CatalogModelArtifact) MarshalJSON() ([]byte, error) { toSerialize, err := o.ToMap() if err != nil { @@ -205,6 +341,21 @@ func (o CatalogModelArtifact) MarshalJSON() ([]byte, error) { func (o CatalogModelArtifact) ToMap() (map[string]interface{}, error) { toSerialize := map[string]interface{}{} + if !IsNil(o.CustomProperties) { + toSerialize["customProperties"] = o.CustomProperties + } + if !IsNil(o.Description) { + toSerialize["description"] = o.Description + } + if !IsNil(o.ExternalId) { + toSerialize["externalId"] = o.ExternalId + } + if !IsNil(o.Name) { + toSerialize["name"] = o.Name + } + if !IsNil(o.Id) { + toSerialize["id"] = o.Id + } if !IsNil(o.CreateTimeSinceEpoch) { toSerialize["createTimeSinceEpoch"] = o.CreateTimeSinceEpoch } @@ -213,9 +364,6 @@ func (o CatalogModelArtifact) ToMap() (map[string]interface{}, error) { } toSerialize["artifactType"] = o.ArtifactType toSerialize["uri"] = o.Uri - if !IsNil(o.CustomProperties) { - toSerialize["customProperties"] = o.CustomProperties - } return toSerialize, nil } diff --git a/go.mod b/go.mod index 96ebbe78..8fcff20f 100644 --- a/go.mod +++ b/go.mod @@ -206,7 +206,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect + golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 golang.org/x/net v0.43.0 // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.28.0 // indirect diff --git a/templates/go-server/api.mustache b/templates/go-server/api.mustache index df847fda..e9fd7d3b 100644 --- a/templates/go-server/api.mustache +++ b/templates/go-server/api.mustache @@ -3,8 +3,7 @@ package {{packageName}} import ( "context" - "net/http"{{#apiInfo}}{{#apis}}{{#imports}} - "{{import}}"{{/imports}}{{/apis}}{{/apiInfo}} + "net/http" model "github.com/kubeflow/model-registry/pkg/openapi" ) @@ -30,5 +29,5 @@ type {{classname}}Servicer interface { {{#operations}}{{#operation}} {{#isDeprecated}} // Deprecated {{/isDeprecated}} - {{operationId}}(context.Context{{#allParams}}, {{^isPrimitiveType}}model.{{/isPrimitiveType}}{{dataType}}{{/allParams}}) (ImplResponse, error){{/operation}}{{/operations}} + {{operationId}}(context.Context{{#allParams}}, {{^isPrimitiveType}}{{^isContainer}}model.{{/isContainer}}{{#isContainer}}{{^items.isPrimitiveType}}model.{{/items.isPrimitiveType}}{{/isContainer}}{{/isPrimitiveType}}{{dataType}}{{/allParams}}) (ImplResponse, error){{/operation}}{{/operations}} }{{/apis}}{{/apiInfo}}