Compare commits

184 Commits

| SHA1 |
|---|
| 3bfb7cdef7 |
| 4949f66e95 |
| b293d53867 |
| a420f612fd |
| 07d7ca9c19 |
| 525fca9bfb |
| e4e7e03afc |
| dca71b4c82 |
| 3b4d11b9bb |
| 99a27c2377 |
| 6eaeb71afa |
| 1674840dc1 |
| 9900fdc69a |
| 6a255e8e35 |
| 4248ee077a |
| 772cb25693 |
| cb567bae15 |
| 390d4835cd |
| b1cedab627 |
| 7499971950 |
| 228b62d77e |
| a30065bc2c |
| 34550a5f07 |
| 41b53ae73d |
| dfc1a224de |
| e4dd95ec87 |
| eb3bd81c4f |
| 266023f9dc |
| dd26f5dc51 |
| 29ab55cfef |
| 9573e69c0a |
| 7c50427c02 |
| 655a9d5eee |
| 60ab78af95 |
| 36e658f17a |
| fb8d493f12 |
| 9be856091f |
| b9085ccc56 |
| 1da137c913 |
| 1c9b6ffabc |
| 58680c990d |
| 16866854a7 |
| ca4b82ba11 |
| 6d7b314a4b |
| 0666a7c3dc |
| 96b840330f |
| 77db33d3e1 |
| e419bba715 |
| 348ec21a63 |
| 9c8852767f |
| 4ec88777b9 |
| 220c23c9be |
| 585ea7c495 |
| 3844d7c287 |
| 30c1fc1452 |
| 2b02a22307 |
| 4a8756613d |
| b3d4b1ef5e |
| f1c84d6715 |
| bff6a47ad2 |
| 3ddb84d3ae |
| c2f5a7c776 |
| 3df6225763 |
| 71db5939a5 |
| 39e91a6980 |
| e5293f97f8 |
| 77b7a58360 |
| 875e349d5b |
| a56dc20249 |
| 949f73e6b9 |
| 32bf6ec955 |
| 8fc73138c2 |
| a70a022d1d |
| 91f18fb72e |
| 81dfa00643 |
| ee86a5bc80 |
| b4c43a0dfc |
| f97d9c682c |
| 223e8bd8e1 |
| 2706506f21 |
| 7a98228f9d |
| 31188d675c |
| fa02ee6bb4 |
| 8c9d1903a7 |
| a81ed96d2b |
| 44b45989b2 |
| 724345d688 |
| 41084e7c3d |
| 26ce0085da |
| 34671eaaf4 |
| f333dc9a70 |
| de8a251c8a |
| f86d88022e |
| 883f45696b |
| 7a2cfeb384 |
| 8356c82a6b |
| c8e680d13c |
| be68035427 |
| b59b25bdba |
| db98d7b190 |
| cdf362a42c |
| 23e6e420d4 |
| 095bbae787 |
| 532525278a |
| f12ff4c936 |
| 903eb23453 |
| 009bd76d71 |
| d83a40b536 |
| b5858325f5 |
| 0816f1484d |
| 1aa78b5fc9 |
| 1d850f1b1b |
| 719ffd63b2 |
| 4554d487e2 |
| 88dfddfa5c |
| 4a33154d9b |
| c843a03dfe |
| 36ba2cc295 |
| 929b423de9 |
| f96d87df7e |
| 277b3c2a86 |
| 0f0b58b42b |
| 3014cb90f2 |
| 54f2d8a503 |
| bde2a612c7 |
| 8c030e257a |
| 1f29feeae9 |
| 19df266997 |
| 367a87fb56 |
| a7494aabab |
| 06961086fb |
| 9e0e3f8d64 |
| dbebee9380 |
| 412999891c |
| cce093bae3 |
| d094092f92 |
| e0169c08e6 |
| c7eb14c9cc |
| 4ad9f87b5f |
| 026add59fe |
| 0fa7dbfbd6 |
| 6b11b58e43 |
| 47374731d5 |
| 89b97016d8 |
| 19e0bc8ecf |
| aca5e43960 |
| 000d57cb42 |
| 050b97c579 |
| 8b4ffb2dd2 |
| 26d7b2fa48 |
| 2abf354763 |
| c7f480f3e0 |
| d6acd1d9c6 |
| fc5ab19f47 |
| b09af1c2c7 |
| 7bdf9ba14b |
| 2dbb1b0387 |
| 6545c37be1 |
| 126945bc4e |
| f8905e9295 |
| 06fdaafd69 |
| 1e91d33421 |
| d2c12d0447 |
| 6faa8679d5 |
| 21197e5747 |
| a3f4b6e707 |
| 3bf656a05d |
| c7aed7f8f9 |
| d2c04b069b |
| ef98424c1f |
| 185c389190 |
| 89af055243 |
| 5f69126991 |
| f22f82a20a |
| 3efb314e27 |
| 273cc63557 |
| 4e4447ab14 |
| 11af8840ec |
| fdbb22956c |
| f011eff313 |
| 9e0beacd96 |
| edaacadbe2 |
| 04ba57994b |
| b59a4ebd5c |
@@ -0,0 +1,2 @@
# Always check-out / check-in files with LF line endings.
* text=auto eol=lf
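The new `.gitattributes` line normalizes line endings for every file Git treats as text. As a usage note (a hedged sketch of standard git practice, not part of this diff), contributors typically renormalize their checkout once after such a file lands:

```shell
# One-time renormalization after a .gitattributes change (standard git usage,
# not shown in this diff): rewrites index entries to match the new eol policy.
git add --renormalize .
git commit -m "Normalize line endings to LF"
```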
@@ -12,7 +12,9 @@ updates:
    schedule:
      interval: "weekly"
  - package-ecosystem: "pip"
    directory: "/clients/python/"
    directories:
      - "/clients/python/"
      - "/jobs/async-upload"
    schedule:
      interval: "weekly"
  - package-ecosystem: "docker"
@@ -20,6 +20,10 @@
  - changed-files:
      - any-glob-to-any-file: "csi/**"

"Area/Jobs/Async-upload":
  - changed-files:
      - any-glob-to-any-file: "jobs/async-upload/**"

"Area/Manifests":
  - changed-files:
      - any-glob-to-any-file: "manifests/**"
@@ -18,7 +18,7 @@
- [ ] Automated tests are provided as part of the PR for major new functionalities; testing instructions have been added in the PR body (for PRs involving changes that are not immediately obvious).
- [ ] The developer has manually tested the changes and verified that the changes work.
- [ ] Code changes follow the [kubeflow contribution guidelines](https://www.kubeflow.org/docs/about/contributing/).
- [ ] **For first time contributors**: Please reach out to the [Reviewers](../OWNERS) to ensure all tests are being run, ensuring the label `ok-to-test` has been added to the PR.
- [ ] **For first time contributors**: Please reach out to the [Reviewers](https://github.com/kubeflow/model-registry/blob/main/OWNERS) to ensure all tests are being run, ensuring the label `ok-to-test` has been added to the PR.

If you have UI changes
@@ -0,0 +1,83 @@
name: Test async-upload Job
on:
  push:
    branches:
      - "main"
    paths-ignore:
      - "LICENSE*"
      - "**.gitignore"
      - "**.md"
      - "**.txt"
      - ".github/ISSUE_TEMPLATE/**"
      - ".github/dependabot.yml"
      - "docs/**"
  pull_request:
    paths:
      - "jobs/async-upload/**"
      - ".github/workflows/**"

permissions:
  contents: read

env:
  # Async Job
  JOB_IMG_REGISTRY: ghcr.io
  JOB_IMG_ORG: kubeflow
  JOB_IMG_NAME: model-registry/job/async-upload
  JOB_IMG_VERSION: cicd
  # MR Server
  IMG_REGISTRY: ghcr.io
  IMG_ORG: kubeflow
  IMG_REPO: model-registry/server
  IMG_VERSION: cicd
  PUSH_IMAGE: false

jobs:
  py-test:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: jobs/async-upload
    steps:
      - uses: actions/checkout@v5
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10" # refers to the Container image
      - name: Install Poetry
        run: |
          pipx install poetry
      - name: Install dependencies
        run: |
          make install
      - name: Run tests
        run: |
          make test
      - name: Remove AppArmor profile for mysql in KinD on GHA # https://github.com/kubeflow/manifests/issues/2507
        run: |
          set -x
          sudo apparmor_parser -R /etc/apparmor.d/usr.sbin.mysqld
      - name: Run E2E tests
        run: |
          make test-e2e
  job-test:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: jobs/async-upload
    steps:
      - uses: actions/checkout@v5
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10" # refers to the Container image
      - name: Install Poetry
        run: |
          pipx install poetry
      - name: Remove AppArmor profile for mysql in KinD on GHA # https://github.com/kubeflow/manifests/issues/2507
        run: |
          set -x
          sudo apparmor_parser -R /etc/apparmor.d/usr.sbin.mysqld
      - name: Execute Sample Job E2E test
        run: |
          make test-integration
@@ -0,0 +1,63 @@
name: Build and Push async-upload container image

on:
  push:
    branches:
      - 'main'
    tags:
      - 'v*'
    paths:
      - 'jobs/async-upload/**'
      - '!LICENSE*'
      - '!DOCKERFILE*'
      - '!**.gitignore'
      - '!**.md'
      - '!**.txt'

env:
  IMG_REGISTRY: ghcr.io
  IMG_ORG: kubeflow
  IMG_NAME: model-registry/job/async-upload
  REGISTRY_USER: ${{ github.actor }}
  REGISTRY_PWD: ${{ secrets.GITHUB_TOKEN }}

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v5

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.IMG_REGISTRY }}
          username: ${{ env.REGISTRY_USER }}
          password: ${{ env.REGISTRY_PWD }}

      - name: Set main-branch environment # this is for main-sha tag image build
        if: github.ref == 'refs/heads/main'
        run: |
          commit_sha=${{ github.sha }}
          tag=main-${commit_sha:0:7}
          echo "VERSION=${tag}" >> $GITHUB_ENV

      - name: Set tag environment # this is for v* tag image build
        if: startsWith(github.ref, 'refs/tags/v')
        run: |
          echo "VERSION=${{ github.ref_name }}" >> $GITHUB_ENV

      - name: Build and push Docker image
        uses: docker/build-push-action@v6
        with:
          context: ./jobs/async-upload
          push: true
          tags: ${{ env.IMG_REGISTRY }}/${{ env.IMG_ORG }}/${{ env.IMG_NAME }}:${{ env.VERSION }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
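The `Set main-branch environment` step derives the image tag from the first seven characters of the commit SHA using bash substring expansion. A minimal sketch of the same expansion outside CI (the SHA value below is one from the commit list above, used purely for illustration):

```shell
# Bash substring expansion as used by the workflow: ${var:offset:length}.
commit_sha=3bfb7cdef7
tag="main-${commit_sha:0:7}"
echo "$tag"   # prints: main-3bfb7cd
```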
@@ -23,6 +23,9 @@ env:
jobs:
  build-csi-image:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      # Assign context variable for various action contexts (tag, main, CI)
      - name: Assigning tag context

@@ -32,7 +35,7 @@ jobs:
        if: github.head_ref == '' && github.ref == 'refs/heads/main'
        run: echo "BUILD_CONTEXT=main" >> $GITHUB_ENV
      # checkout branch
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5
      # set image version
      - name: Set main-branch environment
        if: env.BUILD_CONTEXT == 'main'
@@ -20,9 +20,19 @@ env:
  PUSH_IMAGE: true
  DOCKER_USER: ${{ github.actor }}
  DOCKER_PWD: ${{ secrets.GITHUB_TOKEN }}

permissions: # default workflow permission, overridden for specific job where required
  contents: read

jobs:
  prepare:
    uses: ./.github/workflows/prepare.yml
  build-image:
    permissions:
      contents: read
      packages: write
    runs-on: ubuntu-latest
    needs: prepare
    steps:
      # Assign context variable for various action contexts (tag, main, CI)
      - name: Assigning tag context

@@ -32,7 +42,7 @@ jobs:
        if: github.head_ref == '' && github.ref == 'refs/heads/main'
        run: echo "BUILD_CONTEXT=main" >> $GITHUB_ENV
      # checkout branch
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5
      # set image version
      - name: Set main-branch environment
        if: env.BUILD_CONTEXT == 'main'
@@ -27,7 +27,7 @@ jobs:
      packages: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        uses: actions/checkout@v5

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

@@ -74,5 +74,6 @@ jobs:
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            DEPLOYMENT_MODE=standalone
            STYLE_THEME=mui-theme
          cache-from: type=gha
          cache-to: type=gha,mode=max
@@ -1,4 +1,5 @@
name: Build and Push UI and BFF Images
name: Build and Push UI Image
# this workflow builds an image to support local testing
on:
  push:
    branches:

@@ -7,51 +8,72 @@ on:
      - 'v*'
    paths:
      - 'clients/ui/**'
      - '!LICENSE*'
      - '!DOCKERFILE*'
      - '!**.gitignore'
      - '!**.md'
      - '!**.txt'
env:
  IMG_REGISTRY: ghcr.io
  IMG_ORG: kubeflow
  IMG_UI_REPO: model-registry/ui
  PUSH_IMAGE: true
  IMG_UI_REPO: model-registry/ui # this image is intended for local development, not production
  DOCKER_USER: ${{ github.actor }}
  DOCKER_PWD: ${{ secrets.GITHUB_TOKEN }}
jobs:
  build-image:
  build-and-push:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      # Assign context variable for various action contexts (tag, main, CI)
      - name: Assigning tag context
        if: github.head_ref == '' && startsWith(github.ref, 'refs/tags/v')
        run: echo "BUILD_CONTEXT=tag" >> $GITHUB_ENV
      # Assign context variable for various action contexts (main, CI)
      - name: Assigning main context
        if: github.head_ref == '' && github.ref == 'refs/heads/main'
        run: echo "BUILD_CONTEXT=main" >> $GITHUB_ENV
      # checkout branch
      - uses: actions/checkout@v4
      # set image version
      - name: Set main-branch environment
        if: env.BUILD_CONTEXT == 'main'
        run: |
          commit_sha=${{ github.event.after }}
          tag=main-${commit_sha:0:7}
          echo "VERSION=${tag}" >> $GITHUB_ENV
      - name: Set tag environment
        if: env.BUILD_CONTEXT == 'tag'
        run: |
          echo "VERSION=${{ github.ref_name }}" >> $GITHUB_ENV
      - name: Build and Push UI Image
        shell: bash
        env:
          IMG_REPO: ${{ env.IMG_UI_REPO }}
        run: ./scripts/build_deploy.sh
      - name: Tag Latest UI Image
        if: env.BUILD_CONTEXT == 'main'
        shell: bash
        env:
          IMG_REPO: ${{ env.IMG_UI_REPO }}
          IMG: "${{ env.IMG_REGISTRY }}/${{ env.IMG_ORG }}/${{ env.IMG_UI_REPO }}"
          BUILD_IMAGE: false # image is already built in "Build and Push UI Image" step
        run: |
          docker tag ${{ env.IMG }}:$VERSION ${{ env.IMG }}:latest
          # BUILD_IMAGE=false skip the build, just push the tag made above
          VERSION=latest ./scripts/build_deploy.sh
      - name: Checkout repository
        uses: actions/checkout@v5

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.IMG_REGISTRY }}
          username: ${{ env.DOCKER_USER }}
          password: ${{ env.DOCKER_PWD }}

      - name: Set main-branch environment
        if: github.ref == 'refs/heads/main'
        run: |
          commit_sha=${{ github.sha }}
          tag=main-${commit_sha:0:7}
          echo "VERSION=${tag}" >> $GITHUB_ENV

      - name: Set tag environment
        if: startsWith(github.ref, 'refs/tags/v')
        run: |
          echo "VERSION=${{ github.ref_name }}" >> $GITHUB_ENV

      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: "${{ env.IMG_REGISTRY }}/${{ env.IMG_ORG }}/${{ env.IMG_UI_REPO }}"
          tags: |
            type=ref,event=branch
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=sha
            type=raw,value=${{ env.VERSION }},enable=${{ env.VERSION != '' }}
            type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}

      - name: Build and push Docker image
        uses: docker/build-push-action@v6
        with:
          context: ./clients/ui
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            DEPLOYMENT_MODE=kubeflow
            STYLE_THEME=mui-theme
          cache-from: type=gha
          cache-to: type=gha,mode=max
@@ -11,6 +11,10 @@ on:
      - ".github/dependabot.yml"
      - "docs/**"
      - "clients/python/**"

permissions:
  contents: read

env:
  IMG_REGISTRY: ghcr.io
  IMG_ORG: kubeflow

@@ -21,7 +25,7 @@ jobs:
  build-and-test-image:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5
      - name: Generate Tag
        shell: bash
        id: tags
@@ -10,6 +10,11 @@ on:
      - "!**.gitignore"
      - "!**.md"
      - "!**.txt"

permissions:
  contents: read
  packages: read

env:
  IMG_ORG: kubeflow
  IMG_REPO: model-registry/ui

@@ -20,7 +25,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      # checkout branch
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5
      - name: Build UI Image
        shell: bash
        run: ./scripts/build_deploy.sh
@@ -13,30 +13,24 @@ on:
      - ".github/ISSUE_TEMPLATE/**"
      - ".github/dependabot.yml"
      - "docs/**"

permissions:
  contents: read

jobs:
  prepare:
    uses: ./.github/workflows/prepare.yml
  build:
    needs: prepare
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Checkout
        uses: actions/checkout@v5
      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.23"
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: 3.9
          go-version: "1.24"
      - name: Build
        run: make clean build
      - name: Check if there are uncommitted file changes
        run: |
          clean=$(git status --porcelain)
          if [[ -z "$clean" ]]; then
            echo "Empty git status --porcelain: $clean"
          else
            echo "Uncommitted file changes detected: $clean"
            git diff
            exit 1
          fi
        run: make build/compile
      - name: Unit tests
        run: make test-cover
@@ -1,28 +1,53 @@
name: Check DB schema structs
on:
  pull_request:
    paths:
      - ".github/workflows/**"
      - "internal/db/schema/**"
      - "internal/datastore/embedmd/mysql/migrations/**"
  push:
    branches:
      - 'main'
    tags:
      - 'v*'

permissions:
  contents: read

jobs:
  check-schema-structs:
  check-mysql-schema-structs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5
      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.23.6"
      - name: Generate DB schema structs
        run: make gen/gorm
          go-version: "1.24.4"
      - name: Generate MySQL DB schema structs
        run: make gen/gorm/mysql
      - name: Check if there are uncommitted file changes
        run: |
          clean=$(git status --porcelain)
          if [[ -z "$clean" ]]; then
            echo "Empty git status --porcelain: $clean"
            echo "MySQL schema is up to date."
          else
            echo "Uncommitted file changes detected: $clean"
            echo "Uncommitted file changes detected after generating MySQL schema: $clean"
            git diff
            exit 1
          fi
  check-postgres-schema-structs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.24.4"
      - name: Generate PostgreSQL DB schema structs
        run: make gen/gorm/postgres
      - name: Check if there are uncommitted file changes
        run: |
          clean=$(git status --porcelain)
          if [[ -z "$clean" ]]; then
            echo "PostgreSQL schema is up to date."
          else
            echo "Uncommitted file changes detected after generating PostgreSQL schema: $clean"
            git diff
            exit 1
          fi
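Both jobs follow the same pattern: regenerate the GORM structs, then fail if `git status --porcelain` reports drift. A sketch of reproducing the MySQL variant locally, assuming Docker is available for the database the Makefile target starts:

```shell
# Reproduce the CI schema check locally; the make target is the one added in
# this diff, the drift check mirrors the workflow step above.
make gen/gorm/mysql
if [[ -n "$(git status --porcelain)" ]]; then
  echo "Generated structs are out of date; commit the regenerated files." >&2
  git diff
  exit 1
fi
```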
@@ -4,11 +4,15 @@ on:
    paths:
      - ".github/workflows/**"
      - "api/openapi/model-registry.yaml"

permissions:
  contents: read

jobs:
  validate:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5
      - name: Validate OpenAPI spec
        run: |
          make openapi/validate
@@ -20,6 +20,9 @@ on:
      - "pkg/openapi/**"
      - "go.mod"

permissions:
  contents: read

env:
  BRANCH: ${{ github.base_ref }}
jobs:

@@ -28,7 +31,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Clone the code
        uses: actions/checkout@v4
        uses: actions/checkout@v5

      - name: Setup Go
        uses: actions/setup-go@v5
@@ -21,6 +21,9 @@ on:
      # csi build depends on base go.mod https://github.com/kubeflow/model-registry/issues/311
      - "go.mod"

permissions:
  contents: read

env:
  IMG_REGISTRY: ghcr.io
  IMG_ORG: kubeflow

@@ -32,7 +35,7 @@ jobs:
  build-and-test-csi-image:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5

      - name: Generate tag
        shell: bash
@@ -6,6 +6,9 @@ on:
      - main
  pull_request:

permissions:
  contents: read

jobs:
  fossa-scan:
    if: github.repository_owner == 'kubeflow' # FOSSA is not intended to run on forks.

@@ -16,7 +19,7 @@ jobs:
      FOSSA_API_KEY: 80871bdd477c2c97f65e9822cae99d20 # This is a push-only token that is safe to be exposed.
    steps:
      - name: Checkout tree
        uses: actions/checkout@v4
        uses: actions/checkout@v5

      - name: Run FOSSA scan and upload build data
        uses: fossas/fossa-action@v1.7.0
@@ -0,0 +1,27 @@
on:
  workflow_call

permissions:
  contents: read

jobs:
  prepare:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version: "1.24"
      - name: Prepare
        run: make clean build/prepare
      - name: Check if there are uncommitted file changes
        run: |
          clean=$(git status --porcelain)
          if [[ -z "$clean" ]]; then
            echo "Empty git status --porcelain: $clean"
          else
            echo "Uncommitted file changes detected: $clean"
            git diff
            exit 1
          fi
@@ -14,7 +14,7 @@ jobs:
      FORCE_COLOR: "1"
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
        uses: actions/checkout@v5
        with:
          fetch-depth: 0
      - name: Set up Python
@@ -14,6 +14,9 @@ on:
      - ".github/dependabot.yml"
      - "docs/**"

permissions:
  contents: read

jobs:
  lint:
    name: ${{ matrix.session }}

@@ -28,7 +31,7 @@ jobs:
      FORCE_COLOR: "1"
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
        uses: actions/checkout@v5
      - name: Set up Python ${{ matrix.python }}
        uses: actions/setup-python@v5
        with:

@@ -75,7 +78,7 @@ jobs:
      nodejs: ["20"]
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
        uses: actions/checkout@v5
        with:
          fetch-depth: 0
      - name: Set up Python

@@ -140,7 +143,7 @@ jobs:
      IMG_REPO: model-registry
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
        uses: actions/checkout@v5
      - name: Set up Python ${{ matrix.python }}
        uses: actions/setup-python@v5
        with:

@@ -221,6 +224,15 @@ jobs:
          kubectl port-forward service/distribution-registry-test-service 5001:5001 &
          sleep 2
          nox --python=${{ matrix.python }} --session=e2e -- --cov-report=xml
      - name: Nox test fuzz (main only)
        if: github.ref == 'refs/heads/main'
        working-directory: clients/python
        run: |
          kubectl port-forward -n kubeflow service/model-registry-service 8080:8080 &
          kubectl port-forward -n minio svc/minio 9000:9000 &
          kubectl port-forward service/distribution-registry-test-service 5001:5001 &
          sleep 2
          nox --python=${{ matrix.python }} --session=fuzz

  docs-build:
    name: ${{ matrix.session }}

@@ -235,7 +247,7 @@ jobs:
      FORCE_COLOR: "1"
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
        uses: actions/checkout@v5
      - name: Set up Python ${{ matrix.python }}
        uses: actions/setup-python@v5
        with:
@@ -1,61 +0,0 @@
name: run-robot-tests
run-name: Run Robot Framework tests
# Run workflow
on:
  # For every push to repository
  push:
    # To any branch
    branches:
      - "*"
  # For every pull request
  pull_request:
    # But ignore this paths
    paths-ignore:
      - "LICENSE*"
      - "DOCKERFILE*"
      - "**.gitignore"
      - "**.md"
      - "**.txt"
      - ".github/ISSUE_TEMPLATE/**"
      - ".github/dependabot.yml"
      - "docs/**"
      - "scripts/**"
# Define workflow jobs
jobs:
  # Job runs Robot Framework tests against locally build image from current code
  run-robot-tests:
    # Ubuntu latest is sufficient system for run
    runs-on: ubuntu-latest
    # Define steps of job
    steps:
      # Get checkout action to get this repository
      - uses: actions/checkout@v4
      # Install defined Python version to run Robot Framework tests
      - name: Install Python 3.9.x
        # Get setup-python action to install Python
        uses: actions/setup-python@v5
        with:
          # Set Python version to install
          python-version: "3.9"
          # Set architecture of Python to install
          architecture: "x64"
      # Install required Python packages for running Robot Framework tests
      - name: Install required Python packages
        # Install required Python packages using pip
        run: pip install -r test/robot/requirements.txt
      # Install model_registry Python package from current code
      - name: Install model_registry Python package
        # Install model_registry package as editable using pip
        run: pip install -e clients/python
      # Start docker compose with locally build image from current code
      - name: Start docker compose with local image
        # Start docker compose in the background
        run: docker compose -f docker-compose-local.yaml up --detach
      # Run Robot Framework tests in REST mode against running docker compose
      - name: Run Robot Framework tests (REST mode)
        # Run Robot Framework tests in REST mode from test/robot directory
        run: robot test/robot
      # Shutdown docker compose with locally build image from current code
      - name: Shutdown docker compose with local image
        # Shutdown docker compose running in the background
        run: docker compose -f docker-compose-local.yaml down
@@ -0,0 +1,69 @@
name: Fuzz Test
on:
  workflow_dispatch:
    inputs:
      pr_number:
        description: 'The pull request number to run fuzz tests on'
        required: true
        type: number

permissions:
  contents: read

env:
  IMG_REGISTRY: ghcr.io
  IMG_ORG: kubeflow
  IMG_REPO: model-registry/server
  IMG_VERSION: latest
  PUSH_IMAGE: false

jobs:
  test-fuzz:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: clients/python
    steps:
      - name: Get PR details
        id: pr
        uses: actions/github-script@v7
        with:
          script: |
            try {
              const pr = await github.rest.pulls.get({
                owner: context.repo.owner,
                repo: context.repo.repo,
                pull_number: ${{ github.event.inputs.pr_number }}
              });
              return {
                sha: pr.data.head.sha,
                ref: pr.data.head.ref
              };
            } catch (error) {
              console.log(`Error fetching PR #${{ github.event.inputs.pr_number }}: ${error.message}`);
              throw error;
            }

      - name: Checkout PR
        uses: actions/checkout@v5
        with:
          ref: ${{ fromJson(steps.pr.outputs.result).sha }}

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install Poetry
        run: |
          pipx install poetry

      - name: Remove AppArmor profile for mysql in KinD on GHA # https://github.com/kubeflow/manifests/issues/2507
        run: |
          set -x
          sudo apparmor_parser -R /etc/apparmor.d/usr.sbin.mysqld

      - name: Run Fuzz Tests
        run: |
          echo "Starting fuzz tests..."
          make test-fuzz
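Since this workflow only runs on `workflow_dispatch`, it has to be triggered manually against a specific PR. A hedged sketch using the GitHub CLI (the workflow file name is an assumption; it is not shown in this diff):

```shell
# Trigger the manual fuzz run against PR 123.
# "fuzz.yml" is an assumed file name for the workflow defined above.
gh workflow run fuzz.yml --repo kubeflow/model-registry -f pr_number=123
```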
@@ -18,11 +18,15 @@ on:
      - "!DOCKERFILE*"
      - "!**.gitignore"
      - "!**.md"

permissions:
  contents: read

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5

      - name: Setup Go
        uses: actions/setup-go@v5
@@ -18,11 +18,15 @@ on:
      - "!DOCKERFILE*"
      - "!**.gitignore"
      - "!**.md"

permissions:
  contents: read

jobs:
  test-and-build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v5

      - name: Set up Node.js
        uses: actions/setup-node@v4
@@ -16,9 +16,6 @@ __debug*
# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# Go workspace file
go.work

# Idea files
.idea

@@ -27,6 +24,9 @@ go.work
model-registry
metadata.sqlite.db

# Temporary files for running the project
.port-forwards.pid

# Ignore go vendor and code coverage files
vendor
coverage.*

@@ -52,3 +52,7 @@ istio-*

# VSCode files
.vscode/

# Python
venv/
.python-version
Dockerfile
@@ -1,11 +1,12 @@
# Build the model-registry binary
FROM --platform=$BUILDPLATFORM registry.access.redhat.com/ubi8/go-toolset:1.23 AS common
FROM --platform=$BUILDPLATFORM registry.access.redhat.com/ubi9/go-toolset:1.24 AS common
ARG TARGETOS
ARG TARGETARCH

WORKDIR /workspace
# Copy the Go Modules manifests
COPY ["go.mod", "go.sum", "./"]
# Copy the Go Modules manifests and workspace file
COPY ["go.mod", "go.sum", "go.work", "go.work.sum", "./"]
COPY ["pkg/openapi/go.mod", "pkg/openapi/"]
# cache deps before building and copying source so that we don't need to re-download as much
# and so that source changes don't invalidate our downloaded layer
RUN go mod download

@@ -22,49 +23,15 @@ COPY templates/ templates/
COPY patches/ patches/
COPY catalog/ catalog/

###### Dev stage - start ######
# see: https://github.com/kubeflow/model-registry/pull/984#discussion_r2048732415

FROM common AS dev

USER root

RUN CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o model-registry

FROM registry.access.redhat.com/ubi8/ubi-minimal:latest AS dev-build

WORKDIR /
COPY --from=dev /workspace/model-registry .
USER 65532:65532

ENTRYPOINT ["/model-registry"]

###### Dev stage - end ######

FROM common AS builder

USER root
# default NodeJS 14 is not enough for openapi-generator-cli, switch to Node JS currently supported
RUN yum remove -y nodejs npm
RUN yum module -y reset nodejs
RUN yum module -y enable nodejs:18
# install npm and java for openapi-generator-cli
RUN yum install -y nodejs npm java-11 python3

RUN make deps

# NOTE: The two instructions below are effectively equivalent to 'make clean build'
# DO NOT REMOVE THE 'build/prepare' TARGET!!!
# It ensures consistent repeatable Dockerfile builds

# prepare the build in a separate layer
RUN make clean build/prepare
# compile separately to optimize multi-platform builds
RUN CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} make build/compile

# Use distroless as minimal base image to package the model-registry binary
# Refer to https://github.com/GoogleContainerTools/distroless for more details
FROM registry.access.redhat.com/ubi8/ubi-minimal:latest
FROM registry.access.redhat.com/ubi9/ubi-minimal:latest
WORKDIR /
# copy the registry binary
COPY --from=builder /workspace/model-registry .
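The split into `build/prepare` and `build/compile` keeps the expensive, platform-independent preparation in one cached layer while only the compile step varies per `$TARGETARCH`. A hedged sketch of a multi-platform build that exercises this (the image tag is illustrative, not from this diff):

```shell
# Multi-platform build; BUILDPLATFORM/TARGETOS/TARGETARCH in the Dockerfile
# above are populated automatically by buildx. The tag is illustrative.
docker buildx build \
  --platform linux/amd64,linux/arm64 \
  -t ghcr.io/example/model-registry:dev .
```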
@@ -1,9 +1,10 @@
# Build the model-registry binary
FROM registry.access.redhat.com/ubi8/go-toolset:1.23 AS builder
FROM registry.access.redhat.com/ubi9/go-toolset:1.24 AS builder

WORKDIR /workspace
# Copy the Go Modules manifests
COPY ["go.mod", "go.sum", "./"]
# Copy the Go Modules manifests and workspace file
COPY ["go.mod", "go.sum", "go.work", "./"]
COPY ["pkg/openapi/go.mod", "pkg/openapi/"]
# cache deps before building and copying source so that we don't need to re-download as much
# and so that source changes don't invalidate our downloaded layer
RUN go mod download

@@ -25,7 +26,7 @@ RUN CGO_ENABLED=1 GOOS=linux GOARCH=amd64 make clean/odh build/odh

# Use distroless as minimal base image to package the model-registry binary
# Refer to https://github.com/GoogleContainerTools/distroless for more details
FROM registry.access.redhat.com/ubi8/ubi-minimal:latest
FROM registry.access.redhat.com/ubi9/ubi-minimal:latest
WORKDIR /
# copy the registry binary
COPY --from=builder /workspace/model-registry .
Makefile
@@ -14,8 +14,6 @@ ENVTEST ?= $(PROJECT_BIN)/setup-envtest
# add tools bin directory
PATH := $(PROJECT_BIN):$(PATH)

MLMD_VERSION ?= 1.14.0

# docker executable
DOCKER ?= docker
# default Dockerfile

@@ -31,7 +29,9 @@ IMG_REPO ?= model-registry/server
# container image build path
BUILD_PATH ?= .
# container image
ifdef IMG_REGISTRY
ifdef IMG
	IMG := ${IMG}
else ifdef IMG_REGISTRY
	IMG := ${IMG_REGISTRY}/${IMG_ORG}/${IMG_REPO}
else
	IMG := ${IMG_ORG}/${IMG_REPO}

@@ -55,44 +55,11 @@ endif

model-registry: build

# clean the ml-metadata protos and trigger a fresh new build which downloads
# ml-metadata protos based on specified MLMD_VERSION
.PHONY: update/ml_metadata
update/ml_metadata: clean/ml_metadata clean build

clean/ml_metadata:
	rm -rf api/grpc/ml_metadata/proto/*.proto

api/grpc/ml_metadata/proto/metadata_source.proto:
	mkdir -p api/grpc/ml_metadata/proto/
	cd api/grpc/ml_metadata/proto/ && \
	curl -LO "https://raw.githubusercontent.com/google/ml-metadata/v${MLMD_VERSION}/ml_metadata/proto/metadata_source.proto" && \
	sed -i 's#syntax = "proto[23]";#&\noption go_package = "github.com/kubeflow/model-registry/internal/ml_metadata/proto";#' metadata_source.proto

api/grpc/ml_metadata/proto/metadata_store.proto:
	mkdir -p api/grpc/ml_metadata/proto/
	cd api/grpc/ml_metadata/proto/ && \
	curl -LO "https://raw.githubusercontent.com/google/ml-metadata/v${MLMD_VERSION}/ml_metadata/proto/metadata_store.proto" && \
	sed -i 's#syntax = "proto[23]";#&\noption go_package = "github.com/kubeflow/model-registry/internal/ml_metadata/proto";#' metadata_store.proto

api/grpc/ml_metadata/proto/metadata_store_service.proto:
	mkdir -p api/grpc/ml_metadata/proto/
	cd api/grpc/ml_metadata/proto/ && \
	curl -LO "https://raw.githubusercontent.com/google/ml-metadata/v${MLMD_VERSION}/ml_metadata/proto/metadata_store_service.proto" && \
	sed -i 's#syntax = "proto[23]";#&\noption go_package = "github.com/kubeflow/model-registry/internal/ml_metadata/proto";#' metadata_store_service.proto

internal/ml_metadata/proto/%.pb.go: api/grpc/ml_metadata/proto/%.proto
	bin/protoc -I./api/grpc --go_out=./internal --go_opt=paths=source_relative \
		--go-grpc_out=./internal --go-grpc_opt=paths=source_relative $<

.PHONY: gen/grpc
gen/grpc: internal/ml_metadata/proto/metadata_store.pb.go internal/ml_metadata/proto/metadata_store_service.pb.go

internal/converter/generated/converter.go: internal/converter/*.go
	${GOVERTER} gen github.com/kubeflow/model-registry/internal/converter/

.PHONY: gen/converter
gen/converter: gen/grpc internal/converter/generated/converter.go
gen/converter: internal/converter/generated/converter.go

api/openapi/model-registry.yaml: api/openapi/src/model-registry.yaml api/openapi/src/lib/*.yaml bin/yq
	scripts/merge_openapi.sh model-registry.yaml

@@ -137,16 +104,47 @@ start/mysql:
stop/mysql:
	./scripts/teardown_mysql_db.sh

# generate the gorm structs
.PHONY: gen/gorm
gen/gorm: bin/golang-migrate start/mysql
# Start the PostgreSQL database
.PHONY: start/postgres
start/postgres:
	./scripts/start_postgres_db.sh

# Stop the PostgreSQL database
.PHONY: stop/postgres
stop/postgres:
	./scripts/teardown_postgres_db.sh

# generate the gorm structs for MySQL
.PHONY: gen/gorm/mysql
gen/gorm/mysql: bin/golang-migrate start/mysql
	@(trap 'cd $(CURDIR) && $(MAKE) stop/mysql' EXIT; \
	$(GOLANG_MIGRATE) -path './internal/datastore/embedmd/mysql/migrations' -database 'mysql://root:root@tcp(localhost:3306)/model-registry' up && \
	cd gorm-gen && go run main.go --db-type mysql --dsn 'root:root@tcp(localhost:3306)/model-registry?charset=utf8mb4&parseTime=True&loc=Local')
	cd gorm-gen && GOWORK=off go run main.go --db-type mysql --dsn 'root:root@tcp(localhost:3306)/model-registry?charset=utf8mb4&parseTime=True&loc=Local')

# generate the gorm structs for PostgreSQL
.PHONY: gen/gorm/postgres
gen/gorm/postgres: bin/golang-migrate start/postgres
	@(trap 'cd $(CURDIR) && $(MAKE) stop/postgres' EXIT; \
	$(GOLANG_MIGRATE) -path './internal/datastore/embedmd/postgres/migrations' -database 'postgres://postgres:postgres@localhost:5432/model-registry?sslmode=disable' up && \
	cd gorm-gen && GOWORK=off go run main.go --db-type postgres --dsn 'postgres://postgres:postgres@localhost:5432/model-registry?sslmode=disable' && \
	cd $(CURDIR) && ./scripts/remove_gorm_defaults.sh)

# generate the gorm structs (defaults to MySQL for backward compatibility)
# Use GORM_DB_TYPE=postgres to generate for PostgreSQL instead
.PHONY: gen/gorm
gen/gorm: bin/golang-migrate
ifeq ($(GORM_DB_TYPE),postgres)
	$(MAKE) gen/gorm/postgres
else
	$(MAKE) gen/gorm/mysql
endif

.PHONY: vet
vet:
	${GO} vet ./...
	@echo "Running go vet on all packages..."
	@${GO} vet $$(${GO} list ./... | grep -vF github.com/kubeflow/model-registry/internal/db/filter) && \
	echo "Checking filter package (parser.go excluded due to participle struct tags)..." && \
	cd internal/db/filter && ${GO} build -o /dev/null . 2>&1 | grep -E "vet:|error:" || echo "✓ Filter package builds successfully"

.PHONY: clean/csi
clean/csi:

@@ -164,24 +162,12 @@ clean-internal-server-openapi:

.PHONY: clean
clean: clean-pkg-openapi clean-internal-server-openapi clean/csi
	rm -Rf ./model-registry internal/ml_metadata/proto/*.go internal/converter/generated/*.go
	rm -Rf ./model-registry internal/converter/generated/*.go

.PHONY: clean/odh
clean/odh:
	rm -Rf ./model-registry

bin/protoc:
	./scripts/install_protoc.sh

bin/go-enum:
	GOBIN=$(PROJECT_BIN) ${GO} install github.com/searKing/golang/tools/go-enum@v1.2.97

bin/protoc-gen-go:
	GOBIN=$(PROJECT_BIN) ${GO} install google.golang.org/protobuf/cmd/protoc-gen-go@v1.31.0

bin/protoc-gen-go-grpc:
	GOBIN=$(PROJECT_BIN) ${GO} install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.3.0

bin/envtest:
	GOBIN=$(PROJECT_BIN) ${GO} install sigs.k8s.io/controller-runtime/tools/setup-envtest@v0.0.0-20240320141353-395cfc7486e6

@@ -199,7 +185,11 @@ bin/yq:

GOLANG_MIGRATE ?= ${PROJECT_BIN}/migrate
bin/golang-migrate:
	GOBIN=$(PROJECT_PATH)/bin ${GO} install -tags 'mysql' github.com/golang-migrate/migrate/v4/cmd/migrate@v4.18.3
	GOBIN=$(PROJECT_PATH)/bin ${GO} install -tags 'mysql,postgres' github.com/golang-migrate/migrate/v4/cmd/migrate@v4.18.3

GENQLIENT ?= ${PROJECT_BIN}/genqlient
bin/genqlient:
	GOBIN=$(PROJECT_PATH)/bin ${GO} install github.com/Khan/genqlient@v0.7.0

OPENAPI_GENERATOR ?= ${PROJECT_BIN}/openapi-generator-cli
NPM ?= "$(shell which npm)"

@@ -224,7 +214,7 @@ clean/deps:
	rm -Rf bin/*

.PHONY: deps
deps: bin/protoc bin/go-enum bin/protoc-gen-go bin/protoc-gen-go-grpc bin/golangci-lint bin/goverter bin/openapi-generator-cli bin/envtest
deps: bin/golangci-lint bin/goverter bin/openapi-generator-cli bin/envtest

.PHONY: vendor
vendor:

@@ -261,7 +251,7 @@ build/compile/csi:
build/csi: build/prepare/csi build/compile/csi

.PHONY: gen
gen: deps gen/grpc gen/openapi gen/openapi-server gen/converter
gen: deps gen/openapi gen/openapi-server gen/converter
	${GO} generate ./...

.PHONY: lint

@@ -275,15 +265,15 @@ lint/csi: bin/golangci-lint
	${GOLANGCI_LINT} run internal/csi/...

.PHONY: test
test: gen bin/envtest
test: bin/envtest
	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) -p path)" ${GO} test ./internal/... ./pkg/...

.PHONY: test-nocache
test-nocache: gen bin/envtest
test-nocache: bin/envtest
	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) -p path)" ${GO} test ./internal/... ./pkg/... -count=1

.PHONY: test-cover
test-cover: gen bin/envtest
test-cover: bin/envtest
	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) -p path)" ${GO} test ./internal/... ./pkg/... -coverprofile=coverage.txt
	${GO} tool cover -html=coverage.txt -o coverage.html

@@ -366,7 +356,7 @@ controller/vet: ## Run go vet against code.

.PHONY: controller/test
controller/test: controller/manifests controller/generate controller/fmt controller/vet bin/envtest ## Run tests.
	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(PROJECT_BIN) -p path)" go test $$(go list ./internal/controller/... | grep -v /e2e) -coverprofile cover.out
	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(PROJECT_BIN) -p path)" go test $$(go list ./internal/controller/... | grep -vF /e2e) -coverprofile cover.out

##@ Build
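The new `gen/gorm` dispatch keeps the old entry point working while routing to a backend-specific target. Usage, per the targets added above:

```shell
# Default (MySQL, backward compatible):
make gen/gorm

# Explicit backends:
make gen/gorm/mysql
make gen/gorm/postgres

# Or select PostgreSQL through the dispatch variable:
GORM_DB_TYPE=postgres make gen/gorm
```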
README.md
@@ -5,7 +5,7 @@
[FOSSA status badge](https://app.fossa.com/projects/custom%2B162%2Fgithub.com%2Fkubeflow%2Fmodel-registry?ref=badge_shield&issueType=license)
[OpenSSF best practices badge](https://www.bestpractices.dev/projects/9937)

Model registry provides a central repository for model developers to store and manage models, versions, and artifacts metadata. A Go-based application that leverages [ml_metadata](https://github.com/google/ml-metadata/) project under the hood.
Model registry provides a central repository for model developers to store and manage models, versions, and artifacts metadata.

## Red Hat's Pledge
- Red Hat drives the project's development through Open Source principles, ensuring transparency, sustainability, and community ownership.

@@ -23,7 +23,7 @@ Model registry provides a central repository for model developers to store and m
- [Blog KF 1.10 introducing UI for Model Registry, CSI, and other features](https://blog.kubeflow.org/kubeflow-1.10-release/#model-registry)
2. Installation
- [installing Model Registry standalone](https://www.kubeflow.org/docs/components/model-registry/installation/#standalone-installation)
- [installing Model Registry with Kubeflow manifests](https://github.com/kubeflow/manifests/tree/master/apps/model-registry/upstream#readme)
- [installing Model Registry with Kubeflow manifests](https://github.com/kubeflow/manifests/tree/master/applications/model-registry/upstream#readme)
- [installing Model Registry using ODH Operator](https://github.com/opendatahub-io/model-registry-operator/tree/main/docs#readme)
3. Concepts
- [Logical Model](./docs/logical_model.md)

@@ -45,7 +45,7 @@ Model registry provides a central repository for model developers to store and m
8. [UI](clients/ui/README.md)

## Pre-requisites:
- go >= 1.23
- go >= 1.24
- protoc v24.3 - [Protocol Buffers v24.3 Release](https://github.com/protocolbuffers/protobuf/releases/tag/v24.3)
- npm >= 10.2.0 - [Installing Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm)
- Java >= 11.0

@@ -62,9 +62,7 @@ Run the following command to start the OpenAPI proxy server from source:
```shell
make run/proxy
```
The proxy service implements the OpenAPI defined in [model-registry.yaml](api/openapi/model-registry.yaml) to create a Model Registry specific REST API on top of the existing ml-metadata server.

> **NOTE** The ml-metadata server must be running and accessible from the environment where model-registry starts up.
The proxy service implements the OpenAPI defined in [model-registry.yaml](api/openapi/model-registry.yaml) to create a Model Registry specific REST API.

### Model registry logical model

@@ -72,8 +70,8 @@ For a high-level documentation of the Model Registry _logical model_, please che

## Model Registry Core

The model registry core is the layer which implements the core/business logic by interacting with the underlying ml-metadata server.
It provides a model registry domain-specific [api](pkg/api/api.go) that is in charge to proxy all, appropriately transformed, requests to ml-metadata using gRPC calls.
The model registry core is the layer which implements the core/business logic by interacting with the underlying datastore internal service.
It provides a model registry domain-specific [api](pkg/api/api.go) that is in charge to proxy all, appropriately transformed, requests to the datastore internal service.

### Model registry library

@@ -143,8 +141,6 @@ Subsequent builds will re-use the cached tools layer.

#### Running the proxy server

> **NOTE:** ml-metadata server must be running and accessible, see more info on how to start the gRPC server in the official ml-metadata [documentation](https://github.com/google/ml-metadata).

The following command starts the proxy server:

```shell

@@ -155,11 +151,11 @@ Where, `<uid>`, `<gid>`, and `<host-path>` are the same as in the migrate command.
And `<hostname>` and `<port>` are the local ip and port to use to expose the container's default `8080` listening port.
The server listens on `localhost` by default, hence the `-n 0.0.0.0` option allows the server port to be exposed.

#### Running model registry & ml-metadata
#### Running model registry

> **NOTE:** Docker compose must be installed in your environment.

There are two `docker-compose` files that make the startup of both model registry and ml-metadara easier, by simply running:
There are two `docker-compose` files that make the startup of both model registry and a MySQL database easier, by simply running:

```shell
docker compose -f docker-compose[-local].yaml up

@@ -167,20 +163,38 @@ docker compose -f docker-compose[-local].yaml up

The main difference between the two docker compose files is that the `-local` one builds the model registry from source, while the other one downloads the `latest` pushed [quay.io](https://quay.io/repository/opendatahub/model-registry?tab=tags) image.

When shutting down the docker compose, you might want to clean up the SQLite db file generated by ML Metadata, for example `./test/config/ml-metadata/metadata.sqlite.db`

### Testing architecture

The following diagram illustrates the testing strategy for the several components in the Model Registry project:

[Figure: Model Registry testing areas diagram]

Go layers components are tested with Unit Tests written in Go, as well as Integration Tests leveraging Testcontainers.
This allows verifying that the expected "Core layer" of logical data mapping developed and implemented in Go matches technical expectations.

Python client is also tested with Unit Tests and Integration Tests written in Python.

End-to-end testing is developed with Pytest and Robot Framework; this higher-level layer of testing is used to demonstrate *User Stories* from a high-level perspective.
End-to-end testing is developed with KinD and Pytest; this higher-level layer of testing is used to demonstrate *User Stories* from a high-level perspective.

## Related Components

### Model Catalog Service
- [Model Catalog Service](catalog/README.md) - Federated model discovery across external catalogs

### Kubernetes Components
- [Controller](cmd/controller/README.md) - Kubernetes controller for model registry CRDs
- [CSI Driver](cmd/csi/README.md) - Container Storage Interface for model artifacts

### Client Components
- [UI Backend for Frontend (BFF)](clients/ui/bff/README.md) - Go-based BFF service for the React UI
- [UI Frontend](clients/ui/frontend/README.md) - React-based frontend application

### Job Components
- [Async Upload Job](jobs/async-upload/README.md) - Background job for handling asynchronous model uploads

### Development & Deployment
- [Development Environment](devenv/README.md) - Local development setup and tools
- [Kubernetes Manifests](manifests/kustomize/README.md) - Kustomize-based Kubernetes deployment manifests

## FAQ
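As a quick smoke test of the compose setup described above, a hedged sketch: the REST path and `v1alpha3` prefix are assumptions based on the Model Registry API and are not shown in this diff; only the compose file name and the default `8080` port come from the README.

```shell
# Start the locally-built stack, probe the REST API, then shut it down.
# The v1alpha3 endpoint path below is an assumption, not from this diff.
docker compose -f docker-compose-local.yaml up --detach
curl -s "http://localhost:8080/api/model_registry/v1alpha3/registered_models"
docker compose -f docker-compose-local.yaml down
```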
@@ -72,6 +72,7 @@ git checkout -b mr_maintainer-$TDATE-upstreamSync
pushd manifests/kustomize/base && kustomize edit set image ghcr.io/kubeflow/model-registry/server=ghcr.io/kubeflow/model-registry/server:$VVERSION && popd
pushd manifests/kustomize/options/csi && kustomize edit set image ghcr.io/kubeflow/model-registry/storage-initializer=ghcr.io/kubeflow/model-registry/storage-initializer:$VVERSION && popd
pushd manifests/kustomize/options/ui/base && kustomize edit set image model-registry-ui=ghcr.io/kubeflow/model-registry/ui:$VVERSION && popd
pushd manifests/kustomize/options/catalog && kustomize edit set image ghcr.io/kubeflow/model-registry/server=ghcr.io/kubeflow/model-registry/server:$VVERSION && popd
git add .
git commit -s
|
@ -0,0 +1,64 @@
|
|||
# Security Policy
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Kubeflow Model Registry versions are expressed as `vX.Y.Z`, where X is the major version,
|
||||
Y is the minor version, and Z is the patch version, following the
|
||||
[Semantic Versioning](https://semver.org/) terminology.
|
||||
|
||||
The Kubeflow Model Registry project maintains release branches for the most recent two minor releases.
|
||||
Applicable fixes, including security fixes, may be backported to those two release branches,
|
||||
depending on severity and feasibility.
|
||||
|
||||
Users are encouraged to stay updated with the latest releases to benefit from security patches and
|
||||
improvements.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
We're extremely grateful for security researchers and users that report vulnerabilities to the
|
||||
Kubeflow Open Source Community. All reports are thoroughly investigated by Kubeflow projects owners.
|
||||
|
||||
You can use the following ways to report security vulnerabilities privately:
|
||||
|
||||
- Using the Kubeflow Model Registry repository [GitHub Security Advisory](https://github.com/kubeflow/model-registry/security/advisories/new).
|
||||
- Using our private Kubeflow Steering Committee mailing list: ksc@kubeflow.org.
|
||||
|
||||
Please provide detailed information to help us understand and address the issue promptly.
|
||||
|
||||
## Disclosure Process
|
||||
|
||||
**Acknowledgment**: We will acknowledge receipt of your report within 10 business days.
|
||||
|
||||
**Assessment**: The Kubeflow projects owners will investigate the reported issue to determine its
|
||||
validity and severity.
|
||||
|
||||
**Resolution**: If the issue is confirmed, we will work on a fix and prepare a release.
|
||||
|
||||
**Notification**: Once a fix is available, we will notify the reporter and coordinate a public
|
||||
disclosure.
|
||||
|
||||
**Public Disclosure**: Details of the vulnerability and the fix will be published in the project's
|
||||
release notes and communicated through appropriate channels.
|
||||
|
||||
## Prevention Mechanisms
|
||||
|
||||
Kubeflow Model Registry employs several measures to prevent security issues:
|
||||
|
||||
**Code Reviews**: All code changes are reviewed by maintainers to ensure code quality and security.
|
||||
|
||||
**Dependency Management**: Regular updates and monitoring of dependencies (e.g. Dependabot) to
|
||||
address known vulnerabilities.
|
||||
|
||||
**Continuous Integration**: Automated testing and security checks are integrated into the CI/CD pipeline.
|
||||
|
||||
**Image Scanning**: Container images are scanned for vulnerabilities.
|
||||
|
||||
## Communication Channels
|
||||
|
||||
For the general questions please join the following resources:
|
||||
|
||||
- Kubeflow [Slack channels](https://www.kubeflow.org/docs/about/community/#kubeflow-slack-channels).
|
||||
|
||||
- Kubeflow discuss [mailing list](https://www.kubeflow.org/docs/about/community/#kubeflow-mailing-list).
|
||||
|
||||
Please **do not report** security vulnerabilities through public channels.
|
File diff suppressed because it is too large.

File diff suppressed because it is too large.
@@ -20,13 +20,12 @@ paths:
       parameters:
         - name: source
           description: |-
-            Filter models by source. If not provided, models from all sources
-            are returned. If multiple sources are provided, models from any of
-            the sources are returned.
+            Filter models by source. This parameter is currently required and
+            may only be specified once.
           schema:
             type: string
           in: query
-          required: false
+          required: true
         - name: q
           description: Free-form keyword search used to filter the response.
           schema:
@@ -138,6 +137,15 @@ paths:
           required: true
 components:
   schemas:
+    ArtifactTypeQueryParam:
+      description: Supported artifact types for querying.
+      enum:
+        - model-artifact
+        - doc-artifact
+        - dataset-artifact
+        - metric
+        - parameter
+      type: string
     BaseModel:
       type: object
       properties:
@@ -298,6 +306,10 @@ components:
         name:
           description: The name of the catalog source.
           type: string
+        enabled:
+          description: Whether the catalog source is enabled.
+          type: boolean
+          default: true
     CatalogSourceList:
       description: List of CatalogSource entities.
       allOf:
@@ -566,6 +578,43 @@ components:
           type: string
         in: query
         required: false
+    filterQuery:
+      examples:
+        filterQuery:
+          value: "name='my-model' AND state='LIVE'"
+      name: filterQuery
+      description: |
+        A SQL-like query string to filter the list of entities. The query supports rich filtering capabilities with automatic type inference.
+
+        **Supported Operators:**
+        - Comparison: `=`, `!=`, `<>`, `>`, `<`, `>=`, `<=`
+        - Pattern matching: `LIKE`, `ILIKE` (case-insensitive)
+        - Set membership: `IN`
+        - Logical: `AND`, `OR`
+        - Grouping: `()` for complex expressions
+
+        **Data Types:**
+        - Strings: `"value"` or `'value'`
+        - Numbers: `42`, `3.14`, `1e-5`
+        - Booleans: `true`, `false` (case-insensitive)
+
+        **Property Access:**
+        - Standard properties: `name`, `id`, `state`, `createTimeSinceEpoch`
+        - Custom properties: Any user-defined property name
+        - Escaped properties: Use backticks for special characters: `` `custom-property` ``
+        - Type-specific access: `property.string_value`, `property.double_value`, `property.int_value`, `property.bool_value`
+
+        **Examples:**
+        - Basic: `name = "my-model"`
+        - Comparison: `accuracy > 0.95`
+        - Pattern: `name LIKE "%tensorflow%"`
+        - Complex: `(name = "model-a" OR name = "model-b") AND state = "LIVE"`
+        - Custom property: `framework.string_value = "pytorch"`
+        - Escaped property: `` `mlflow.source.type` = "notebook" ``
+      schema:
+        type: string
+      in: query
+      required: false
     pageSize:
       examples:
         pageSize:
@@ -595,6 +644,30 @@ components:
       $ref: "#/components/schemas/SortOrder"
       in: query
       required: false
+    artifactType:
+      style: form
+      explode: true
+      examples:
+        artifactType:
+          value: model-artifact
+      name: artifactType
+      description: "Specifies the artifact type for listing artifacts."
+      schema:
+        $ref: "#/components/schemas/ArtifactTypeQueryParam"
+      in: query
+      required: false
+    stepIds:
+      style: form
+      explode: true
+      examples:
+        stepIds:
+          value: "1,2,3"
+      name: stepIds
+      description: "Comma-separated list of step IDs to filter metrics by."
+      schema:
+        type: string
+      in: query
+      required: false
     securitySchemes:
       Bearer:
         scheme: bearer
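Editor's note: the `filterQuery` grammar above travels as an ordinary query-string value, so the only client-side concern is URL encoding. A minimal sketch in Go follows; the host, port, and the assumption that the `/models` endpoint accepts `filterQuery` are illustrative only, since the path definitions that would confirm it are truncated in this diff (the base path comes from the catalog README later in this page):

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

func main() {
	// Compose /models?source=...&filterQuery=...; url.Values handles the
	// percent-encoding of quotes, spaces, and operators in the filter.
	q := url.Values{}
	q.Set("source", "yaml-catalog") // source is required per the spec change above
	q.Set("filterQuery", `name LIKE "%tensorflow%" AND state = "LIVE"`)

	// Base URL is an assumption for local testing.
	u := "http://localhost:8080/api/model_catalog/v1alpha1/models?" + q.Encode()

	resp, err := http.Get(u)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body))
}
```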
File diff suppressed because it is too large
@@ -20,13 +20,12 @@ paths:
       parameters:
         - name: source
           description: |-
-            Filter models by source. If not provided, models from all sources
-            are returned. If multiple sources are provided, models from any of
-            the sources are returned.
+            Filter models by source. This parameter is currently required and
+            may only be specified once.
           schema:
             type: string
           in: query
-          required: false
+          required: true
         - name: q
           description: Free-form keyword search used to filter the response.
           schema:
@@ -210,6 +209,10 @@ components:
         name:
           description: The name of the catalog source.
           type: string
+        enabled:
+          description: Whether the catalog source is enabled.
+          type: boolean
+          default: true
     CatalogSourceList:
       description: List of CatalogSource entities.
       allOf:
@@ -1,5 +1,14 @@
 components:
   schemas:
+    ArtifactTypeQueryParam:
+      description: Supported artifact types for querying.
+      enum:
+        - model-artifact
+        - doc-artifact
+        - dataset-artifact
+        - metric
+        - parameter
+      type: string
     BaseModel:
       type: object
       properties:
@@ -295,6 +304,43 @@ components:
           type: string
         in: query
         required: false
+    filterQuery:
+      examples:
+        filterQuery:
+          value: "name='my-model' AND state='LIVE'"
+      name: filterQuery
+      description: |
+        A SQL-like query string to filter the list of entities. The query supports rich filtering capabilities with automatic type inference.
+
+        **Supported Operators:**
+        - Comparison: `=`, `!=`, `<>`, `>`, `<`, `>=`, `<=`
+        - Pattern matching: `LIKE`, `ILIKE` (case-insensitive)
+        - Set membership: `IN`
+        - Logical: `AND`, `OR`
+        - Grouping: `()` for complex expressions
+
+        **Data Types:**
+        - Strings: `"value"` or `'value'`
+        - Numbers: `42`, `3.14`, `1e-5`
+        - Booleans: `true`, `false` (case-insensitive)
+
+        **Property Access:**
+        - Standard properties: `name`, `id`, `state`, `createTimeSinceEpoch`
+        - Custom properties: Any user-defined property name
+        - Escaped properties: Use backticks for special characters: `` `custom-property` ``
+        - Type-specific access: `property.string_value`, `property.double_value`, `property.int_value`, `property.bool_value`
+
+        **Examples:**
+        - Basic: `name = "my-model"`
+        - Comparison: `accuracy > 0.95`
+        - Pattern: `name LIKE "%tensorflow%"`
+        - Complex: `(name = "model-a" OR name = "model-b") AND state = "LIVE"`
+        - Custom property: `framework.string_value = "pytorch"`
+        - Escaped property: `` `mlflow.source.type` = "notebook" ``
+      schema:
+        type: string
+      in: query
+      required: false
     pageSize:
       examples:
         pageSize:
@@ -324,6 +370,30 @@ components:
       $ref: "#/components/schemas/SortOrder"
       in: query
       required: false
+    artifactType:
+      style: form
+      explode: true
+      examples:
+        artifactType:
+          value: model-artifact
+      name: artifactType
+      description: "Specifies the artifact type for listing artifacts."
+      schema:
+        $ref: "#/components/schemas/ArtifactTypeQueryParam"
+      in: query
+      required: false
+    stepIds:
+      style: form
+      explode: true
+      examples:
+        stepIds:
+          value: "1,2,3"
+      name: stepIds
+      description: "Comma-separated list of step IDs to filter metrics by."
+      schema:
+        type: string
+      in: query
+      required: false
     securitySchemes:
       Bearer:
         scheme: bearer
File diff suppressed because it is too large
@@ -1,5 +1,10 @@
 PROJECT_BIN := $(CURDIR)/../bin
 OPENAPI_GENERATOR := $(PROJECT_BIN)/openapi-generator-cli
+GENQLIENT_BIN ?= $(PROJECT_BIN)/genqlient
+GENQLIENT_CONFIG := internal/catalog/genqlient/genqlient.yaml
+GENQLIENT_OUTPUT := internal/catalog/genqlient/generated.go
+GENQLIENT_SOURCES := $(wildcard internal/catalog/genqlient/queries/*.graphql)
+GRAPHQL_SCHEMA := internal/catalog/genqlient/queries/schema.graphql

 .PHONY: gen/openapi-server
 gen/openapi-server: internal/server/openapi/api_model_catalog_service.go
@@ -16,10 +21,24 @@ pkg/openapi/client.go: ../api/openapi/catalog.yaml
 		--ignore-file-override ./.openapi-generator-ignore --additional-properties=isGoSubmodule=true,enumClassPrefix=true,useOneOfDiscriminatorLookup=true
 	gofmt -w pkg/openapi

+.PHONY: gen/graphql
+gen/graphql: $(GENQLIENT_OUTPUT)
+
+$(GENQLIENT_OUTPUT): $(GENQLIENT_CONFIG) $(GENQLIENT_SOURCES) $(PROJECT_BIN)/genqlient
+	$(GENQLIENT_BIN) --config $(GENQLIENT_CONFIG)
+
+.PHONY: download/graphql-schema
+download/graphql-schema:
+	npx get-graphql-schema https://catalog.redhat.com/api/containers/graphql/ > $(GRAPHQL_SCHEMA)
+
 .PHONY: clean-pkg-openapi
 clean-pkg-openapi:
 	while IFS= read -r file; do rm -f "pkg/openapi/$$file"; done < pkg/openapi/.openapi-generator/FILES

+.PHONY: clean-graphql
+clean-graphql:
+	rm -f $(GENQLIENT_OUTPUT)
+
 .PHONY: clean-internal-server-openapi
 clean-internal-server-openapi:
 	while IFS= read -r file; do rm -f "internal/server/openapi/$$file"; done < internal/server/openapi/.openapi-generator/FILES
@@ -0,0 +1,159 @@
# Model Catalog Service

The Model Catalog Service provides a **read-only discovery service** for ML models across multiple catalog sources. It acts as a federated metadata aggregation layer, allowing users to search and discover models from various external catalogs through a unified REST API.

## Architecture Overview

The catalog service operates as a **metadata aggregation layer** that:
- Federates model discovery across different external catalogs
- Provides a unified REST API for model search and discovery
- Uses pluggable source providers for extensibility
- Operates without traditional database storage (file-based configuration)

### Supported Catalog Sources

- **YAML Catalog** - Static YAML files containing model metadata
- **Red Hat Ecosystem Catalog (RHEC)** - GraphQL API integration for container and model discovery. It can also serve as a reference implementation for extending the service with your own GraphQL providers.

## REST API

### Base URL
`/api/model_catalog/v1alpha1`

### Endpoints

| Method | Path | Description |
|--------|------|-------------|
| `GET` | `/sources` | List all catalog sources with pagination |
| `GET` | `/models` | Search models across sources (requires `source` parameter) |
| `GET` | `/sources/{source_id}/models/{model_name+}` | Get specific model details |
| `GET` | `/sources/{source_id}/models/{model_name}/artifacts` | List model artifacts |
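As a quick smoke test of the `/sources` endpoint above, the listing can be decoded with a hand-rolled struct that mirrors the `CatalogSource` schema. This is a hedged sketch: the host and port are assumptions, and the `items` field name follows the list schemas used elsewhere in this diff, so verify it against the generated client in `pkg/openapi`, which provides the full types:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// Minimal local mirror of the CatalogSource shape from the spec.
type catalogSource struct {
	ID      string `json:"id"`
	Name    string `json:"name"`
	Enabled *bool  `json:"enabled,omitempty"`
}

type catalogSourceList struct {
	Items []catalogSource `json:"items"`
}

func main() {
	// Host and port are assumptions for local testing.
	resp, err := http.Get("http://localhost:8080/api/model_catalog/v1alpha1/sources")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var list catalogSourceList
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		panic(err)
	}
	for _, s := range list.Items {
		fmt.Printf("%s\t%s\n", s.ID, s.Name)
	}
}
```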

### OpenAPI Specification

View the complete API specification:
- [Swagger UI](https://www.kubeflow.org/docs/components/model-registry/reference/model-catalog-rest-api/#swagger-ui)
- [Swagger Playground](https://petstore.swagger.io/?url=https://raw.githubusercontent.com/kubeflow/model-registry/main/api/openapi/catalog.yaml)

## Data Models

### CatalogSource
Simple source metadata:
```json
{
  "id": "string",
  "name": "string"
}
```

### CatalogModel
Rich model metadata, including:
- Basic info: `name`, `description`, `readme`, `maturity`
- Technical: `language[]`, `tasks[]`, `libraryName`
- Legal: `license`, `licenseLink`, `provider`
- Extensible: `customProperties` (key-value metadata)

### CatalogModelArtifact
Artifact references:
```json
{
  "uri": "string",
  "customProperties": {}
}
```

## Configuration

The catalog service uses **file-based configuration** instead of a traditional database:

```yaml
# catalog-sources.yaml
catalogs:
  - id: "yaml-catalog"
    name: "Local YAML Catalog"
    type: "yaml"
    properties:
      path: "./models"
  - id: "rhec-catalog"
    name: "Red Hat Ecosystem Catalog"
    type: "rhec"
    properties:
      # RHEC-specific configuration
```

## Development

### Prerequisites
- Go >= 1.24
- Java >= 11.0 (for OpenAPI generation)
- Node.js >= 20.0.0 (for GraphQL schema downloads)

### Building

Generate OpenAPI server code:
```bash
make gen/openapi-server
```

Generate OpenAPI client code:
```bash
make gen/openapi
```

Generate the GraphQL client (for RHEC integration):
```bash
make gen/graphql
```

### Project Structure

```
catalog/
├── cmd/                  # Main application entry point
├── internal/
│   ├── catalog/          # Core catalog logic and providers
│   │   ├── genqlient/    # GraphQL client generation
│   │   └── testdata/     # Test fixtures
│   └── server/openapi/   # REST API implementation
├── pkg/openapi/          # Generated OpenAPI client
├── scripts/              # Build and generation scripts
└── Makefile              # Build targets
```

### Adding New Catalog Providers

1. Implement the `CatalogSourceProvider` interface:
```go
type CatalogSourceProvider interface {
    GetModel(ctx context.Context, name string) (*model.CatalogModel, error)
    ListModels(ctx context.Context, params ListModelsParams) (model.CatalogModelList, error)
    GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error)
}
```

2. Register your provider (a minimal provider sketch follows this list):
```go
catalog.RegisterCatalogType("my-catalog", func(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
    return NewMyCatalogProvider(source)
})
```
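
For orientation, here is a hedged sketch of a do-nothing provider that satisfies the interface above. `NewMyCatalogProvider` and the `my-catalog` type name are placeholders from the README's own example; the `model` import alias matches the one used in `sources.go` elsewhere in this diff, and the return conventions follow the interface's documented contract:

```go
package catalog

import (
	"context"

	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

// myCatalogProvider is a stub; a real implementation would read models
// from the backing store configured in source.Properties.
type myCatalogProvider struct{}

func NewMyCatalogProvider(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
	return &myCatalogProvider{}, nil
}

func (p *myCatalogProvider) GetModel(ctx context.Context, name string) (*model.CatalogModel, error) {
	return nil, nil // not found: nil without an error, per the interface contract
}

func (p *myCatalogProvider) ListModels(ctx context.Context, params ListModelsParams) (model.CatalogModelList, error) {
	return model.CatalogModelList{Items: []model.CatalogModel{}}, nil
}

func (p *myCatalogProvider) GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error) {
	return &model.CatalogModelArtifactList{Items: []model.CatalogModelArtifact{}}, nil
}

func init() {
	// Same registration pattern as the hf provider later in this diff.
	if err := RegisterCatalogType("my-catalog", NewMyCatalogProvider); err != nil {
		panic(err)
	}
}
```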

### Testing

The catalog service includes comprehensive testing:
- Unit tests for core catalog logic
- Integration tests for provider implementations
- OpenAPI contract validation

### Configuration Hot Reloading

The service automatically reloads its configuration when the catalog sources file changes, enabling dynamic catalog updates without service restarts.

## Integration

The catalog service is designed to complement the main Model Registry service by providing:
- External model discovery capabilities
- Unified metadata aggregation
- Read-only access to distributed model catalogs

For complete Model Registry documentation, see the [main README](../README.md).
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"sync"

 	"github.com/golang/glog"
 	"k8s.io/apimachinery/pkg/util/yaml"
@@ -12,25 +13,10 @@ import (
 	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
 )

-type SortDirection int
-
-const (
-	SortDirectionAscending SortDirection = iota
-	SortDirectionDescending
-)
-
-type SortField int
-
-const (
-	SortByUnspecified SortField = iota
-	SortByName
-	SortByPublished
-)
-
 type ListModelsParams struct {
-	Query         string
-	SortBy        SortField
-	SortDirection SortDirection
+	Query     string
+	OrderBy   model.OrderByField
+	SortOrder model.SortOrder
 }

 // CatalogSourceProvider is implemented by catalog source types, e.g. YamlCatalog
@@ -39,7 +25,15 @@ type CatalogSourceProvider interface {
 	// nothing is found with the name provided it returns nil, without an
 	// error.
 	GetModel(ctx context.Context, name string) (*model.CatalogModel, error)
+
+	// ListModels returns all models according to the parameters. If
+	// nothing suitable is found, it returns an empty list.
 	ListModels(ctx context.Context, params ListModelsParams) (model.CatalogModelList, error)
+
+	// GetArtifacts returns all artifacts for a particular model. If no
+	// model is found with that name, it returns nil. If the model is
+	// found, but has no artifacts, an empty list is returned.
 	GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error)
 }

 // CatalogSourceConfig is a single entry from the catalog sources YAML file.
@@ -75,11 +69,35 @@ type CatalogSource struct {
 	Metadata model.CatalogSource
 }

-func LoadCatalogSources(catalogsPath string) (map[string]CatalogSource, error) {
+type SourceCollection struct {
+	sourcesMu sync.RWMutex
+	sources   map[string]CatalogSource
+}
+
+func NewSourceCollection(sources map[string]CatalogSource) *SourceCollection {
+	return &SourceCollection{sources: sources}
+}
+
+func (sc *SourceCollection) All() map[string]CatalogSource {
+	sc.sourcesMu.RLock()
+	defer sc.sourcesMu.RUnlock()
+
+	return sc.sources
+}
+
+func (sc *SourceCollection) Get(name string) (src CatalogSource, ok bool) {
+	sc.sourcesMu.RLock()
+	defer sc.sourcesMu.RUnlock()
+
+	src, ok = sc.sources[name]
+	return
+}
+
+func (sc *SourceCollection) load(path string) error {
 	// Get absolute path of the catalog config file
-	absConfigPath, err := filepath.Abs(catalogsPath)
+	absConfigPath, err := filepath.Abs(path)
 	if err != nil {
-		return nil, fmt.Errorf("failed to get absolute path for %s: %v", catalogsPath, err)
+		return fmt.Errorf("failed to get absolute path for %s: %v", path, err)
 	}

 	// Get the directory of the config file to resolve relative paths
@@ -88,12 +106,12 @@ func LoadCatalogSources(catalogsPath string) (map[string]CatalogSource, error)
 	// Save current working directory
 	originalWd, err := os.Getwd()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get current working directory: %v", err)
+		return fmt.Errorf("failed to get current working directory: %v", err)
 	}

 	// Change to the config directory to make relative paths work
 	if err := os.Chdir(configDir); err != nil {
-		return nil, fmt.Errorf("failed to change to config directory %s: %v", configDir, err)
+		return fmt.Errorf("failed to change to config directory %s: %v", configDir, err)
 	}

 	// Ensure we restore the original working directory when we're done
@@ -106,34 +124,45 @@ func LoadCatalogSources(catalogsPath string) (map[string]CatalogSource, error)
 	config := sourceConfig{}
 	bytes, err := os.ReadFile(absConfigPath)
 	if err != nil {
-		return nil, err
+		return err
 	}

 	if err = yaml.UnmarshalStrict(bytes, &config); err != nil {
-		return nil, err
+		return err
 	}

-	catalogs := make(map[string]CatalogSource, len(config.Catalogs))
+	sources := make(map[string]CatalogSource, len(config.Catalogs))
 	for _, catalogConfig := range config.Catalogs {
+		// If enabled is explicitly set to false, skip
+		hasEnabled := catalogConfig.HasEnabled()
+		if hasEnabled && *catalogConfig.Enabled == false {
+			continue
+		}
+		// If not explicitly set, default to enabled
+		if !hasEnabled {
+			t := true
+			catalogConfig.CatalogSource.Enabled = &t
+		}
+
 		catalogType := catalogConfig.Type
 		glog.Infof("reading config type %s...", catalogType)
 		registerFunc, ok := registeredCatalogTypes[catalogType]
 		if !ok {
-			return nil, fmt.Errorf("catalog type %s not registered", catalogType)
+			return fmt.Errorf("catalog type %s not registered", catalogType)
 		}
 		id := catalogConfig.GetId()
 		if len(id) == 0 {
-			return nil, fmt.Errorf("invalid catalog id %s", id)
+			return fmt.Errorf("invalid catalog id %s", id)
 		}
-		if _, exists := catalogs[id]; exists {
-			return nil, fmt.Errorf("duplicate catalog id %s", id)
+		if _, exists := sources[id]; exists {
+			return fmt.Errorf("duplicate catalog id %s", id)
 		}
 		provider, err := registerFunc(&catalogConfig)
 		if err != nil {
-			return nil, fmt.Errorf("error reading catalog type %s with id %s: %v", catalogType, id, err)
+			return fmt.Errorf("error reading catalog type %s with id %s: %v", catalogType, id, err)
 		}

-		catalogs[id] = CatalogSource{
+		sources[id] = CatalogSource{
 			Provider: provider,
 			Metadata: catalogConfig.CatalogSource,
 		}
@@ -141,5 +170,36 @@ func LoadCatalogSources(catalogsPath string) (map[string]CatalogSource, error)
 		glog.Infof("loaded config %s of type %s", id, catalogType)
 	}

-	return catalogs, nil
+	sc.sourcesMu.Lock()
+	defer sc.sourcesMu.Unlock()
+	sc.sources = sources
+
+	return nil
+}
+
+func LoadCatalogSources(path string) (*SourceCollection, error) {
+	sc := &SourceCollection{}
+	err := sc.load(path)
+	if err != nil {
+		return nil, err
+	}
+
+	go func() {
+		changes, err := getMonitor().Path(path)
+		if err != nil {
+			glog.Errorf("unable to watch sources file: %v", err)
+			// Not fatal, we just won't get automatic updates.
+		}
+
+		for range changes {
+			glog.Infof("Reloading sources %s", path)
+
+			err = sc.load(path)
+			if err != nil {
+				glog.Errorf("unable to load sources: %v", err)
+			}
+		}
+	}()
+
+	return sc, nil
+}
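Editor's note: the refactor above replaces the one-shot map return with a `SourceCollection` whose contents can be swapped in place under an `RWMutex`. A hedged usage sketch follows; since the package lives under `internal/`, the caller must live inside the same module, and the config path is illustrative:

```go
package main

import (
	"fmt"
	"log"

	"github.com/kubeflow/model-registry/catalog/internal/catalog"
)

func main() {
	// Load sources once; LoadCatalogSources also starts a background
	// goroutine that reloads them when the file changes.
	sc, err := catalog.LoadCatalogSources("catalog-sources.yaml")
	if err != nil {
		log.Fatalf("load sources: %v", err)
	}

	// Reads go through the mutex-guarded accessors.
	if src, ok := sc.Get("yaml-catalog"); ok {
		fmt.Println("found source:", src.Metadata.Name)
	}
	for id := range sc.All() {
		fmt.Println("source id:", id)
	}
}
```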

@@ -4,6 +4,8 @@ import (
 	"reflect"
 	"sort"
 	"testing"
+
+	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
 )

 func TestLoadCatalogSources(t *testing.T) {
@@ -19,7 +21,7 @@ func TestLoadCatalogSources(t *testing.T) {
 		{
 			name:    "test-catalog-sources",
 			args:    args{catalogsPath: "testdata/test-catalog-sources.yaml"},
-			want:    []string{"catalog1", "catalog2"},
+			want:    []string{"catalog1", "catalog3", "catalog4"},
 			wantErr: false,
 		},
 	}
@@ -30,8 +32,8 @@ func TestLoadCatalogSources(t *testing.T) {
 				t.Errorf("LoadCatalogSources() error = %v, wantErr %v", err, tt.wantErr)
 				return
 			}
-			gotKeys := make([]string, 0, len(got))
-			for k := range got {
+			gotKeys := make([]string, 0, len(got.All()))
+			for k := range got.All() {
 				gotKeys = append(gotKeys, k)
 			}
 			sort.Strings(gotKeys)
@@ -41,3 +43,60 @@ func TestLoadCatalogSources(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadCatalogSourcesEnabledDisabled(t *testing.T) {
+	trueValue := true
+	type args struct {
+		catalogsPath string
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    map[string]model.CatalogSource
+		wantErr bool
+	}{
+		{
+			name: "test-catalog-sources-enabled-and-disabled",
+			args: args{catalogsPath: "testdata/test-catalog-sources.yaml"},
+			want: map[string]model.CatalogSource{
+				"catalog1": {
+					Id:      "catalog1",
+					Name:    "Catalog 1",
+					Enabled: &trueValue,
+				},
+				"catalog3": {
+					Id:      "catalog3",
+					Name:    "Catalog 3",
+					Enabled: &trueValue,
+				},
+				"catalog4": {
+					Id:      "catalog4",
+					Name:    "Catalog 4",
+					Enabled: &trueValue,
+				},
+			},
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := LoadCatalogSources(tt.args.catalogsPath)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("LoadCatalogSources() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if err != nil {
+				return
+			}
+
+			gotMetadata := make(map[string]model.CatalogSource)
+			for id, source := range got.All() {
+				gotMetadata[id] = source.Metadata
+			}
+
+			if !reflect.DeepEqual(gotMetadata, tt.want) {
+				t.Errorf("LoadCatalogSources() got metadata = %#v, want %#v", gotMetadata, tt.want)
+			}
+		})
+	}
+}
@@ -0,0 +1,32 @@
## Using Genqlient with the Red Hat Ecosystem Catalog

Genqlient retrieves metadata for the Model Catalog from a CatalogSource, in this case the Red Hat Ecosystem Catalog (RHEC), by issuing GraphQL queries to the RHEC API.

This directory contains the files needed to generate a type-safe Go GraphQL client for the RHEC using [genqlient](https://github.com/Khan/genqlient).

### File Structure

- `genqlient.yaml`: The configuration file for `genqlient`. It specifies the location of the GraphQL schema, the directory containing the GraphQL queries, and the output file for the generated code.
- `queries/`: This directory contains the GraphQL schema and query files.
  - `schema.graphql`: The GraphQL schema for the RHEC API.
  - `*.graphql`: Files containing the GraphQL queries.

### Generating the Client

To regenerate the client, first ensure the required tools are installed by running `make deps` from the project root. Then generate the client by running the following command from the `catalog` directory:

```bash
make gen/graphql
```

This generates the `generated.go` file in the current directory.

### Downloading the Schema

The `schema.graphql` file can be updated by downloading the latest version from the RHEC API. Run the following command from the `catalog` directory:

```bash
make download/graphql-schema
```

This downloads the schema and saves it to the correct location. After updating the schema, regenerate the client so it stays in sync.
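The generated functions in the file below (`FindRepositoryImages`, `GetRepository`) each take a `graphql.Client`. A hedged sketch of wiring one up against the public RHEC endpoint (the endpoint URL is the one used by the Makefile's `download/graphql-schema` target; the registry and repository values are illustrative, and the import path assumes a caller inside the catalog module since the package is internal):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"

	"github.com/Khan/genqlient/graphql"

	"github.com/kubeflow/model-registry/catalog/internal/catalog/genqlient"
)

func main() {
	// graphql.NewClient takes the endpoint and anything with a Do method.
	client := graphql.NewClient("https://catalog.redhat.com/api/containers/graphql/", http.DefaultClient)

	// Fetch repository metadata; "registry.access.redhat.com" and
	// "ubi9/ubi" are placeholder coordinates for illustration.
	resp, err := genqlient.GetRepository(context.Background(), client, "registry.access.redhat.com", "ubi9/ubi")
	if err != nil {
		log.Fatalf("GetRepository: %v", err)
	}

	data := resp.GetGet_repository_by_registry_path().GetData()
	fmt.Println(data.GetDisplay_data().GetShort_description())
}
```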
@@ -0,0 +1,389 @@
// Code generated by github.com/Khan/genqlient, DO NOT EDIT.

package genqlient

import (
	"context"
	"time"

	"github.com/Khan/genqlient/graphql"
)

// FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse includes the requested fields of the GraphQL type ContainerImagePaginatedResponse.
type FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse struct {
	Error FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError `json:"error"`
	Total int `json:"total"`
	Data []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage `json:"data"`
}

// GetError returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse.Error, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse) GetError() FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError {
	return v.Error
}

// GetTotal returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse.Total, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse) GetTotal() int {
	return v.Total
}

// GetData returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse.Data, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse) GetData() []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage {
	return v.Data
}

// FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage includes the requested fields of the GraphQL type ContainerImage.
// The GraphQL type's documentation follows.
//
// Metadata about images contained in RedHat and ISV repositories
type FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage struct {
	// The date when the entry was created. Value is created automatically on creation.
	Creation_date time.Time `json:"creation_date"`
	// The date when the entry was last updated.
	Last_update_date time.Time `json:"last_update_date"`
	// Published repositories associated with the container image.
	Repositories []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo `json:"repositories"`
	// Data parsed from image metadata.
	// These fields are not computed from any other source.
	Parsed_data FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedData `json:"parsed_data"`
}

// GetCreation_date returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage.Creation_date, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage) GetCreation_date() time.Time {
	return v.Creation_date
}

// GetLast_update_date returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage.Last_update_date, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage) GetLast_update_date() time.Time {
	return v.Last_update_date
}

// GetRepositories returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage.Repositories, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage) GetRepositories() []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo {
	return v.Repositories
}

// GetParsed_data returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage.Parsed_data, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage) GetParsed_data() FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedData {
	return v.Parsed_data
}

// FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedData includes the requested fields of the GraphQL type ParsedData.
type FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedData struct {
	Labels []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel `json:"labels"`
}

// GetLabels returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedData.Labels, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedData) GetLabels() []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel {
	return v.Labels
}

// FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel includes the requested fields of the GraphQL type Label.
// The GraphQL type's documentation follows.
//
// Image label.
type FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel struct {
	// The name of the label
	Name string `json:"name"`
	// Value of the label.
	Value string `json:"value"`
}

// GetName returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel.Name, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel) GetName() string {
	return v.Name
}

// GetValue returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel.Value, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageParsed_dataParsedDataLabelsLabel) GetValue() string {
	return v.Value
}

// FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo includes the requested fields of the GraphQL type ContainerImageRepo.
type FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo struct {
	// Hostname of the registry where the repository can be accessed.
	Registry string `json:"registry"`
	// List of container tags assigned to this layer.
	Tags []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepoTagsContainerImageRepoTag `json:"tags"`
}

// GetRegistry returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo.Registry, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo) GetRegistry() string {
	return v.Registry
}

// GetTags returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo.Tags, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepo) GetTags() []FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepoTagsContainerImageRepoTag {
	return v.Tags
}

// FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepoTagsContainerImageRepoTag includes the requested fields of the GraphQL type ContainerImageRepoTag.
type FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepoTagsContainerImageRepoTag struct {
	// The name of the tag.
	Name string `json:"name"`
}

// GetName returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepoTagsContainerImageRepoTag.Name, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImageRepositoriesContainerImageRepoTagsContainerImageRepoTag) GetName() string {
	return v.Name
}

// FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError includes the requested fields of the GraphQL type ResponseError.
type FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError struct {
	Detail string `json:"detail"`
	Status int `json:"status"`
}

// GetDetail returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError.Detail, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError) GetDetail() string {
	return v.Detail
}

// GetStatus returns FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError.Status, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseError) GetStatus() int {
	return v.Status
}

// FindRepositoryImagesResponse is returned by FindRepositoryImages on success.
type FindRepositoryImagesResponse struct {
	// List images for a repository. Exclude total for improved performance.
	Find_repository_images_by_registry_path FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse `json:"find_repository_images_by_registry_path"`
}

// GetFind_repository_images_by_registry_path returns FindRepositoryImagesResponse.Find_repository_images_by_registry_path, and is useful for accessing the field via an interface.
func (v *FindRepositoryImagesResponse) GetFind_repository_images_by_registry_path() FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponse {
	return v.Find_repository_images_by_registry_path
}

// GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse includes the requested fields of the GraphQL type ContainerRepositoryResponse.
type GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse struct {
	Error GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError `json:"error"`
	Data GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository `json:"data"`
}

// GetError returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse.Error, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse) GetError() GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError {
	return v.Error
}

// GetData returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse.Data, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse) GetData() GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository {
	return v.Data
}

// GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository includes the requested fields of the GraphQL type ContainerRepository.
// The GraphQL type's documentation follows.
//
// Contains metadata associated with Red Hat and ISV repositories
type GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository struct {
	// The date when the entry was created. Value is created automatically on creation.
	Creation_date time.Time `json:"creation_date"`
	// The date when the entry was last updated.
	Last_update_date time.Time `json:"last_update_date"`
	// The release categories of a repository.
	Release_categories []string `json:"release_categories"`
	// Label of the vendor that owns this repository.
	Vendor_label string `json:"vendor_label"`
	Display_data GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData `json:"display_data"`
}

// GetCreation_date returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository.Creation_date, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository) GetCreation_date() time.Time {
	return v.Creation_date
}

// GetLast_update_date returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository.Last_update_date, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository) GetLast_update_date() time.Time {
	return v.Last_update_date
}

// GetRelease_categories returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository.Release_categories, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository) GetRelease_categories() []string {
	return v.Release_categories
}

// GetVendor_label returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository.Vendor_label, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository) GetVendor_label() string {
	return v.Vendor_label
}

// GetDisplay_data returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository.Display_data, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepository) GetDisplay_data() GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData {
	return v.Display_data
}

// GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData includes the requested fields of the GraphQL type RepositoryDisplayData.
// The GraphQL type's documentation follows.
//
// Display data for Catalog.
type GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData struct {
	// The short description of the repository.
	Short_description string `json:"short_description"`
	// The long description of the repository.
	Long_description string `json:"long_description"`
}

// GetShort_description returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData.Short_description, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData) GetShort_description() string {
	return v.Short_description
}

// GetLong_description returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData.Long_description, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseDataContainerRepositoryDisplay_dataRepositoryDisplayData) GetLong_description() string {
	return v.Long_description
}

// GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError includes the requested fields of the GraphQL type ResponseError.
type GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError struct {
	Detail string `json:"detail"`
	Status int `json:"status"`
}

// GetDetail returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError.Detail, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError) GetDetail() string {
	return v.Detail
}

// GetStatus returns GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError.Status, and is useful for accessing the field via an interface.
func (v *GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponseError) GetStatus() int {
	return v.Status
}

// GetRepositoryResponse is returned by GetRepository on success.
type GetRepositoryResponse struct {
	// Get a repository by registry and path (product line/image name).
	Get_repository_by_registry_path GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse `json:"get_repository_by_registry_path"`
}

// GetGet_repository_by_registry_path returns GetRepositoryResponse.Get_repository_by_registry_path, and is useful for accessing the field via an interface.
func (v *GetRepositoryResponse) GetGet_repository_by_registry_path() GetRepositoryGet_repository_by_registry_pathContainerRepositoryResponse {
	return v.Get_repository_by_registry_path
}

// __FindRepositoryImagesInput is used internally by genqlient
type __FindRepositoryImagesInput struct {
	Registry string `json:"registry"`
	Repository string `json:"repository"`
}

// GetRegistry returns __FindRepositoryImagesInput.Registry, and is useful for accessing the field via an interface.
func (v *__FindRepositoryImagesInput) GetRegistry() string { return v.Registry }

// GetRepository returns __FindRepositoryImagesInput.Repository, and is useful for accessing the field via an interface.
func (v *__FindRepositoryImagesInput) GetRepository() string { return v.Repository }

// __GetRepositoryInput is used internally by genqlient
type __GetRepositoryInput struct {
	Registry string `json:"registry"`
	Repository string `json:"repository"`
}

// GetRegistry returns __GetRepositoryInput.Registry, and is useful for accessing the field via an interface.
func (v *__GetRepositoryInput) GetRegistry() string { return v.Registry }

// GetRepository returns __GetRepositoryInput.Repository, and is useful for accessing the field via an interface.
func (v *__GetRepositoryInput) GetRepository() string { return v.Repository }

// The query executed by FindRepositoryImages.
const FindRepositoryImages_Operation = `
query FindRepositoryImages ($registry: String!, $repository: String!) {
	find_repository_images_by_registry_path(registry: $registry, repository: $repository, sort_by: [{field:"creation_date",order:DESC}]) {
		error {
			detail
			status
		}
		total
		data {
			creation_date
			last_update_date
			repositories {
				registry
				tags {
					name
				}
			}
			parsed_data {
				labels {
					name
					value
				}
			}
		}
	}
}
`

func FindRepositoryImages(
	ctx_ context.Context,
	client_ graphql.Client,
	registry string,
	repository string,
) (data_ *FindRepositoryImagesResponse, err_ error) {
	req_ := &graphql.Request{
		OpName: "FindRepositoryImages",
		Query:  FindRepositoryImages_Operation,
		Variables: &__FindRepositoryImagesInput{
			Registry:   registry,
			Repository: repository,
		},
	}

	data_ = &FindRepositoryImagesResponse{}
	resp_ := &graphql.Response{Data: data_}

	err_ = client_.MakeRequest(
		ctx_,
		req_,
		resp_,
	)

	return data_, err_
}

// The query executed by GetRepository.
const GetRepository_Operation = `
query GetRepository ($registry: String!, $repository: String!) {
	get_repository_by_registry_path(registry: $registry, repository: $repository) {
		error {
			detail
			status
		}
		data {
			creation_date
			last_update_date
			release_categories
			vendor_label
			display_data {
				short_description
				long_description
			}
		}
	}
}
`

func GetRepository(
	ctx_ context.Context,
	client_ graphql.Client,
	registry string,
	repository string,
) (data_ *GetRepositoryResponse, err_ error) {
	req_ := &graphql.Request{
		OpName: "GetRepository",
		Query:  GetRepository_Operation,
		Variables: &__GetRepositoryInput{
			Registry:   registry,
			Repository: repository,
		},
	}

	data_ = &GetRepositoryResponse{}
	resp_ := &graphql.Response{Data: data_}

	err_ = client_.MakeRequest(
		ctx_,
		req_,
		resp_,
	)

	return data_, err_
}
@@ -0,0 +1,10 @@
# genqlient.yaml
schema: queries/schema.graphql
operations:
  - "queries/find_repository_images.graphql"
  - "queries/get_repository.graphql"
generated: generated.go
package: genqlient
bindings:
  DateTime:
    type: time.Time
@@ -0,0 +1,29 @@
query FindRepositoryImages($registry: String!, $repository: String!) {
  find_repository_images_by_registry_path(
    registry: $registry
    repository: $repository
    sort_by: [{ field: "creation_date", order: DESC }]
  ) {
    error {
      detail
      status
    }
    total
    data {
      creation_date
      last_update_date
      repositories {
        registry
        tags {
          name
        }
      }
      parsed_data {
        labels {
          name
          value
        }
      }
    }
  }
}
@@ -0,0 +1,21 @@
query GetRepository($registry: String!, $repository: String!) {
  get_repository_by_registry_path(
    registry: $registry
    repository: $repository
  ) {
    error {
      detail
      status
    }
    data {
      creation_date
      last_update_date
      release_categories
      vendor_label
      display_data {
        short_description
        long_description
      }
    }
  }
}
File diff suppressed because it is too large
@@ -0,0 +1,125 @@
package catalog

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/golang/glog"
	"github.com/kubeflow/model-registry/catalog/pkg/openapi"
	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

type hfCatalogImpl struct {
	client  *http.Client
	apiKey  string
	baseURL string
}

var _ CatalogSourceProvider = &hfCatalogImpl{}

const (
	defaultHuggingFaceURL = "https://huggingface.co"
)

func (h *hfCatalogImpl) GetModel(ctx context.Context, name string) (*openapi.CatalogModel, error) {
	// TODO: Implement HuggingFace model retrieval
	return nil, fmt.Errorf("HuggingFace model retrieval not yet implemented")
}

func (h *hfCatalogImpl) ListModels(ctx context.Context, params ListModelsParams) (model.CatalogModelList, error) {
	// TODO: Implement HuggingFace model listing
	// For now, return empty list to satisfy interface
	return model.CatalogModelList{
		Items:    []model.CatalogModel{},
		PageSize: 0,
		Size:     0,
	}, nil
}

func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, name string) (*openapi.CatalogModelArtifactList, error) {
	// TODO: Implement HuggingFace model artifacts retrieval
	// For now, return empty list to satisfy interface
	return &openapi.CatalogModelArtifactList{
		Items:    []openapi.CatalogModelArtifact{},
		PageSize: 0,
		Size:     0,
	}, nil
}

// validateCredentials checks if the HuggingFace API credentials are valid
func (h *hfCatalogImpl) validateCredentials(ctx context.Context) error {
	glog.Infof("Validating HuggingFace API credentials")

	// Make a simple API call to validate credentials
	apiURL := h.baseURL + "/api/whoami-v2"
	req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create validation request: %w", err)
	}

	if h.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+h.apiKey)
	}

	resp, err := h.client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to validate HuggingFace credentials: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusUnauthorized {
		return fmt.Errorf("invalid HuggingFace API credentials")
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("HuggingFace API validation failed with status: %d", resp.StatusCode)
	}

	glog.Infof("HuggingFace credentials validated successfully")
	return nil
}

// newHfCatalog creates a new HuggingFace catalog source
func newHfCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
	apiKey, ok := source.Properties["apiKey"].(string)
	if !ok || apiKey == "" {
		return nil, fmt.Errorf("missing or invalid 'apiKey' property for HuggingFace catalog")
	}

	baseURL := defaultHuggingFaceURL
	if url, ok := source.Properties["url"].(string); ok && url != "" {
		baseURL = strings.TrimSuffix(url, "/")
	}

	// Optional model limit for future implementation
	modelLimit := 100
	if limit, ok := source.Properties["modelLimit"].(int); ok && limit > 0 {
		modelLimit = limit
	}

	glog.Infof("Configuring HuggingFace catalog with URL: %s, modelLimit: %d", baseURL, modelLimit)

	h := &hfCatalogImpl{
		client:  &http.Client{Timeout: 30 * time.Second},
		apiKey:  apiKey,
		baseURL: baseURL,
	}

	// Validate credentials during initialization (as required by Jira ticket)
	ctx := context.Background()
	if err := h.validateCredentials(ctx); err != nil {
		glog.Errorf("HuggingFace catalog credential validation failed: %v", err)
		return nil, fmt.Errorf("failed to validate HuggingFace catalog credentials: %w", err)
	}

	glog.Infof("HuggingFace catalog source configured successfully")
	return h, nil
}

func init() {
	if err := RegisterCatalogType("hf", newHfCatalog); err != nil {
		panic(err)
	}
}
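As a compact illustration of how an `hf` source is wired up, the config can be built in Go exactly the way the tests below do it. This is a hedged sketch mirroring those fixtures; the id, name, key, and URL values are placeholders:

```go
package catalog

import "github.com/kubeflow/model-registry/catalog/pkg/openapi"

// newExampleHfSource builds a CatalogSourceConfig like the ones the
// tests below pass to newHfCatalog. The apiKey value is a placeholder;
// a real deployment would supply a valid HuggingFace token.
func newExampleHfSource() *CatalogSourceConfig {
	return &CatalogSourceConfig{
		CatalogSource: openapi.CatalogSource{
			Id:   "hf-catalog",
			Name: "HuggingFace Catalog",
		},
		Type: "hf",
		Properties: map[string]any{
			"apiKey":     "hf_xxx-placeholder",
			"url":        "https://huggingface.co",
			"modelLimit": 100, // optional; defaults to 100 in newHfCatalog
		},
	}
}
```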
@@ -0,0 +1,174 @@
package catalog

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

func TestNewHfCatalog_MissingAPIKey(t *testing.T) {
	source := &CatalogSourceConfig{
		CatalogSource: openapi.CatalogSource{
			Id:   "test_hf",
			Name: "Test HF",
		},
		Type: "hf",
		Properties: map[string]any{
			"url": "https://huggingface.co",
		},
	}

	_, err := newHfCatalog(source)
	if err == nil {
		t.Fatal("Expected error for missing API key, got nil")
	}
	if err.Error() != "missing or invalid 'apiKey' property for HuggingFace catalog" {
		t.Fatalf("Expected specific error message, got: %s", err.Error())
	}
}

func TestNewHfCatalog_WithValidCredentials(t *testing.T) {
	// Create mock server that returns valid response for credential validation
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Check for authorization header
		auth := r.Header.Get("Authorization")
		if auth != "Bearer test-api-key" {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		switch r.URL.Path {
		case "/api/whoami-v2":
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`{"name": "test-user", "type": "user"}`))
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()

	source := &CatalogSourceConfig{
		CatalogSource: openapi.CatalogSource{
			Id:   "test_hf",
			Name: "Test HF",
		},
		Type: "hf",
		Properties: map[string]any{
			"apiKey":     "test-api-key",
			"url":        server.URL,
			"modelLimit": 10,
		},
	}

	catalog, err := newHfCatalog(source)
	if err != nil {
		t.Fatalf("Failed to create HF catalog: %v", err)
	}

	hfCatalog := catalog.(*hfCatalogImpl)

	// Test that methods return appropriate responses for stub implementation
	ctx := context.Background()

	// Test GetModel - should return not implemented error
	model, err := hfCatalog.GetModel(ctx, "test-model")
	if err == nil {
		t.Fatal("Expected not implemented error, got nil")
	}
	if model != nil {
		t.Fatal("Expected nil model, got non-nil")
	}

	// Test ListModels - should return empty list
	listParams := ListModelsParams{
		Query:     "",
		OrderBy:   openapi.ORDERBYFIELD_NAME,
		SortOrder: openapi.SORTORDER_ASC,
	}
	modelList, err := hfCatalog.ListModels(ctx, listParams)
	if err != nil {
		t.Fatalf("Failed to list models: %v", err)
	}
	if len(modelList.Items) != 0 {
		t.Fatalf("Expected 0 models, got %d", len(modelList.Items))
	}

	// Test GetArtifacts - should return empty list
	artifacts, err := hfCatalog.GetArtifacts(ctx, "test-model")
	if err != nil {
		t.Fatalf("Failed to get artifacts: %v", err)
	}
	if artifacts == nil {
		t.Fatal("Expected artifacts list, got nil")
	}
	if len(artifacts.Items) != 0 {
		t.Fatalf("Expected 0 artifacts, got %d", len(artifacts.Items))
	}
}

func TestNewHfCatalog_InvalidCredentials(t *testing.T) {
	// Create mock server that returns 401
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}))
	defer server.Close()

	source := &CatalogSourceConfig{
		CatalogSource: openapi.CatalogSource{
			Id:   "test_hf",
			Name: "Test HF",
		},
		Type: "hf",
		Properties: map[string]any{
			"apiKey": "invalid-key",
			"url":    server.URL,
		},
	}

	_, err := newHfCatalog(source)
	if err == nil {
		t.Fatal("Expected error for invalid credentials, got nil")
	}
	if !strings.Contains(err.Error(), "invalid HuggingFace API credentials") {
		t.Fatalf("Expected credential validation error, got: %s", err.Error())
	}
}

func TestNewHfCatalog_DefaultConfiguration(t *testing.T) {
	// Create mock server for default HuggingFace URL
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"name": "test-user"}`))
	}))
	defer server.Close()

	source := &CatalogSourceConfig{
		CatalogSource: openapi.CatalogSource{
			Id:   "test_hf",
			Name: "Test HF",
		},
		Type: "hf",
		Properties: map[string]any{
			"apiKey": "test-key",
			"url":    server.URL, // Override default for testing
		},
	}

	catalog, err := newHfCatalog(source)
	if err != nil {
		t.Fatalf("Failed to create HF catalog with defaults: %v", err)
	}

	hfCatalog := catalog.(*hfCatalogImpl)
	if hfCatalog.apiKey != "test-key" {
		t.Fatalf("Expected apiKey 'test-key', got '%s'", hfCatalog.apiKey)
	}
	if hfCatalog.baseURL != server.URL {
		t.Fatalf("Expected baseURL '%s', got '%s'", server.URL, hfCatalog.baseURL)
	}
}
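The mock above pins down the validation contract: a Bearer token in the Authorization header, a GET against /api/whoami-v2, and a 401 mapping to the "invalid HuggingFace API credentials" error. The real validateCredentials is not part of this diff, so the following is only a plausible sketch of that behavior, not the actual implementation:

// validateCredentialsSketch is a hypothetical reconstruction based on the
// test mock: call the whoami endpoint with the API key and reject non-200s.
func (h *hfCatalogImpl) validateCredentialsSketch(ctx context.Context) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.baseURL+"/api/whoami-v2", nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+h.apiKey)
	resp, err := h.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusUnauthorized {
		return fmt.Errorf("invalid HuggingFace API credentials")
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status %d from HuggingFace", resp.StatusCode)
	}
	return nil
}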
@@ -0,0 +1,238 @@
package catalog

import (
	"fmt"
	"hash/crc32"
	"io"
	"os"
	"path/filepath"
	"sync"
	"sync/atomic"

	"github.com/fsnotify/fsnotify"
	"github.com/golang/glog"
)

// monitor sends events when the contents of a file have changed.
//
// Unfortunately, simply watching the file misses events for our primary case
// of k8s mounted configmaps because the files we're watching are actually
// symlinks which aren't modified:
//
//	drwxrwxrwx 1 root root 138 Jul 2 15:45 .
//	drwxr-xr-x 1 root root 116 Jul 2 15:52 ..
//	drwxr-xr-x 1 root root  62 Jul 2 15:45 ..2025_07_02_15_45_09.2837733502
//	lrwxrwxrwx 1 root root  32 Jul 2 15:45 ..data -> ..2025_07_02_15_45_09.2837733502
//	lrwxrwxrwx 1 root root  26 Jul 2 13:18 sample-catalog.yaml -> ..data/sample-catalog.yaml
//	lrwxrwxrwx 1 root root  19 Jul 2 13:18 sources.yaml -> ..data/sources.yaml
//
// Updates are written to a new directory and the ..data symlink is updated. No
// fsnotify events will ever be triggered for the YAML files.
//
// The approach taken here is to watch the directory containing the file for
// any change and then hash the contents of the file to avoid false-positives.
type monitor struct {
	watcher *fsnotify.Watcher
	closed  <-chan struct{}

	recordsMu sync.RWMutex
	records   map[string]map[string]*monitorRecord
}

var _monitor *monitor
var initMonitor sync.Once

// getMonitor returns a singleton monitor instance. Panics on failure.
func getMonitor() *monitor {
	initMonitor.Do(func() {
		var err error
		_monitor, err = newMonitor()
		if err != nil {
			panic(fmt.Sprintf("Unable to create file monitor: %v", err))
		}
	})
	if _monitor == nil {
		// Panic in case someone traps the panic that occurred during
		// initialization and tries to call this again.
		panic("Unable to get file monitor")
	}

	return _monitor
}

func newMonitor() (*monitor, error) {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, err
	}

	m := &monitor{
		watcher: watcher,
		records: map[string]map[string]*monitorRecord{},
	}

	go m.monitor()
	return m, nil
}

// Close stops the monitor and waits for the background goroutine to exit.
//
// All channels returned by Path() will be closed.
func (m *monitor) Close() {
	select {
	case <-m.closed:
		// Already closed, nothing to do.
		return
	default:
		// Fallthrough
	}

	m.watcher.Close()
	<-m.closed

	m.recordsMu.Lock()
	defer m.recordsMu.Unlock()

	uniqCh := make(map[chan<- struct{}]struct{})
	for dir := range m.records {
		for file := range m.records[dir] {
			record, ok := m.records[dir][file]
			if !ok {
				continue
			}
			for _, ch := range record.channels {
				uniqCh[ch] = struct{}{}
			}
		}
	}
	for ch := range uniqCh {
		close(ch)
	}
	m.records = nil
}

// Path returns a channel that receives an event when the contents of a file
// change. The file does not need to exist before calling this method, however
// the provided path should only be a file or a symlink (not a directory,
// device, etc.). The returned channel will be closed when the monitor is
// closed.
func (m *monitor) Path(p string) (<-chan struct{}, error) {
	absPath, err := filepath.Abs(p)
	if err != nil {
		return nil, fmt.Errorf("abs: %w", err)
	}

	m.recordsMu.Lock()
	defer m.recordsMu.Unlock()

	dir, base := filepath.Split(absPath)
	dir = filepath.Clean(dir)

	err = m.watcher.Add(dir)
	if err != nil {
		return nil, fmt.Errorf("unable to watch directory %q: %w", dir, err)
	}

	if _, exists := m.records[dir]; !exists {
		m.records[dir] = make(map[string]*monitorRecord, 1)
	}

	ch := make(chan struct{}, 1)

	if _, exists := m.records[dir][base]; !exists {
		m.records[dir][base] = &monitorRecord{
			channels: []chan<- struct{}{ch},
		}
	} else {
		r := m.records[dir][base]
		r.channels = append(r.channels, ch)
	}
	m.records[dir][base].updateHash(filepath.Join(dir, base))

	return ch, nil
}

func (m *monitor) monitor() {
	closed := make(chan struct{})
	m.closed = closed
	defer close(closed)

	for {
		select {
		case err, ok := <-m.watcher.Errors:
			if !ok {
				return
			}

			glog.Errorf("fsnotify error: %v", err)
		case e, ok := <-m.watcher.Events:
			if !ok {
				return
			}

			glog.V(2).Infof("fsnotify.Event: %v", e)

			switch e.Op {
			case fsnotify.Create, fsnotify.Write:
				// Fallthrough
			default:
				// Ignore fsnotify.Remove, fsnotify.Rename and fsnotify.Chmod
				continue
			}

			func() {
				m.recordsMu.RLock()
				defer m.recordsMu.RUnlock()

				dir := filepath.Dir(e.Name)

				dc := m.records[dir]
				if dc == nil {
					return
				}

				for base, record := range dc {
					path := filepath.Join(dir, base)
					if !record.updateHash(path) {
						continue
					}
					for _, ch := range record.channels {
						// Send the event, ignore any that would block.
						select {
						case ch <- struct{}{}:
						default:
							glog.Errorf("monitor: missed event for path %s", path)
						}
					}
				}
			}()
		}
	}
}

type monitorRecord struct {
	channels []chan<- struct{}
	hash     uint32
}

// updateHash recalculates the hash and returns true if it has changed.
func (mr *monitorRecord) updateHash(path string) bool {
	newHash := mr.calculateHash(path)
	oldHash := atomic.SwapUint32(&mr.hash, newHash)
	return oldHash != newHash
}

func (monitorRecord) calculateHash(path string) uint32 {
	fh, err := os.Open(path)
	if err != nil {
		return 0
	}
	defer fh.Close()

	h := crc32.NewIEEE()
	_, err = io.Copy(h, fh)
	if err != nil {
		return 0
	}
	return h.Sum32()
}
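From the consumer's side the contract is small: request a channel for a path and treat every receive as "the file's contents may have changed". A minimal sketch of the intended usage pattern (the reload callback is hypothetical; the YAML catalog later in this diff follows the same loop):

// watchAndReload wires a file into the singleton monitor and re-runs reload
// whenever the file's contents change. reload is a hypothetical callback.
func watchAndReload(path string, reload func() error) error {
	changes, err := getMonitor().Path(path)
	if err != nil {
		return err
	}
	go func() {
		for range changes { // the channel is closed when the monitor is closed
			if err := reload(); err != nil {
				glog.Errorf("reload of %s failed: %v", path, err)
			}
		}
	}()
	return nil
}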
@@ -0,0 +1,179 @@
package catalog

import (
	"fmt"
	"os"
	"path/filepath"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestMonitor(t *testing.T) {
	assert := assert.New(t)

	mon, err := newMonitor()
	if !assert.NoError(err) {
		return
	}

	tmpDir := t.TempDir()
	fileA := filepath.Join(tmpDir, "a")
	fileB := filepath.Join(tmpDir, "b")
	fileC := filepath.Join(tmpDir, "c")

	_watchMonitor := func(ch <-chan struct{}, err error) *monitorWatcher {
		if err != nil {
			t.Fatalf("watchMonitor passed error %v", err)
		}
		return watchMonitor(ch)
	}

	a := _watchMonitor(mon.Path(fileA))
	b := _watchMonitor(mon.Path(fileB))

	updateFile(t, fileA)
	a.AssertCount(t, 1)
	b.AssertCount(t, 0, "unchanged file should not have any events")

	a.Reset()
	updateFile(t, fileB)
	b.AssertCount(t, 1)
	updateFile(t, fileB)
	b.AssertCount(t, 2)
	a.AssertCount(t, 0, "unchanged file should not have any events")

	b.Reset()
	updateFile(t, fileC)
	a.AssertCount(t, 0, "unchanged file should not have an event")
	b.AssertCount(t, 0, "unchanged file should not have an event")

	// Ensure that Close doesn't hang.
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		mon.Close()
	}()
	assert.Eventually(func() bool {
		select {
		case <-finished:
			return true
		default:
			return false
		}
	}, time.Second, 50*time.Millisecond)

	// Verify that the monitor channels closed.
	assert.True(a.Done())
	assert.True(b.Done())
}

func TestMonitorSymlinks(t *testing.T) {
	assert := assert.New(t)

	tmpDir := t.TempDir()
	mon, err := newMonitor()
	if !assert.NoError(err) {
		return
	}
	defer mon.Close()

	// Watch the files on the published path.
	_watchMonitor := func(ch <-chan struct{}, err error) *monitorWatcher {
		if err != nil {
			t.Fatalf("watchMonitor passed error %v", err)
		}
		return watchMonitor(ch)
	}

	a := _watchMonitor(mon.Path(filepath.Join(tmpDir, "a")))
	b := _watchMonitor(mon.Path(filepath.Join(tmpDir, "b")))

	// Set up a directory structure with symlinks like k8s does for mounted
	// configmaps.
	// a -> latest/a, b -> latest/b, latest -> v1
	assert.NoError(os.Mkdir(filepath.Join(tmpDir, "v1"), 0777))
	updateFile(t, filepath.Join(tmpDir, "v1", "a"), "foo")
	updateFile(t, filepath.Join(tmpDir, "v1", "b"), "bar")
	assert.NoError(os.Symlink("v1", filepath.Join(tmpDir, "latest")))
	assert.NoError(os.Symlink(filepath.Join("latest", "a"), filepath.Join(tmpDir, "a")))
	assert.NoError(os.Symlink(filepath.Join("latest", "b"), filepath.Join(tmpDir, "b")))

	a.AssertCount(t, 1)
	b.AssertCount(t, 1)
	a.Reset()
	b.Reset()

	// Make a new version directory
	os.Mkdir(filepath.Join(tmpDir, "v2"), 0777)
	updateFile(t, filepath.Join(tmpDir, "v2", "a"), "UPDATED")
	updateFile(t, filepath.Join(tmpDir, "v2", "b"), "bar")

	a.AssertCount(t, 0)
	b.AssertCount(t, 0)
	a.Reset()
	b.Reset()

	// Update the symlink to point to the new version:
	assert.NoError(os.Rename(filepath.Join(tmpDir, "latest"), filepath.Join(tmpDir, "latest_tmp")))
	assert.NoError(os.Symlink(filepath.Join("v2"), filepath.Join(tmpDir, "latest")))
	assert.NoError(os.Remove(filepath.Join(tmpDir, "latest_tmp")))
	assert.NoError(os.RemoveAll(filepath.Join(tmpDir, "v1")))

	a.AssertCount(t, 1)
	b.AssertCount(t, 0)
}

type monitorWatcher struct {
	count int32
	done  int32
}

func (mw *monitorWatcher) Reset() {
	atomic.StoreInt32(&mw.count, 0)
}

func (mw *monitorWatcher) AssertCount(t *testing.T, expected int, args ...any) bool {
	t.Helper()
	return assert.Eventually(t, func() bool {
		return int(atomic.LoadInt32(&mw.count)) == expected
	}, time.Second, 10*time.Millisecond, args...)
}

func (mw *monitorWatcher) Count() int {
	return int(atomic.LoadInt32(&mw.count))
}

func (mw *monitorWatcher) Done() bool {
	return atomic.LoadInt32(&mw.done) != 0
}

func watchMonitor(ch <-chan struct{}) *monitorWatcher {
	mw := &monitorWatcher{}

	go func() {
		defer atomic.StoreInt32(&mw.done, 1)
		for range ch {
			atomic.AddInt32(&mw.count, 1)
		}
	}()

	return mw
}

func updateFile(t *testing.T, path string, contents ...string) {
	fh, err := os.Create(path)
	if err != nil {
		t.Fatalf("unable to open %q: %v", path, err)
	}
	if len(contents) == 0 {
		fmt.Fprintf(fh, "%s\n", time.Now())
	} else {
		for _, line := range contents {
			fmt.Fprintf(fh, "%s\n", line)
		}
	}
	fh.Close()
}
@@ -0,0 +1,334 @@
package catalog

import (
	"context"
	"fmt"
	"math"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Khan/genqlient/graphql"
	"github.com/kubeflow/model-registry/catalog/internal/catalog/genqlient"
	"github.com/kubeflow/model-registry/catalog/pkg/openapi"
	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
	models "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

type rhecModel struct {
	models.CatalogModel `yaml:",inline"`
	Artifacts           []*openapi.CatalogModelArtifact `yaml:"artifacts"`
}

// rhecCatalogConfig defines the structure of the RHEC catalog configuration.
type rhecCatalogConfig struct {
	Models         []string `yaml:"models"`
	ExcludedModels []string `yaml:"excludedModels"`
}

type rhecCatalogImpl struct {
	modelsLock sync.RWMutex
	models     map[string]*rhecModel
}

var _ CatalogSourceProvider = &rhecCatalogImpl{}

func (r *rhecCatalogImpl) GetModel(ctx context.Context, name string) (*openapi.CatalogModel, error) {
	r.modelsLock.RLock()
	defer r.modelsLock.RUnlock()

	rm := r.models[name]
	if rm == nil {
		return nil, nil
	}
	cp := rm.CatalogModel
	return &cp, nil
}

func (r *rhecCatalogImpl) ListModels(ctx context.Context, params ListModelsParams) (openapi.CatalogModelList, error) {
	r.modelsLock.RLock()
	defer r.modelsLock.RUnlock()

	var filteredModels []*model.CatalogModel
	for _, rm := range r.models {
		cm := rm.CatalogModel
		if params.Query != "" {
			query := strings.ToLower(params.Query)
			// Check if query matches name, description, tasks, provider, or libraryName
			if !strings.Contains(strings.ToLower(cm.Name), query) &&
				!strings.Contains(strings.ToLower(cm.GetDescription()), query) &&
				!strings.Contains(strings.ToLower(cm.GetProvider()), query) &&
				!strings.Contains(strings.ToLower(cm.GetLibraryName()), query) {

				// Check tasks
				foundInTasks := false
				for _, task := range cm.GetTasks() { // Use GetTasks() for nil safety
					if strings.Contains(strings.ToLower(task), query) {
						foundInTasks = true
						break
					}
				}
				if !foundInTasks {
					continue // Skip if no match in any searchable field
				}
			}
		}
		filteredModels = append(filteredModels, &cm)
	}

	// Sort the filtered models
	sort.Slice(filteredModels, func(i, j int) bool {
		a := filteredModels[i]
		b := filteredModels[j]

		var less bool
		switch params.OrderBy {
		case model.ORDERBYFIELD_CREATE_TIME:
			// Convert CreateTimeSinceEpoch (string) to int64 for comparison
			// Handle potential nil or conversion errors by treating as 0
			aTime, _ := strconv.ParseInt(a.GetCreateTimeSinceEpoch(), 10, 64)
			bTime, _ := strconv.ParseInt(b.GetCreateTimeSinceEpoch(), 10, 64)
			less = aTime < bTime
		case model.ORDERBYFIELD_LAST_UPDATE_TIME:
			// Convert LastUpdateTimeSinceEpoch (string) to int64 for comparison
			// Handle potential nil or conversion errors by treating as 0
			aTime, _ := strconv.ParseInt(a.GetLastUpdateTimeSinceEpoch(), 10, 64)
			bTime, _ := strconv.ParseInt(b.GetLastUpdateTimeSinceEpoch(), 10, 64)
			less = aTime < bTime
		case model.ORDERBYFIELD_NAME:
			fallthrough
		default:
			// Fallback to name sort if an unknown sort field is provided
			less = strings.Compare(a.Name, b.Name) < 0
		}

		if params.SortOrder == model.SORTORDER_DESC {
			return !less
		}
		return less
	})

	count := len(filteredModels)
	if count > math.MaxInt32 {
		count = math.MaxInt32
	}

	list := model.CatalogModelList{
		Items:    make([]model.CatalogModel, count),
		PageSize: int32(count),
		Size:     int32(count),
	}
	for i := range list.Items {
		list.Items[i] = *filteredModels[i]
	}
	return list, nil // Return the struct value directly
}

func (r *rhecCatalogImpl) GetArtifacts(ctx context.Context, name string) (*openapi.CatalogModelArtifactList, error) {
	r.modelsLock.RLock()
	defer r.modelsLock.RUnlock()

	rm := r.models[name]
	if rm == nil {
		return nil, nil
	}

	count := len(rm.Artifacts)
	if count > math.MaxInt32 {
		count = math.MaxInt32
	}

	list := openapi.CatalogModelArtifactList{
		Items:    make([]openapi.CatalogModelArtifact, count),
		PageSize: int32(count),
		Size:     int32(count),
	}
	for i := range list.Items {
		list.Items[i] = *rm.Artifacts[i]
	}
	return &list, nil
}

func fetchRepository(ctx context.Context, client graphql.Client, repository string) (*genqlient.GetRepositoryResponse, error) {
	resp, err := genqlient.GetRepository(ctx, client, "registry.access.redhat.com", repository)
	if err != nil {
		return nil, fmt.Errorf("failed to query rhec repository: %w", err)
	}

	if err := resp.Get_repository_by_registry_path.Error; err.Detail != "" || err.Status != 0 {
		return nil, fmt.Errorf("rhec repository query error: detail: %s, status: %d", err.Detail, err.Status)
	}
	return resp, nil
}

func fetchRepositoryImages(ctx context.Context, client graphql.Client, repository string) ([]genqlient.FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage, error) {
	resp, err := genqlient.FindRepositoryImages(ctx, client, "registry.access.redhat.com", repository)
	if err != nil {
		return nil, fmt.Errorf("failed to query rhec images: %w", err)
	}

	if err := resp.Find_repository_images_by_registry_path.Error; err.Detail != "" || err.Status != 0 {
		return nil, fmt.Errorf("rhec images query error: detail: %s, status: %d", err.Detail, err.Status)
	}
	return resp.Find_repository_images_by_registry_path.Data, nil
}

func newRhecModel(repoData *genqlient.GetRepositoryResponse, imageData genqlient.FindRepositoryImagesFind_repository_images_by_registry_pathContainerImagePaginatedResponseDataContainerImage, imageTagName, repositoryName string) *rhecModel {
	sourceId := "rhec"
	createTime := repoData.Get_repository_by_registry_path.Data.Creation_date.Format(time.RFC3339)
	lastUpdateTime := repoData.Get_repository_by_registry_path.Data.Last_update_date.Format(time.RFC3339)
	description := repoData.Get_repository_by_registry_path.Data.Display_data.Short_description
	readme := repoData.Get_repository_by_registry_path.Data.Display_data.Long_description
	provider := repoData.Get_repository_by_registry_path.Data.Vendor_label

	var maturity *string
	if len(repoData.Get_repository_by_registry_path.Data.Release_categories) > 0 {
		maturityStr := repoData.Get_repository_by_registry_path.Data.Release_categories[0]
		maturity = &maturityStr
	}

	var tasks []string
	for _, label := range imageData.Parsed_data.Labels {
		tasks = append(tasks, label.Value)
	}
	imageCreationDate := imageData.Creation_date.Format(time.RFC3339)
	imageLastUpdateDate := imageData.Last_update_date.Format(time.RFC3339)

	modelName := repositoryName + ":" + imageTagName

	return &rhecModel{
		CatalogModel: openapi.CatalogModel{
			Name:                     modelName,
			CreateTimeSinceEpoch:     &createTime,
			LastUpdateTimeSinceEpoch: &lastUpdateTime,
			Description:              &description,
			Readme:                   &readme,
			Maturity:                 maturity,
			Language:                 []string{},
			Tasks:                    tasks,
			Provider:                 &provider,
			Logo:                     nil,
			License:                  nil,
			LicenseLink:              nil,
			LibraryName:              nil,
			SourceId:                 &sourceId,
		},
		Artifacts: []*openapi.CatalogModelArtifact{
			{
				Uri:                      "oci://registry.redhat.io/" + repositoryName + ":" + imageTagName,
				CreateTimeSinceEpoch:     &imageCreationDate,
				LastUpdateTimeSinceEpoch: &imageLastUpdateDate,
			},
		},
	}
}

func (r *rhecCatalogImpl) load(modelsList []string, excludedModelsList []string) error {
	graphqlClient := graphql.NewClient("https://catalog.redhat.com/api/containers/graphql/", http.DefaultClient)
	ctx := context.Background()

	models := make(map[string]*rhecModel)
	for _, repo := range modelsList {
		repoData, err := fetchRepository(ctx, graphqlClient, repo)
		if err != nil {
			return err
		}

		imagesData, err := fetchRepositoryImages(ctx, graphqlClient, repo)
		if err != nil {
			return err
		}

		for _, image := range imagesData {
			for _, imageRepository := range image.Repositories {
				for _, imageTag := range imageRepository.Tags {
					tagName := imageTag.Name
					fullModelName := repo + ":" + tagName

					if isModelExcluded(fullModelName, excludedModelsList) {
						continue
					}

					model := newRhecModel(repoData, image, tagName, repo)
					models[fullModelName] = model
				}
			}
		}
	}

	r.modelsLock.Lock()
	defer r.modelsLock.Unlock()
	r.models = models

	return nil
}

func isModelExcluded(modelName string, patterns []string) bool {
	for _, pattern := range patterns {
		if strings.HasSuffix(pattern, "*") {
			if strings.HasPrefix(modelName, strings.TrimSuffix(pattern, "*")) {
				return true
			}
		} else if modelName == pattern {
			return true
		}
	}
	return false
}

func newRhecCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error) {
	modelsData, ok := source.Properties["models"]
	if !ok {
		return nil, fmt.Errorf("missing 'models' property for rhec catalog")
	}

	modelsList, ok := modelsData.([]any)
	if !ok {
		return nil, fmt.Errorf("'models' property should be a list")
	}

	models := make([]string, len(modelsList))
	for i, v := range modelsList {
		models[i], ok = v.(string)
		if !ok {
			return nil, fmt.Errorf("invalid entry in 'models' list, expected a string")
		}
	}

	// Excluded models is an optional source property.
	var excludedModels []string
	if excludedModelsData, ok := source.Properties["excludedModels"]; ok {
		excludedModelsList, ok := excludedModelsData.([]any)
		if !ok {
			return nil, fmt.Errorf("'excludedModels' property should be a list")
		}
		excludedModels = make([]string, len(excludedModelsList))
		for i, v := range excludedModelsList {
			excludedModels[i], ok = v.(string)
			if !ok {
				return nil, fmt.Errorf("invalid entry in 'excludedModels' list, expected a string")
			}
		}
	}

	r := &rhecCatalogImpl{
		models: make(map[string]*rhecModel),
	}

	err := r.load(models, excludedModels)
	if err != nil {
		return nil, fmt.Errorf("error loading rhec catalog: %w", err)
	}

	return r, nil
}

func init() {
	if err := RegisterCatalogType("rhec", newRhecCatalog); err != nil {
		panic(err)
	}
}
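isModelExcluded supports two pattern forms: an exact model name, or a prefix wildcard ending in "*". A short illustration (the "rhelai1/private-*" pattern is invented for the example; the exact-match pattern comes from the test config in this diff):

func Example_isModelExcluded() {
	patterns := []string{
		"rhelai1/modelcar-granite-7b-starter:latest", // exact match
		"rhelai1/private-*",                          // prefix wildcard (hypothetical pattern)
	}
	fmt.Println(isModelExcluded("rhelai1/modelcar-granite-7b-starter:latest", patterns))
	fmt.Println(isModelExcluded("rhelai1/private-model:v1", patterns))
	fmt.Println(isModelExcluded("rhelai1/modelcar-granite-7b-starter:1.0", patterns))
	// Output:
	// true
	// true
	// false
}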
@@ -0,0 +1,393 @@
package catalog

import (
	"context"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/kubeflow/model-registry/catalog/pkg/openapi"
	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

func TestRhecCatalogImpl_GetModel(t *testing.T) {
	modelTime := time.Now()
	createTime := modelTime.Format(time.RFC3339)
	lastUpdateTime := modelTime.Add(5 * time.Minute).Format(time.RFC3339)
	sourceId := "rhec"
	provider := "redhat"

	testModels := map[string]*rhecModel{
		"model1": {
			CatalogModel: openapi.CatalogModel{
				Name:                     "model1",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			Artifacts: []*openapi.CatalogModelArtifact{},
		},
	}

	r := &rhecCatalogImpl{
		models: testModels,
	}

	tests := []struct {
		name      string
		modelName string
		want      *openapi.CatalogModel
		wantErr   bool
	}{
		{
			name:      "get existing model",
			modelName: "model1",
			want: &openapi.CatalogModel{
				Name:                     "model1",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			wantErr: false,
		},
		{
			name:      "get non-existent model",
			modelName: "not-exist",
			want:      nil,
			wantErr:   false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := r.GetModel(context.Background(), tt.modelName)
			if (err != nil) != tt.wantErr {
				t.Errorf("GetModel() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("GetModel() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

func TestRhecCatalogImpl_GetArtifacts(t *testing.T) {
	modelTime := time.Now()
	createTime := modelTime.Format(time.RFC3339)
	lastUpdateTime := modelTime.Add(5 * time.Minute).Format(time.RFC3339)
	sourceId := "rhec"
	provider := "redhat"
	artifactCreateTime := modelTime.Add(10 * time.Minute).Format(time.RFC3339)
	artifactLastUpdateTime := modelTime.Add(15 * time.Minute).Format(time.RFC3339)

	testModels := map[string]*rhecModel{
		"model1": {
			CatalogModel: openapi.CatalogModel{
				Name:                     "model1",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			Artifacts: []*openapi.CatalogModelArtifact{
				{
					Uri:                      "test-uri",
					CreateTimeSinceEpoch:     &artifactCreateTime,
					LastUpdateTimeSinceEpoch: &artifactLastUpdateTime,
				},
			},
		},
		"model2-no-artifacts": {
			CatalogModel: openapi.CatalogModel{
				Name:                     "model2-no-artifacts",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			Artifacts: []*openapi.CatalogModelArtifact{},
		},
	}

	r := &rhecCatalogImpl{
		models: testModels,
	}

	tests := []struct {
		name      string
		modelName string
		want      *openapi.CatalogModelArtifactList
		wantErr   bool
	}{
		{
			name:      "get artifacts for existing model",
			modelName: "model1",
			want: &openapi.CatalogModelArtifactList{
				Items: []openapi.CatalogModelArtifact{
					{
						Uri:                      "test-uri",
						CreateTimeSinceEpoch:     &artifactCreateTime,
						LastUpdateTimeSinceEpoch: &artifactLastUpdateTime,
					},
				},
				PageSize: 1,
				Size:     1,
			},
			wantErr: false,
		},
		{
			name:      "get artifacts for model with no artifacts",
			modelName: "model2-no-artifacts",
			want: &openapi.CatalogModelArtifactList{
				Items:    []openapi.CatalogModelArtifact{},
				PageSize: 0,
				Size:     0,
			},
			wantErr: false,
		},
		{
			name:      "get artifacts for non-existent model",
			modelName: "not-exist",
			want:      nil,
			wantErr:   false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := r.GetArtifacts(context.Background(), tt.modelName)
			if (err != nil) != tt.wantErr {
				t.Errorf("GetArtifacts() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("GetArtifacts() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

func TestRhecCatalogListModels(t *testing.T) {
	modelTime := time.Now()
	createTime := modelTime.Format(time.RFC3339)
	lastUpdateTime := modelTime.Add(5 * time.Minute).Format(time.RFC3339)
	sourceId := "rhec"
	provider := "redhat"
	artifactCreateTime := modelTime.Add(10 * time.Minute).Format(time.RFC3339)
	artifactLastUpdateTime := modelTime.Add(15 * time.Minute).Format(time.RFC3339)

	testModels := map[string]*rhecModel{
		"model3": {
			CatalogModel: openapi.CatalogModel{
				Name:                     "model3",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			Artifacts: []*openapi.CatalogModelArtifact{
				{
					Uri:                      "test-uri",
					CreateTimeSinceEpoch:     &artifactCreateTime,
					LastUpdateTimeSinceEpoch: &artifactLastUpdateTime,
				},
			},
		},
		"model1": {
			CatalogModel: openapi.CatalogModel{
				Name:                     "model1",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			Artifacts: []*openapi.CatalogModelArtifact{
				{
					Uri:                      "test-uri",
					CreateTimeSinceEpoch:     &artifactCreateTime,
					LastUpdateTimeSinceEpoch: &artifactLastUpdateTime,
				},
			},
		},
		"model1:v2": {
			CatalogModel: openapi.CatalogModel{
				Name:                     "model1:v2",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			Artifacts: []*openapi.CatalogModelArtifact{
				{
					Uri:                      "test-uri",
					CreateTimeSinceEpoch:     &artifactCreateTime,
					LastUpdateTimeSinceEpoch: &artifactLastUpdateTime,
				},
			},
		},
		"model2": {
			CatalogModel: openapi.CatalogModel{
				Name:                     "model2",
				CreateTimeSinceEpoch:     &createTime,
				LastUpdateTimeSinceEpoch: &lastUpdateTime,
				Provider:                 &provider,
				SourceId:                 &sourceId,
			},
			Artifacts: []*openapi.CatalogModelArtifact{},
		},
	}

	r := &rhecCatalogImpl{
		models: testModels,
	}

	tests := []struct {
		name      string
		modelName string
		params    ListModelsParams
		want      openapi.CatalogModelList
		wantErr   bool
	}{
		{
			name: "list models and sort order",
			want: openapi.CatalogModelList{
				Items: []openapi.CatalogModel{
					{
						Name:                     "model1",
						CreateTimeSinceEpoch:     &createTime,
						LastUpdateTimeSinceEpoch: &lastUpdateTime,
						Provider:                 &provider,
						SourceId:                 &sourceId,
					},
					{
						Name:                     "model1:v2",
						CreateTimeSinceEpoch:     &createTime,
						LastUpdateTimeSinceEpoch: &lastUpdateTime,
						Provider:                 &provider,
						SourceId:                 &sourceId,
					},
					{
						Name:                     "model2",
						CreateTimeSinceEpoch:     &createTime,
						LastUpdateTimeSinceEpoch: &lastUpdateTime,
						Provider:                 &provider,
						SourceId:                 &sourceId,
					},
					{
						Name:                     "model3",
						CreateTimeSinceEpoch:     &createTime,
						LastUpdateTimeSinceEpoch: &lastUpdateTime,
						Provider:                 &provider,
						SourceId:                 &sourceId,
					},
				},
				PageSize: 4,
				Size:     4,
			},
			wantErr: false,
		},
		{
			name:      "list models with query and sort order",
			modelName: "model1",
			params: ListModelsParams{
				Query:     "model1",
				OrderBy:   model.ORDERBYFIELD_NAME,
				SortOrder: model.SORTORDER_ASC,
			},
			want: openapi.CatalogModelList{
				Items: []openapi.CatalogModel{
					{
						Name:                     "model1",
						CreateTimeSinceEpoch:     &createTime,
						LastUpdateTimeSinceEpoch: &lastUpdateTime,
						Provider:                 &provider,
						SourceId:                 &sourceId,
					},
					{
						Name:                     "model1:v2",
						CreateTimeSinceEpoch:     &createTime,
						LastUpdateTimeSinceEpoch: &lastUpdateTime,
						Provider:                 &provider,
						SourceId:                 &sourceId,
					},
				},
				PageSize: 2,
				Size:     2,
			},
			wantErr: false,
		},
		{
			name:      "get non-existent model",
			modelName: "not-exist",
			params: ListModelsParams{
				Query:     "not-exist",
				OrderBy:   model.ORDERBYFIELD_NAME,
				SortOrder: model.SORTORDER_ASC,
			},
			want:    openapi.CatalogModelList{Items: []openapi.CatalogModel{}},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := r.ListModels(context.Background(), tt.params)
			if (err != nil) != tt.wantErr {
				t.Errorf("ListModels() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("ListModels() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

func TestIsModelExcluded(t *testing.T) {
	tests := []struct {
		name      string
		modelName string
		patterns  []string
		want      bool
	}{
		{
			name:      "exact match",
			modelName: "model1:v1",
			patterns:  []string{"model1:v1"},
			want:      true,
		},
		{
			name:      "wildcard match",
			modelName: "model1:v2",
			patterns:  []string{"model1:*"},
			want:      true,
		},
		{
			name:      "no match",
			modelName: "model2:v1",
			patterns:  []string{"model1:*"},
			want:      false,
		},
		{
			name:      "multiple patterns with match",
			modelName: "model3:v1",
			patterns:  []string{"model2:*", "model3:v1"},
			want:      true,
		},
		{
			name:      "empty patterns",
			modelName: "model1:v1",
			patterns:  []string{},
			want:      false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := isModelExcluded(tt.modelName, tt.patterns)
			if got != tt.want {
				t.Errorf("isModelExcluded() = %v, want %v", got, tt.want)
			}
		})
	}
}
@@ -0,0 +1,4 @@

source: empty-catalog
models: []
@@ -2,6 +2,7 @@ catalogs:
   - name: "Catalog 1"
     id: catalog1
     type: yaml
+    enabled: true
     properties:
       privateProp11: 54321
       privateProp12: privateStringValue
@@ -9,7 +10,24 @@ catalogs:
   - name: "Catalog 2"
     id: catalog2
     type: yaml
+    enabled: false
     properties:
       privateProp21: 12345
       privateProp22: privateStringValue2
       yamlCatalogPath: test-yaml-catalog.yaml
+  - name: "Catalog 3"
+    id: catalog3
+    type: rhec
+    enabled: true
+    properties:
+      models:
+        - rhelai1/modelcar-granite-7b-starter
+  - name: "Catalog 4"
+    id: catalog4
+    type: rhec
+    enabled: true
+    properties:
+      models:
+        - rhelai1/modelcar-granite-7b-starter
+      excludedModels:
+        - rhelai1/modelcar-granite-7b-starter:latest
@@ -0,0 +1,17 @@
catalogs:
  - name: "HuggingFace Test Catalog"
    id: hf_test
    type: hf
    enabled: true
    properties:
      apiKey: "hf_test_api_key_here"
      url: "https://huggingface.co"
      modelLimit: 50
  - name: "HuggingFace Invalid Credentials"
    id: hf_invalid
    type: hf
    enabled: false # disabled so it doesn't cause startup failures in tests
    properties:
      apiKey: "invalid_key"
      url: "https://huggingface.co"
      modelLimit: 10
@@ -0,0 +1,40 @@

source: test-list-models
models:
  - name: model-alpha
    description: A model for text generation.
    tasks: ["text-generation", "nlp"]
    provider: IBM
    libraryName: transformers
    createTimeSinceEpoch: "1678886400000" # March 15, 2023 00:00:00 GMT
  - name: model-beta
    description: Another model for image recognition.
    tasks: ["image-recognition"]
    provider: Google
    libraryName: tensorflow
    createTimeSinceEpoch: "1681564800000" # April 15, 2023 00:00:00 GMT
  - name: model-gamma
    description: A specialized model for natural language processing.
    tasks: ["nlp"]
    provider: IBM
    libraryName: pytorch
    createTimeSinceEpoch: "1675209600000" # February 1, 2023 00:00:00 GMT
  - name: another-model-alpha
    description: A different model for text summarization.
    tasks: ["text-summarization", "nlp"]
    provider: Microsoft
    libraryName: huggingface
    createTimeSinceEpoch: "1684243200000" # May 16, 2023 00:00:00 GMT
  - name: model-with-no-tasks
    description: This model has no specific tasks.
    tasks: []
    provider: None
    libraryName: custom
    createTimeSinceEpoch: "1672531200000" # January 1, 2023 00:00:00 GMT
  - name: Z-model
    description: The last model in alphabetical order.
    tasks: ["optimization"]
    provider: Oracle
    libraryName: scikit-learn
    createTimeSinceEpoch: "1690934400000" # August 2, 2023 00:00:00 GMT
@@ -324,8 +324,7 @@ models:
     createTimeSinceEpoch: "1733514949000"
     lastUpdateTimeSinceEpoch: "1734637721000"
     artifacts:
-      - protocol: oci
-        uri: oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892
+      - uri: oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892
   - name: rhelai1/granite-8b-code-instruct
     provider: IBM
     description: |-
@@ -668,5 +667,27 @@ models:
     createTimeSinceEpoch: "1733514949000"
     lastUpdateTimeSinceEpoch: "1734637721000"
     artifacts:
-      - protocol: oci
-        uri: oci://registry.redhat.io/rhelai1/granite-8b-code-instruct:1.3-1732870892
+      - uri: oci://registry.redhat.io/rhelai1/granite-8b-code-instruct:1.3-1732870892
+        createTimeSinceEpoch: "1733514949000"
+        lastUpdateTimeSinceEpoch: "1734637721000"
+        customProperties:
+          foo:
+            string_value: bar
+          baz:
+            string_value: qux
+  - name: model-with-no-artifacts
+    provider: Test
+    description: A model used for testing the GetArtifacts method when no artifacts are present.
+    readme: |
+      # Model with No Artifacts
+      This is a test model.
+    language: ["en"]
+    license: apache-2.0
+    licenseLink: https://www.apache.org/licenses/LICENSE-2.0.txt
+    maturity: Development
+    libraryName: testlib
+    tasks:
+      - test-task
+    createTimeSinceEpoch: "1700000000000"
+    lastUpdateTimeSinceEpoch: "1700000000000"
+    artifacts: []
@@ -3,21 +3,23 @@ package catalog
 import (
 	"context"
 	"fmt"
+	"math"
 	"os"
+	"path/filepath"
+	"sort"
+	"strconv"
+	"strings"
+	"sync"
 
 	"k8s.io/apimachinery/pkg/util/yaml"
 
+	"github.com/golang/glog"
 	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
 )
 
-type yamlArtifacts struct {
-	Protocol string `yaml:"protocol"`
-	URI      string `yaml:"uri"`
-}
-
 type yamlModel struct {
 	model.CatalogModel `yaml:",inline"`
-	Artifacts          []yamlArtifacts `yaml:"artifacts"`
+	Artifacts          []*model.CatalogModelArtifact `yaml:"artifacts"`
 }
 
 type yamlCatalog struct {
@@ -26,13 +28,16 @@ type yamlCatalog struct {
 }
 
 type yamlCatalogImpl struct {
-	models map[string]*yamlModel
-	source *CatalogSourceConfig
+	modelsLock sync.RWMutex
+	models     map[string]*yamlModel
 }
 
 var _ CatalogSourceProvider = &yamlCatalogImpl{}
 
 func (y *yamlCatalogImpl) GetModel(ctx context.Context, name string) (*model.CatalogModel, error) {
+	y.modelsLock.RLock()
+	defer y.modelsLock.RUnlock()
+
 	ym := y.models[name]
 	if ym == nil {
 		return nil, nil
@@ -42,11 +47,135 @@ func (y *yamlCatalogImpl) GetModel(ctx context.Context, name string) (*model.CatalogModel, error) {
 }
 
 func (y *yamlCatalogImpl) ListModels(ctx context.Context, params ListModelsParams) (model.CatalogModelList, error) {
-	//TODO implement me
-	panic("implement me")
+	y.modelsLock.RLock()
+	defer y.modelsLock.RUnlock()
+
+	var filteredModels []*model.CatalogModel
+	for _, ym := range y.models {
+		cm := ym.CatalogModel
+		if params.Query != "" {
+			query := strings.ToLower(params.Query)
+			// Check if query matches name, description, tasks, provider, or libraryName
+			if !strings.Contains(strings.ToLower(cm.Name), query) &&
+				!strings.Contains(strings.ToLower(cm.GetDescription()), query) &&
+				!strings.Contains(strings.ToLower(cm.GetProvider()), query) &&
+				!strings.Contains(strings.ToLower(cm.GetLibraryName()), query) {
+
+				// Check tasks
+				foundInTasks := false
+				for _, task := range cm.GetTasks() { // Use GetTasks() for nil safety
+					if strings.Contains(strings.ToLower(task), query) {
+						foundInTasks = true
+						break
+					}
+				}
+				if !foundInTasks {
+					continue // Skip if no match in any searchable field
+				}
+			}
+		}
+		filteredModels = append(filteredModels, &cm)
+	}
+
+	// Sort the filtered models
+	sort.Slice(filteredModels, func(i, j int) bool {
+		a := filteredModels[i]
+		b := filteredModels[j]
+
+		var less bool
+		switch params.OrderBy {
+		case model.ORDERBYFIELD_CREATE_TIME:
+			// Convert CreateTimeSinceEpoch (string) to int64 for comparison
+			// Handle potential nil or conversion errors by treating as 0
+			aTime, _ := strconv.ParseInt(a.GetCreateTimeSinceEpoch(), 10, 64)
+			bTime, _ := strconv.ParseInt(b.GetCreateTimeSinceEpoch(), 10, 64)
+			less = aTime < bTime
+		case model.ORDERBYFIELD_LAST_UPDATE_TIME:
+			// Convert LastUpdateTimeSinceEpoch (string) to int64 for comparison
+			// Handle potential nil or conversion errors by treating as 0
+			aTime, _ := strconv.ParseInt(a.GetLastUpdateTimeSinceEpoch(), 10, 64)
+			bTime, _ := strconv.ParseInt(b.GetLastUpdateTimeSinceEpoch(), 10, 64)
+			less = aTime < bTime
+		case model.ORDERBYFIELD_NAME:
+			fallthrough
+		default:
+			// Fallback to name sort if an unknown sort field is provided
+			less = strings.Compare(a.Name, b.Name) < 0
+		}
+
+		if params.SortOrder == model.SORTORDER_DESC {
+			return !less
+		}
+		return less
+	})
+
+	count := len(filteredModels)
+	if count > math.MaxInt32 {
+		count = math.MaxInt32
+	}
+
+	list := model.CatalogModelList{
+		Items:    make([]model.CatalogModel, count),
+		PageSize: int32(count),
+		Size:     int32(count),
+	}
+	for i := range list.Items {
+		list.Items[i] = *filteredModels[i]
+	}
+	return list, nil // Return the struct value directly
 }
 
-// TODO start background thread to watch file
 func (y *yamlCatalogImpl) GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error) {
+	y.modelsLock.RLock()
+	defer y.modelsLock.RUnlock()
+
 	ym := y.models[name]
 	if ym == nil {
 		return nil, nil
 	}
 
 	count := len(ym.Artifacts)
 	if count > math.MaxInt32 {
 		count = math.MaxInt32
 	}
 
 	list := model.CatalogModelArtifactList{
 		Items:    make([]model.CatalogModelArtifact, count),
 		PageSize: int32(count),
 		Size:     int32(count),
 	}
 	for i := range list.Items {
 		list.Items[i] = *ym.Artifacts[i]
 	}
 	return &list, nil
 }
 
+func (y *yamlCatalogImpl) load(path string, excludedModelsList []string) error {
+	bytes, err := os.ReadFile(path)
+	if err != nil {
+		return fmt.Errorf("failed to read %s file: %v", yamlCatalogPath, err)
+	}
+
+	var contents yamlCatalog
+	if err = yaml.UnmarshalStrict(bytes, &contents); err != nil {
+		return fmt.Errorf("failed to parse %s file: %v", yamlCatalogPath, err)
+	}
+
+	models := make(map[string]*yamlModel)
+	for i := range contents.Models {
+		modelName := contents.Models[i].Name
+		if isModelExcluded(modelName, excludedModelsList) {
+			continue
+		}
+		models[modelName] = &contents.Models[i]
+	}
+
+	y.modelsLock.Lock()
+	defer y.modelsLock.Unlock()
+	y.models = models
+
+	return nil
+}
+
 const yamlCatalogPath = "yamlCatalogPath"
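With this, the YAML provider's listing behavior matches the RHEC provider's: filter on name, description, provider, library name, and tasks, then sort. A small caller-side sketch (provider stands for any CatalogSourceProvider; the "nlp" query string is arbitrary):

// listNLPModelsNewestFirst is a usage sketch: query across the searchable
// fields, then order by creation time, newest first.
func listNLPModelsNewestFirst(ctx context.Context, provider CatalogSourceProvider) ([]model.CatalogModel, error) {
	list, err := provider.ListModels(ctx, ListModelsParams{
		Query:     "nlp",
		OrderBy:   model.ORDERBYFIELD_CREATE_TIME,
		SortOrder: model.SORTORDER_DESC,
	})
	if err != nil {
		return nil, err
	}
	return list.Items, nil
}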
@@ -55,30 +184,54 @@ func newYamlCatalog(source *CatalogSourceConfig) (CatalogSourceProvider, error)
 	if !exists || yamlModelFile == "" {
 		return nil, fmt.Errorf("missing %s string property", yamlCatalogPath)
 	}
-	bytes, err := os.ReadFile(yamlModelFile)
+
+	yamlModelFile, err := filepath.Abs(yamlModelFile)
 	if err != nil {
-		return nil, fmt.Errorf("failed to read %s file: %v", yamlCatalogPath, err)
+		return nil, fmt.Errorf("abs: %w", err)
 	}
 
-	var contents yamlCatalog
-	if err = yaml.UnmarshalStrict(bytes, &contents); err != nil {
-		return nil, fmt.Errorf("failed to parse %s file: %v", yamlCatalogPath, err)
+	// Excluded models is an optional source property.
+	var excludedModels []string
+	if excludedModelsData, ok := source.Properties["excludedModels"]; ok {
+		excludedModelsList, ok := excludedModelsData.([]any)
+		if !ok {
+			return nil, fmt.Errorf("'excludedModels' property should be a list")
+		}
+		excludedModels = make([]string, len(excludedModelsList))
+		for i, v := range excludedModelsList {
+			excludedModels[i], ok = v.(string)
+			if !ok {
+				return nil, fmt.Errorf("invalid entry in 'excludedModels' list, expected a string")
+			}
+		}
 	}
 
-	// override catalog name from Yaml Catalog File if set
-	if source.Name != "" {
-		source.Name = contents.Source
+	p := &yamlCatalogImpl{
+		models: make(map[string]*yamlModel),
 	}
+	err = p.load(yamlModelFile, excludedModels)
+	if err != nil {
+		return nil, err
+	}
 
-	models := make(map[string]*yamlModel, len(contents.Models))
-	for i := range contents.Models {
-		models[contents.Models[i].Name] = &contents.Models[i]
-	}
+	go func() {
+		changes, err := getMonitor().Path(yamlModelFile)
+		if err != nil {
+			glog.Errorf("unable to watch YAML catalog file: %v", err)
+			// Not fatal, we just won't get automatic updates.
+		}
 
-	return &yamlCatalogImpl{
-		models: models,
-		source: source,
-	}, nil
+		for range changes {
+			glog.Infof("Reloading YAML catalog %s", yamlModelFile)
+
+			err = p.load(yamlModelFile, excludedModels)
+			if err != nil {
+				glog.Errorf("unable to load YAML catalog: %v", err)
+			}
+		}
+	}()
+
+	return p, nil
 }
 
 func init() {
@@ -4,6 +4,7 @@ import (
 	"context"
 	"testing"
 
+	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
 	"github.com/stretchr/testify/assert"
 )
@ -29,14 +30,178 @@ func TestYAMLCatalogGetModel(t *testing.T) {
|
|||
assert.Nil(notFound)
|
||||
}
|
||||
|
||||
func TestYAMLCatalogGetArtifacts(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
provider := testYAMLProvider(t, "testdata/test-yaml-catalog.yaml")
|
||||
|
||||
// Test case 1: Model with artifacts
|
||||
artifacts, err := provider.GetArtifacts(context.Background(), "rhelai1/granite-8b-code-base")
|
||||
if assert.NoError(err) {
|
||||
assert.NotNil(artifacts)
|
||||
assert.Equal(int32(1), artifacts.Size)
|
||||
assert.Equal(int32(1), artifacts.PageSize)
|
||||
assert.Len(artifacts.Items, 1)
|
||||
assert.Equal("oci://registry.redhat.io/rhelai1/granite-8b-code-base:1.3-1732870892", artifacts.Items[0].Uri)
|
||||
}
|
||||
|
||||
// Test case 2: Model with no artifacts
|
||||
noArtifactsModel, err := provider.GetArtifacts(context.Background(), "model-with-no-artifacts")
|
||||
if assert.NoError(err) {
|
||||
assert.NotNil(noArtifactsModel)
|
||||
assert.Equal(int32(0), noArtifactsModel.Size)
|
||||
assert.Equal(int32(0), noArtifactsModel.PageSize)
|
||||
assert.Len(noArtifactsModel.Items, 0)
|
||||
}
|
||||
|
||||
// Test case 3: Model not found
|
||||
notFoundArtifacts, err := provider.GetArtifacts(context.Background(), "non-existent-model")
|
||||
assert.NoError(err)
|
||||
assert.Nil(notFoundArtifacts)
|
||||
}
|
||||
|
||||
func TestYAMLCatalogListModels(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
provider := testYAMLProvider(t, "testdata/test-list-models-catalog.yaml")
|
||||
ctx := context.Background()
|
||||
|
||||
// Test case 1: List all models, default sort (by name ascending)
|
||||
models, err := provider.ListModels(ctx, ListModelsParams{})
|
||||
if assert.NoError(err) {
|
||||
assert.NotNil(models)
|
||||
assert.Equal(int32(6), models.Size)
|
||||
assert.Equal(int32(6), models.PageSize)
|
||||
assert.Len(models.Items, 6)
|
||||
assert.Equal("Z-model", models.Items[0].Name) // Z-model should be first due to string comparison for alphabetical sort
|
||||
assert.Equal("another-model-alpha", models.Items[1].Name)
|
||||
assert.Equal("model-alpha", models.Items[2].Name)
|
||||
assert.Equal("model-beta", models.Items[3].Name)
|
||||
assert.Equal("model-gamma", models.Items[4].Name)
|
||||
assert.Equal("model-with-no-tasks", models.Items[5].Name)
|
||||
}
|
||||
|
||||
// Test case 2: List all models, sort by name ascending
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{OrderBy: model.ORDERBYFIELD_NAME, SortOrder: model.SORTORDER_ASC})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(6), models.Size)
|
||||
assert.Equal("Z-model", models.Items[0].Name)
|
||||
assert.Equal("another-model-alpha", models.Items[1].Name)
|
||||
}
|
||||
|
||||
// Test case 3: List all models, sort by name descending
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{OrderBy: model.ORDERBYFIELD_NAME, SortOrder: model.SORTORDER_DESC})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(6), models.Size)
|
||||
assert.Equal("model-with-no-tasks", models.Items[0].Name)
|
||||
assert.Equal("model-gamma", models.Items[1].Name)
|
||||
}
|
||||
|
||||
// Test case 4: List all models, sort by created (CreateTimeSinceEpoch) ascending
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{OrderBy: model.ORDERBYFIELD_CREATE_TIME, SortOrder: model.SORTORDER_ASC})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(6), models.Size)
|
||||
assert.Equal("model-with-no-tasks", models.Items[0].Name) // Jan 1, 2023
|
||||
assert.Equal("model-gamma", models.Items[1].Name) // Feb 1, 2023
|
||||
}
|
||||
|
||||
// Test case 5: List all models, sort by published (CreateTimeSinceEpoch) descending
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{OrderBy: model.ORDERBYFIELD_CREATE_TIME, SortOrder: model.SORTORDER_DESC})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(6), models.Size)
|
||||
assert.Equal("Z-model", models.Items[0].Name) // Aug 2, 2023
|
||||
assert.Equal("another-model-alpha", models.Items[1].Name) // May 16, 2023
|
||||
}
|
||||
|
||||
// Test case 6: Filter by query "model" (should match all 6 models)
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{Query: "model"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(6), models.Size)
|
||||
assert.Equal("Z-model", models.Items[0].Name)
|
||||
assert.Equal("another-model-alpha", models.Items[1].Name)
|
||||
assert.Equal("model-alpha", models.Items[2].Name)
|
||||
assert.Equal("model-beta", models.Items[3].Name)
|
||||
assert.Equal("model-gamma", models.Items[4].Name)
|
||||
assert.Equal("model-with-no-tasks", models.Items[5].Name)
|
||||
}
|
||||
|
||||
// Test case 7: Filter by query "text" (should match model-alpha, another-model-alpha)
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{Query: "text"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(2), models.Size)
|
||||
assert.Equal("another-model-alpha", models.Items[0].Name) // Alphabetical order
|
||||
assert.Equal("model-alpha", models.Items[1].Name)
|
||||
}
|
||||
|
||||
// Test case 8: Filter by query "nlp" (should match model-alpha, model-gamma, another-model-alpha)
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{Query: "nlp"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(3), models.Size)
|
||||
assert.Equal("another-model-alpha", models.Items[0].Name)
|
||||
assert.Equal("model-alpha", models.Items[1].Name)
|
||||
assert.Equal("model-gamma", models.Items[2].Name)
|
||||
}
|
||||
|
||||
// Test case 9: Filter by query "IBM" (should match model-alpha, model-gamma)
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{Query: "IBM"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(2), models.Size)
|
||||
assert.Equal("model-alpha", models.Items[0].Name)
|
||||
assert.Equal("model-gamma", models.Items[1].Name)
|
||||
}
|
||||
|
||||
// Test case 10: Filter by query "transformers" (should match model-alpha)
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{Query: "transformers"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(int32(1), models.Size)
|
||||
assert.Equal("model-alpha", models.Items[0].Name)
|
||||
}
|
||||
|
||||
// Test case 11: Filter by query "nonexistent" (should return empty list)
|
||||
models, err = provider.ListModels(ctx, ListModelsParams{Query: "nonexistent"})
|
||||
assert.NoError(err)
|
||||
assert.NotNil(models)
|
||||
assert.Equal(int32(0), models.Size)
|
||||
assert.Equal(int32(0), models.PageSize)
|
||||
assert.Len(models.Items, 0)
|
||||
|
||||
// Test case 12: Empty catalog
|
||||
emptyProvider := testYAMLProvider(t, "testdata/empty-catalog.yaml") // Assuming an empty-catalog.yaml exists or will be created
|
||||
emptyModels, err := emptyProvider.ListModels(ctx, ListModelsParams{})
|
||||
assert.NoError(err)
|
||||
assert.NotNil(emptyModels)
|
||||
assert.Equal(int32(0), emptyModels.Size)
|
||||
assert.Equal(int32(0), emptyModels.PageSize)
|
||||
assert.Len(emptyModels.Items, 0)
|
||||
|
||||
// Test case 13: Test with excluded models
|
||||
excludedProvider := testYAMLProviderWithExclusions(t, "testdata/test-list-models-catalog.yaml", []any{
|
||||
"model-alpha",
|
||||
})
|
||||
excludedModels, err := excludedProvider.ListModels(ctx, ListModelsParams{})
|
||||
if assert.NoError(err) {
|
||||
assert.NotNil(excludedModels)
|
||||
assert.Equal(int32(5), excludedModels.Size)
|
||||
for _, m := range excludedModels.Items {
|
||||
assert.NotEqual("model-alpha", m.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testYAMLProvider(t *testing.T, path string) CatalogSourceProvider {
|
||||
return testYAMLProviderWithExclusions(t, path, nil)
|
||||
}
|
||||
|
||||
func testYAMLProviderWithExclusions(t *testing.T, path string, excludedModels []any) CatalogSourceProvider {
|
||||
properties := map[string]any{
|
||||
yamlCatalogPath: path,
|
||||
}
|
||||
if excludedModels != nil {
|
||||
properties["excludedModels"] = excludedModels
|
||||
}
|
||||
provider, err := newYamlCatalog(&CatalogSourceConfig{
|
||||
Properties: map[string]any{
|
||||
yamlCatalogPath: path,
|
||||
},
|
||||
Properties: properties,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("newYamlCatalog(%s) failed: %v", path, err)
|
||||
t.Fatalf("newYamlCatalog(%s) with exclusions failed: %v", path, err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
|
|
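The exclusions path above is only exercised indirectly by test case 13. For context, a minimal standalone sketch of the same pattern, in the same package and against the same fixture; the test name and the excluded model here are illustrative, not part of the change:

```go
func TestExcludedModelIsHidden(t *testing.T) {
	// Exclude one model, then confirm ListModels never returns it.
	provider := testYAMLProviderWithExclusions(t, "testdata/test-list-models-catalog.yaml", []any{"model-gamma"})

	models, err := provider.ListModels(context.Background(), ListModelsParams{})
	if err != nil {
		t.Fatalf("ListModels failed: %v", err)
	}
	for _, m := range models.Items {
		if m.Name == "model-gamma" {
			t.Errorf("excluded model %q was returned", m.Name)
		}
	}
}
```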
@@ -4,6 +4,7 @@ error.go
helpers.go
impl.go
logger.go
+model_artifact_type_query_param.go
model_base_model.go
model_base_resource_dates.go
model_base_resource_list.go
@@ -7,7 +7,6 @@ import (
	"math"
	"net/http"
	"slices"
	"strconv"
	"strings"

	"github.com/kubeflow/model-registry/catalog/internal/catalog"

@@ -18,15 +17,53 @@ import (
// This service should implement the business logic for every endpoint for the ModelCatalogServiceAPI s.coreApi.
// Include any external packages or services that will be required by this service.
type ModelCatalogServiceAPIService struct {
-	sources map[string]catalog.CatalogSource
+	sources *catalog.SourceCollection
}

-func (m *ModelCatalogServiceAPIService) GetAllModelArtifacts(context.Context, string, string) (ImplResponse, error) {
-	return Response(http.StatusNotImplemented, "Not implemented"), nil
+// GetAllModelArtifacts retrieves all model artifacts for a given model from the specified source.
+func (m *ModelCatalogServiceAPIService) GetAllModelArtifacts(ctx context.Context, sourceID string, name string) (ImplResponse, error) {
+	source, ok := m.sources.Get(sourceID)
+	if !ok {
+		return notFound("Unknown source"), nil
+	}
+
+	artifacts, err := source.Provider.GetArtifacts(ctx, name)
+	if err != nil {
+		return Response(http.StatusInternalServerError, err), err
+	}
+
+	return Response(http.StatusOK, artifacts), nil
}

-func (m *ModelCatalogServiceAPIService) FindModels(ctx context.Context, source string, q string, pageSize string, orderBy model.OrderByField, sortOder model.SortOrder, nextPageToken string) (ImplResponse, error) {
-	return Response(http.StatusNotImplemented, "Not implemented"), nil
+func (m *ModelCatalogServiceAPIService) FindModels(ctx context.Context, sourceID string, q string, pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) {
+	source, ok := m.sources.Get(sourceID)
+	if !ok {
+		return notFound("Unknown source"), errors.New("Unknown source")
+	}
+
+	p, err := newPaginator[model.CatalogModel](pageSize, orderBy, sortOrder, nextPageToken)
+	if err != nil {
+		return ErrorResponse(http.StatusBadRequest, err), err
+	}
+
+	listModelsParams := catalog.ListModelsParams{
+		Query:     q,
+		OrderBy:   p.OrderBy,
+		SortOrder: p.SortOrder,
+	}
+
+	models, err := source.Provider.ListModels(ctx, listModelsParams)
+	if err != nil {
+		return ErrorResponse(http.StatusInternalServerError, err), err
+	}
+
+	page, next := p.Paginate(models.Items)
+
+	models.Items = page
+	models.PageSize = p.PageSize
+	models.NextPageToken = next.Token()
+
+	return Response(http.StatusOK, models), nil
}

func (m *ModelCatalogServiceAPIService) GetModel(ctx context.Context, sourceID string, name string) (ImplResponse, error) {

@@ -34,7 +71,7 @@ func (m *ModelCatalogServiceAPIService) GetModel(ctx context.Context, sourceID s
		return m.GetAllModelArtifacts(ctx, sourceID, name)
	}

-	source, ok := m.sources[sourceID]
+	source, ok := m.sources.Get(sourceID)
	if !ok {
		return notFound("Unknown source"), nil
	}

@@ -51,28 +88,22 @@ func (m *ModelCatalogServiceAPIService) GetModel(ctx context.Context, sourceID s
}

func (m *ModelCatalogServiceAPIService) FindSources(ctx context.Context, name string, strPageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (ImplResponse, error) {
-	// TODO: Implement real pagination in here by reusing the nextPageToken
-	// code from https://github.com/kubeflow/model-registry/pull/1205.
-
-	if len(m.sources) > math.MaxInt32 {
+	sources := m.sources.All()
+	if len(sources) > math.MaxInt32 {
		err := errors.New("too many registered models")
		return ErrorResponse(http.StatusInternalServerError, err), err
	}

-	var pageSize int32 = 10
-	if strPageSize != "" {
-		pageSize64, err := strconv.ParseInt(strPageSize, 10, 32)
-		if err != nil {
-			return ErrorResponse(http.StatusBadRequest, err), err
-		}
-		pageSize = int32(pageSize64)
+	paginator, err := newPaginator[model.CatalogSource](strPageSize, orderBy, sortOrder, nextPageToken)
+	if err != nil {
+		return ErrorResponse(http.StatusBadRequest, err), err
	}

-	items := make([]model.CatalogSource, 0, len(m.sources))
+	items := make([]model.CatalogSource, 0, len(sources))

	name = strings.ToLower(name)

-	for _, v := range m.sources {
+	for _, v := range sources {
		if !strings.Contains(strings.ToLower(v.Metadata.Name), name) {
			continue
		}

@@ -87,15 +118,14 @@ func (m *ModelCatalogServiceAPIService) FindSources(ctx context.Context, name st
	slices.SortStableFunc(items, cmpFunc)

	total := int32(len(items))
-	if total > pageSize {
-		items = items[:pageSize]
-	}

+	pagedItems, next := paginator.Paginate(items)
+
	res := model.CatalogSourceList{
-		PageSize:      pageSize,
-		Items:         items,
+		PageSize:      paginator.PageSize,
+		Items:         pagedItems,
		Size:          total,
-		NextPageToken: "",
+		NextPageToken: next.Token(),
	}
	return Response(http.StatusOK, res), nil
}

@@ -128,7 +158,7 @@ func genCatalogCmpFunc(orderBy model.OrderByField, sortOrder model.SortOrder) (f
var _ ModelCatalogServiceAPIServicer = &ModelCatalogServiceAPIService{}

// NewModelCatalogServiceAPIService creates a default api service
-func NewModelCatalogServiceAPIService(sources map[string]catalog.CatalogSource) ModelCatalogServiceAPIServicer {
+func NewModelCatalogServiceAPIService(sources *catalog.SourceCollection) ModelCatalogServiceAPIServicer {
	return &ModelCatalogServiceAPIService{
		sources: sources,
	}
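To make the new constructor signature concrete, a minimal wiring sketch in the same package. Naming the provider interface `catalog.CatalogSourceProvider` is an assumption inferred from the test helpers elsewhere in this change, and `"example"`/`"llm"` are illustrative values:

```go
package openapi

import (
	"context"

	"github.com/kubeflow/model-registry/catalog/internal/catalog"
	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

// wireService builds the service from a SourceCollection and runs one
// paginated query against a single registered source.
func wireService(ctx context.Context, provider catalog.CatalogSourceProvider) (ImplResponse, error) {
	sources := catalog.NewSourceCollection(map[string]catalog.CatalogSource{
		"example": {
			Metadata: model.CatalogSource{Id: "example", Name: "Example Source"},
			Provider: provider,
		},
	})
	svc := NewModelCatalogServiceAPIService(sources)
	// Empty nextPageToken means "first page"; the response carries the cursor
	// for the next page.
	return svc.FindModels(ctx, "example", "llm", "10", model.ORDERBYFIELD_NAME, model.SORTORDER_ASC, "")
}
```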
@@ -3,8 +3,11 @@ package openapi
import (
	"context"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/kubeflow/model-registry/catalog/internal/catalog"
	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"

@@ -12,8 +15,306 @@ import (
	"github.com/stretchr/testify/require"
)

// timeToMillisStringPointer converts time.Time to *string representing milliseconds since epoch.
func timeToMillisStringPointer(t time.Time) *string {
	s := strconv.FormatInt(t.UnixMilli(), 10)
	return &s
}

// pointerOrDefault returns the value pointed to by p, or def if p is nil.
func pointerOrDefault(p *string, def string) string {
	if p == nil {
		return def
	}
	return *p
}

func TestFindModels(t *testing.T) {
	// Define common models for testing
	time1 := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
	time2 := time.Date(2023, 1, 2, 0, 0, 0, 0, time.UTC)
	time3 := time.Date(2023, 1, 3, 0, 0, 0, 0, time.UTC)
	time4 := time.Date(2023, 1, 4, 0, 0, 0, 0, time.UTC)

	// Updated model definitions to match OpenAPI schema (no direct Id or Published, use Name, CreateTime, LastUpdateTime)
	modelA := &model.CatalogModel{Name: "Model A", CreateTimeSinceEpoch: timeToMillisStringPointer(time1), LastUpdateTimeSinceEpoch: timeToMillisStringPointer(time4)}
	modelB := &model.CatalogModel{Name: "Model B", CreateTimeSinceEpoch: timeToMillisStringPointer(time2), LastUpdateTimeSinceEpoch: timeToMillisStringPointer(time3)}
	modelC := &model.CatalogModel{Name: "Another Model C", CreateTimeSinceEpoch: timeToMillisStringPointer(time3), LastUpdateTimeSinceEpoch: timeToMillisStringPointer(time2)}
	modelD := &model.CatalogModel{Name: "My Model D", CreateTimeSinceEpoch: timeToMillisStringPointer(time4), LastUpdateTimeSinceEpoch: timeToMillisStringPointer(time1)}

	testCases := []struct {
		name              string
		sourceID          string
		mockModels        map[string]*model.CatalogModel
		q                 string
		pageSize          string
		orderBy           model.OrderByField
		sortOrder         model.SortOrder
		nextPageToken     string
		expectedStatus    int
		expectedModelList *model.CatalogModelList
	}{
		{
			name:     "Successful query with no filters",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA, "modelB": modelB, "modelC": modelC, "modelD": modelD,
			},
			q:              "",
			pageSize:       "10",
			orderBy:        model.ORDERBYFIELD_NAME,
			sortOrder:      model.SORTORDER_ASC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelC, *modelA, *modelB, *modelD}, // Sorted by Name ASC: Another Model C, Model A, Model B, My Model D
				Size:          4,
				PageSize:      10, // Default page size
				NextPageToken: "",
			},
		},
		{
			name:     "Filter by query 'Model'",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA, "modelB": modelB, "modelC": modelC, "modelD": modelD,
			},
			q:              "Model",
			pageSize:       "10",
			orderBy:        model.ORDERBYFIELD_NAME,
			sortOrder:      model.SORTORDER_ASC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelC, *modelA, *modelB, *modelD}, // Corrected to include modelC and sorted by name ASC
				Size:          4, // Corrected from 3 to 4
				PageSize:      10,
				NextPageToken: "",
			},
		},
		{
			name:     "Filter by query 'model' (case insensitive)",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA, "modelB": modelB, "modelC": modelC, "modelD": modelD,
			},
			q:              "model",
			pageSize:       "10",
			orderBy:        model.ORDERBYFIELD_NAME,
			sortOrder:      model.SORTORDER_ASC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelC, *modelA, *modelB, *modelD}, // Corrected to include modelC and sorted by name ASC
				Size:          4, // Corrected from 3 to 4
				PageSize:      10,
				NextPageToken: "",
			},
		},
		{
			name:     "Page size limit",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA, "modelB": modelB, "modelC": modelC, "modelD": modelD,
			},
			q:              "",
			pageSize:       "2",
			orderBy:        model.ORDERBYFIELD_NAME,
			sortOrder:      model.SORTORDER_ASC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelC, *modelA}, // First 2 after sorting by Name ASC
				Size:          4, // Total size remains 4
				PageSize:      2,
				NextPageToken: (&stringCursor{Value: "Model A", ID: "Model A"}).String(),
			},
		},
		{
			name:     "Sort by ID Descending (mocked as Name Descending)",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA, "modelB": modelB, "modelC": modelC, "modelD": modelD,
			},
			q:              "",
			pageSize:       "10",
			orderBy:        model.ORDERBYFIELD_ID,
			sortOrder:      model.SORTORDER_DESC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelD, *modelB, *modelA, *modelC}, // Sorted by Name DESC
				Size:          4,
				PageSize:      10,
				NextPageToken: "",
			},
		},
		{
			name:     "Sort by CreateTime Ascending",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA, "modelB": modelB, "modelC": modelC, "modelD": modelD,
			},
			q:              "",
			pageSize:       "10",
			orderBy:        model.ORDERBYFIELD_CREATE_TIME,
			sortOrder:      model.SORTORDER_ASC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelA, *modelB, *modelC, *modelD}, // Sorted by CreateTime ASC
				Size:          4,
				PageSize:      10,
				NextPageToken: "",
			},
		},
		{
			name:     "Sort by LastUpdateTime Descending",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA, "modelB": modelB, "modelC": modelC, "modelD": modelD,
			},
			q:              "",
			pageSize:       "10",
			orderBy:        model.ORDERBYFIELD_LAST_UPDATE_TIME,
			sortOrder:      model.SORTORDER_DESC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelA, *modelB, *modelC, *modelD}, // Corrected to be sorted by LastUpdateTime DESC (modelA has latest time4, modelD has earliest time1)
				Size:          4,
				PageSize:      10,
				NextPageToken: "",
			},
		},
		{
			name:              "Invalid source ID",
			sourceID:          "unknown-source",
			mockModels:        map[string]*model.CatalogModel{},
			q:                 "",
			pageSize:          "10",
			orderBy:           model.ORDERBYFIELD_ID,
			sortOrder:         model.SORTORDER_ASC,
			expectedStatus:    http.StatusNotFound,
			expectedModelList: nil,
		},
		{
			name:     "Invalid pageSize string",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA,
			},
			q:                 "",
			pageSize:          "abc",
			orderBy:           model.ORDERBYFIELD_ID,
			sortOrder:         model.SORTORDER_ASC,
			expectedStatus:    http.StatusBadRequest,
			expectedModelList: nil,
		},
		{
			name:     "Unsupported orderBy field",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA,
			},
			q:                 "",
			pageSize:          "10",
			orderBy:           "UNSUPPORTED_FIELD",
			sortOrder:         model.SORTORDER_ASC,
			expectedStatus:    http.StatusBadRequest,
			expectedModelList: nil,
		},
		{
			name:     "Unsupported sortOrder field",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelA": modelA,
			},
			q:                 "",
			pageSize:          "10",
			orderBy:           model.ORDERBYFIELD_ID,
			sortOrder:         "UNSUPPORTED_ORDER",
			expectedStatus:    http.StatusBadRequest,
			expectedModelList: nil,
		},
		{
			name:           "Empty models in source",
			sourceID:       "source1",
			mockModels:     map[string]*model.CatalogModel{},
			q:              "",
			pageSize:       "10",
			orderBy:        model.ORDERBYFIELD_ID,
			sortOrder:      model.SORTORDER_ASC,
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{},
				Size:          0,
				PageSize:      10,
				NextPageToken: "",
			},
		},
		{
			name:     "Default sort (ID ascending) and default page size",
			sourceID: "source1",
			mockModels: map[string]*model.CatalogModel{
				"modelB": modelB, "modelA": modelA, "modelD": modelD, "modelC": modelC,
			},
			q:              "",
			pageSize:       "", // Default page size
			orderBy:        "", // Default order by ID
			sortOrder:      "", // Default sort order ASC
			expectedStatus: http.StatusOK,
			expectedModelList: &model.CatalogModelList{
				Items:         []model.CatalogModel{*modelC, *modelA, *modelB, *modelD}, // Sorted by Name ASC (as ID is mocked to use Name)
				Size:          4,
				PageSize:      10, // Default page size
				NextPageToken: "",
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Create mock source collection
			sources := catalog.NewSourceCollection(map[string]catalog.CatalogSource{
				"source1": {
					Metadata: model.CatalogSource{Id: "source1", Name: "Test Source 1"},
					Provider: &mockModelProvider{
						models: tc.mockModels,
					},
				},
			})
			service := NewModelCatalogServiceAPIService(sources)

			resp, err := service.FindModels(
				context.Background(),
				tc.sourceID,
				tc.q,
				tc.pageSize,
				tc.orderBy,
				tc.sortOrder,
				tc.nextPageToken,
			)

			assert.Equal(t, tc.expectedStatus, resp.Code)

			if tc.expectedStatus != http.StatusOK {
				assert.NotNil(t, err)
				return
			}

			require.NotNil(t, resp.Body)
			modelList, ok := resp.Body.(model.CatalogModelList)
			require.True(t, ok, "Response body should be a CatalogModelList")

			assert.Equal(t, tc.expectedModelList.Size, modelList.Size)
			assert.Equal(t, tc.expectedModelList.PageSize, modelList.PageSize)
			if !assert.Equal(t, tc.expectedModelList.NextPageToken, modelList.NextPageToken) && tc.expectedModelList.NextPageToken != "" {
				assert.Equal(t, decodeStringCursor(tc.expectedModelList.NextPageToken), decodeStringCursor(modelList.NextPageToken))
			}

			// Deep equality check for items
			assert.Equal(t, tc.expectedModelList.Items, modelList.Items)
		})
	}
}
func TestFindSources(t *testing.T) {
	// Setup test cases
	trueValue := true
	testCases := []struct {
		name     string
		catalogs map[string]catalog.CatalogSource

@@ -42,7 +343,7 @@ func TestFindSources(t *testing.T) {
			name: "Single catalog",
			catalogs: map[string]catalog.CatalogSource{
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -57,13 +358,13 @@ func TestFindSources(t *testing.T) {
			name: "Multiple catalogs with no filter",
			catalogs: map[string]catalog.CatalogSource{
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1", Enabled: &trueValue},
				},
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2", Enabled: &trueValue},
				},
				"catalog3": {
-					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -78,13 +379,13 @@ func TestFindSources(t *testing.T) {
			name: "Filter by name",
			catalogs: map[string]catalog.CatalogSource{
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1", Enabled: &trueValue},
				},
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2", Enabled: &trueValue},
				},
				"catalog3": {
-					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog", Enabled: &trueValue},
				},
			},
			nameFilter: "Test",

@@ -99,13 +400,13 @@ func TestFindSources(t *testing.T) {
			name: "Filter by name case insensitive",
			catalogs: map[string]catalog.CatalogSource{
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1", Enabled: &trueValue},
				},
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2", Enabled: &trueValue},
				},
				"catalog3": {
-					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog", Enabled: &trueValue},
				},
			},
			nameFilter: "test",

@@ -120,13 +421,13 @@ func TestFindSources(t *testing.T) {
			name: "Pagination - limit results",
			catalogs: map[string]catalog.CatalogSource{
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1", Enabled: &trueValue},
				},
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2", Enabled: &trueValue},
				},
				"catalog3": {
-					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog3", Name: "Another Catalog", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -141,10 +442,10 @@ func TestFindSources(t *testing.T) {
			name: "Default page size",
			catalogs: map[string]catalog.CatalogSource{
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1", Enabled: &trueValue},
				},
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "Test Catalog 2", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -159,7 +460,7 @@ func TestFindSources(t *testing.T) {
			name: "Invalid page size",
			catalogs: map[string]catalog.CatalogSource{
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "Test Catalog 1", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -172,13 +473,13 @@ func TestFindSources(t *testing.T) {
			name: "Sort by ID ascending",
			catalogs: map[string]catalog.CatalogSource{
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog", Enabled: &trueValue},
				},
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog", Enabled: &trueValue},
				},
				"catalog3": {
-					Metadata: model.CatalogSource{Id: "catalog3", Name: "C Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog3", Name: "C Catalog", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -194,13 +495,13 @@ func TestFindSources(t *testing.T) {
			name: "Sort by ID descending",
			catalogs: map[string]catalog.CatalogSource{
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog", Enabled: &trueValue},
				},
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog", Enabled: &trueValue},
				},
				"catalog3": {
-					Metadata: model.CatalogSource{Id: "catalog3", Name: "C Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog3", Name: "C Catalog", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -216,13 +517,13 @@ func TestFindSources(t *testing.T) {
			name: "Sort by name ascending",
			catalogs: map[string]catalog.CatalogSource{
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog", Enabled: &trueValue},
				},
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog", Enabled: &trueValue},
				},
				"catalog3": {
-					Metadata: model.CatalogSource{Id: "catalog3", Name: "C Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog3", Name: "C Catalog", Enabled: &trueValue},
				},
			},
			nameFilter: "",

@@ -238,10 +539,10 @@ func TestFindSources(t *testing.T) {
			name: "Sort by name descending",
			catalogs: map[string]catalog.CatalogSource{
				"catalog2": {
-					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog2", Name: "B Catalog", Enabled: &trueValue},
				},
				"catalog1": {
-					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog"},
+					Metadata: model.CatalogSource{Id: "catalog1", Name: "A Catalog", Enabled: &trueValue},
				},
				"catalog3": {
					Metadata: model.CatalogSource{Id: "catalog3", Name: "C Catalog"},

@@ -310,7 +611,7 @@ func TestFindSources(t *testing.T) {
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Create service with test catalogs
-			service := NewModelCatalogServiceAPIService(tc.catalogs)
+			service := NewModelCatalogServiceAPIService(catalog.NewSourceCollection(tc.catalogs))

			// Call FindSources
			resp, err := service.FindSources(

@@ -393,7 +694,8 @@ func TestFindSources(t *testing.T) {

// Define a mock model provider
type mockModelProvider struct {
-	models map[string]*model.CatalogModel
+	models    map[string]*model.CatalogModel
+	artifacts map[string][]model.CatalogModelArtifact
}

// Implement GetModel method for the mock provider

@@ -406,11 +708,71 @@ func (m *mockModelProvider) GetModel(ctx context.Context, name string) (*model.C
}

func (m *mockModelProvider) ListModels(ctx context.Context, params catalog.ListModelsParams) (model.CatalogModelList, error) {
-	return model.CatalogModelList{}, nil
+	var filteredModels []*model.CatalogModel
+	for _, mdl := range m.models {
+		if params.Query == "" || strings.Contains(strings.ToLower(mdl.Name), strings.ToLower(params.Query)) {
+			filteredModels = append(filteredModels, mdl)
+		}
+	}
+
+	// Sort the filtered models
+	sort.SliceStable(filteredModels, func(i, j int) bool {
+		cmp := 0
+		switch params.OrderBy {
+		case model.ORDERBYFIELD_CREATE_TIME:
+			// Parse CreateTimeSinceEpoch strings to int64 for comparison
+			t1, _ := strconv.ParseInt(pointerOrDefault(filteredModels[i].CreateTimeSinceEpoch, "0"), 10, 64)
+			t2, _ := strconv.ParseInt(pointerOrDefault(filteredModels[j].CreateTimeSinceEpoch, "0"), 10, 64)
+			cmp = int(t1 - t2)
+		case model.ORDERBYFIELD_LAST_UPDATE_TIME:
+			// Parse LastUpdateTimeSinceEpoch strings to int64 for comparison
+			t1, _ := strconv.ParseInt(pointerOrDefault(filteredModels[i].LastUpdateTimeSinceEpoch, "0"), 10, 64)
+			t2, _ := strconv.ParseInt(pointerOrDefault(filteredModels[j].LastUpdateTimeSinceEpoch, "0"), 10, 64)
+			cmp = int(t1 - t2)
+		case model.ORDERBYFIELD_NAME:
+			fallthrough
+		default:
+			cmp = strings.Compare(filteredModels[i].Name, filteredModels[j].Name)
+		}
+
+		if params.SortOrder == model.SORTORDER_DESC {
+			return cmp > 0
+		}
+		return cmp < 0
+	})
+
+	items := make([]model.CatalogModel, len(filteredModels))
+	for i, mdl := range filteredModels {
+		items[i] = *mdl
+	}
+
+	return model.CatalogModelList{
+		Items:         items,
+		Size:          int32(len(items)),
+		PageSize:      int32(len(items)), // Mock returns all filtered items as one "page"
+		NextPageToken: "",
+	}, nil
}

func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string) (*model.CatalogModelArtifactList, error) {
	artifacts, exists := m.artifacts[name]
	if !exists {
		return &model.CatalogModelArtifactList{
			Items:         []model.CatalogModelArtifact{},
			Size:          0,
			PageSize:      0, // Or a default page size if applicable
			NextPageToken: "",
		}, nil
	}
	return &model.CatalogModelArtifactList{
		Items:         artifacts,
		Size:          int32(len(artifacts)),
		PageSize:      int32(len(artifacts)),
		NextPageToken: "",
	}, nil
}

func TestGetModel(t *testing.T) {

	testCases := []struct {
		name    string
		sources map[string]catalog.CatalogSource

@@ -472,7 +834,7 @@ func TestGetModel(t *testing.T) {
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Create service with test sources
-			service := NewModelCatalogServiceAPIService(tc.sources)
+			service := NewModelCatalogServiceAPIService(catalog.NewSourceCollection(tc.sources))

			// Call GetModel
			resp, _ := service.GetModel(

@@ -501,3 +863,106 @@ func TestGetModel(t *testing.T) {
		})
	}
}

func TestGetAllModelArtifacts(t *testing.T) {
	testCases := []struct {
		name              string
		sources           map[string]catalog.CatalogSource
		sourceID          string
		modelName         string
		expectedStatus    int
		expectedArtifacts []model.CatalogModelArtifact
	}{
		{
			name: "Existing artifacts for model in source",
			sources: map[string]catalog.CatalogSource{
				"source1": {
					Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"},
					Provider: &mockModelProvider{
						artifacts: map[string][]model.CatalogModelArtifact{
							"test-model": {
								{
									Uri: "s3://bucket/artifact1",
								},
								{
									Uri: "s3://bucket/artifact2",
								},
							},
						},
					},
				},
			},
			sourceID:       "source1",
			modelName:      "test-model",
			expectedStatus: http.StatusOK,
			expectedArtifacts: []model.CatalogModelArtifact{
				{
					Uri: "s3://bucket/artifact1",
				},
				{
					Uri: "s3://bucket/artifact2",
				},
			},
		},
		{
			name: "Non-existing source",
			sources: map[string]catalog.CatalogSource{
				"source1": {
					Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"},
				},
			},
			sourceID:          "source2",
			modelName:         "test-model",
			expectedStatus:    http.StatusNotFound,
			expectedArtifacts: nil,
		},
		{
			name: "Existing source, no artifacts for model",
			sources: map[string]catalog.CatalogSource{
				"source1": {
					Metadata: model.CatalogSource{Id: "source1", Name: "Test Source"},
					Provider: &mockModelProvider{
						artifacts: map[string][]model.CatalogModelArtifact{},
					},
				},
			},
			sourceID:          "source1",
			modelName:         "test-model",
			expectedStatus:    http.StatusOK,
			expectedArtifacts: []model.CatalogModelArtifact{}, // Should be an empty slice, not nil
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Create service with test sources
			service := NewModelCatalogServiceAPIService(catalog.NewSourceCollection(tc.sources))

			// Call GetAllModelArtifacts
			resp, _ := service.GetAllModelArtifacts(
				context.Background(),
				tc.sourceID,
				tc.modelName,
			)

			// Check response status
			assert.Equal(t, tc.expectedStatus, resp.Code)

			// If we expect an error or not found, we don't need to check the response body
			if tc.expectedStatus != http.StatusOK {
				return
			}

			// For successful responses, check the response body
			require.NotNil(t, resp.Body)

			// Type assertion to access the list of artifacts
			artifactList, ok := resp.Body.(*model.CatalogModelArtifactList)
			require.True(t, ok, "Response body should be a CatalogModelArtifactList")

			// Check the artifacts
			assert.Equal(t, tc.expectedArtifacts, artifactList.Items)
			assert.Equal(t, int32(len(tc.expectedArtifacts)), artifactList.Size)
		})
	}
}
@@ -0,0 +1,123 @@
package openapi

import (
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"

	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

type paginator[T model.Sortable] struct {
	PageSize  int32
	OrderBy   model.OrderByField
	SortOrder model.SortOrder
	cursor    *stringCursor
}

func newPaginator[T model.Sortable](pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (*paginator[T], error) {
	if orderBy != "" && !orderBy.IsValid() {
		return nil, fmt.Errorf("unsupported order by field: %s", orderBy)
	}
	if sortOrder != "" && !sortOrder.IsValid() {
		return nil, fmt.Errorf("unsupported sort order field: %s", sortOrder)
	}

	p := &paginator[T]{
		PageSize:  10, // Default page size
		OrderBy:   orderBy,
		SortOrder: sortOrder,
	}

	if pageSize != "" {
		pageSize64, err := strconv.ParseInt(pageSize, 10, 32)
		if err != nil {
			return nil, fmt.Errorf("error converting page size to int32: %w", err)
		}
		p.PageSize = int32(pageSize64)
	}

	if nextPageToken != "" {
		p.cursor = decodeStringCursor(nextPageToken)
	}

	return p, nil
}

func (p *paginator[T]) Token() string {
	if p == nil || p.cursor == nil {
		return ""
	}
	return p.cursor.String()
}

func (p *paginator[T]) Paginate(items []T) ([]T, *paginator[T]) {
	startIndex := 0
	if p.cursor != nil {
		for i, item := range items {
			itemValue := item.SortValue(p.OrderBy)
			id := item.SortValue(model.ORDERBYFIELD_ID)
			if id != "" && id == p.cursor.ID && itemValue == p.cursor.Value {
				startIndex = i + 1
				break
			}
		}
	}

	if startIndex >= len(items) {
		return []T{}, nil
	}

	var pagedItems []T
	var next *paginator[T]

	endIndex := startIndex + int(p.PageSize)
	if endIndex > len(items) {
		endIndex = len(items)
	}
	pagedItems = items[startIndex:endIndex]

	if endIndex < len(items) {
		lastItem := pagedItems[len(pagedItems)-1]
		lastItemID := lastItem.SortValue(model.ORDERBYFIELD_ID)
		if lastItemID != "" {
			next = &paginator[T]{
				PageSize:  p.PageSize,
				OrderBy:   p.OrderBy,
				SortOrder: p.SortOrder,
				cursor: &stringCursor{
					Value: lastItem.SortValue(p.OrderBy),
					ID:    lastItemID,
				},
			}
		}
	}

	return pagedItems, next
}

type stringCursor struct {
	Value string
	ID    string
}

func (c *stringCursor) String() string {
	return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", c.Value, c.ID)))
}

func decodeStringCursor(encoded string) *stringCursor {
	decoded, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		// Show the first page on a bad token.
		return nil
	}
	parts := strings.SplitN(string(decoded), ":", 2)
	if len(parts) != 2 {
		return nil
	}
	return &stringCursor{
		Value: parts[0],
		ID:    parts[1],
	}
}
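A usage sketch for the paginator above, in the same package, assuming `items` is already sorted by the requested field (the paginator only slices and builds cursors; it does not sort):

```go
// twoPages fetches two consecutive pages from a pre-sorted slice.
func twoPages(items []model.CatalogSource) ([]model.CatalogSource, []model.CatalogSource, error) {
	p, err := newPaginator[model.CatalogSource]("10", model.ORDERBYFIELD_NAME, model.SORTORDER_ASC, "")
	if err != nil {
		return nil, nil, err
	}
	page1, next := p.Paginate(items)

	// The client echoes back next.Token(), an opaque base64 "value:id" cursor,
	// and Paginate resumes after the last item of page1. Token() is nil-safe:
	// when there are no more pages it returns "", which restarts from the top.
	p2, err := newPaginator[model.CatalogSource]("10", model.ORDERBYFIELD_NAME, model.SORTORDER_ASC, next.Token())
	if err != nil {
		return nil, nil, err
	}
	page2, _ := p2.Paginate(items)
	return page1, page2, nil
}
```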
@@ -0,0 +1,180 @@
package openapi

import (
	"strconv"
	"testing"

	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
	"github.com/stretchr/testify/assert"
)

func createCatalogSource(id int) model.CatalogSource {
	return model.CatalogSource{
		Id:   "source" + strconv.Itoa(id),
		Name: "Source " + strconv.Itoa(id),
	}
}

func createCatalogSources(count int) []model.CatalogSource {
	sources := make([]model.CatalogSource, count)
	for i := 0; i < count; i++ {
		sources[i] = createCatalogSource(i)
	}
	return sources
}

func TestPaginateSources(t *testing.T) {
	allSources := createCatalogSources(25)

	testCases := []struct {
		name               string
		items              []model.CatalogSource
		pageSize           string
		orderBy            model.OrderByField
		nextPageToken      string
		expectedItemsCount int
		expectedNextToken  bool
		expectedFirstID    string
		expectedLastID     string
	}{
		{
			name:               "First page, full page",
			items:              allSources,
			pageSize:           "10",
			orderBy:            "ID",
			nextPageToken:      "",
			expectedItemsCount: 10,
			expectedNextToken:  true,
			expectedFirstID:    "source0",
			expectedLastID:     "source9",
		},
		{
			name:               "Second page, full page",
			items:              allSources,
			pageSize:           "10",
			orderBy:            "ID",
			nextPageToken:      (&stringCursor{Value: "source9", ID: "source9"}).String(),
			expectedItemsCount: 10,
			expectedNextToken:  true,
			expectedFirstID:    "source10",
			expectedLastID:     "source19",
		},
		{
			name:               "Last page, partial page",
			items:              allSources,
			pageSize:           "10",
			orderBy:            "ID",
			nextPageToken:      (&stringCursor{Value: "source19", ID: "source19"}).String(),
			expectedItemsCount: 5,
			expectedNextToken:  false,
			expectedFirstID:    "source20",
			expectedLastID:     "source24",
		},
		{
			name:               "Page size larger than items",
			items:              allSources,
			pageSize:           "30",
			orderBy:            "ID",
			nextPageToken:      "",
			expectedItemsCount: 25,
			expectedNextToken:  false,
			expectedFirstID:    "source0",
			expectedLastID:     "source24",
		},
		{
			name:               "Empty items",
			items:              []model.CatalogSource{},
			pageSize:           "10",
			orderBy:            "ID",
			nextPageToken:      "",
			expectedItemsCount: 0,
			expectedNextToken:  false,
		},
		{
			name:               "Order by Name, first page",
			items:              allSources,
			pageSize:           "5",
			orderBy:            "NAME",
			nextPageToken:      "",
			expectedItemsCount: 5,
			expectedNextToken:  true,
			expectedFirstID:    "source0",
			expectedLastID:     "source4",
		},
		{
			name:               "Order by Name, second page",
			items:              allSources,
			pageSize:           "5",
			orderBy:            "NAME",
			nextPageToken:      (&stringCursor{Value: "Source 4", ID: "source4"}).String(),
			expectedItemsCount: 5,
			expectedNextToken:  true,
			expectedFirstID:    "source5",
			expectedLastID:     "source9",
		},
		{
			name:               "Invalid token",
			items:              allSources,
			pageSize:           "10",
			orderBy:            "ID",
			nextPageToken:      "invalid-token",
			expectedItemsCount: 10,
			expectedNextToken:  true,
			expectedFirstID:    "source0",
			expectedLastID:     "source9",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			paginator, err := newPaginator[model.CatalogSource](tc.pageSize, tc.orderBy, "", tc.nextPageToken)
			if !assert.NoError(t, err) {
				return
			}

			pagedItems, newNextPageToken := paginator.Paginate(tc.items)

			assert.Equal(t, tc.expectedItemsCount, len(pagedItems))
			if tc.expectedNextToken {
				assert.NotEmpty(t, newNextPageToken)
			} else {
				assert.Empty(t, newNextPageToken)
			}

			if tc.expectedItemsCount > 0 {
				assert.Equal(t, tc.expectedFirstID, pagedItems[0].Id)
				assert.Equal(t, tc.expectedLastID, pagedItems[len(pagedItems)-1].Id)
			}
		})
	}
}

func TestPaginateSources_NoDuplicates(t *testing.T) {
	allSources := createCatalogSources(100)
	pageSize := "10"
	orderBy := "ID"

	seenItems := make(map[string]struct{}, len(allSources))
	totalSeen := 0

	paginator, err := newPaginator[model.CatalogSource](pageSize, model.OrderByField(orderBy), "", "")
	if !assert.NoError(t, err) {
		return
	}

	for paginator != nil {
		var pagedItems []model.CatalogSource
		pagedItems, paginator = paginator.Paginate(allSources)

		for _, item := range pagedItems {
			if _, ok := seenItems[item.Id]; ok {
				t.Errorf("Duplicate item found: %s", item.Id)
			}
			seenItems[item.Id] = struct{}{}
		}

		totalSeen += len(pagedItems)
	}

	assert.Equal(t, len(allSources), totalSeen, "Total number of items seen should match the original slice")
}
@@ -16,6 +16,16 @@ import (
	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

// AssertArtifactTypeQueryParamConstraints checks if the values respects the defined constraints
func AssertArtifactTypeQueryParamConstraints(obj model.ArtifactTypeQueryParam) error {
	return nil
}

// AssertArtifactTypeQueryParamRequired checks if the required fields are not zero-ed
func AssertArtifactTypeQueryParamRequired(obj model.ArtifactTypeQueryParam) error {
	return nil
}

// AssertBaseModelConstraints checks if the values respects the defined constraints
func AssertBaseModelConstraints(obj model.BaseModel) error {
	return nil
@@ -1,6 +1,7 @@
api_model_catalog_service.go
client.go
configuration.go
+model_artifact_type_query_param.go
model_base_model.go
model_base_resource_dates.go
model_base_resource_list.go
@@ -33,7 +33,7 @@ type ApiFindModelsRequest struct {
	nextPageToken *string
}

-// Filter models by source. If not provided, models from all sources are returned. If multiple sources are provided, models from any of the sources are returned.
+// Filter models by source. This parameter is currently required and may only be specified once.
func (r ApiFindModelsRequest) Source(source string) ApiFindModelsRequest {
	r.source = &source
	return r

@@ -107,10 +107,11 @@ func (a *ModelCatalogServiceAPIService) FindModelsExecute(r ApiFindModelsRequest
	localVarHeaderParams := make(map[string]string)
	localVarQueryParams := url.Values{}
	localVarFormParams := url.Values{}

-	if r.source != nil {
-		parameterAddToHeaderOrQuery(localVarQueryParams, "source", r.source, "")
+	if r.source == nil {
+		return localVarReturnValue, nil, reportError("source is required and must be specified")
	}

+	parameterAddToHeaderOrQuery(localVarQueryParams, "source", r.source, "")
	if r.q != nil {
		parameterAddToHeaderOrQuery(localVarQueryParams, "q", r.q, "")
	}
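Client-side, this change means `Source` must always be set before executing the request. A hypothetical call sketch: `NewConfiguration`, `NewAPIClient`, the `ModelCatalogServiceAPI` field, and the `Q` setter follow the usual openapi-generator Go client layout and are assumed here rather than shown in this diff:

```go
package main

import (
	"context"
	"fmt"

	openapi "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

func main() {
	client := openapi.NewAPIClient(openapi.NewConfiguration())

	models, _, err := client.ModelCatalogServiceAPI.
		FindModels(context.Background()).
		Source("example"). // now mandatory; omitting it fails before any request is sent
		Q("llm").
		Execute()
	if err != nil {
		fmt.Println("FindModels failed:", err)
		return
	}
	fmt.Println("models returned:", models.Size)
}
```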
@@ -0,0 +1,116 @@
/*
Model Catalog REST API

REST API for Model Registry to create and manage ML model metadata

API version: v1alpha1
*/

// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.

package openapi

import (
	"encoding/json"
	"fmt"
)

// ArtifactTypeQueryParam Supported artifact types for querying.
type ArtifactTypeQueryParam string

// List of ArtifactTypeQueryParam
const (
	ARTIFACTTYPEQUERYPARAM_MODEL_ARTIFACT   ArtifactTypeQueryParam = "model-artifact"
	ARTIFACTTYPEQUERYPARAM_DOC_ARTIFACT     ArtifactTypeQueryParam = "doc-artifact"
	ARTIFACTTYPEQUERYPARAM_DATASET_ARTIFACT ArtifactTypeQueryParam = "dataset-artifact"
	ARTIFACTTYPEQUERYPARAM_METRIC           ArtifactTypeQueryParam = "metric"
	ARTIFACTTYPEQUERYPARAM_PARAMETER        ArtifactTypeQueryParam = "parameter"
)

// All allowed values of ArtifactTypeQueryParam enum
var AllowedArtifactTypeQueryParamEnumValues = []ArtifactTypeQueryParam{
	"model-artifact",
	"doc-artifact",
	"dataset-artifact",
	"metric",
	"parameter",
}

func (v *ArtifactTypeQueryParam) UnmarshalJSON(src []byte) error {
	var value string
	err := json.Unmarshal(src, &value)
	if err != nil {
		return err
	}
	enumTypeValue := ArtifactTypeQueryParam(value)
	for _, existing := range AllowedArtifactTypeQueryParamEnumValues {
		if existing == enumTypeValue {
			*v = enumTypeValue
			return nil
		}
	}

	return fmt.Errorf("%+v is not a valid ArtifactTypeQueryParam", value)
}

// NewArtifactTypeQueryParamFromValue returns a pointer to a valid ArtifactTypeQueryParam
// for the value passed as argument, or an error if the value passed is not allowed by the enum
func NewArtifactTypeQueryParamFromValue(v string) (*ArtifactTypeQueryParam, error) {
	ev := ArtifactTypeQueryParam(v)
	if ev.IsValid() {
		return &ev, nil
	} else {
		return nil, fmt.Errorf("invalid value '%v' for ArtifactTypeQueryParam: valid values are %v", v, AllowedArtifactTypeQueryParamEnumValues)
	}
}

// IsValid return true if the value is valid for the enum, false otherwise
func (v ArtifactTypeQueryParam) IsValid() bool {
	for _, existing := range AllowedArtifactTypeQueryParamEnumValues {
		if existing == v {
			return true
		}
	}
	return false
}

// Ptr returns reference to ArtifactTypeQueryParam value
func (v ArtifactTypeQueryParam) Ptr() *ArtifactTypeQueryParam {
	return &v
}

type NullableArtifactTypeQueryParam struct {
	value *ArtifactTypeQueryParam
	isSet bool
}

func (v NullableArtifactTypeQueryParam) Get() *ArtifactTypeQueryParam {
	return v.value
}

func (v *NullableArtifactTypeQueryParam) Set(val *ArtifactTypeQueryParam) {
	v.value = val
	v.isSet = true
}

func (v NullableArtifactTypeQueryParam) IsSet() bool {
	return v.isSet
}

func (v *NullableArtifactTypeQueryParam) Unset() {
	v.value = nil
	v.isSet = false
}

func NewNullableArtifactTypeQueryParam(val *ArtifactTypeQueryParam) *NullableArtifactTypeQueryParam {
	return &NullableArtifactTypeQueryParam{value: val, isSet: true}
}

func (v NullableArtifactTypeQueryParam) MarshalJSON() ([]byte, error) {
	return json.Marshal(v.value)
}

func (v *NullableArtifactTypeQueryParam) UnmarshalJSON(src []byte) error {
	v.isSet = true
	return json.Unmarshal(src, &v.value)
}
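A round-trip sketch for the generated enum helpers above; the import path mirrors the `catalog/pkg/openapi` module used elsewhere in this change and is an assumption:

```go
package main

import (
	"encoding/json"
	"fmt"

	openapi "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

func main() {
	// Construct from a raw string; invalid values are rejected.
	v, err := openapi.NewArtifactTypeQueryParamFromValue("model-artifact")
	fmt.Println(v, err) // model-artifact <nil>

	// UnmarshalJSON enforces the same allow-list.
	var p openapi.ArtifactTypeQueryParam
	err = json.Unmarshal([]byte(`"flight-recorder"`), &p)
	fmt.Println(err != nil) // true: not a valid ArtifactTypeQueryParam
}
```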
@@ -23,6 +23,8 @@ type CatalogSource struct {
	Id string `json:"id"`
	// The name of the catalog source.
	Name string `json:"name"`
+	// Whether the catalog source is enabled.
+	Enabled *bool `json:"enabled,omitempty"`
}

// NewCatalogSource instantiates a new CatalogSource object

@@ -33,6 +35,8 @@ func NewCatalogSource(id string, name string) *CatalogSource {
	this := CatalogSource{}
	this.Id = id
	this.Name = name
+	var enabled bool = true
+	this.Enabled = &enabled
	return &this
}

@@ -41,6 +45,8 @@ func NewCatalogSource(id string, name string) *CatalogSource {
// but it doesn't guarantee that properties required by API are set
func NewCatalogSourceWithDefaults() *CatalogSource {
	this := CatalogSource{}
+	var enabled bool = true
+	this.Enabled = &enabled
	return &this
}

@@ -92,6 +98,38 @@ func (o *CatalogSource) SetName(v string) {
	o.Name = v
}

// GetEnabled returns the Enabled field value if set, zero value otherwise.
func (o *CatalogSource) GetEnabled() bool {
	if o == nil || IsNil(o.Enabled) {
		var ret bool
		return ret
	}
	return *o.Enabled
}

// GetEnabledOk returns a tuple with the Enabled field value if set, nil otherwise
// and a boolean to check if the value has been set.
func (o *CatalogSource) GetEnabledOk() (*bool, bool) {
	if o == nil || IsNil(o.Enabled) {
		return nil, false
	}
	return o.Enabled, true
}

// HasEnabled returns a boolean if a field has been set.
func (o *CatalogSource) HasEnabled() bool {
	if o != nil && !IsNil(o.Enabled) {
		return true
	}

	return false
}

// SetEnabled gets a reference to the given bool and assigns it to the Enabled field.
func (o *CatalogSource) SetEnabled(v bool) {
	o.Enabled = &v
}

func (o CatalogSource) MarshalJSON() ([]byte, error) {
	toSerialize, err := o.ToMap()
	if err != nil {

@@ -104,6 +142,9 @@ func (o CatalogSource) ToMap() (map[string]interface{}, error) {
	toSerialize := map[string]interface{}{}
	toSerialize["id"] = o.Id
	toSerialize["name"] = o.Name
+	if !IsNil(o.Enabled) {
+		toSerialize["enabled"] = o.Enabled
+	}
	return toSerialize, nil
}
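A sketch of how the new `Enabled` default behaves through the generated constructor and marshalling code above. Note that `ToMap` returns a map, so `encoding/json` emits the keys alphabetically:

```go
package main

import (
	"encoding/json"
	"fmt"

	openapi "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

func main() {
	// The constructor defaults Enabled to true; ToMap only emits the field
	// when it is set, so the default survives a round trip.
	src := openapi.NewCatalogSource("catalog1", "Test Catalog 1")
	b, _ := json.Marshal(src)
	fmt.Println(string(b)) // {"enabled":true,"id":"catalog1","name":"Test Catalog 1"}

	src.SetEnabled(false)
	fmt.Println(src.GetEnabled()) // false
}
```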
@@ -0,0 +1,37 @@
package openapi

type Sortable interface {
	// SortValue returns the value of a requested field converted to a string.
	SortValue(field OrderByField) string
}

func (s CatalogSource) SortValue(field OrderByField) string {
	switch field {
	case ORDERBYFIELD_ID:
		return s.Id
	case ORDERBYFIELD_NAME:
		return s.Name
	}
	return ""
}

func (m CatalogModel) SortValue(field OrderByField) string {
	switch field {
	case ORDERBYFIELD_ID:
		return m.Name // Name is ID for models
	case ORDERBYFIELD_NAME:
		return m.Name
	case ORDERBYFIELD_LAST_UPDATE_TIME:
		return unrefString(m.LastUpdateTimeSinceEpoch)
	case ORDERBYFIELD_CREATE_TIME:
		return unrefString(m.CreateTimeSinceEpoch)
	}
	return ""
}

func unrefString(v *string) string {
	if v == nil {
		return ""
	}
	return *v
}
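A sketch of one way the `Sortable` interface can back a generic sort, written as a hypothetical helper package. Since `SortValue` returns strings, lexicographic order matches numeric order for the epoch-millisecond fields only while the values have the same digit count (true for contemporary timestamps):

```go
package catalogsort

import (
	"slices"
	"strings"

	model "github.com/kubeflow/model-registry/catalog/pkg/openapi"
)

// SortModels orders models in place by any OrderByField via SortValue.
func SortModels(items []model.CatalogModel, field model.OrderByField, desc bool) {
	slices.SortStableFunc(items, func(a, b model.CatalogModel) int {
		c := strings.Compare(a.SortValue(field), b.SortValue(field))
		if desc {
			return -c
		}
		return c
	})
}
```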
@@ -6,4 +6,4 @@
/.python-version
__pycache__/
venv/
.port-forwards.pid
.hypothesis/
@@ -1,7 +1,8 @@
all: install tidy

IMG_VERSION ?= latest
-IMG ?= ghcr.io/kubeflow/model-registry/server:${IMG_VERSION}
+IMG ?= ghcr.io/kubeflow/model-registry/server
+BUILD_IMAGE ?= true # whether to build the MR server image

.PHONY: install
install:

@@ -16,7 +17,11 @@ clean:

.PHONY: deploy-latest-mr
deploy-latest-mr:
-	cd ../../ && IMG_VERSION=${IMG_VERSION} IMG=${IMG} make image/build ARGS="--load$(if ${DEV_BUILD}, --target dev-build)" && LOCAL=1 ./scripts/deploy_on_kind.sh
+	cd ../../ && \
+	$(if $(filter true,$(BUILD_IMAGE)),\
+		IMG_VERSION=${IMG_VERSION} IMG=${IMG} make image/build ARGS="--load$(if ${DEV_BUILD}, --target dev-build)" && \
+	) \
+	LOCAL=1 ./scripts/deploy_on_kind.sh
+	kubectl port-forward -n kubeflow services/model-registry-service 8080:8080 & echo $$! >> .port-forwards.pid

.PHONY: deploy-test-minio

@@ -36,6 +41,16 @@ test-e2e: deploy-latest-mr deploy-local-registry deploy-test-minio
	$(MAKE) test-e2e-cleanup
	@exit $$STATUS

+.PHONY: test-fuzz
+test-fuzz: deploy-latest-mr deploy-local-registry deploy-test-minio
+	@echo "Starting test-fuzz"
+	poetry install --all-extras
+	@set -a; . ../../scripts/manifests/minio/.env; set +a; \
+	poetry run pytest --fuzz -v -s --hypothesis-show-statistics
+	@rm -f ../../scripts/manifests/minio/.env
+	$(MAKE) test-e2e-cleanup
+	@exit $$STATUS
+
.PHONY: test-e2e-run
test-e2e-run:
	@echo "Ensuring all extras are installed..."

@@ -47,6 +62,8 @@ test-e2e-run:

.PHONY: test-e2e-cleanup
test-e2e-cleanup:
	@echo "Cleaning up database..."
	cd ../../ && ./scripts/cleanup.sh
	@echo "Cleaning up port-forward processes..."
	@if [ -f .port-forwards.pid ]; then \
		kill $$(cat .port-forwards.pid) || true; \
@@ -345,6 +345,12 @@ Then you can run tests:
make test test-e2e
```

Then you can run fuzz tests:

```bash
make test-fuzz
```

### Using Nox

Common tasks, such as building documentation and running tests, can be executed using [`nox`](https://github.com/wntrblm/nox) sessions.
@@ -60,6 +60,7 @@ def tests(session: Session) -> None:
        "pytest-asyncio",
        "uvloop",
        "olot",
+        "schemathesis",
    )
    session.run(
        "pytest",

@@ -83,6 +84,7 @@ def e2e_tests(session: Session) -> None:
        "boto3",
        "olot",
        "uvloop",
+        "schemathesis",
    )
    try:
        session.run(

@@ -99,6 +101,22 @@ def e2e_tests(session: Session) -> None:
    session.notify("coverage", posargs=[])


+@session(name="fuzz", python=python_versions)
+def fuzz_tests(session: Session) -> None:
+    """Run the fuzzing tests."""
+    session.install(
+        ".",
+        "requests",
+        "pytest",
+        "uvloop",
+        "olot",
+        "schemathesis",
+    )
+    session.run(
+        "pytest",
+        "--fuzz",
+        "-rA",
+    )
+
+
@session(python=python_versions[0])
def coverage(session: Session) -> None:
    """Produce the coverage report."""
File diff suppressed because it is too large
@@ -1,6 +1,6 @@
[tool.poetry]
name = "model-registry"
version = "0.2.20"
version = "0.3.0"
description = "Client for Kubeflow Model Registry"
authors = ["Isabella Basso do Amaral <idoamara@redhat.com>"]
license = "Apache-2.0"
@@ -26,7 +26,7 @@ nest-asyncio = "^1.6.0"
# necessary for modern type annotations using pydantic on 3.9
eval-type-backport = "^0.2.0"

huggingface-hub = { version = ">=0.20.1,<0.34.0", optional = true }
huggingface-hub = { version = ">=0.20.1,<0.35.0", optional = true }
olot = { version = "^0.1.6", optional = true }
boto3 = { version = "^1.37.34", optional = true }
@@ -40,7 +40,7 @@ optional = true

[tool.poetry.group.docs.dependencies]
sphinx = "^7.2.6"
furo = ">=2023.9.10,<2025.0.0"
furo = ">=2023.9.10,<2026.0.0"
myst-parser = { extras = ["linkify"], version = ">=2,<4" }
sphinx-autobuild = ">=2021.3.14,<2025.0.0"
@@ -55,11 +55,12 @@ ray = [
    {version = "^2.43.0", python = ">=3.9, <3.13"}
]
uvloop = "^0.21.0"
pytest-asyncio = ">=0.23.7,<0.27.0"
pytest-asyncio = "^1.1.0"
requests = "^2.32.2"
black = ">=24.4.2,<26.0.0"
types-python-dateutil = "^2.9.0.20240906"
pytest-html = "^4.1.1"
schemathesis = ">=4.0.3"

[tool.coverage.run]
branch = true
@@ -81,7 +82,10 @@ line-length = 119

[tool.pytest.ini_options]
asyncio_mode = "auto"
markers = ["e2e: end-to-end testing"]
markers = [
    "e2e: end-to-end testing",
    "fuzz: mark a test as a fuzzing (property-based or randomized) test"
]

[tool.ruff]
target-version = "py39"
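The new `fuzz` marker makes these tests selectable from the command line (`pytest -m fuzz` to run only fuzz tests, `-m "not fuzz"` to skip them). A tiny illustrative test opting into the marker:

```python
import pytest


@pytest.mark.fuzz
def test_roundtrip_is_stable():
    # Placeholder property check; the real fuzz tests are schemathesis-driven.
    assert int(str(123)) == 123
```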
@@ -0,0 +1,5 @@
base-url = "${API_HOST}"

[generation]
# Don't shrink failing examples to save time
no-shrink = true
@@ -10,6 +10,7 @@ mr_openapi/models/artifact.py
mr_openapi/models/artifact_create.py
mr_openapi/models/artifact_list.py
mr_openapi/models/artifact_state.py
mr_openapi/models/artifact_type_query_param.py
mr_openapi/models/artifact_update.py
mr_openapi/models/base_model.py
mr_openapi/models/base_resource.py
@@ -17,11 +18,25 @@ mr_openapi/models/base_resource_create.py
mr_openapi/models/base_resource_dates.py
mr_openapi/models/base_resource_list.py
mr_openapi/models/base_resource_update.py
mr_openapi/models/data_set.py
mr_openapi/models/data_set_create.py
mr_openapi/models/data_set_update.py
mr_openapi/models/doc_artifact.py
mr_openapi/models/doc_artifact_create.py
mr_openapi/models/doc_artifact_update.py
mr_openapi/models/error.py
mr_openapi/models/execution_state.py
mr_openapi/models/experiment.py
mr_openapi/models/experiment_create.py
mr_openapi/models/experiment_list.py
mr_openapi/models/experiment_run.py
mr_openapi/models/experiment_run_create.py
mr_openapi/models/experiment_run_list.py
mr_openapi/models/experiment_run_state.py
mr_openapi/models/experiment_run_status.py
mr_openapi/models/experiment_run_update.py
mr_openapi/models/experiment_state.py
mr_openapi/models/experiment_update.py
mr_openapi/models/inference_service.py
mr_openapi/models/inference_service_create.py
mr_openapi/models/inference_service_list.py
@@ -34,6 +49,10 @@ mr_openapi/models/metadata_proto_value.py
mr_openapi/models/metadata_string_value.py
mr_openapi/models/metadata_struct_value.py
mr_openapi/models/metadata_value.py
mr_openapi/models/metric.py
mr_openapi/models/metric_create.py
mr_openapi/models/metric_list.py
mr_openapi/models/metric_update.py
mr_openapi/models/model_artifact.py
mr_openapi/models/model_artifact_create.py
mr_openapi/models/model_artifact_list.py
@@ -44,6 +63,10 @@ mr_openapi/models/model_version_list.py
mr_openapi/models/model_version_state.py
mr_openapi/models/model_version_update.py
mr_openapi/models/order_by_field.py
mr_openapi/models/parameter.py
mr_openapi/models/parameter_create.py
mr_openapi/models/parameter_type.py
mr_openapi/models/parameter_update.py
mr_openapi/models/registered_model.py
mr_openapi/models/registered_model_create.py
mr_openapi/models/registered_model_list.py
@@ -1,6 +1,6 @@
"""Main package for the Kubeflow model registry."""

__version__ = "0.2.20"
__version__ = "0.3.0"

from ._client import ModelRegistry
@@ -22,6 +22,8 @@ from mr_openapi import (
)
from mr_openapi import (
    ArtifactState,
    DocArtifactCreate,
    DocArtifactUpdate,
    ModelArtifactCreate,
    ModelArtifactUpdate,
)
@@ -47,7 +49,7 @@ class Artifact(BaseResourceModel, ABC):
    """

    name: str | None = None
    uri: str
    uri: str | None = None
    state: ArtifactState = ArtifactState.UNKNOWN

    @classmethod
@@ -87,11 +89,23 @@ class DocArtifact(Artifact):

    @override
    def create(self, **kwargs) -> Any:
        raise NotImplementedError
        """Create a new DocArtifactCreate object."""
        return DocArtifactCreate(
            customProperties=self._map_custom_properties(),
            **self._props_as_dict(exclude=("id", "custom_properties")),
            artifactType="doc-artifact",
            **kwargs,
        )

    @override
    def update(self, **kwargs) -> Any:
        raise NotImplementedError
        """Create a new DocArtifactUpdate object."""
        return DocArtifactUpdate(
            customProperties=self._map_custom_properties(),
            **self._props_as_dict(exclude=("id", "name", "custom_properties")),
            artifactType="doc-artifact",
            **kwargs,
        )

    @override
    def as_basemodel(self) -> DocArtifactBaseModel:
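With `create()` and `update()` now implemented, a `DocArtifact` builds its own request payloads. A rough sketch (the `model_registry.types` import path is an assumption, not confirmed by this diff):

```python
from model_registry.types import DocArtifact  # assumed import path

doc = DocArtifact(name="model-card", uri="s3://bucket/model-card.md")
create_payload = doc.create()  # DocArtifactCreate with artifactType="doc-artifact"
update_payload = doc.update()  # DocArtifactUpdate; name is excluded, as shown above
```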
|
@@ -105,7 +119,6 @@ class DocArtifact(Artifact):
    @override
    def from_basemodel(cls, source: DocArtifactBaseModel) -> DocArtifact:
        assert source.name
        assert source.uri
        assert source.state
        return cls(
            id=source.id,
@@ -189,7 +202,6 @@ class ModelArtifact(Artifact):
    def from_basemodel(cls, source: ModelArtifactBaseModel) -> ModelArtifact:
        """Create a new ModelArtifact object from a BaseModel object."""
        assert source.name
        assert source.uri
        assert source.state
        return cls(
            id=source.id,
@@ -36,8 +36,9 @@ class ModelVersion(BaseResourceModel):

    Attributes:
        name: Name of this version.
        author: Author of the model version.
        description: Description of the object.
        author: Author of this model version.
        state: Status of this model version.
        description: Description of this object.
        external_id: Customizable ID. Has to be unique among instances of the same type.
        artifacts: Artifacts associated with this version.
    """
@@ -45,6 +46,7 @@ class ModelVersion(BaseResourceModel):
    name: str
    author: str | None = None
    state: ModelVersionState = ModelVersionState.LIVE
    registered_model_id: str | None = None

    @override
    def create(self, *, registered_model_id: str, **kwargs) -> ModelVersionCreate:  # type: ignore[override]
@@ -75,6 +77,7 @@ class ModelVersion(BaseResourceModel):
            author=source.author,
            description=source.description,
            external_id=source.external_id,
            registered_model_id=source.registered_model_id,
            create_time_since_epoch=source.create_time_since_epoch,
            last_update_time_since_epoch=source.last_update_time_since_epoch,
            custom_properties=cls._unmap_custom_properties(source.custom_properties)
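Since the version now carries `registered_model_id` itself, the link to the parent model survives a `from_basemodel` round trip. A minimal sketch (the import path is an assumption; the `create` signature is from the hunk above):

```python
from model_registry.types import ModelVersion  # assumed import path

mv = ModelVersion(name="v1.0", author="jane", registered_model_id="42")
payload = mv.create(registered_model_id=mv.registered_model_id)  # ModelVersionCreate
```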
@@ -75,6 +75,9 @@ Class | Method | HTTP request | Description
------------ | ------------- | ------------- | -------------
*ModelRegistryServiceApi* | [**create_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#create_artifact) | **POST** /api/model_registry/v1alpha3/artifacts | Create an Artifact
*ModelRegistryServiceApi* | [**create_environment_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#create_environment_inference_service) | **POST** /api/model_registry/v1alpha3/serving_environments/{servingenvironmentId}/inference_services | Create a InferenceService in ServingEnvironment
*ModelRegistryServiceApi* | [**create_experiment**](mr_openapi/docs/ModelRegistryServiceApi.md#create_experiment) | **POST** /api/model_registry/v1alpha3/experiments | Create an Experiment
*ModelRegistryServiceApi* | [**create_experiment_experiment_run**](mr_openapi/docs/ModelRegistryServiceApi.md#create_experiment_experiment_run) | **POST** /api/model_registry/v1alpha3/experiments/{experimentId}/experiment_runs | Create an ExperimentRun in Experiment
*ModelRegistryServiceApi* | [**create_experiment_run**](mr_openapi/docs/ModelRegistryServiceApi.md#create_experiment_run) | **POST** /api/model_registry/v1alpha3/experiment_runs | Create an ExperimentRun
*ModelRegistryServiceApi* | [**create_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#create_inference_service) | **POST** /api/model_registry/v1alpha3/inference_services | Create a InferenceService
*ModelRegistryServiceApi* | [**create_inference_service_serve**](mr_openapi/docs/ModelRegistryServiceApi.md#create_inference_service_serve) | **POST** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId}/serves | Create a ServeModel action in a InferenceService
*ModelRegistryServiceApi* | [**create_model_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#create_model_artifact) | **POST** /api/model_registry/v1alpha3/model_artifacts | Create a ModelArtifact
@@ -83,6 +86,8 @@ Class | Method | HTTP request | Description
*ModelRegistryServiceApi* | [**create_registered_model_version**](mr_openapi/docs/ModelRegistryServiceApi.md#create_registered_model_version) | **POST** /api/model_registry/v1alpha3/registered_models/{registeredmodelId}/versions | Create a ModelVersion in RegisteredModel
*ModelRegistryServiceApi* | [**create_serving_environment**](mr_openapi/docs/ModelRegistryServiceApi.md#create_serving_environment) | **POST** /api/model_registry/v1alpha3/serving_environments | Create a ServingEnvironment
*ModelRegistryServiceApi* | [**find_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#find_artifact) | **GET** /api/model_registry/v1alpha3/artifact | Get an Artifact that matches search parameters.
*ModelRegistryServiceApi* | [**find_experiment**](mr_openapi/docs/ModelRegistryServiceApi.md#find_experiment) | **GET** /api/model_registry/v1alpha3/experiment | Get an Experiment that matches search parameters.
*ModelRegistryServiceApi* | [**find_experiment_run**](mr_openapi/docs/ModelRegistryServiceApi.md#find_experiment_run) | **GET** /api/model_registry/v1alpha3/experiment_run | Get an ExperimentRun that matches search parameters.
*ModelRegistryServiceApi* | [**find_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#find_inference_service) | **GET** /api/model_registry/v1alpha3/inference_service | Get an InferenceServices that matches search parameters.
*ModelRegistryServiceApi* | [**find_model_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#find_model_artifact) | **GET** /api/model_registry/v1alpha3/model_artifact | Get a ModelArtifact that matches search parameters.
*ModelRegistryServiceApi* | [**find_model_version**](mr_openapi/docs/ModelRegistryServiceApi.md#find_model_version) | **GET** /api/model_registry/v1alpha3/model_version | Get a ModelVersion that matches search parameters.
@@ -91,6 +96,13 @@ Class | Method | HTTP request | Description
*ModelRegistryServiceApi* | [**get_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#get_artifact) | **GET** /api/model_registry/v1alpha3/artifacts/{id} | Get an Artifact
*ModelRegistryServiceApi* | [**get_artifacts**](mr_openapi/docs/ModelRegistryServiceApi.md#get_artifacts) | **GET** /api/model_registry/v1alpha3/artifacts | List All Artifacts
*ModelRegistryServiceApi* | [**get_environment_inference_services**](mr_openapi/docs/ModelRegistryServiceApi.md#get_environment_inference_services) | **GET** /api/model_registry/v1alpha3/serving_environments/{servingenvironmentId}/inference_services | List All ServingEnvironment's InferenceServices
*ModelRegistryServiceApi* | [**get_experiment**](mr_openapi/docs/ModelRegistryServiceApi.md#get_experiment) | **GET** /api/model_registry/v1alpha3/experiments/{experimentId} | Get an Experiment
*ModelRegistryServiceApi* | [**get_experiment_experiment_runs**](mr_openapi/docs/ModelRegistryServiceApi.md#get_experiment_experiment_runs) | **GET** /api/model_registry/v1alpha3/experiments/{experimentId}/experiment_runs | List All Experiment's ExperimentRuns
*ModelRegistryServiceApi* | [**get_experiment_run**](mr_openapi/docs/ModelRegistryServiceApi.md#get_experiment_run) | **GET** /api/model_registry/v1alpha3/experiment_runs/{experimentrunId} | Get an ExperimentRun
*ModelRegistryServiceApi* | [**get_experiment_run_artifacts**](mr_openapi/docs/ModelRegistryServiceApi.md#get_experiment_run_artifacts) | **GET** /api/model_registry/v1alpha3/experiment_runs/{experimentrunId}/artifacts | List all artifacts associated with the `ExperimentRun`
*ModelRegistryServiceApi* | [**get_experiment_run_metric_history**](mr_openapi/docs/ModelRegistryServiceApi.md#get_experiment_run_metric_history) | **GET** /api/model_registry/v1alpha3/experiment_runs/{experimentrunId}/metric_history | Get metric history for an ExperimentRun
*ModelRegistryServiceApi* | [**get_experiment_runs**](mr_openapi/docs/ModelRegistryServiceApi.md#get_experiment_runs) | **GET** /api/model_registry/v1alpha3/experiment_runs | List All ExperimentRuns
*ModelRegistryServiceApi* | [**get_experiments**](mr_openapi/docs/ModelRegistryServiceApi.md#get_experiments) | **GET** /api/model_registry/v1alpha3/experiments | List All Experiments
*ModelRegistryServiceApi* | [**get_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#get_inference_service) | **GET** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId} | Get a InferenceService
*ModelRegistryServiceApi* | [**get_inference_service_model**](mr_openapi/docs/ModelRegistryServiceApi.md#get_inference_service_model) | **GET** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId}/model | Get InferenceService's RegisteredModel
*ModelRegistryServiceApi* | [**get_inference_service_serves**](mr_openapi/docs/ModelRegistryServiceApi.md#get_inference_service_serves) | **GET** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId}/serves | List All InferenceService's ServeModel actions
@@ -107,11 +119,14 @@ Class | Method | HTTP request | Description
*ModelRegistryServiceApi* | [**get_serving_environment**](mr_openapi/docs/ModelRegistryServiceApi.md#get_serving_environment) | **GET** /api/model_registry/v1alpha3/serving_environments/{servingenvironmentId} | Get a ServingEnvironment
*ModelRegistryServiceApi* | [**get_serving_environments**](mr_openapi/docs/ModelRegistryServiceApi.md#get_serving_environments) | **GET** /api/model_registry/v1alpha3/serving_environments | List All ServingEnvironments
*ModelRegistryServiceApi* | [**update_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#update_artifact) | **PATCH** /api/model_registry/v1alpha3/artifacts/{id} | Update an Artifact
*ModelRegistryServiceApi* | [**update_experiment**](mr_openapi/docs/ModelRegistryServiceApi.md#update_experiment) | **PATCH** /api/model_registry/v1alpha3/experiments/{experimentId} | Update an Experiment
*ModelRegistryServiceApi* | [**update_experiment_run**](mr_openapi/docs/ModelRegistryServiceApi.md#update_experiment_run) | **PATCH** /api/model_registry/v1alpha3/experiment_runs/{experimentrunId} | Update an ExperimentRun
*ModelRegistryServiceApi* | [**update_inference_service**](mr_openapi/docs/ModelRegistryServiceApi.md#update_inference_service) | **PATCH** /api/model_registry/v1alpha3/inference_services/{inferenceserviceId} | Update a InferenceService
*ModelRegistryServiceApi* | [**update_model_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#update_model_artifact) | **PATCH** /api/model_registry/v1alpha3/model_artifacts/{modelartifactId} | Update a ModelArtifact
*ModelRegistryServiceApi* | [**update_model_version**](mr_openapi/docs/ModelRegistryServiceApi.md#update_model_version) | **PATCH** /api/model_registry/v1alpha3/model_versions/{modelversionId} | Update a ModelVersion
*ModelRegistryServiceApi* | [**update_registered_model**](mr_openapi/docs/ModelRegistryServiceApi.md#update_registered_model) | **PATCH** /api/model_registry/v1alpha3/registered_models/{registeredmodelId} | Update a RegisteredModel
*ModelRegistryServiceApi* | [**update_serving_environment**](mr_openapi/docs/ModelRegistryServiceApi.md#update_serving_environment) | **PATCH** /api/model_registry/v1alpha3/serving_environments/{servingenvironmentId} | Update a ServingEnvironment
*ModelRegistryServiceApi* | [**upsert_experiment_run_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#upsert_experiment_run_artifact) | **POST** /api/model_registry/v1alpha3/experiment_runs/{experimentrunId}/artifacts | Upsert an Artifact in an ExperimentRun
*ModelRegistryServiceApi* | [**upsert_model_version_artifact**](mr_openapi/docs/ModelRegistryServiceApi.md#upsert_model_version_artifact) | **POST** /api/model_registry/v1alpha3/model_versions/{modelversionId}/artifacts | Upsert an Artifact in a ModelVersion
@@ -121,6 +136,7 @@ Class | Method | HTTP request | Description
- [ArtifactCreate](mr_openapi/docs/ArtifactCreate.md)
- [ArtifactList](mr_openapi/docs/ArtifactList.md)
- [ArtifactState](mr_openapi/docs/ArtifactState.md)
- [ArtifactTypeQueryParam](mr_openapi/docs/ArtifactTypeQueryParam.md)
- [ArtifactUpdate](mr_openapi/docs/ArtifactUpdate.md)
- [BaseModel](mr_openapi/docs/BaseModel.md)
- [BaseResource](mr_openapi/docs/BaseResource.md)
@@ -128,11 +144,25 @@
- [BaseResourceDates](mr_openapi/docs/BaseResourceDates.md)
- [BaseResourceList](mr_openapi/docs/BaseResourceList.md)
- [BaseResourceUpdate](mr_openapi/docs/BaseResourceUpdate.md)
- [DataSet](mr_openapi/docs/DataSet.md)
- [DataSetCreate](mr_openapi/docs/DataSetCreate.md)
- [DataSetUpdate](mr_openapi/docs/DataSetUpdate.md)
- [DocArtifact](mr_openapi/docs/DocArtifact.md)
- [DocArtifactCreate](mr_openapi/docs/DocArtifactCreate.md)
- [DocArtifactUpdate](mr_openapi/docs/DocArtifactUpdate.md)
- [Error](mr_openapi/docs/Error.md)
- [ExecutionState](mr_openapi/docs/ExecutionState.md)
- [Experiment](mr_openapi/docs/Experiment.md)
- [ExperimentCreate](mr_openapi/docs/ExperimentCreate.md)
- [ExperimentList](mr_openapi/docs/ExperimentList.md)
- [ExperimentRun](mr_openapi/docs/ExperimentRun.md)
- [ExperimentRunCreate](mr_openapi/docs/ExperimentRunCreate.md)
- [ExperimentRunList](mr_openapi/docs/ExperimentRunList.md)
- [ExperimentRunState](mr_openapi/docs/ExperimentRunState.md)
- [ExperimentRunStatus](mr_openapi/docs/ExperimentRunStatus.md)
- [ExperimentRunUpdate](mr_openapi/docs/ExperimentRunUpdate.md)
- [ExperimentState](mr_openapi/docs/ExperimentState.md)
- [ExperimentUpdate](mr_openapi/docs/ExperimentUpdate.md)
- [InferenceService](mr_openapi/docs/InferenceService.md)
- [InferenceServiceCreate](mr_openapi/docs/InferenceServiceCreate.md)
- [InferenceServiceList](mr_openapi/docs/InferenceServiceList.md)
@@ -145,6 +175,10 @@
- [MetadataStringValue](mr_openapi/docs/MetadataStringValue.md)
- [MetadataStructValue](mr_openapi/docs/MetadataStructValue.md)
- [MetadataValue](mr_openapi/docs/MetadataValue.md)
- [Metric](mr_openapi/docs/Metric.md)
- [MetricCreate](mr_openapi/docs/MetricCreate.md)
- [MetricList](mr_openapi/docs/MetricList.md)
- [MetricUpdate](mr_openapi/docs/MetricUpdate.md)
- [ModelArtifact](mr_openapi/docs/ModelArtifact.md)
- [ModelArtifactCreate](mr_openapi/docs/ModelArtifactCreate.md)
- [ModelArtifactList](mr_openapi/docs/ModelArtifactList.md)
@@ -155,6 +189,10 @@
- [ModelVersionState](mr_openapi/docs/ModelVersionState.md)
- [ModelVersionUpdate](mr_openapi/docs/ModelVersionUpdate.md)
- [OrderByField](mr_openapi/docs/OrderByField.md)
- [Parameter](mr_openapi/docs/Parameter.md)
- [ParameterCreate](mr_openapi/docs/ParameterCreate.md)
- [ParameterType](mr_openapi/docs/ParameterType.md)
- [ParameterUpdate](mr_openapi/docs/ParameterUpdate.md)
- [RegisteredModel](mr_openapi/docs/RegisteredModel.md)
- [RegisteredModelCreate](mr_openapi/docs/RegisteredModelCreate.md)
- [RegisteredModelList](mr_openapi/docs/RegisteredModelList.md)
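As a rough sketch of the expanded surface in use, here is one of the new experiment endpoints called through the generated async client; the host URL and the exact top-level exports are assumptions, not confirmed by this diff:

```python
import asyncio

from mr_openapi import ApiClient, Configuration, ModelRegistryServiceApi


async def main():
    # Host is an assumption; matches the default port-forward above.
    config = Configuration(host="http://localhost:8080")
    async with ApiClient(config) as client:
        api = ModelRegistryServiceApi(client)
        experiments = await api.get_experiments()  # GET .../v1alpha3/experiments
        print(experiments)


asyncio.run(main())
```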
@@ -35,6 +35,7 @@ from mr_openapi.models.artifact import Artifact
from mr_openapi.models.artifact_create import ArtifactCreate
from mr_openapi.models.artifact_list import ArtifactList
from mr_openapi.models.artifact_state import ArtifactState
from mr_openapi.models.artifact_type_query_param import ArtifactTypeQueryParam
from mr_openapi.models.artifact_update import ArtifactUpdate
from mr_openapi.models.base_model import BaseModel
from mr_openapi.models.base_resource import BaseResource
@@ -42,11 +43,25 @@ from mr_openapi.models.base_resource_create import BaseResourceCreate
from mr_openapi.models.base_resource_dates import BaseResourceDates
from mr_openapi.models.base_resource_list import BaseResourceList
from mr_openapi.models.base_resource_update import BaseResourceUpdate
from mr_openapi.models.data_set import DataSet
from mr_openapi.models.data_set_create import DataSetCreate
from mr_openapi.models.data_set_update import DataSetUpdate
from mr_openapi.models.doc_artifact import DocArtifact
from mr_openapi.models.doc_artifact_create import DocArtifactCreate
from mr_openapi.models.doc_artifact_update import DocArtifactUpdate
from mr_openapi.models.error import Error
from mr_openapi.models.execution_state import ExecutionState
from mr_openapi.models.experiment import Experiment
from mr_openapi.models.experiment_create import ExperimentCreate
from mr_openapi.models.experiment_list import ExperimentList
from mr_openapi.models.experiment_run import ExperimentRun
from mr_openapi.models.experiment_run_create import ExperimentRunCreate
from mr_openapi.models.experiment_run_list import ExperimentRunList
from mr_openapi.models.experiment_run_state import ExperimentRunState
from mr_openapi.models.experiment_run_status import ExperimentRunStatus
from mr_openapi.models.experiment_run_update import ExperimentRunUpdate
from mr_openapi.models.experiment_state import ExperimentState
from mr_openapi.models.experiment_update import ExperimentUpdate
from mr_openapi.models.inference_service import InferenceService
from mr_openapi.models.inference_service_create import InferenceServiceCreate
from mr_openapi.models.inference_service_list import InferenceServiceList
@@ -59,6 +74,10 @@ from mr_openapi.models.metadata_proto_value import MetadataProtoValue
from mr_openapi.models.metadata_string_value import MetadataStringValue
from mr_openapi.models.metadata_struct_value import MetadataStructValue
from mr_openapi.models.metadata_value import MetadataValue
from mr_openapi.models.metric import Metric
from mr_openapi.models.metric_create import MetricCreate
from mr_openapi.models.metric_list import MetricList
from mr_openapi.models.metric_update import MetricUpdate
from mr_openapi.models.model_artifact import ModelArtifact
from mr_openapi.models.model_artifact_create import ModelArtifactCreate
from mr_openapi.models.model_artifact_list import ModelArtifactList
@@ -69,6 +88,10 @@ from mr_openapi.models.model_version_list import ModelVersionList
from mr_openapi.models.model_version_state import ModelVersionState
from mr_openapi.models.model_version_update import ModelVersionUpdate
from mr_openapi.models.order_by_field import OrderByField
from mr_openapi.models.parameter import Parameter
from mr_openapi.models.parameter_create import ParameterCreate
from mr_openapi.models.parameter_type import ParameterType
from mr_openapi.models.parameter_update import ParameterUpdate
from mr_openapi.models.registered_model import RegisteredModel
from mr_openapi.models.registered_model_create import RegisteredModelCreate
from mr_openapi.models.registered_model_list import RegisteredModelList
File diff suppressed because it is too large
@@ -18,6 +18,7 @@ from mr_openapi.models.artifact import Artifact
from mr_openapi.models.artifact_create import ArtifactCreate
from mr_openapi.models.artifact_list import ArtifactList
from mr_openapi.models.artifact_state import ArtifactState
from mr_openapi.models.artifact_type_query_param import ArtifactTypeQueryParam
from mr_openapi.models.artifact_update import ArtifactUpdate
from mr_openapi.models.base_model import BaseModel
from mr_openapi.models.base_resource import BaseResource
@@ -25,11 +26,25 @@ from mr_openapi.models.base_resource_create import BaseResourceCreate
from mr_openapi.models.base_resource_dates import BaseResourceDates
from mr_openapi.models.base_resource_list import BaseResourceList
from mr_openapi.models.base_resource_update import BaseResourceUpdate
from mr_openapi.models.data_set import DataSet
from mr_openapi.models.data_set_create import DataSetCreate
from mr_openapi.models.data_set_update import DataSetUpdate
from mr_openapi.models.doc_artifact import DocArtifact
from mr_openapi.models.doc_artifact_create import DocArtifactCreate
from mr_openapi.models.doc_artifact_update import DocArtifactUpdate
from mr_openapi.models.error import Error
from mr_openapi.models.execution_state import ExecutionState
from mr_openapi.models.experiment import Experiment
from mr_openapi.models.experiment_create import ExperimentCreate
from mr_openapi.models.experiment_list import ExperimentList
from mr_openapi.models.experiment_run import ExperimentRun
from mr_openapi.models.experiment_run_create import ExperimentRunCreate
from mr_openapi.models.experiment_run_list import ExperimentRunList
from mr_openapi.models.experiment_run_state import ExperimentRunState
from mr_openapi.models.experiment_run_status import ExperimentRunStatus
from mr_openapi.models.experiment_run_update import ExperimentRunUpdate
from mr_openapi.models.experiment_state import ExperimentState
from mr_openapi.models.experiment_update import ExperimentUpdate
from mr_openapi.models.inference_service import InferenceService
from mr_openapi.models.inference_service_create import InferenceServiceCreate
from mr_openapi.models.inference_service_list import InferenceServiceList
@@ -42,6 +57,10 @@ from mr_openapi.models.metadata_proto_value import MetadataProtoValue
from mr_openapi.models.metadata_string_value import MetadataStringValue
from mr_openapi.models.metadata_struct_value import MetadataStructValue
from mr_openapi.models.metadata_value import MetadataValue
from mr_openapi.models.metric import Metric
from mr_openapi.models.metric_create import MetricCreate
from mr_openapi.models.metric_list import MetricList
from mr_openapi.models.metric_update import MetricUpdate
from mr_openapi.models.model_artifact import ModelArtifact
from mr_openapi.models.model_artifact_create import ModelArtifactCreate
from mr_openapi.models.model_artifact_list import ModelArtifactList
@@ -52,6 +71,10 @@ from mr_openapi.models.model_version_list import ModelVersionList
from mr_openapi.models.model_version_state import ModelVersionState
from mr_openapi.models.model_version_update import ModelVersionUpdate
from mr_openapi.models.order_by_field import OrderByField
from mr_openapi.models.parameter import Parameter
from mr_openapi.models.parameter_create import ParameterCreate
from mr_openapi.models.parameter_type import ParameterType
from mr_openapi.models.parameter_update import ParameterUpdate
from mr_openapi.models.registered_model import RegisteredModel
from mr_openapi.models.registered_model_create import RegisteredModelCreate
from mr_openapi.models.registered_model_list import RegisteredModelList
@@ -22,10 +22,13 @@ from pydantic import (
)
from typing_extensions import Self

from mr_openapi.models.data_set import DataSet
from mr_openapi.models.doc_artifact import DocArtifact
from mr_openapi.models.metric import Metric
from mr_openapi.models.model_artifact import ModelArtifact
from mr_openapi.models.parameter import Parameter

ARTIFACT_ONE_OF_SCHEMAS = ["DocArtifact", "ModelArtifact"]
ARTIFACT_ONE_OF_SCHEMAS = ["DataSet", "DocArtifact", "Metric", "ModelArtifact", "Parameter"]


class Artifact(BaseModel):
@@ -35,8 +38,14 @@ class Artifact(BaseModel):
    oneof_schema_1_validator: ModelArtifact | None = None
    # data type: DocArtifact
    oneof_schema_2_validator: DocArtifact | None = None
    actual_instance: DocArtifact | ModelArtifact | None = None
    one_of_schemas: set[str] = {"DocArtifact", "ModelArtifact"}
    # data type: DataSet
    oneof_schema_3_validator: DataSet | None = None
    # data type: Metric
    oneof_schema_4_validator: Metric | None = None
    # data type: Parameter
    oneof_schema_5_validator: Parameter | None = None
    actual_instance: DataSet | DocArtifact | Metric | ModelArtifact | Parameter | None = None
    one_of_schemas: set[str] = {"DataSet", "DocArtifact", "Metric", "ModelArtifact", "Parameter"}

    model_config = ConfigDict(
        validate_assignment=True,
@@ -72,16 +81,31 @@ class Artifact(BaseModel):
            error_messages.append(f"Error! Input type `{type(v)}` is not `DocArtifact`")
        else:
            match += 1
        # validate data type: DataSet
        if not isinstance(v, DataSet):
            error_messages.append(f"Error! Input type `{type(v)}` is not `DataSet`")
        else:
            match += 1
        # validate data type: Metric
        if not isinstance(v, Metric):
            error_messages.append(f"Error! Input type `{type(v)}` is not `Metric`")
        else:
            match += 1
        # validate data type: Parameter
        if not isinstance(v, Parameter):
            error_messages.append(f"Error! Input type `{type(v)}` is not `Parameter`")
        else:
            match += 1
        if match > 1:
            # more than 1 match
            raise ValueError(
                "Multiple matches found when setting `actual_instance` in Artifact with oneOf schemas: DocArtifact, ModelArtifact. Details: "
                "Multiple matches found when setting `actual_instance` in Artifact with oneOf schemas: DataSet, DocArtifact, Metric, ModelArtifact, Parameter. Details: "
                + ", ".join(error_messages)
            )
        if match == 0:
            # no match
            raise ValueError(
                "No match found when setting `actual_instance` in Artifact with oneOf schemas: DocArtifact, ModelArtifact. Details: "
                "No match found when setting `actual_instance` in Artifact with oneOf schemas: DataSet, DocArtifact, Metric, ModelArtifact, Parameter. Details: "
                + ", ".join(error_messages)
            )
        return v
@@ -103,26 +127,56 @@ class Artifact(BaseModel):
            msg = "Failed to lookup data type from the field `artifactType` in the input."
            raise ValueError(msg)

        # check if data type is `DataSet`
        if _data_type == "dataset-artifact":
            instance.actual_instance = DataSet.from_json(json_str)
            return instance

        # check if data type is `DocArtifact`
        if _data_type == "doc-artifact":
            instance.actual_instance = DocArtifact.from_json(json_str)
            return instance

        # check if data type is `Metric`
        if _data_type == "metric":
            instance.actual_instance = Metric.from_json(json_str)
            return instance

        # check if data type is `ModelArtifact`
        if _data_type == "model-artifact":
            instance.actual_instance = ModelArtifact.from_json(json_str)
            return instance

        # check if data type is `Parameter`
        if _data_type == "parameter":
            instance.actual_instance = Parameter.from_json(json_str)
            return instance

        # check if data type is `DataSet`
        if _data_type == "DataSet":
            instance.actual_instance = DataSet.from_json(json_str)
            return instance

        # check if data type is `DocArtifact`
        if _data_type == "DocArtifact":
            instance.actual_instance = DocArtifact.from_json(json_str)
            return instance

        # check if data type is `Metric`
        if _data_type == "Metric":
            instance.actual_instance = Metric.from_json(json_str)
            return instance

        # check if data type is `ModelArtifact`
        if _data_type == "ModelArtifact":
            instance.actual_instance = ModelArtifact.from_json(json_str)
            return instance

        # check if data type is `Parameter`
        if _data_type == "Parameter":
            instance.actual_instance = Parameter.from_json(json_str)
            return instance

        # deserialize data into ModelArtifact
        try:
            instance.actual_instance = ModelArtifact.from_json(json_str)
@@ -135,17 +189,35 @@ class Artifact(BaseModel):
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into DataSet
        try:
            instance.actual_instance = DataSet.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into Metric
        try:
            instance.actual_instance = Metric.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into Parameter
        try:
            instance.actual_instance = Parameter.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))

        if match > 1:
            # more than 1 match
            raise ValueError(
                "Multiple matches found when deserializing the JSON string into Artifact with oneOf schemas: DocArtifact, ModelArtifact. Details: "
                "Multiple matches found when deserializing the JSON string into Artifact with oneOf schemas: DataSet, DocArtifact, Metric, ModelArtifact, Parameter. Details: "
                + ", ".join(error_messages)
            )
        if match == 0:
            # no match
            raise ValueError(
                "No match found when deserializing the JSON string into Artifact with oneOf schemas: DocArtifact, ModelArtifact. Details: "
                "No match found when deserializing the JSON string into Artifact with oneOf schemas: DataSet, DocArtifact, Metric, ModelArtifact, Parameter. Details: "
                + ", ".join(error_messages)
            )
        return instance
@@ -159,7 +231,7 @@ class Artifact(BaseModel):
            return self.actual_instance.to_json()
        return json.dumps(self.actual_instance)

    def to_dict(self) -> dict[str, Any] | DocArtifact | ModelArtifact | None:
    def to_dict(self) -> dict[str, Any] | DataSet | DocArtifact | Metric | ModelArtifact | Parameter | None:
        """Returns the dict representation of the actual instance."""
        if self.actual_instance is None:
            return None
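Concretely, the discriminator routing above means a JSON payload now resolves to the new types by its `artifactType` value. A small sketch (fields beyond `artifactType` are assumptions):

```python
from mr_openapi.models.artifact import Artifact

payload = '{"artifactType": "metric", "name": "accuracy"}'
artifact = Artifact.from_json(payload)
# Resolved by the "metric" branch shown above.
print(type(artifact.actual_instance).__name__)  # "Metric"
```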
@@ -22,10 +22,19 @@ from pydantic import (
)
from typing_extensions import Self

from mr_openapi.models.data_set_create import DataSetCreate
from mr_openapi.models.doc_artifact_create import DocArtifactCreate
from mr_openapi.models.metric_create import MetricCreate
from mr_openapi.models.model_artifact_create import ModelArtifactCreate
from mr_openapi.models.parameter_create import ParameterCreate

ARTIFACTCREATE_ONE_OF_SCHEMAS = ["DocArtifactCreate", "ModelArtifactCreate"]
ARTIFACTCREATE_ONE_OF_SCHEMAS = [
    "DataSetCreate",
    "DocArtifactCreate",
    "MetricCreate",
    "ModelArtifactCreate",
    "ParameterCreate",
]


class ArtifactCreate(BaseModel):
@@ -35,8 +44,22 @@ class ArtifactCreate(BaseModel):
    oneof_schema_1_validator: ModelArtifactCreate | None = None
    # data type: DocArtifactCreate
    oneof_schema_2_validator: DocArtifactCreate | None = None
    actual_instance: DocArtifactCreate | ModelArtifactCreate | None = None
    one_of_schemas: set[str] = {"DocArtifactCreate", "ModelArtifactCreate"}
    # data type: DataSetCreate
    oneof_schema_3_validator: DataSetCreate | None = None
    # data type: MetricCreate
    oneof_schema_4_validator: MetricCreate | None = None
    # data type: ParameterCreate
    oneof_schema_5_validator: ParameterCreate | None = None
    actual_instance: (
        DataSetCreate | DocArtifactCreate | MetricCreate | ModelArtifactCreate | ParameterCreate | None
    ) = None
    one_of_schemas: set[str] = {
        "DataSetCreate",
        "DocArtifactCreate",
        "MetricCreate",
        "ModelArtifactCreate",
        "ParameterCreate",
    }

    model_config = ConfigDict(
        validate_assignment=True,
@@ -72,16 +95,31 @@ class ArtifactCreate(BaseModel):
            error_messages.append(f"Error! Input type `{type(v)}` is not `DocArtifactCreate`")
        else:
            match += 1
        # validate data type: DataSetCreate
        if not isinstance(v, DataSetCreate):
            error_messages.append(f"Error! Input type `{type(v)}` is not `DataSetCreate`")
        else:
            match += 1
        # validate data type: MetricCreate
        if not isinstance(v, MetricCreate):
            error_messages.append(f"Error! Input type `{type(v)}` is not `MetricCreate`")
        else:
            match += 1
        # validate data type: ParameterCreate
        if not isinstance(v, ParameterCreate):
            error_messages.append(f"Error! Input type `{type(v)}` is not `ParameterCreate`")
        else:
            match += 1
        if match > 1:
            # more than 1 match
            raise ValueError(
                "Multiple matches found when setting `actual_instance` in ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: "
                "Multiple matches found when setting `actual_instance` in ArtifactCreate with oneOf schemas: DataSetCreate, DocArtifactCreate, MetricCreate, ModelArtifactCreate, ParameterCreate. Details: "
                + ", ".join(error_messages)
            )
        if match == 0:
            # no match
            raise ValueError(
                "No match found when setting `actual_instance` in ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: "
                "No match found when setting `actual_instance` in ArtifactCreate with oneOf schemas: DataSetCreate, DocArtifactCreate, MetricCreate, ModelArtifactCreate, ParameterCreate. Details: "
                + ", ".join(error_messages)
            )
        return v
@@ -103,26 +141,56 @@ class ArtifactCreate(BaseModel):
            msg = "Failed to lookup data type from the field `artifactType` in the input."
            raise ValueError(msg)

        # check if data type is `DataSetCreate`
        if _data_type == "dataset-artifact":
            instance.actual_instance = DataSetCreate.from_json(json_str)
            return instance

        # check if data type is `DocArtifactCreate`
        if _data_type == "doc-artifact":
            instance.actual_instance = DocArtifactCreate.from_json(json_str)
            return instance

        # check if data type is `MetricCreate`
        if _data_type == "metric":
            instance.actual_instance = MetricCreate.from_json(json_str)
            return instance

        # check if data type is `ModelArtifactCreate`
        if _data_type == "model-artifact":
            instance.actual_instance = ModelArtifactCreate.from_json(json_str)
            return instance

        # check if data type is `ParameterCreate`
        if _data_type == "parameter":
            instance.actual_instance = ParameterCreate.from_json(json_str)
            return instance

        # check if data type is `DataSetCreate`
        if _data_type == "DataSetCreate":
            instance.actual_instance = DataSetCreate.from_json(json_str)
            return instance

        # check if data type is `DocArtifactCreate`
        if _data_type == "DocArtifactCreate":
            instance.actual_instance = DocArtifactCreate.from_json(json_str)
            return instance

        # check if data type is `MetricCreate`
        if _data_type == "MetricCreate":
            instance.actual_instance = MetricCreate.from_json(json_str)
            return instance

        # check if data type is `ModelArtifactCreate`
        if _data_type == "ModelArtifactCreate":
            instance.actual_instance = ModelArtifactCreate.from_json(json_str)
            return instance

        # check if data type is `ParameterCreate`
        if _data_type == "ParameterCreate":
            instance.actual_instance = ParameterCreate.from_json(json_str)
            return instance

        # deserialize data into ModelArtifactCreate
        try:
            instance.actual_instance = ModelArtifactCreate.from_json(json_str)
@@ -135,17 +203,35 @@ class ArtifactCreate(BaseModel):
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into DataSetCreate
        try:
            instance.actual_instance = DataSetCreate.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into MetricCreate
        try:
            instance.actual_instance = MetricCreate.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into ParameterCreate
        try:
            instance.actual_instance = ParameterCreate.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))

        if match > 1:
            # more than 1 match
            raise ValueError(
                "Multiple matches found when deserializing the JSON string into ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: "
                "Multiple matches found when deserializing the JSON string into ArtifactCreate with oneOf schemas: DataSetCreate, DocArtifactCreate, MetricCreate, ModelArtifactCreate, ParameterCreate. Details: "
                + ", ".join(error_messages)
            )
        if match == 0:
            # no match
            raise ValueError(
                "No match found when deserializing the JSON string into ArtifactCreate with oneOf schemas: DocArtifactCreate, ModelArtifactCreate. Details: "
                "No match found when deserializing the JSON string into ArtifactCreate with oneOf schemas: DataSetCreate, DocArtifactCreate, MetricCreate, ModelArtifactCreate, ParameterCreate. Details: "
                + ", ".join(error_messages)
            )
        return instance
@@ -159,7 +245,17 @@ class ArtifactCreate(BaseModel):
            return self.actual_instance.to_json()
        return json.dumps(self.actual_instance)

    def to_dict(self) -> dict[str, Any] | DocArtifactCreate | ModelArtifactCreate | None:
    def to_dict(
        self,
    ) -> (
        dict[str, Any]
        | DataSetCreate
        | DocArtifactCreate
        | MetricCreate
        | ModelArtifactCreate
        | ParameterCreate
        | None
    ):
        """Returns the dict representation of the actual instance."""
        if self.actual_instance is None:
            return None
@@ -0,0 +1,34 @@
"""Model Registry REST API.

REST API for Model Registry to create and manage ML model metadata

The version of the OpenAPI document: v1alpha3
Generated by OpenAPI Generator (https://openapi-generator.tech)

Do not edit the class manually.
"""  # noqa: E501

from __future__ import annotations

import json
from enum import Enum

from typing_extensions import Self


class ArtifactTypeQueryParam(str, Enum):
    """Supported artifact types for querying."""

    """
    allowed enum values
    """
    MODEL_MINUS_ARTIFACT = "model-artifact"
    DOC_MINUS_ARTIFACT = "doc-artifact"
    DATASET_MINUS_ARTIFACT = "dataset-artifact"
    METRIC = "metric"
    PARAMETER = "parameter"

    @classmethod
    def from_json(cls, json_str: str) -> Self:
        """Create an instance of ArtifactTypeQueryParam from a JSON string."""
        return cls(json.loads(json_str))
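The enum values double as the wire strings for the `artifactType` discriminator, so parsing is direct:

```python
from mr_openapi.models.artifact_type_query_param import ArtifactTypeQueryParam

param = ArtifactTypeQueryParam.from_json('"dataset-artifact"')
assert param is ArtifactTypeQueryParam.DATASET_MINUS_ARTIFACT
```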
@@ -22,10 +22,19 @@ from pydantic import (
)
from typing_extensions import Self

from mr_openapi.models.data_set_update import DataSetUpdate
from mr_openapi.models.doc_artifact_update import DocArtifactUpdate
from mr_openapi.models.metric_update import MetricUpdate
from mr_openapi.models.model_artifact_update import ModelArtifactUpdate
from mr_openapi.models.parameter_update import ParameterUpdate

ARTIFACTUPDATE_ONE_OF_SCHEMAS = ["DocArtifactUpdate", "ModelArtifactUpdate"]
ARTIFACTUPDATE_ONE_OF_SCHEMAS = [
    "DataSetUpdate",
    "DocArtifactUpdate",
    "MetricUpdate",
    "ModelArtifactUpdate",
    "ParameterUpdate",
]


class ArtifactUpdate(BaseModel):
@@ -35,8 +44,22 @@ class ArtifactUpdate(BaseModel):
    oneof_schema_1_validator: ModelArtifactUpdate | None = None
    # data type: DocArtifactUpdate
    oneof_schema_2_validator: DocArtifactUpdate | None = None
    actual_instance: DocArtifactUpdate | ModelArtifactUpdate | None = None
    one_of_schemas: set[str] = {"DocArtifactUpdate", "ModelArtifactUpdate"}
    # data type: DataSetUpdate
    oneof_schema_3_validator: DataSetUpdate | None = None
    # data type: MetricUpdate
    oneof_schema_4_validator: MetricUpdate | None = None
    # data type: ParameterUpdate
    oneof_schema_5_validator: ParameterUpdate | None = None
    actual_instance: (
        DataSetUpdate | DocArtifactUpdate | MetricUpdate | ModelArtifactUpdate | ParameterUpdate | None
    ) = None
    one_of_schemas: set[str] = {
        "DataSetUpdate",
        "DocArtifactUpdate",
        "MetricUpdate",
        "ModelArtifactUpdate",
        "ParameterUpdate",
    }

    model_config = ConfigDict(
        validate_assignment=True,
@@ -72,16 +95,31 @@ class ArtifactUpdate(BaseModel):
            error_messages.append(f"Error! Input type `{type(v)}` is not `DocArtifactUpdate`")
        else:
            match += 1
        # validate data type: DataSetUpdate
        if not isinstance(v, DataSetUpdate):
            error_messages.append(f"Error! Input type `{type(v)}` is not `DataSetUpdate`")
        else:
            match += 1
        # validate data type: MetricUpdate
        if not isinstance(v, MetricUpdate):
            error_messages.append(f"Error! Input type `{type(v)}` is not `MetricUpdate`")
        else:
            match += 1
        # validate data type: ParameterUpdate
        if not isinstance(v, ParameterUpdate):
            error_messages.append(f"Error! Input type `{type(v)}` is not `ParameterUpdate`")
        else:
            match += 1
        if match > 1:
            # more than 1 match
            raise ValueError(
                "Multiple matches found when setting `actual_instance` in ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: "
                "Multiple matches found when setting `actual_instance` in ArtifactUpdate with oneOf schemas: DataSetUpdate, DocArtifactUpdate, MetricUpdate, ModelArtifactUpdate, ParameterUpdate. Details: "
                + ", ".join(error_messages)
            )
        if match == 0:
            # no match
            raise ValueError(
                "No match found when setting `actual_instance` in ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: "
                "No match found when setting `actual_instance` in ArtifactUpdate with oneOf schemas: DataSetUpdate, DocArtifactUpdate, MetricUpdate, ModelArtifactUpdate, ParameterUpdate. Details: "
                + ", ".join(error_messages)
            )
        return v
@@ -103,26 +141,56 @@ class ArtifactUpdate(BaseModel):
            msg = "Failed to lookup data type from the field `artifactType` in the input."
            raise ValueError(msg)

        # check if data type is `DataSetUpdate`
        if _data_type == "dataset-artifact":
            instance.actual_instance = DataSetUpdate.from_json(json_str)
            return instance

        # check if data type is `DocArtifactUpdate`
        if _data_type == "doc-artifact":
            instance.actual_instance = DocArtifactUpdate.from_json(json_str)
            return instance

        # check if data type is `MetricUpdate`
        if _data_type == "metric":
            instance.actual_instance = MetricUpdate.from_json(json_str)
            return instance

        # check if data type is `ModelArtifactUpdate`
        if _data_type == "model-artifact":
            instance.actual_instance = ModelArtifactUpdate.from_json(json_str)
            return instance

        # check if data type is `ParameterUpdate`
        if _data_type == "parameter":
            instance.actual_instance = ParameterUpdate.from_json(json_str)
            return instance

        # check if data type is `DataSetUpdate`
        if _data_type == "DataSetUpdate":
            instance.actual_instance = DataSetUpdate.from_json(json_str)
            return instance

        # check if data type is `DocArtifactUpdate`
        if _data_type == "DocArtifactUpdate":
            instance.actual_instance = DocArtifactUpdate.from_json(json_str)
            return instance

        # check if data type is `MetricUpdate`
        if _data_type == "MetricUpdate":
            instance.actual_instance = MetricUpdate.from_json(json_str)
            return instance

        # check if data type is `ModelArtifactUpdate`
        if _data_type == "ModelArtifactUpdate":
            instance.actual_instance = ModelArtifactUpdate.from_json(json_str)
            return instance

        # check if data type is `ParameterUpdate`
        if _data_type == "ParameterUpdate":
            instance.actual_instance = ParameterUpdate.from_json(json_str)
            return instance

        # deserialize data into ModelArtifactUpdate
        try:
            instance.actual_instance = ModelArtifactUpdate.from_json(json_str)
@@ -135,17 +203,35 @@ class ArtifactUpdate(BaseModel):
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into DataSetUpdate
        try:
            instance.actual_instance = DataSetUpdate.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into MetricUpdate
        try:
            instance.actual_instance = MetricUpdate.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into ParameterUpdate
        try:
            instance.actual_instance = ParameterUpdate.from_json(json_str)
            match += 1
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))

        if match > 1:
            # more than 1 match
            raise ValueError(
                "Multiple matches found when deserializing the JSON string into ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: "
                "Multiple matches found when deserializing the JSON string into ArtifactUpdate with oneOf schemas: DataSetUpdate, DocArtifactUpdate, MetricUpdate, ModelArtifactUpdate, ParameterUpdate. Details: "
                + ", ".join(error_messages)
            )
        if match == 0:
            # no match
            raise ValueError(
                "No match found when deserializing the JSON string into ArtifactUpdate with oneOf schemas: DocArtifactUpdate, ModelArtifactUpdate. Details: "
                "No match found when deserializing the JSON string into ArtifactUpdate with oneOf schemas: DataSetUpdate, DocArtifactUpdate, MetricUpdate, ModelArtifactUpdate, ParameterUpdate. Details: "
                + ", ".join(error_messages)
            )
        return instance
@@ -159,7 +245,17 @@ class ArtifactUpdate(BaseModel):
            return self.actual_instance.to_json()
        return json.dumps(self.actual_instance)

    def to_dict(self) -> dict[str, Any] | DocArtifactUpdate | ModelArtifactUpdate | None:
    def to_dict(
        self,
    ) -> (
        dict[str, Any]
        | DataSetUpdate
        | DocArtifactUpdate
        | MetricUpdate
        | ModelArtifactUpdate
        | ParameterUpdate
        | None
    ):
        """Returns the dict representation of the actual instance."""
        if self.actual_instance is None:
            return None
@@ -42,7 +42,7 @@ class BaseResource(BaseModel):
    description: StrictStr | None = Field(default=None, description="An optional description about the resource.")
    external_id: StrictStr | None = Field(
        default=None,
        description="The external id that come from the clients' system. This field is optional. If set, it must be unique among all resources within a database instance.",
        description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.",
        alias="externalId",
    )
    name: StrictStr | None = Field(
@@ -0,0 +1,173 @@
"""Model Registry REST API.

REST API for Model Registry to create and manage ML model metadata

The version of the OpenAPI document: v1alpha3
Generated by OpenAPI Generator (https://openapi-generator.tech)

Do not edit the class manually.
"""  # noqa: E501

from __future__ import annotations

import json
import pprint
import re  # noqa: F401
from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field, StrictStr
from typing_extensions import Self

from mr_openapi.models.artifact_state import ArtifactState
from mr_openapi.models.metadata_value import MetadataValue


class DataSet(BaseModel):
    """A dataset artifact representing training or test data."""  # noqa: E501

    custom_properties: dict[str, MetadataValue] | None = Field(
        default=None,
        description="User provided custom properties which are not defined by its type.",
        alias="customProperties",
    )
    description: StrictStr | None = Field(default=None, description="An optional description about the resource.")
    external_id: StrictStr | None = Field(
        default=None,
        description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.",
        alias="externalId",
    )
    name: StrictStr | None = Field(
        default=None,
        description="The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set.",
    )
    id: StrictStr | None = Field(default=None, description="The unique server generated id of the resource.")
    create_time_since_epoch: StrictStr | None = Field(
        default=None,
        description="Output only. Create time of the resource in millisecond since epoch.",
        alias="createTimeSinceEpoch",
    )
    last_update_time_since_epoch: StrictStr | None = Field(
        default=None,
        description="Output only. Last update time of the resource since epoch in millisecond since epoch.",
        alias="lastUpdateTimeSinceEpoch",
    )
    artifact_type: StrictStr | None = Field(default="dataset-artifact", alias="artifactType")
    digest: StrictStr | None = Field(default=None, description="A unique hash or identifier for the dataset content.")
    source_type: StrictStr | None = Field(
        default=None,
        description='The type of data source (e.g., "s3", "hdfs", "local", "database").',
        alias="sourceType",
    )
    source: StrictStr | None = Field(
        default=None, description="The location or connection string for the dataset source."
    )
    var_schema: StrictStr | None = Field(
        default=None, description="JSON schema or description of the dataset structure.", alias="schema"
    )
    profile: StrictStr | None = Field(default=None, description="Statistical profile or summary of the dataset.")
    uri: StrictStr | None = Field(
        default=None,
        description="The uniform resource identifier of the physical dataset. May be empty if there is no physical dataset.",
    )
    state: ArtifactState | None = None
    __properties: ClassVar[list[str]] = [
        "customProperties",
        "description",
        "externalId",
        "name",
        "id",
        "createTimeSinceEpoch",
        "lastUpdateTimeSinceEpoch",
        "artifactType",
        "digest",
        "sourceType",
        "source",
        "schema",
        "profile",
        "uri",
        "state",
    ]

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )

    def to_str(self) -> str:
        """Returns the string representation of the model using alias."""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Self | None:
        """Create an instance of DataSet from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        * OpenAPI `readOnly` fields are excluded.
        """
        excluded_fields: set[str] = {
            "create_time_since_epoch",
            "last_update_time_since_epoch",
        }

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # override the default output from pydantic by calling `to_dict()` of each value in custom_properties (dict)
        _field_dict = {}
        if self.custom_properties:
            for _key in self.custom_properties:
                if self.custom_properties[_key]:
                    _field_dict[_key] = self.custom_properties[_key].to_dict()
            _dict["customProperties"] = _field_dict
        return _dict

    @classmethod
    def from_dict(cls, obj: dict[str, Any] | None) -> Self | None:
        """Create an instance of DataSet from a dict."""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        return cls.model_validate(
            {
                "customProperties": (
                    {_k: MetadataValue.from_dict(_v) for _k, _v in obj["customProperties"].items()}
                    if obj.get("customProperties") is not None
                    else None
                ),
                "description": obj.get("description"),
                "externalId": obj.get("externalId"),
                "name": obj.get("name"),
                "id": obj.get("id"),
                "createTimeSinceEpoch": obj.get("createTimeSinceEpoch"),
                "lastUpdateTimeSinceEpoch": obj.get("lastUpdateTimeSinceEpoch"),
                "artifactType": obj.get("artifactType") if obj.get("artifactType") is not None else "dataset-artifact",
                "digest": obj.get("digest"),
                "sourceType": obj.get("sourceType"),
                "source": obj.get("source"),
                "schema": obj.get("schema"),
                "profile": obj.get("profile"),
                "uri": obj.get("uri"),
                "state": obj.get("state"),
            }
        )
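
As a quick illustration, a round-trip sketch for the new DataSet model (not part of the diff; the module path `mr_openapi.models.data_set` and all field values are assumptions):

# Sketch: round-tripping a DataSet through from_dict()/to_dict().
from mr_openapi.models.data_set import DataSet

ds = DataSet.from_dict(
    {
        "name": "training-data-v1",               # hypothetical name
        "digest": "sha256:abc123",                # hypothetical content hash
        "sourceType": "s3",
        "uri": "s3://bucket/train.parquet",       # hypothetical location
        "createTimeSinceEpoch": "1700000000000",  # readOnly, normally server-set
    }
)
assert ds is not None
assert ds.artifact_type == "dataset-artifact"  # defaulted when omitted
# to_dict() excludes the OpenAPI readOnly timestamps, so the create time
# passed in above does not survive the round trip.
assert "createTimeSinceEpoch" not in ds.to_dict()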
@@ -0,0 +1,151 @@
"""Model Registry REST API.

REST API for Model Registry to create and manage ML model metadata

The version of the OpenAPI document: v1alpha3
Generated by OpenAPI Generator (https://openapi-generator.tech)

Do not edit the class manually.
"""  # noqa: E501

from __future__ import annotations

import json
import pprint
import re  # noqa: F401
from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field, StrictStr
from typing_extensions import Self

from mr_openapi.models.artifact_state import ArtifactState
from mr_openapi.models.metadata_value import MetadataValue


class DataSetCreate(BaseModel):
    """A dataset artifact to be created."""  # noqa: E501

    custom_properties: dict[str, MetadataValue] | None = Field(
        default=None,
        description="User provided custom properties which are not defined by its type.",
        alias="customProperties",
    )
    description: StrictStr | None = Field(default=None, description="An optional description about the resource.")
    external_id: StrictStr | None = Field(
        default=None,
        description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.",
        alias="externalId",
    )
    name: StrictStr | None = Field(
        default=None,
        description="The client provided name of the artifact. This field is optional. If set, it must be unique among all the artifacts of the same artifact type within a database instance and cannot be changed once set.",
    )
    artifact_type: StrictStr | None = Field(default="dataset-artifact", alias="artifactType")
    digest: StrictStr | None = Field(default=None, description="A unique hash or identifier for the dataset content.")
    source_type: StrictStr | None = Field(
        default=None,
        description='The type of data source (e.g., "s3", "hdfs", "local", "database").',
        alias="sourceType",
    )
    source: StrictStr | None = Field(
        default=None, description="The location or connection string for the dataset source."
    )
    var_schema: StrictStr | None = Field(
        default=None, description="JSON schema or description of the dataset structure.", alias="schema"
    )
    profile: StrictStr | None = Field(default=None, description="Statistical profile or summary of the dataset.")
    uri: StrictStr | None = Field(
        default=None,
        description="The uniform resource identifier of the physical dataset. May be empty if there is no physical dataset.",
    )
    state: ArtifactState | None = None
    __properties: ClassVar[list[str]] = [
        "customProperties",
        "description",
        "externalId",
        "name",
        "artifactType",
        "digest",
        "sourceType",
        "source",
        "schema",
        "profile",
        "uri",
        "state",
    ]

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )

    def to_str(self) -> str:
        """Returns the string representation of the model using alias."""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Self | None:
        """Create an instance of DataSetCreate from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        """
        excluded_fields: set[str] = set()

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # override the default output from pydantic by calling `to_dict()` of each value in custom_properties (dict)
        _field_dict = {}
        if self.custom_properties:
            for _key in self.custom_properties:
                if self.custom_properties[_key]:
                    _field_dict[_key] = self.custom_properties[_key].to_dict()
            _dict["customProperties"] = _field_dict
        return _dict

    @classmethod
    def from_dict(cls, obj: dict[str, Any] | None) -> Self | None:
        """Create an instance of DataSetCreate from a dict."""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        return cls.model_validate(
            {
                "customProperties": (
                    {_k: MetadataValue.from_dict(_v) for _k, _v in obj["customProperties"].items()}
                    if obj.get("customProperties") is not None
                    else None
                ),
                "description": obj.get("description"),
                "externalId": obj.get("externalId"),
                "name": obj.get("name"),
                "artifactType": obj.get("artifactType") if obj.get("artifactType") is not None else "dataset-artifact",
                "digest": obj.get("digest"),
                "sourceType": obj.get("sourceType"),
                "source": obj.get("source"),
                "schema": obj.get("schema"),
                "profile": obj.get("profile"),
                "uri": obj.get("uri"),
                "state": obj.get("state"),
            }
        )
@@ -0,0 +1,145 @@
"""Model Registry REST API.

REST API for Model Registry to create and manage ML model metadata

The version of the OpenAPI document: v1alpha3
Generated by OpenAPI Generator (https://openapi-generator.tech)

Do not edit the class manually.
"""  # noqa: E501

from __future__ import annotations

import json
import pprint
import re  # noqa: F401
from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field, StrictStr
from typing_extensions import Self

from mr_openapi.models.artifact_state import ArtifactState
from mr_openapi.models.metadata_value import MetadataValue


class DataSetUpdate(BaseModel):
    """A dataset artifact to be updated."""  # noqa: E501

    custom_properties: dict[str, MetadataValue] | None = Field(
        default=None,
        description="User provided custom properties which are not defined by its type.",
        alias="customProperties",
    )
    description: StrictStr | None = Field(default=None, description="An optional description about the resource.")
    external_id: StrictStr | None = Field(
        default=None,
        description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.",
        alias="externalId",
    )
    artifact_type: StrictStr | None = Field(default="dataset-artifact", alias="artifactType")
    digest: StrictStr | None = Field(default=None, description="A unique hash or identifier for the dataset content.")
    source_type: StrictStr | None = Field(
        default=None,
        description='The type of data source (e.g., "s3", "hdfs", "local", "database").',
        alias="sourceType",
    )
    source: StrictStr | None = Field(
        default=None, description="The location or connection string for the dataset source."
    )
    var_schema: StrictStr | None = Field(
        default=None, description="JSON schema or description of the dataset structure.", alias="schema"
    )
    profile: StrictStr | None = Field(default=None, description="Statistical profile or summary of the dataset.")
    uri: StrictStr | None = Field(
        default=None,
        description="The uniform resource identifier of the physical dataset. May be empty if there is no physical dataset.",
    )
    state: ArtifactState | None = None
    __properties: ClassVar[list[str]] = [
        "customProperties",
        "description",
        "externalId",
        "artifactType",
        "digest",
        "sourceType",
        "source",
        "schema",
        "profile",
        "uri",
        "state",
    ]

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )

    def to_str(self) -> str:
        """Returns the string representation of the model using alias."""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Self | None:
        """Create an instance of DataSetUpdate from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        """
        excluded_fields: set[str] = set()

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # override the default output from pydantic by calling `to_dict()` of each value in custom_properties (dict)
        _field_dict = {}
        if self.custom_properties:
            for _key in self.custom_properties:
                if self.custom_properties[_key]:
                    _field_dict[_key] = self.custom_properties[_key].to_dict()
            _dict["customProperties"] = _field_dict
        return _dict

    @classmethod
    def from_dict(cls, obj: dict[str, Any] | None) -> Self | None:
        """Create an instance of DataSetUpdate from a dict."""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        return cls.model_validate(
            {
                "customProperties": (
                    {_k: MetadataValue.from_dict(_v) for _k, _v in obj["customProperties"].items()}
                    if obj.get("customProperties") is not None
                    else None
                ),
                "description": obj.get("description"),
                "externalId": obj.get("externalId"),
                "artifactType": obj.get("artifactType") if obj.get("artifactType") is not None else "dataset-artifact",
                "digest": obj.get("digest"),
                "sourceType": obj.get("sourceType"),
                "source": obj.get("source"),
                "schema": obj.get("schema"),
                "profile": obj.get("profile"),
                "uri": obj.get("uri"),
                "state": obj.get("state"),
            }
        )
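
Note that DataSetUpdate omits `name`, consistent with the name description above ("cannot be changed once set"). A short sketch of building a patch body (not part of the diff; module path and values are assumptions):

# Sketch: an update payload carries only mutable fields.
from mr_openapi.models.data_set_update import DataSetUpdate

patch = DataSetUpdate(description="refreshed snapshot", digest="sha256:def456")
body = patch.to_dict()
assert "name" not in body  # the model has no name field at all
print(body)  # {'description': ..., 'artifactType': 'dataset-artifact', 'digest': ...}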
@@ -0,0 +1,143 @@
"""Model Registry REST API.

REST API for Model Registry to create and manage ML model metadata

The version of the OpenAPI document: v1alpha3
Generated by OpenAPI Generator (https://openapi-generator.tech)

Do not edit the class manually.
"""  # noqa: E501

from __future__ import annotations

import json
import pprint
import re  # noqa: F401
from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field, StrictStr
from typing_extensions import Self

from mr_openapi.models.experiment_state import ExperimentState
from mr_openapi.models.metadata_value import MetadataValue


class Experiment(BaseModel):
    """An experiment in model registry. An experiment has ExperimentRun children."""  # noqa: E501

    custom_properties: dict[str, MetadataValue] | None = Field(
        default=None,
        description="User provided custom properties which are not defined by its type.",
        alias="customProperties",
    )
    description: StrictStr | None = Field(default=None, description="An optional description about the resource.")
    external_id: StrictStr | None = Field(
        default=None,
        description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.",
        alias="externalId",
    )
    name: StrictStr = Field(
        description="The client provided name of the experiment. It must be unique among all the Experiments of the same type within a Model Registry instance and cannot be changed once set."
    )
    id: StrictStr | None = Field(default=None, description="The unique server generated id of the resource.")
    create_time_since_epoch: StrictStr | None = Field(
        default=None,
        description="Output only. Create time of the resource in millisecond since epoch.",
        alias="createTimeSinceEpoch",
    )
    last_update_time_since_epoch: StrictStr | None = Field(
        default=None,
        description="Output only. Last update time of the resource since epoch in millisecond since epoch.",
        alias="lastUpdateTimeSinceEpoch",
    )
    owner: StrictStr | None = None
    state: ExperimentState | None = None
    __properties: ClassVar[list[str]] = [
        "customProperties",
        "description",
        "externalId",
        "name",
        "id",
        "createTimeSinceEpoch",
        "lastUpdateTimeSinceEpoch",
        "owner",
        "state",
    ]

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )

    def to_str(self) -> str:
        """Returns the string representation of the model using alias."""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Self | None:
        """Create an instance of Experiment from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        * OpenAPI `readOnly` fields are excluded.
        """
        excluded_fields: set[str] = {
            "create_time_since_epoch",
            "last_update_time_since_epoch",
        }

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # override the default output from pydantic by calling `to_dict()` of each value in custom_properties (dict)
        _field_dict = {}
        if self.custom_properties:
            for _key in self.custom_properties:
                if self.custom_properties[_key]:
                    _field_dict[_key] = self.custom_properties[_key].to_dict()
            _dict["customProperties"] = _field_dict
        return _dict

    @classmethod
    def from_dict(cls, obj: dict[str, Any] | None) -> Self | None:
        """Create an instance of Experiment from a dict."""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        return cls.model_validate(
            {
                "customProperties": (
                    {_k: MetadataValue.from_dict(_v) for _k, _v in obj["customProperties"].items()}
                    if obj.get("customProperties") is not None
                    else None
                ),
                "description": obj.get("description"),
                "externalId": obj.get("externalId"),
                "name": obj.get("name"),
                "id": obj.get("id"),
                "createTimeSinceEpoch": obj.get("createTimeSinceEpoch"),
                "lastUpdateTimeSinceEpoch": obj.get("lastUpdateTimeSinceEpoch"),
                "owner": obj.get("owner"),
                "state": obj.get("state"),
            }
        )
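
Unlike the artifact models above, Experiment declares `name` as required, so construction without it fails validation. A small sketch (not part of the diff; module path and values are assumptions):

# Sketch: Experiment requires `name`; other fields are optional.
from pydantic import ValidationError

from mr_openapi.models.experiment import Experiment

exp = Experiment(name="churn-model-tuning", owner="alice")  # hypothetical values
print(exp.to_json())

try:
    Experiment.from_dict({"owner": "alice"})  # no name -> validation error
except ValidationError as err:
    print("name is required:", err.error_count(), "error(s)")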
@@ -0,0 +1,114 @@
"""Model Registry REST API.

REST API for Model Registry to create and manage ML model metadata

The version of the OpenAPI document: v1alpha3
Generated by OpenAPI Generator (https://openapi-generator.tech)

Do not edit the class manually.
"""  # noqa: E501

from __future__ import annotations

import json
import pprint
import re  # noqa: F401
from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field, StrictStr
from typing_extensions import Self

from mr_openapi.models.experiment_state import ExperimentState
from mr_openapi.models.metadata_value import MetadataValue


class ExperimentCreate(BaseModel):
    """An experiment in model registry. An experiment has ExperimentRun children."""  # noqa: E501

    custom_properties: dict[str, MetadataValue] | None = Field(
        default=None,
        description="User provided custom properties which are not defined by its type.",
        alias="customProperties",
    )
    description: StrictStr | None = Field(default=None, description="An optional description about the resource.")
    external_id: StrictStr | None = Field(
        default=None,
        description="The external id that come from the clients’ system. This field is optional. If set, it must be unique among all resources within a database instance.",
        alias="externalId",
    )
    name: StrictStr = Field(
        description="The client provided name of the experiment. It must be unique among all the Experiments of the same type within a Model Registry instance and cannot be changed once set."
    )
    owner: StrictStr | None = None
    state: ExperimentState | None = None
    __properties: ClassVar[list[str]] = ["customProperties", "description", "externalId", "name", "owner", "state"]

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )

    def to_str(self) -> str:
        """Returns the string representation of the model using alias."""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Self | None:
        """Create an instance of ExperimentCreate from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        """
        excluded_fields: set[str] = set()

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # override the default output from pydantic by calling `to_dict()` of each value in custom_properties (dict)
        _field_dict = {}
        if self.custom_properties:
            for _key in self.custom_properties:
                if self.custom_properties[_key]:
                    _field_dict[_key] = self.custom_properties[_key].to_dict()
            _dict["customProperties"] = _field_dict
        return _dict

    @classmethod
    def from_dict(cls, obj: dict[str, Any] | None) -> Self | None:
        """Create an instance of ExperimentCreate from a dict."""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        return cls.model_validate(
            {
                "customProperties": (
                    {_k: MetadataValue.from_dict(_v) for _k, _v in obj["customProperties"].items()}
                    if obj.get("customProperties") is not None
                    else None
                ),
                "description": obj.get("description"),
                "externalId": obj.get("externalId"),
                "name": obj.get("name"),
                "owner": obj.get("owner"),
                "state": obj.get("state"),
            }
        )
@@ -0,0 +1,99 @@
"""Model Registry REST API.

REST API for Model Registry to create and manage ML model metadata

The version of the OpenAPI document: v1alpha3
Generated by OpenAPI Generator (https://openapi-generator.tech)

Do not edit the class manually.
"""  # noqa: E501

from __future__ import annotations

import json
import pprint
import re  # noqa: F401
from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr
from typing_extensions import Self

from mr_openapi.models.experiment import Experiment


class ExperimentList(BaseModel):
    """List of Experiments."""  # noqa: E501

    next_page_token: StrictStr = Field(
        description="Token to use to retrieve next page of results.", alias="nextPageToken"
    )
    page_size: StrictInt = Field(description="Maximum number of resources to return in the result.", alias="pageSize")
    size: StrictInt = Field(description="Number of items in result list.")
    items: list[Experiment]
    __properties: ClassVar[list[str]] = ["nextPageToken", "pageSize", "size", "items"]

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )

    def to_str(self) -> str:
        """Returns the string representation of the model using alias."""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Self | None:
        """Create an instance of ExperimentList from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        """
        excluded_fields: set[str] = set()

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # override the default output from pydantic by calling `to_dict()` of each item in items (list)
        _items = []
        if self.items:
            for _item in self.items:
                if _item:
                    _items.append(_item.to_dict())
            _dict["items"] = _items
        return _dict

    @classmethod
    def from_dict(cls, obj: dict[str, Any] | None) -> Self | None:
        """Create an instance of ExperimentList from a dict."""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        return cls.model_validate(
            {
                "nextPageToken": obj.get("nextPageToken"),
                "pageSize": obj.get("pageSize"),
                "size": obj.get("size"),
                "items": (
                    [Experiment.from_dict(_item) for _item in obj["items"]] if obj.get("items") is not None else None
                ),
            }
        )
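
Finally, a sketch of consuming a paged ExperimentList response (not part of the diff; the JSON body, module path, and paging convention are assumptions):

# Sketch: parsing a hypothetical paged list response.
from mr_openapi.models.experiment_list import ExperimentList

page = ExperimentList.from_json(
    '{"nextPageToken": "tok-2", "pageSize": 2, "size": 2,'
    ' "items": [{"name": "exp-a"}, {"name": "exp-b"}]}'
)
assert page is not None
for exp in page.items:
    print(exp.name)
# Presumably an empty nextPageToken signals the last page; otherwise the token
# is passed back on the next list request.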
Some files were not shown because too many files have changed in this diff.