Skip to content

Commit 82ea264

Browse files
committed
use server side backend selection instead of duplicating logic in CLI
Remove the CLI side GetRequiredBackendFromModelInfo mapping and delegate backend resolution to the server via a new GET /models/backend?model=<ref> endpoint. The endpoint reuses Scheduler.selectBackendForModel so the platform aware fallback chain (vLLM, MLX, SGLang, diffusers) stays in one place. Also address earlier review feedback: sentinel error for cancelled install, cmd.InOrStdin for prompt input, and unit tests for the new helpers.
1 parent e5b6afe commit 82ea264

File tree

8 files changed

+120
-48
lines changed

8 files changed

+120
-48
lines changed

cmd/cli/commands/run.go

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -828,22 +828,15 @@ func newRunCmd() *cobra.Command {
828828
return nil
829829
}
830830

831-
modelInfo, err := desktopClient.Inspect(model, false)
831+
_, err := desktopClient.Inspect(model, false)
832832
modelFoundLocally := err == nil
833833
if err != nil && !errors.Is(err, desktop.ErrNotFound) {
834834
return handleClientError(err, "Failed to inspect model")
835835
}
836836

837-
if !modelFoundLocally {
838-
remoteInfo, remoteErr := desktopClient.Inspect(model, true)
839-
if remoteErr == nil {
840-
modelInfo = remoteInfo
841-
}
842-
}
843-
844837
backend := ""
845-
if modelInfo.ID != "" {
846-
backend, _ = GetRequiredBackendFromModelInfo(&modelInfo)
838+
if resolvedBackend, resolveErr := ResolveRequiredBackend(model); resolveErr == nil {
839+
backend = resolvedBackend
847840
}
848841

849842
if backend != "" {

cmd/cli/commands/utils.go

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@ import (
1414
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
1515
"github.com/docker/model-runner/pkg/distribution/distribution"
1616
"github.com/docker/model-runner/pkg/distribution/oci/reference"
17-
"github.com/docker/model-runner/pkg/distribution/types"
18-
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
19-
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
2017
"github.com/docker/model-runner/pkg/inference/backends/vllm"
21-
dmrm "github.com/docker/model-runner/pkg/inference/models"
2218
"github.com/moby/term"
2319
"github.com/olekukonko/tablewriter"
2420
"github.com/olekukonko/tablewriter/renderer"
@@ -367,22 +363,13 @@ func EnsureBackendAvailable(backend string, cmd *cobra.Command) error {
367363
return nil
368364
}
369365

370-
func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) {
371-
config, ok := modelInfo.Config.(*types.Config)
372-
if !ok {
373-
return llamacpp.Name, nil
366+
func ResolveRequiredBackend(model string) (string, error) {
367+
selection, err := desktopClient.ResolveModelBackend(model)
368+
if err != nil {
369+
return "", err
374370
}
375371

376-
switch config.Format {
377-
case types.FormatSafetensors:
378-
return vllm.Name, nil
379-
case types.FormatGGUF:
380-
return llamacpp.Name, nil
381-
case types.FormatDiffusers:
382-
return diffusers.Name, nil
383-
default:
384-
return llamacpp.Name, nil
385-
}
372+
return selection.Backend, nil
386373
}
387374

388375
func printNextSteps(out io.Writer, messages []string) {

cmd/cli/commands/utils_test.go

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,15 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"net/url"
1011
"strings"
1112
"testing"
1213

1314
"github.com/docker/model-runner/cmd/cli/desktop"
1415
mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
15-
"github.com/docker/model-runner/pkg/distribution/types"
1616
"github.com/docker/model-runner/pkg/inference"
17-
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
18-
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
1917
"github.com/docker/model-runner/pkg/inference/backends/vllm"
20-
dmrm "github.com/docker/model-runner/pkg/inference/models"
18+
"github.com/docker/model-runner/pkg/inference/scheduling"
2119
"github.com/spf13/cobra"
2220
"github.com/stretchr/testify/require"
2321
"go.uber.org/mock/gomock"
@@ -219,22 +217,28 @@ func TestEnsureBackendAvailableCancelled(t *testing.T) {
219217
require.Contains(t, out.String(), "docker model install-runner --backend vllm")
220218
}
221219

222-
func TestGetRequiredBackendFromModelInfo(t *testing.T) {
223-
t.Run("safetensors chooses vllm", func(t *testing.T) {
224-
backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatSafetensors}})
225-
require.NoError(t, err)
226-
require.Equal(t, vllm.Name, backend)
227-
})
220+
func TestResolveRequiredBackend(t *testing.T) {
221+
ctrl := gomock.NewController(t)
222+
defer ctrl.Finish()
228223

229-
t.Run("gguf chooses llamacpp", func(t *testing.T) {
230-
backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatGGUF}})
231-
require.NoError(t, err)
232-
require.Equal(t, llamacpp.Name, backend)
233-
})
224+
client := mockdesktop.NewMockDockerHttpClient(ctrl)
225+
modelRunner = desktop.NewContextForMock(client)
226+
desktopClient = desktop.New(modelRunner)
234227

235-
t.Run("diffusers chooses diffusers backend", func(t *testing.T) {
236-
backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatDiffusers}})
237-
require.NoError(t, err)
238-
require.Equal(t, diffusers.Name, backend)
239-
})
228+
model := "ai/functiongemma-vllm:270M"
229+
selection := scheduling.ModelBackendSelection{Backend: vllm.Name, Installed: false}
230+
body, err := json.Marshal(selection)
231+
require.NoError(t, err)
232+
233+
expectedResolveURL := modelRunner.URL(inference.ModelsPrefix + "/backend?model=" + url.QueryEscape(model))
234+
expectedUserAgent := "docker-model-cli/" + desktop.Version
235+
236+
client.EXPECT().Do(gomock.Cond(func(req any) bool {
237+
r, ok := req.(*http.Request)
238+
return ok && r.URL.String() == expectedResolveURL && r.Header.Get("User-Agent") == expectedUserAgent
239+
})).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(body))}, nil)
240+
241+
backend, err := ResolveRequiredBackend(model)
242+
require.NoError(t, err)
243+
require.Equal(t, vllm.Name, backend)
240244
}

cmd/cli/desktop/desktop.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,36 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) {
362362
return modelInspect, nil
363363
}
364364

365+
func (c *Client) ResolveModelBackend(model string) (scheduling.ModelBackendSelection, error) {
366+
resolvePath := fmt.Sprintf("%s/backend?model=%s", inference.ModelsPrefix, url.QueryEscape(model))
367+
368+
resp, err := c.doRequest(http.MethodGet, resolvePath, nil)
369+
if err != nil {
370+
return scheduling.ModelBackendSelection{}, c.handleQueryError(err, resolvePath)
371+
}
372+
defer resp.Body.Close()
373+
374+
if resp.StatusCode != http.StatusOK {
375+
if resp.StatusCode == http.StatusNotFound {
376+
return scheduling.ModelBackendSelection{}, errors.Wrap(ErrNotFound, model)
377+
}
378+
body, _ := io.ReadAll(resp.Body)
379+
return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to resolve model backend: %s: %s", resp.Status, string(body))
380+
}
381+
382+
body, err := io.ReadAll(resp.Body)
383+
if err != nil {
384+
return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to read response body: %w", err)
385+
}
386+
387+
var selection scheduling.ModelBackendSelection
388+
if err := json.Unmarshal(body, &selection); err != nil {
389+
return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to unmarshal response body: %w", err)
390+
}
391+
392+
return selection, nil
393+
}
394+
365395
func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) {
366396
modelsRoute := c.modelRunner.OpenAIPathPrefix() + "/models"
367397
rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model)

pkg/inference/scheduling/api.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,9 @@ type ModelConfigEntry struct {
119119
Mode inference.BackendMode
120120
Config inference.BackendConfiguration
121121
}
122+
123+
// ModelBackendSelection describes the backend selected by the scheduler for a model.
124+
type ModelBackendSelection struct {
125+
Backend string `json:"backend"`
126+
Installed bool `json:"installed"`
127+
}

pkg/inference/scheduling/http_handler.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc {
119119
m["GET "+inference.InferencePrefix+"/status"] = h.GetBackendStatus
120120
m["GET "+inference.InferencePrefix+"/ps"] = h.GetRunningBackends
121121
m["GET "+inference.InferencePrefix+"/df"] = h.GetDiskUsage
122+
m["GET "+inference.ModelsPrefix+"/backend"] = h.GetModelBackend
122123
m["POST "+inference.InferencePrefix+"/unload"] = h.Unload
123124
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = h.Configure
124125
m["POST "+inference.InferencePrefix+"/_configure"] = h.Configure
@@ -540,6 +541,30 @@ func (h *HTTPHandler) GetModelConfigs(w http.ResponseWriter, r *http.Request) {
540541
}
541542
}
542543

544+
// GetModelBackend resolves the backend selected by the scheduler for the provided model.
545+
func (h *HTTPHandler) GetModelBackend(w http.ResponseWriter, r *http.Request) {
546+
modelRef := r.URL.Query().Get("model")
547+
if modelRef == "" {
548+
http.Error(w, "model is required", http.StatusBadRequest)
549+
return
550+
}
551+
552+
backend, err := h.scheduler.ResolveBackendForModel(r.Context(), modelRef)
553+
if err != nil {
554+
http.Error(w, fmt.Sprintf("failed to resolve backend: %v", err), http.StatusNotFound)
555+
return
556+
}
557+
558+
w.Header().Set("Content-Type", "application/json")
559+
if err := json.NewEncoder(w).Encode(ModelBackendSelection{
560+
Backend: backend.Name(),
561+
Installed: h.scheduler.installer.isInstalled(backend.Name()),
562+
}); err != nil {
563+
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
564+
return
565+
}
566+
}
567+
543568
// ServeHTTP implements net/http.Handler.ServeHTTP.
544569
func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
545570
h.lock.RLock()

pkg/inference/scheduling/scheduler.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B
142142
format = inferFormatFromModel(model)
143143
}
144144

145+
return s.selectBackendForFormat(format, backend, modelRef)
146+
}
147+
148+
func (s *Scheduler) selectBackendForFormat(format types.Format, backend inference.Backend, modelRef string) inference.Backend {
145149
switch format {
146150
case types.FormatSafetensors:
147151
// Prefer vLLM for safetensors models (handles platform dispatch internally)
@@ -211,6 +215,28 @@ func inferFormatFromModel(model types.Model) types.Format {
211215
return ""
212216
}
213217

218+
// ResolveBackendForModel resolves the backend that should be used for a model reference.
219+
// It prefers local model metadata and falls back to remote metadata when needed.
220+
func (s *Scheduler) ResolveBackendForModel(ctx context.Context, modelRef string) (inference.Backend, error) {
221+
backend := s.defaultBackend
222+
223+
if model, err := s.modelManager.GetLocal(modelRef); err == nil {
224+
return s.selectBackendForModel(model, backend, modelRef), nil
225+
}
226+
227+
artifact, err := s.modelManager.GetRemote(ctx, modelRef)
228+
if err != nil {
229+
return nil, err
230+
}
231+
232+
config, err := artifact.Config()
233+
if err != nil {
234+
return nil, err
235+
}
236+
237+
return s.selectBackendForFormat(config.GetFormat(), backend, modelRef), nil
238+
}
239+
214240
// ResetInstaller resets the backend installer with a new HTTP client.
215241
func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
216242
s.installer = newInstaller(s.log, s.backends, httpClient, s.deferredBackends)

pkg/routing/router.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ func NewRouter(cfg RouterConfig) *NormalizedServeMux {
4949
if cfg.ModelHandlerMiddleware != nil {
5050
modelEndpoint = cfg.ModelHandlerMiddleware(cfg.ModelHandler)
5151
}
52+
router.Handle(inference.ModelsPrefix+"/backend", cfg.SchedulerHTTP)
5253
router.Handle(inference.ModelsPrefix, modelEndpoint)
5354
router.Handle(inference.ModelsPrefix+"/", modelEndpoint)
5455

0 commit comments

Comments
 (0)