-
Notifications
You must be signed in to change notification settings - Fork 116
Expand file tree
/
Copy pathutils.go
More file actions
388 lines (338 loc) · 12.3 KB
/
utils.go
File metadata and controls
388 lines (338 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
package commands
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/docker/cli/cli-plugins/hooks"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/docker/model-runner/pkg/inference/backends/vllm"
dmrm "github.com/docker/model-runner/pkg/inference/models"
"github.com/moby/term"
"github.com/olekukonko/tablewriter"
"github.com/olekukonko/tablewriter/renderer"
"github.com/olekukonko/tablewriter/tw"
"github.com/spf13/cobra"
)
// Defaults that stripDefaultsFromModelName removes from a model reference
// when shortening it for display.
const (
// defaultOrg is the organization whose "ai/" prefix is stripped from model names.
defaultOrg = "ai"
// defaultTag is the tag whose ":latest" suffix is stripped from model names.
defaultTag = "latest"
)
// Next-step hints that handleClientError renders through hooks.PrintNextSteps
// and appends to the returned error.
const (
// enableViaCLI and enableViaGUI are shown together when the error matches
// desktop.ErrServiceUnavailable.
enableViaCLI = "Enable Docker Model Runner via the CLI → docker desktop enable model-runner"
enableViaGUI = "Enable Docker Model Runner via the GUI → Go to Settings->AI->Enable Docker Model Runner"
// enableVLLM is shown when the error text contains the vllm.ErrorNotFound message.
enableVLLM = "It looks like you're trying to use a model for vLLM → docker model reinstall-runner --backend vllm --gpu cuda"
)
// getDefaultRegistry reports the registry assumed for model references.
// A non-empty DEFAULT_REGISTRY environment variable takes precedence;
// otherwise reference.DefaultRegistry is returned.
func getDefaultRegistry() string {
	override := os.Getenv("DEFAULT_REGISTRY")
	if override == "" {
		return reference.DefaultRegistry
	}
	return override
}
// errNotRunning is the error handleClientError substitutes for the original
// error when the client reports desktop.ErrServiceUnavailable.
var errNotRunning = fmt.Errorf("Docker Model Runner is not running. Please start it and try again.\n")
// handleClientError converts an error from the Model Runner client into the
// error that is reported to the user:
//
//   - if err matches desktop.ErrServiceUnavailable, it is replaced by
//     errNotRunning, followed by the hints for enabling Docker Model Runner;
//   - if err's text contains the vllm.ErrorNotFound message, err is kept and
//     followed by the hint for installing the vLLM backend;
//   - otherwise err is wrapped as "<message>: <err>".
//
// err is expected to be non-nil: its message is inspected via err.Error().
func handleClientError(err error, message string) error {
	var steps []string
	switch {
	case errors.Is(err, desktop.ErrServiceUnavailable):
		err = errNotRunning
		steps = []string{enableViaCLI, enableViaGUI}
	case strings.Contains(err.Error(), vllm.ErrorNotFound.Error()):
		// This branch handles the `run` error.
		steps = []string{enableVLLM}
	default:
		return fmt.Errorf("%s: %w", message, err)
	}

	// Render the hints into a buffer and drop trailing newlines so they can
	// be appended directly below the error text.
	var hints bytes.Buffer
	hooks.PrintNextSteps(&hints, steps)
	return fmt.Errorf("%w\n%s", err, strings.TrimRight(hints.String(), "\n"))
}
// commandPrinter adapts a *cobra.Command so that it can be used as a
// standalone.StatusPrinter.
type commandPrinter struct {
	cmd *cobra.Command
}

// Printf forwards formatted output to the wrapped command's Printf.
func (c *commandPrinter) Printf(format string, args ...any) {
	c.cmd.Printf(format, args...)
}

// Println forwards its arguments to the wrapped command's Println.
func (c *commandPrinter) Println(args ...any) {
	c.cmd.Println(args...)
}

// PrintErrf forwards formatted output to the wrapped command's PrintErrf.
func (c *commandPrinter) PrintErrf(format string, args ...any) {
	c.cmd.PrintErrf(format, args...)
}

// Write sends p to the writer returned by the wrapped command's OutOrStdout.
func (c *commandPrinter) Write(p []byte) (n int, err error) {
	w := c.cmd.OutOrStdout()
	return w.Write(p)
}

// GetFdInfo reports the file descriptor and terminal status used for progress
// display. When the command's output writer is an *os.File, that file is
// inspected; for any other writer, os.Stdout is inspected instead, because
// that is where visual progress bars are expected to appear.
func (c *commandPrinter) GetFdInfo() (fd uintptr, isTerminal bool) {
	target := os.Stdout
	if f, ok := c.cmd.OutOrStdout().(*os.File); ok {
		target = f
	}
	return term.GetFdInfo(target)
}
// asPrinter returns a standalone.StatusPrinter whose output is routed
// through cmd.
func asPrinter(cmd *cobra.Command) standalone.StatusPrinter {
	printer := commandPrinter{cmd: cmd}
	return &printer
}
// stripDefaultsFromModelName shortens a model reference for display by
// removing, in order: a leading default-registry prefix, a leading "ai/"
// organization prefix, and a trailing ":latest" tag. Each part is removed
// independently of the others, and anything else is left untouched.
//
// Which registry prefix is removed depends on getDefaultRegistry:
//   - when it returns "", "index.docker.io" or "docker.io", either of the
//     equivalent prefixes "index.docker.io/" and "docker.io/" is removed;
//   - when it returns any other value, only that registry (with a trailing
//     slash added if missing) is removed.
//
// Examples (with the default registry):
//   - "ai/gemma3:latest" -> "gemma3"
//   - "ai/gemma3:v1" -> "gemma3:v1"
//   - "myorg/gemma3:latest" -> "myorg/gemma3"
//   - "gemma3:latest" -> "gemma3"
//   - "index.docker.io/ai/gemma3:latest" -> "gemma3"
//   - "docker.io/ai/gemma3:latest" -> "gemma3"
//   - "docker.io/myorg/gemma3:latest" -> "myorg/gemma3"
//   - "hf.co/bartowski/model:latest" -> "hf.co/bartowski/model"
func stripDefaultsFromModelName(model string) string {
	registry := getDefaultRegistry()

	var prefixes []string
	switch registry {
	case "", "index.docker.io", "docker.io":
		// The two Docker Hub host names are aliases of each other.
		prefixes = []string{"index.docker.io/", "docker.io/"}
	default:
		// A custom registry replaces the Docker Hub aliases entirely.
		if !strings.HasSuffix(registry, "/") {
			registry += "/"
		}
		prefixes = []string{registry}
	}

	// Remove at most one registry prefix: the first one that matches.
	for _, prefix := range prefixes {
		if rest := strings.TrimPrefix(model, prefix); rest != model {
			model = rest
			break
		}
	}

	// TrimPrefix/TrimSuffix are no-ops when the default org or tag is absent.
	model = strings.TrimPrefix(model, defaultOrg+"/")
	return strings.TrimSuffix(model, ":"+defaultTag)
}
// requireExactArgs builds a cobra.PositionalArgs validator that accepts only
// invocations with exactly n positional arguments. cmdName and usageArgs are
// used solely to compose the usage text of the returned error.
func requireExactArgs(n int, cmdName string, usageArgs string) cobra.PositionalArgs {
	const format = "'docker model %s' requires %d argument(s).\n\n" +
		"Usage: docker model %s %s\n\n" +
		"See 'docker model %s --help' for more information"

	return func(_ *cobra.Command, args []string) error {
		if len(args) == n {
			return nil
		}
		return fmt.Errorf(format, cmdName, n, cmdName, usageArgs, cmdName)
	}
}
// requireMinArgs builds a cobra.PositionalArgs validator that rejects
// invocations with fewer than n positional arguments. cmdName and usageArgs
// are used solely to compose the usage text of the returned error.
func requireMinArgs(n int, cmdName string, usageArgs string) cobra.PositionalArgs {
	const format = "'docker model %s' requires at least %d argument(s).\n\n" +
		"Usage: docker model %s %s\n\n" +
		"See 'docker model %s --help' for more information"

	return func(_ *cobra.Command, args []string) error {
		if len(args) >= n {
			return nil
		}
		return fmt.Errorf(format, cmdName, n, cmdName, usageArgs, cmdName)
	}
}
// runnerFlagOptions holds common runner configuration options.
// Each field points at the variable that receives the value of one
// command-line flag; addRunnerFlags registers a flag only when the
// corresponding pointer is non-nil.
type runnerFlagOptions struct {
// Port receives --port.
Port *uint16
// Host receives --host.
Host *string
// GpuMode receives --gpu.
GpuMode *string
// Backend receives --backend.
Backend *string
// DoNotTrack receives --do-not-track.
DoNotTrack *bool
// Debug receives --debug.
Debug *bool
// ProxyCert receives --proxy-cert.
ProxyCert *string
// TLS receives --tls.
TLS *bool
// TLSPort receives --tls-port.
TLSPort *uint16
// TLSCert receives --tls-cert.
TLSCert *string
// TLSKey receives --tls-key.
TLSKey *string
}
// addRunnerFlags registers the shared runner flags on cmd. A flag is added
// only when the matching pointer in opts is non-nil, so each command opts in
// to exactly the flags it supports. Parsed values are stored through those
// pointers.
func addRunnerFlags(cmd *cobra.Command, opts runnerFlagOptions) {
	fs := cmd.Flags()

	if opts.Port != nil {
		fs.Uint16Var(opts.Port, "port", 0,
			"Docker container port for Docker Model Runner (default: 12434 for Docker Engine, 12435 for Cloud mode)")
	}
	if opts.Host != nil {
		fs.StringVar(opts.Host, "host", "127.0.0.1",
			"Host address to bind Docker Model Runner")
	}
	if opts.GpuMode != nil {
		fs.StringVar(opts.GpuMode, "gpu", "auto",
			"Specify GPU support (none|auto|cuda|rocm|musa|cann)")
	}
	if opts.Backend != nil {
		fs.StringVar(opts.Backend, "backend", "", backendUsage)
	}
	if opts.DoNotTrack != nil {
		fs.BoolVar(opts.DoNotTrack, "do-not-track", false,
			"Do not track models usage in Docker Model Runner")
	}
	if opts.Debug != nil {
		fs.BoolVar(opts.Debug, "debug", false, "Enable debug logging")
	}
	if opts.ProxyCert != nil {
		fs.StringVar(opts.ProxyCert, "proxy-cert", "",
			"Path to a CA certificate file for proxy SSL inspection")
	}
	if opts.TLS != nil {
		fs.BoolVar(opts.TLS, "tls", false,
			"Enable TLS/HTTPS for Docker Model Runner API")
	}
	if opts.TLSPort != nil {
		fs.Uint16Var(opts.TLSPort, "tls-port", 0,
			"TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode)")
	}
	if opts.TLSCert != nil {
		fs.StringVar(opts.TLSCert, "tls-cert", "",
			"Path to TLS certificate file (auto-generated if not provided)")
	}
	if opts.TLSKey != nil {
		fs.StringVar(opts.TLSKey, "tls-key", "",
			"Path to TLS private key file (auto-generated if not provided)")
	}
}
// newTable creates a new table with Docker CLI-style formatting:
// no borders, no column separators, no header line, left-aligned, and 2-space padding.
// The table writes to w.
func newTable(w io.Writer) *tablewriter.Table {
return tablewriter.NewTable(w,
// Renderer settings: turn off every piece of visual framing.
tablewriter.WithRenderer(renderer.NewBlueprint(tw.Rendition{
Borders: tw.BorderNone,
Settings: tw.Settings{
Separators: tw.Separators{
BetweenColumns: tw.Off,
},
Lines: tw.Lines{
ShowHeaderLine: tw.Off,
},
},
})),
// Cell settings: header and rows share the same alignment and padding.
tablewriter.WithConfig(tablewriter.Config{
Header: tw.CellConfig{
Formatting: tw.CellFormatting{
// Header text is emitted as given rather than auto-formatted.
AutoFormat: tw.Off,
},
Alignment: tw.CellAlignment{Global: tw.AlignLeft},
// NOTE(review): the summary says 2-space padding, but the Right padding
// literal as it appears here is a single space — confirm.
Padding: tw.CellPadding{Global: tw.Padding{Left: "", Right: " "}},
},
Row: tw.CellConfig{
Alignment: tw.CellAlignment{Global: tw.AlignLeft},
Padding: tw.CellPadding{Global: tw.Padding{Left: "", Right: " "}},
},
}),
)
}
// CheckBackendInstalled reports whether the named backend is usable according
// to the status returned by desktopClient.Status.
//
// The status payload is decoded as a JSON object mapping backend names to
// state strings. After trimming and lower-casing, a state counts as installed
// only when it begins with "installed" or "running"; a missing entry, a state
// beginning with "not " or "error", and any other state all yield false.
//
// An error is returned when the status call fails or its payload cannot be
// decoded.
func CheckBackendInstalled(backend string) (bool, error) {
	status := desktopClient.Status()
	if status.Error != nil {
		return false, fmt.Errorf("failed to get backend status: %w", status.Error)
	}

	var states map[string]string
	if err := json.Unmarshal(status.Status, &states); err != nil {
		return false, fmt.Errorf("failed to parse backend status: %w", err)
	}

	raw, found := states[backend]
	if !found {
		return false, nil
	}

	state := strings.TrimSpace(strings.ToLower(raw))
	switch {
	case strings.HasPrefix(state, "not "), strings.HasPrefix(state, "error"):
		return false, nil
	case strings.HasPrefix(state, "installed"), strings.HasPrefix(state, "running"):
		return true, nil
	default:
		return false, nil
	}
}
// PromptInstallBackend asks the user whether the named backend should be
// downloaded and installed, and reports the answer.
//
// The question is written to cmd's output writer and the answer is read from
// cmd's input reader (os.Stdin unless the command was given another reader),
// so the prompt honors whatever streams the command was configured with.
//
// An empty line, "y" or "yes" (case-insensitive, surrounding whitespace
// ignored) counts as consent; any other answer is a refusal. An answer that
// is terminated by end-of-input instead of a newline is still honored. An
// error is returned when reading fails, or when the input ends before
// anything was typed, so that a closed input never counts as consent.
func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) {
	fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend)

	line, err := bufio.NewReader(cmd.InOrStdin()).ReadString('\n')
	answer := strings.ToLower(strings.TrimSpace(line))
	if err != nil {
		// ReadString returns the data read so far together with io.EOF when
		// the input ends without a newline; keep that answer if there is one.
		if !errors.Is(err, io.EOF) || answer == "" {
			return false, fmt.Errorf("failed to read input: %w", err)
		}
	}

	switch answer {
	case "", "y", "yes":
		return true, nil
	default:
		return false, nil
	}
}
// InstallBackend installs the named backend by building the command returned
// by newInstallRunner and executing it with "--backend <backend>". Any
// failure is wrapped with the backend name.
//
// The cmd parameter is currently unused by this function.
func InstallBackend(backend string, cmd *cobra.Command) error {
	runner := newInstallRunner()
	runner.SetArgs([]string{"--backend", backend})

	err := runner.Execute()
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to install backend %s: %w", backend, err)
}
// EnsureBackendAvailable makes sure the named backend is installed, offering
// to install it interactively when it is not.
//
// The steps are: check the backend; if it is missing, ask for confirmation
// via PromptInstallBackend; on consent run InstallBackend; then check again.
// It returns nil when the backend is installed at the end, and an error when
// any step fails, when the user declines (after printing the manual install
// command through cmd), or when the backend is still missing after the
// installation.
func EnsureBackendAvailable(backend string, cmd *cobra.Command) error {
	ready, err := CheckBackendInstalled(backend)
	switch {
	case err != nil:
		return err
	case ready:
		return nil
	}

	accepted, err := PromptInstallBackend(backend, cmd)
	switch {
	case err != nil:
		return err
	case !accepted:
		cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend)
		return errors.New("backend installation cancelled")
	}

	if err = InstallBackend(backend, cmd); err != nil {
		return err
	}

	// Verify the result rather than trusting the installer's exit status.
	ready, err = CheckBackendInstalled(backend)
	switch {
	case err != nil:
		return err
	case !ready:
		return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend)
	}

	cmd.Printf("Backend %q installed successfully.\n", backend)
	return nil
}
// GetRequiredBackend inspects the named model through desktopClient and
// returns the backend chosen by GetRequiredBackendFromModelInfo. If the
// inspection fails, it returns an empty backend name and that error.
func GetRequiredBackend(model string) (string, error) {
	info, err := desktopClient.Inspect(model, false)
	if err == nil {
		return GetRequiredBackendFromModelInfo(&info)
	}
	return "", err
}
// GetRequiredBackendFromModelInfo picks the inference backend for a model
// from the format recorded in its configuration:
//
//   - types.FormatSafetensors -> vllm.Name
//   - types.FormatDiffusers   -> "diffusers"
//   - types.FormatGGUF, any other format, or a configuration that is not a
//     *types.Config -> llamacpp.Name
//
// The returned error is always nil in the current implementation.
func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) {
	// llama.cpp is the fallback for everything not matched below.
	backend := llamacpp.Name

	if cfg, isConfig := modelInfo.Config.(*types.Config); isConfig {
		switch cfg.Format {
		case types.FormatSafetensors:
			backend = vllm.Name
		case types.FormatDiffusers:
			backend = "diffusers"
		}
	}
	return backend, nil
}