package models import ( "context" "encoding/json" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "strings" "testing" "github.com/google/go-containerregistry/pkg/registry" "github.com/docker/model-distribution/builder" reg "github.com/docker/model-distribution/registry" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/memory" "github.com/sirupsen/logrus" ) type mockMemoryEstimator struct{} func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {} func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (*inference.RequiredMemory, error) { return &inference.RequiredMemory{RAM: 0, VRAM: 0}, nil } func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, error) { return true, nil } // getProjectRoot returns the absolute path to the project root directory func getProjectRoot(t *testing.T) string { // Start from the current test file's directory dir, err := os.Getwd() if err != nil { t.Fatalf("Failed to get current directory: %v", err) } // Walk up the directory tree until we find the go.mod file for { if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { return dir } parent := filepath.Dir(dir) if parent == dir { t.Fatal("Could not find project root (go.mod)") } dir = parent } } func TestPullModel(t *testing.T) { // Create temp directory for store tempDir, err := os.MkdirTemp("", "model-distribution-test-*") if err != nil { t.Fatalf("Failed to create temp directory: %v", err) } defer os.RemoveAll(tempDir) // Create a test registry server := httptest.NewServer(registry.New()) defer server.Close() // Create a tag for the model uri, err := url.Parse(server.URL) if err != nil { t.Fatalf("Failed to parse registry URL: %v", err) } tag := uri.Host + "/ai/model:v1.0.0" // Prepare the OCI model artifact projectRoot := getProjectRoot(t) model, err := builder.FromGGUF(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create model builder: %v", err) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { t.Fatalf("Failed to add license to model: %v", err) } // Build the OCI model artifact + push it client := reg.NewClient() target, err := client.NewTarget(tag) if err != nil { t.Fatalf("Failed to create model target: %v", err) } err = license.Build(context.Background(), target, os.Stdout) if err != nil { t.Fatalf("Failed to build model: %v", err) } tests := []struct { name string acceptHeader string expectedCT string }{ { name: "default content type", acceptHeader: "", expectedCT: "text/plain", }, { name: "plain text content type", acceptHeader: "text/plain", expectedCT: "text/plain", }, { name: "json content type", acceptHeader: "application/json", expectedCT: "application/json", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { log := logrus.NewEntry(logrus.StandardLogger()) memEstimator := &mockMemoryEstimator{} m := NewManager(log, ClientConfig{ StoreRootPath: tempDir, Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), }, nil, memEstimator) r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tag+`"}`)) if tt.acceptHeader != "" { r.Header.Set("Accept", tt.acceptHeader) } w := httptest.NewRecorder() err = m.PullModel(tag, r, w) if err != nil { t.Fatalf("Failed to pull model: %v", err) } if tt.expectedCT != w.Header().Get("Content-Type") { t.Fatalf("Expected content type %s, got %s", tt.expectedCT, w.Header().Get("Content-Type")) } // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { t.Fatalf("Failed to clean temp directory: %v", err) } if err := os.MkdirAll(tempDir, 0755); err != nil { t.Fatalf("Failed to recreate temp directory: %v", err) } }) } } func TestHandleGetModel(t *testing.T) { // Create temp directory for store tempDir, err := os.MkdirTemp("", "model-distribution-test-*") if err != nil { t.Fatalf("Failed to create temp directory: %v", err) } defer os.RemoveAll(tempDir) // Create a test registry server := httptest.NewServer(registry.New()) defer server.Close() uri, err := url.Parse(server.URL) if err != nil { t.Fatalf("Failed to parse registry URL: %v", err) } // Prepare the OCI model artifact projectRoot := getProjectRoot(t) model, err := builder.FromGGUF(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create model builder: %v", err) } license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt")) if err != nil { t.Fatalf("Failed to add license to model: %v", err) } // Build the OCI model artifact + push it tag := uri.Host + "/ai/model:v1.0.0" client := reg.NewClient() target, err := client.NewTarget(tag) if err != nil { t.Fatalf("Failed to create model target: %v", err) } err = license.Build(context.Background(), target, os.Stdout) if err != nil { t.Fatalf("Failed to build model: %v", err) } tests := []struct { name string remote bool modelName string expectedCode int expectedError string }{ { name: "get local model - success", remote: false, modelName: tag, expectedCode: http.StatusOK, }, { name: "get local model - not found", remote: false, modelName: "nonexistent:v1", expectedCode: http.StatusNotFound, expectedError: "error while getting model", }, { name: "get remote model - success", remote: true, modelName: tag, expectedCode: http.StatusOK, }, { name: "get remote model - not found", remote: true, modelName: uri.Host + "/ai/nonexistent:v1", expectedCode: http.StatusNotFound, expectedError: "not found", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { log := logrus.NewEntry(logrus.StandardLogger()) memEstimator := &mockMemoryEstimator{} m := NewManager(log, ClientConfig{ StoreRootPath: tempDir, Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), Transport: http.DefaultTransport, UserAgent: "test-agent", }, nil, memEstimator) // First pull the model if we're testing local access if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") { r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`)) w := httptest.NewRecorder() err = m.PullModel(tt.modelName, r, w) if err != nil { t.Fatalf("Failed to pull model: %v", err) } } // Create request with remote query param path := inference.ModelsPrefix + "/" + tt.modelName if tt.remote { path += "?remote=true" } r := httptest.NewRequest("GET", path, nil) w := httptest.NewRecorder() // Set the path value for {name} so r.PathValue("name") works r.SetPathValue("name", tt.modelName) // Call the handler directly m.handleGetModel(w, r) // Check response if w.Code != tt.expectedCode { t.Errorf("Expected status code %d, got %d", tt.expectedCode, w.Code) } if tt.expectedError != "" { if !strings.Contains(w.Body.String(), tt.expectedError) { t.Errorf("Expected error containing %q, got %q", tt.expectedError, w.Body.String()) } } else { // For successful responses, verify we got a valid JSON response var response Model if err := json.NewDecoder(w.Body).Decode(&response); err != nil { t.Errorf("Failed to decode response body: %v", err) } } // Clean tempDir after each test if err := os.RemoveAll(tempDir); err != nil { t.Fatalf("Failed to clean temp directory: %v", err) } if err := os.MkdirAll(tempDir, 0755); err != nil { t.Fatalf("Failed to recreate temp directory: %v", err) } }) } }