diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index c3ecb20a6..0a09d32b7 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1528,6 +1528,40 @@ public class ProviderConfig /// [JsonPropertyName("headers")] public IDictionary? Headers { get; set; } + + /// + /// Well-known model name used by the runtime to look up agent configuration + /// (tools, prompts, reasoning behavior) and default token limits. Also used + /// as the wire model when is not set. + /// Falls back to . + /// + [JsonPropertyName("modelId")] + public string? ModelId { get; set; } + + /// + /// Model name sent to the provider API for inference. Use this when the + /// provider's model name (e.g. an Azure deployment name or a custom + /// fine-tune name) differs from . + /// Falls back to , then . + /// + [JsonPropertyName("wireModel")] + public string? WireModel { get; set; } + + /// + /// Overrides the resolved model's default max prompt tokens. The runtime + /// triggers conversation compaction before sending a request when the + /// prompt (system message, history, tool definitions, user message) would + /// exceed this limit. + /// + [JsonPropertyName("maxPromptTokens")] + public int? MaxInputTokens { get; set; } + + /// + /// Overrides the resolved model's default max output tokens. When hit, the + /// model stops generating and returns a truncated response. + /// + [JsonPropertyName("maxOutputTokens")] + public int? MaxOutputTokens { get; set; } } /// diff --git a/dotnet/test/E2E/SessionConfigE2ETests.cs b/dotnet/test/E2E/SessionConfigE2ETests.cs index 4ba42ec50..ddd44ea0d 100644 --- a/dotnet/test/E2E/SessionConfigE2ETests.cs +++ b/dotnet/test/E2E/SessionConfigE2ETests.cs @@ -238,6 +238,62 @@ public async Task Should_Forward_Custom_Provider_Headers_On_Resume() await session2.DisposeAsync(); } + [Fact] + public async Task Should_Forward_Provider_Wire_Model() + { + // Verifies that ProviderConfig.WireModel overrides the model name sent to + // the provider API, while SessionConfig.Model still drives runtime + // configuration lookup (capabilities, prompts, reasoning behavior). + // MaxOutputTokens is also set here to confirm the SDK accepts it without + // serialization errors; the CLI does not echo it as `max_tokens` on the + // OpenAI-style wire request, so we don't assert on it directly (see unit + // tests for serialization coverage). + var session = await CreateSessionAsync(new SessionConfig + { + Model = "claude-sonnet-4.5", + Provider = new ProviderConfig + { + Type = "openai", + BaseUrl = Ctx.ProxyUrl, + ApiKey = "test-provider-key", + WireModel = "test-wire-model", + MaxOutputTokens = 1024, + }, + }); + + await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 1+1?" }); + + var exchange = Assert.Single(await Ctx.GetExchangesAsync()); + Assert.Equal("test-wire-model", exchange.Request.Model); + + await session.DisposeAsync(); + } + + [Fact] + public async Task Should_Use_Provider_Model_Id_As_Wire_Model() + { + // ProviderConfig.ModelId drives both the runtime resolved model AND the wire model + // when WireModel is not specified. Here SessionConfig.Model is intentionally omitted + // so that ModelId is the only model source. + var session = await CreateSessionAsync(new SessionConfig + { + Provider = new ProviderConfig + { + Type = "openai", + BaseUrl = Ctx.ProxyUrl, + ApiKey = "test-provider-key", + ModelId = "claude-sonnet-4.5", + }, + }); + + await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 1+1?" }); + + var exchange = Assert.Single(await Ctx.GetExchangesAsync()); + Assert.Equal("claude-sonnet-4.5", exchange.Request.Model); + + await session.DisposeAsync(); + } + [Fact] public async Task Should_Use_WorkingDirectory_For_Tool_Execution() { diff --git a/dotnet/test/Unit/SerializationTests.cs b/dotnet/test/Unit/SerializationTests.cs index bfdf8db6a..e58b256f4 100644 --- a/dotnet/test/Unit/SerializationTests.cs +++ b/dotnet/test/Unit/SerializationTests.cs @@ -20,7 +20,11 @@ public void ProviderConfig_CanSerializeHeaders_WithSdkOptions() var original = new ProviderConfig { BaseUrl = "https://example.com/provider", - Headers = new Dictionary { ["Authorization"] = "Bearer provider-token" } + Headers = new Dictionary { ["Authorization"] = "Bearer provider-token" }, + ModelId = "gpt-4o", + WireModel = "my-finetune-v3", + MaxInputTokens = 100_000, + MaxOutputTokens = 4096 }; var json = JsonSerializer.Serialize(original, options); @@ -28,11 +32,19 @@ public void ProviderConfig_CanSerializeHeaders_WithSdkOptions() var root = document.RootElement; Assert.Equal("https://example.com/provider", root.GetProperty("baseUrl").GetString()); Assert.Equal("Bearer provider-token", root.GetProperty("headers").GetProperty("Authorization").GetString()); + Assert.Equal("gpt-4o", root.GetProperty("modelId").GetString()); + Assert.Equal("my-finetune-v3", root.GetProperty("wireModel").GetString()); + Assert.Equal(100_000, root.GetProperty("maxPromptTokens").GetInt32()); + Assert.Equal(4096, root.GetProperty("maxOutputTokens").GetInt32()); var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal("https://example.com/provider", deserialized.BaseUrl); Assert.Equal("Bearer provider-token", deserialized.Headers!["Authorization"]); + Assert.Equal("gpt-4o", deserialized.ModelId); + Assert.Equal("my-finetune-v3", deserialized.WireModel); + Assert.Equal(100_000, deserialized.MaxInputTokens); + Assert.Equal(4096, deserialized.MaxOutputTokens); } [Fact] diff --git a/go/internal/e2e/session_config_e2e_test.go b/go/internal/e2e/session_config_e2e_test.go index a62866cee..d3af7f6c0 100644 --- a/go/internal/e2e/session_config_e2e_test.go +++ b/go/internal/e2e/session_config_e2e_test.go @@ -323,6 +323,85 @@ func TestSessionConfigExtrasE2E(t *testing.T) { } }) + t.Run("should forward provider wire model", func(t *testing.T) { + // Verifies that ProviderConfig.WireModel overrides the model name sent to + // the provider API, while SessionConfig.Model still drives runtime + // configuration lookup (capabilities, prompts, reasoning behavior). + // MaxOutputTokens is also set here to confirm the SDK accepts it without + // serialization errors; the CLI does not echo it as `max_tokens` on the + // OpenAI-style wire request, so we don't assert on it directly (see unit + // tests for serialization coverage). + ctx.ConfigureForTest(t) + + maxOutputTokens := 1024 + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Model: "claude-sonnet-4.5", + Provider: &copilot.ProviderConfig{ + Type: "openai", + BaseURL: ctx.ProxyURL, + APIKey: "test-provider-key", + WireModel: "test-wire-model", + MaxOutputTokens: maxOutputTokens, + }, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + if err != nil { + t.Fatalf("SendAndWait failed: %v", err) + } + + exchanges, err := ctx.GetExchanges() + if err != nil { + t.Fatalf("GetExchanges failed: %v", err) + } + if len(exchanges) != 1 { + t.Fatalf("Expected exactly 1 exchange, got %d", len(exchanges)) + } + if exchanges[0].Request.Model != "test-wire-model" { + t.Errorf("Expected request model to be 'test-wire-model', got %q", exchanges[0].Request.Model) + } + }) + + t.Run("should use provider model id as wire model", func(t *testing.T) { + // ProviderConfig.ModelID drives both the runtime resolved model AND the wire + // model when WireModel is not specified. SessionConfig.Model is intentionally + // omitted so that ModelID is the only model source. + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Provider: &copilot.ProviderConfig{ + Type: "openai", + BaseURL: ctx.ProxyURL, + APIKey: "test-provider-key", + ModelID: "claude-sonnet-4.5", + }, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + if err != nil { + t.Fatalf("SendAndWait failed: %v", err) + } + + exchanges, err := ctx.GetExchanges() + if err != nil { + t.Fatalf("GetExchanges failed: %v", err) + } + if len(exchanges) != 1 { + t.Fatalf("Expected exactly 1 exchange, got %d", len(exchanges)) + } + if exchanges[0].Request.Model != "claude-sonnet-4.5" { + t.Errorf("Expected request model to be 'claude-sonnet-4.5', got %q", exchanges[0].Request.Model) + } + }) + t.Run("should use workingDirectory for tool execution", func(t *testing.T) { ctx.ConfigureForTest(t) diff --git a/go/types.go b/go/types.go index 2c1f6b67e..dd3ffbbe3 100644 --- a/go/types.go +++ b/go/types.go @@ -859,6 +859,25 @@ type ProviderConfig struct { Azure *AzureProviderOptions `json:"azure,omitempty"` // Headers are custom HTTP headers included in outbound provider requests. Headers map[string]string `json:"headers,omitempty"` + // ModelID is the well-known model name used by the runtime to look up + // agent configuration (tools, prompts, reasoning behavior) and default + // token limits. Also used as the wire model when WireModel is not set. + // Falls back to SessionConfig.Model. + ModelID string `json:"modelId,omitempty"` + // WireModel is the model name sent to the provider API for inference. Use + // this when the provider's model name (e.g. an Azure deployment name or a + // custom fine-tune name) differs from ModelID. + // Falls back to ModelID, then SessionConfig.Model. + WireModel string `json:"wireModel,omitempty"` + // MaxInputTokens overrides the resolved model's default max prompt tokens. + // The runtime triggers conversation compaction before sending a request + // when the prompt (system message, history, tool definitions, user + // message) would exceed this limit. + MaxInputTokens int `json:"maxPromptTokens,omitempty"` + // MaxOutputTokens overrides the resolved model's default max output + // tokens. When hit, the model stops generating and returns a truncated + // response. + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` } // AzureProviderOptions contains Azure-specific provider configuration diff --git a/go/types_test.go b/go/types_test.go index ef02424a3..d24e6342f 100644 --- a/go/types_test.go +++ b/go/types_test.go @@ -151,3 +151,68 @@ func TestSessionSendRequest_JSONIncludesRequestHeaders(t *testing.T) { t.Fatalf("expected Authorization header, got %v", headers["Authorization"]) } } + +func TestProviderConfig_JSONIncludesAllFields(t *testing.T) { + cfg := ProviderConfig{ + BaseURL: "https://example.com/provider", + APIKey: "test-key", + Headers: map[string]string{"Authorization": "Bearer provider-token"}, + ModelID: "gpt-4o", + WireModel: "my-finetune-v3", + MaxInputTokens: 100000, + MaxOutputTokens: 4096, + } + + data, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal ProviderConfig: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal ProviderConfig: %v", err) + } + + if decoded["baseUrl"] != "https://example.com/provider" { + t.Errorf("expected baseUrl to round-trip, got %v", decoded["baseUrl"]) + } + if decoded["modelId"] != "gpt-4o" { + t.Errorf("expected modelId 'gpt-4o', got %v", decoded["modelId"]) + } + if decoded["wireModel"] != "my-finetune-v3" { + t.Errorf("expected wireModel 'my-finetune-v3', got %v", decoded["wireModel"]) + } + if decoded["maxPromptTokens"] != float64(100000) { + t.Errorf("expected maxPromptTokens 100000, got %v", decoded["maxPromptTokens"]) + } + if decoded["maxOutputTokens"] != float64(4096) { + t.Errorf("expected maxOutputTokens 4096, got %v", decoded["maxOutputTokens"]) + } + headers, ok := decoded["headers"].(map[string]any) + if !ok { + t.Fatalf("expected headers object, got %T", decoded["headers"]) + } + if headers["Authorization"] != "Bearer provider-token" { + t.Errorf("expected Authorization header, got %v", headers["Authorization"]) + } +} + +func TestProviderConfig_JSONOmitsUnsetTokenFields(t *testing.T) { + cfg := ProviderConfig{BaseURL: "https://example.com/provider"} + + data, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal ProviderConfig: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal ProviderConfig: %v", err) + } + + for _, field := range []string{"modelId", "wireModel", "maxPromptTokens", "maxOutputTokens", "headers"} { + if _, present := decoded[field]; present { + t.Errorf("expected %q to be omitted when unset, got %v", field, decoded[field]) + } + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 9b6939489..b1b6b4f46 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -42,6 +42,7 @@ import type { GetAuthStatusResponse, GetStatusResponse, ModelInfo, + ProviderConfig, ResumeSessionConfig, SectionTransformFn, SessionConfig, @@ -64,6 +65,17 @@ import type { } from "./types.js"; import { defaultJoinSessionPermissionHandler } from "./types.js"; +/** + * Convert a {@link ProviderConfig} to its JSON-RPC wire shape, remapping + * camelCase SDK property names to the wire keys expected by the runtime + * (e.g. `maxInputTokens` → `maxPromptTokens`). + */ +function toWireProviderConfig(provider: ProviderConfig): Record { + const { maxInputTokens, ...rest } = provider; + if (maxInputTokens === undefined) return rest; + return { ...rest, maxPromptTokens: maxInputTokens }; +} + /** * Minimum protocol version this SDK can communicate with. * Servers reporting a version below this are rejected. @@ -788,7 +800,7 @@ export class CopilotClient { systemMessage: wireSystemMessage, availableTools: config.availableTools, excludedTools: config.excludedTools, - provider: config.provider, + provider: config.provider ? toWireProviderConfig(config.provider) : undefined, modelCapabilities: config.modelCapabilities, requestPermission: true, requestUserInput: !!config.onUserInputRequest, @@ -931,7 +943,7 @@ export class CopilotClient { name: cmd.name, description: cmd.description, })), - provider: config.provider, + provider: config.provider ? toWireProviderConfig(config.provider) : undefined, modelCapabilities: config.modelCapabilities, requestPermission: config.onPermissionRequest !== defaultJoinSessionPermissionHandler, diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 960d398a9..59dff3d82 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1503,6 +1503,36 @@ export interface ProviderConfig { * Custom HTTP headers to include in outbound provider requests. */ headers?: Record; + + /** + * Well-known model name used by the runtime to look up agent configuration + * (tools, prompts, reasoning behavior) and default token limits. Also used + * as the wire model when {@link wireModel} is not set. + * Falls back to {@link SessionConfig.model}. + */ + modelId?: string; + + /** + * Model name sent to the provider API for inference. Use this when the + * provider's model name (e.g. an Azure deployment name or a custom + * fine-tune name) differs from {@link modelId}. + * Falls back to {@link modelId}, then {@link SessionConfig.model}. + */ + wireModel?: string; + + /** + * Overrides the resolved model's default max prompt tokens. The runtime + * triggers conversation compaction before sending a request when the + * prompt (system message, history, tool definitions, user message) would + * exceed this limit. + */ + maxInputTokens?: number; + + /** + * Overrides the resolved model's default max output tokens. When hit, the + * model stops generating and returns a truncated response. + */ + maxOutputTokens?: number; } /** diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 880646f0d..b2fe998ee 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -224,6 +224,10 @@ describe("CopilotClient", () => { provider: { baseUrl: "https://example.com/provider", headers: { Authorization: "Bearer provider-token" }, + modelId: "gpt-4o", + wireModel: "my-finetune-v3", + maxInputTokens: 100_000, + maxOutputTokens: 4096, }, }); @@ -232,6 +236,10 @@ describe("CopilotClient", () => { expect.objectContaining({ baseUrl: "https://example.com/provider", headers: { Authorization: "Bearer provider-token" }, + modelId: "gpt-4o", + wireModel: "my-finetune-v3", + maxPromptTokens: 100_000, + maxOutputTokens: 4096, }) ); spy.mockRestore(); @@ -255,6 +263,10 @@ describe("CopilotClient", () => { provider: { baseUrl: "https://example.com/provider", headers: { Authorization: "Bearer resume-token" }, + modelId: "gpt-4o", + wireModel: "my-finetune-v3", + maxInputTokens: 100_000, + maxOutputTokens: 4096, }, }); @@ -263,6 +275,10 @@ describe("CopilotClient", () => { expect.objectContaining({ baseUrl: "https://example.com/provider", headers: { Authorization: "Bearer resume-token" }, + modelId: "gpt-4o", + wireModel: "my-finetune-v3", + maxPromptTokens: 100_000, + maxOutputTokens: 4096, }) ); spy.mockRestore(); diff --git a/nodejs/test/e2e/session_config.e2e.test.ts b/nodejs/test/e2e/session_config.e2e.test.ts index d288835db..b86c3fa51 100644 --- a/nodejs/test/e2e/session_config.e2e.test.ts +++ b/nodejs/test/e2e/session_config.e2e.test.ts @@ -325,6 +325,58 @@ describe("Session Configuration", async () => { await session2.disconnect(); }); + it("should forward provider wire model", async () => { + // Verifies that ProviderConfig.wireModel overrides the model name sent to + // the provider API, while SessionConfig.model still drives runtime + // configuration lookup (capabilities, prompts, reasoning behavior). + // maxOutputTokens is also set here to confirm the SDK accepts it without + // serialization errors; the CLI does not echo it as `max_tokens` on the + // OpenAI-style wire request, so we don't assert on it directly (see unit + // tests for serialization coverage). + const session = await client.createSession({ + onPermissionRequest: approveAll, + model: "claude-sonnet-4.5", + provider: { + type: "openai", + baseUrl: openAiEndpoint.url, + apiKey: "test-provider-key", + wireModel: "test-wire-model", + maxOutputTokens: 1024, + }, + }); + + await session.sendAndWait({ prompt: "What is 1+1?" }); + + const exchanges = await openAiEndpoint.getExchanges(); + expect(exchanges.length).toBe(1); + expect(exchanges[0].request.model).toBe("test-wire-model"); + + await session.disconnect(); + }); + + it("should use provider model id as wire model", async () => { + // ProviderConfig.modelId drives both the runtime resolved model AND the wire + // model when wireModel is not specified. SessionConfig.model is intentionally + // omitted so that modelId is the only model source. + const session = await client.createSession({ + onPermissionRequest: approveAll, + provider: { + type: "openai", + baseUrl: openAiEndpoint.url, + apiKey: "test-provider-key", + modelId: "claude-sonnet-4.5", + }, + }); + + await session.sendAndWait({ prompt: "What is 1+1?" }); + + const exchanges = await openAiEndpoint.getExchanges(); + expect(exchanges.length).toBe(1); + expect(exchanges[0].request.model).toBe("claude-sonnet-4.5"); + + await session.disconnect(); + }); + it("should apply workingDirectory on session resume", async () => { const subDir = join(workDir, "resume-subproject"); await mkdir(subDir, { recursive: true }); diff --git a/python/copilot/client.py b/python/copilot/client.py index 56653e2b7..0e03dcbf7 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -2275,6 +2275,14 @@ def _convert_provider_to_wire_format( wire_provider["bearerToken"] = provider["bearer_token"] if "headers" in provider: wire_provider["headers"] = provider["headers"] + if "model_id" in provider: + wire_provider["modelId"] = provider["model_id"] + if "wire_model" in provider: + wire_provider["wireModel"] = provider["wire_model"] + if "max_input_tokens" in provider: + wire_provider["maxPromptTokens"] = provider["max_input_tokens"] + if "max_output_tokens" in provider: + wire_provider["maxOutputTokens"] = provider["max_output_tokens"] if "azure" in provider: azure = provider["azure"] wire_azure: dict[str, Any] = {} diff --git a/python/copilot/session.py b/python/copilot/session.py index 980aa70df..97a505c25 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -832,6 +832,24 @@ class ProviderConfig(TypedDict, total=False): bearer_token: str azure: AzureProviderOptions # Azure-specific options headers: dict[str, str] + # Well-known model name used by the runtime to look up agent configuration + # (tools, prompts, reasoning behavior) and default token limits. Also used + # as the wire model when wire_model is not set. + # Falls back to SessionConfig.model. + model_id: str + # Model name sent to the provider API for inference. Use this when the + # provider's model name (e.g. an Azure deployment name or a custom + # fine-tune name) differs from model_id. + # Falls back to model_id, then SessionConfig.model. + wire_model: str + # Overrides the resolved model's default max prompt tokens. The runtime + # triggers conversation compaction before sending a request when the prompt + # (system message, history, tool definitions, user message) would exceed + # this limit. + max_input_tokens: int + # Overrides the resolved model's default max output tokens. When hit, the + # model stops generating and returns a truncated response. + max_output_tokens: int class SessionConfig(TypedDict, total=False): diff --git a/python/e2e/test_session_config_e2e.py b/python/e2e/test_session_config_e2e.py index 825afe8a1..1fd2cd0a2 100644 --- a/python/e2e/test_session_config_e2e.py +++ b/python/e2e/test_session_config_e2e.py @@ -236,6 +236,57 @@ async def test_should_forward_custom_provider_headers_on_resume(self, ctx: E2ETe await session2.disconnect() await session1.disconnect() + async def test_should_forward_provider_wire_model(self, ctx: E2ETestContext): + # Verifies that ProviderConfig.wire_model overrides the model name sent + # to the provider API, while SessionConfig.model still drives runtime + # configuration lookup (capabilities, prompts, reasoning behavior). + # max_output_tokens is also set here to confirm the SDK accepts it + # without serialization errors; the CLI does not echo it as + # `max_tokens` on the OpenAI-style wire request, so we don't assert on + # it directly (see unit tests for serialization coverage). + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + model="claude-sonnet-4.5", + provider={ + "type": "openai", + "base_url": ctx.proxy_url, + "api_key": "test-provider-key", + "wire_model": "test-wire-model", + "max_output_tokens": 1024, + }, + ) + + await session.send_and_wait("What is 1+1?") + + exchanges = await ctx.get_exchanges() + assert len(exchanges) == 1 + request = exchanges[0]["request"] + assert request["model"] == "test-wire-model" + + await session.disconnect() + + async def test_should_use_provider_model_id_as_wire_model(self, ctx: E2ETestContext): + # ProviderConfig.model_id drives both the runtime resolved model AND the wire + # model when wire_model is not specified. SessionConfig.model is intentionally + # omitted so that model_id is the only model source. + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + provider={ + "type": "openai", + "base_url": ctx.proxy_url, + "api_key": "test-provider-key", + "model_id": "claude-sonnet-4.5", + }, + ) + + await session.send_and_wait("What is 1+1?") + + exchanges = await ctx.get_exchanges() + assert len(exchanges) == 1 + assert exchanges[0]["request"]["model"] == "claude-sonnet-4.5" + + await session.disconnect() + async def test_should_use_workingdirectory_for_tool_execution(self, ctx: E2ETestContext): sub_dir = os.path.join(ctx.work_dir, "subproject") os.makedirs(sub_dir, exist_ok=True) diff --git a/python/test_client.py b/python/test_client.py index c11ba708c..a890ca12e 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -564,12 +564,20 @@ async def mock_request(method, params): provider={ "base_url": "https://example.com/provider", "headers": {"Authorization": "Bearer provider-token"}, + "model_id": "gpt-4o", + "wire_model": "my-finetune-v3", + "max_input_tokens": 100_000, + "max_output_tokens": 4096, }, ) provider = captured["session.create"]["provider"] assert provider["baseUrl"] == "https://example.com/provider" assert provider["headers"] == {"Authorization": "Bearer provider-token"} + assert provider["modelId"] == "gpt-4o" + assert provider["wireModel"] == "my-finetune-v3" + assert provider["maxPromptTokens"] == 100_000 + assert provider["maxOutputTokens"] == 4096 finally: await client.force_stop() @@ -599,12 +607,20 @@ async def mock_request(method, params): provider={ "base_url": "https://example.com/provider", "headers": {"Authorization": "Bearer resume-token"}, + "model_id": "gpt-4o", + "wire_model": "my-finetune-v3", + "max_input_tokens": 100_000, + "max_output_tokens": 4096, }, ) provider = captured["session.resume"]["provider"] assert provider["baseUrl"] == "https://example.com/provider" assert provider["headers"] == {"Authorization": "Bearer resume-token"} + assert provider["modelId"] == "gpt-4o" + assert provider["wireModel"] == "my-finetune-v3" + assert provider["maxPromptTokens"] == 100_000 + assert provider["maxOutputTokens"] == 4096 finally: await client.force_stop() diff --git a/test/snapshots/session_config/should_forward_provider_wire_model.yaml b/test/snapshots/session_config/should_forward_provider_wire_model.yaml new file mode 100644 index 000000000..6d25ae167 --- /dev/null +++ b/test/snapshots/session_config/should_forward_provider_wire_model.yaml @@ -0,0 +1,11 @@ +models: + - claude-sonnet-4.5 + - test-wire-model +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 1+1? + - role: assistant + content: 1 + 1 = 2 diff --git a/test/snapshots/session_config/should_use_provider_model_id_as_wire_model.yaml b/test/snapshots/session_config/should_use_provider_model_id_as_wire_model.yaml new file mode 100644 index 000000000..250402101 --- /dev/null +++ b/test/snapshots/session_config/should_use_provider_model_id_as_wire_model.yaml @@ -0,0 +1,10 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 1+1? + - role: assistant + content: 1 + 1 = 2