Merge pull request #11 from docker/tag-model

Add endpoint for tagging a model
This commit is contained in:
Emily Casey 2025-04-17 15:54:42 -04:00 committed by GitHub
commit d1d8e7a3c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 48 additions and 0 deletions

View File

@ -92,6 +92,7 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
"GET " + inference.ModelsPrefix: m.handleGetModels,
"GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel,
"DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel,
"POST " + inference.ModelsPrefix + "/{name}/tag": m.handleTagModel,
"GET " + inference.InferencePrefix + "/{backend}/v1/models": m.handleOpenAIGetModels,
"GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel,
"GET " + inference.InferencePrefix + "/v1/models": m.handleOpenAIGetModels,
@ -288,6 +289,53 @@ func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) {
}
}
// handleTagModel handles POST <inference-prefix>/models/{name}/tag requests.
// It asks the distribution client to tag the model named by the {name} path
// value with the target "<repo>:<tag>".
//
// The query parameters are:
//   - repo: the repository to tag the model with (required)
//   - tag: the tag to tag the model with (optional, defaults to "latest")
//
// Responses:
//   - 201 Created when the tag was applied
//   - 400 Bad Request when the repo query parameter is missing
//   - 404 Not Found when tagging fails with distribution.ErrModelNotFound
//   - 500 Internal Server Error for any other tagging failure
//   - 503 Service Unavailable when no distribution client is configured
//
// NOTE(review): the route registers {name} as a single-segment wildcard,
// unlike the {name...} wildcards on the sibling model routes; model names
// containing "/" would presumably not reach this handler — confirm.
func (m *Manager) handleTagModel(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}

	// Extract the model name from the request path.
	model := r.PathValue("name")

	// Parse the query string once and read both parameters from it.
	query := r.URL.Query()
	repo := query.Get("repo")
	tag := query.Get("tag")

	// Only repo is mandatory; an absent tag falls back to "latest".
	if repo == "" {
		http.Error(w, "missing repo query parameter", http.StatusBadRequest)
		return
	}
	if tag == "" {
		tag = "latest"
	}

	// Construct the target reference the model will be tagged as.
	target := repo + ":" + tag

	// Tag the source model with the target reference.
	if err := m.distributionClient.Tag(model, target); err != nil {
		m.log.Warnf("Failed to tag model %q: %v", model, err)
		if errors.Is(err, distribution.ErrModelNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
			return
		}
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Respond with success. The status line is already committed at this
	// point, so a failed body write can only be logged.
	w.WriteHeader(http.StatusCreated)
	if _, err := fmt.Fprintf(w, "Model %q tagged successfully with source %q", target, model); err != nil {
		m.log.Warnf("Failed to write tag response for model %q: %v", model, err)
	}
}
// ServeHTTP implements net/http.Handler.ServeHTTP.
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.router.ServeHTTP(w, r)