Merge pull request #11 from docker/tag-model
Add endpoint for tagging a model
This commit is contained in:
commit
d1d8e7a3c1
|
|
@ -92,6 +92,7 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
|
|||
"GET " + inference.ModelsPrefix: m.handleGetModels,
|
||||
"GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel,
|
||||
"DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel,
|
||||
"POST " + inference.ModelsPrefix + "/{name}/tag": m.handleTagModel,
|
||||
"GET " + inference.InferencePrefix + "/{backend}/v1/models": m.handleOpenAIGetModels,
|
||||
"GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel,
|
||||
"GET " + inference.InferencePrefix + "/v1/models": m.handleOpenAIGetModels,
|
||||
|
|
@ -288,6 +289,53 @@ func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// handleTagModel handles POST <inference-prefix>/models/{name}/tag requests.
|
||||
// The query parameters are:
|
||||
// - repo: the repository to tag the model with (required)
|
||||
// - tag: the tag to tag the model with (optional, defaults to "latest")
|
||||
func (m *Manager) handleTagModel(w http.ResponseWriter, r *http.Request) {
|
||||
if m.distributionClient == nil {
|
||||
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract the model name from the request path.
|
||||
model := r.PathValue("name")
|
||||
|
||||
// Extract query parameters.
|
||||
repo := r.URL.Query().Get("repo")
|
||||
tag := r.URL.Query().Get("tag")
|
||||
|
||||
// Validate query parameters.
|
||||
if repo == "" {
|
||||
http.Error(w, "missing repo or tag query parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if tag == "" {
|
||||
tag = "latest"
|
||||
}
|
||||
|
||||
// Construct the target string.
|
||||
target := fmt.Sprintf("%s:%s", repo, tag)
|
||||
|
||||
// Call the Tag method on the distribution client with source and modelName.
|
||||
if err := m.distributionClient.Tag(model, target); err != nil {
|
||||
m.log.Warnf("Failed to tag model %q: %v", model, err)
|
||||
|
||||
if errors.Is(err, distribution.ErrModelNotFound) {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Respond with success.
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte(fmt.Sprintf("Model %q tagged successfully with source %q", modelName, model)))
|
||||
}
|
||||
|
||||
// ServeHTTP implements net/http.Handler.ServeHTTP.
|
||||
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.router.ServeHTTP(w, r)
|
||||
|
|
|
|||
Loading…
Reference in New Issue