diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 908542b..f99412c 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -532,8 +532,8 @@ func (c *Client) invokeHTTP(ctx context.Context, upstream Upstream, tool string, if err != nil { return InvokeResult{}, err } - var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err := decodeRPCResponse(result) + if err != nil { return InvokeResult{}, err } if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" { @@ -545,8 +545,8 @@ func (c *Client) invokeHTTP(ctx context.Context, upstream Upstream, tool string, if err != nil { return InvokeResult{}, err } - rpcResp = rpcResponse{} - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err = decodeRPCResponse(result) + if err != nil { return InvokeResult{}, err } } @@ -575,8 +575,8 @@ func (c *Client) listToolsHTTP(ctx context.Context, upstream Upstream) ([]Tool, if err != nil { return nil, err } - var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err := decodeRPCResponse(result) + if err != nil { return nil, err } if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" { @@ -588,8 +588,8 @@ func (c *Client) listToolsHTTP(ctx context.Context, upstream Upstream) ([]Tool, if err != nil { return nil, err } - rpcResp = rpcResponse{} - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err = decodeRPCResponse(result) + if err != nil { return nil, err } } @@ -664,13 +664,6 @@ func (c *Client) doHTTPEnvelope(ctx context.Context, upstream Upstream, body []b bodyBytes, _ := io.ReadAll(resp.Body) return ForwardResult{StatusCode: resp.StatusCode, Body: bodyBytes, ContentType: contentType, ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionExpired: true, SessionID: sessionID}, nil } - if strings.Contains(contentType, "text/event-stream") { - data, sseErr := extractFirstSSEData(resp.Body) - if sseErr != nil { - return ForwardResult{}, sseErr - } - return ForwardResult{StatusCode: resp.StatusCode, Body: data, ContentType: "application/json", ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionID: sessionID}, nil - } respBody := new(bytes.Buffer) _, err = respBody.ReadFrom(resp.Body) if err != nil { @@ -834,9 +827,14 @@ func (c *Client) initializeHTTPSession(ctx context.Context, upstream Upstream, p c.clearSession(upstream.Name) return false, fmt.Errorf("upstream initialize using MCP %s failed: http %d: %s", protocolVersion, result.StatusCode, extractErrorDetail(result.Body)) } - if len(bytes.TrimSpace(result.Body)) > 0 { + resultBody, err := decodeRPCPayload(result) + if err != nil { + c.clearSession(upstream.Name) + return false, fmt.Errorf("upstream initialize using MCP %s returned invalid SSE JSON-RPC: %w", protocolVersion, err) + } + if len(bytes.TrimSpace(resultBody)) > 0 { var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + if err := json.Unmarshal(resultBody, &rpcResp); err != nil { c.clearSession(upstream.Name) return false, fmt.Errorf("upstream initialize using MCP %s returned invalid JSON-RPC: %w", protocolVersion, err) } @@ -858,6 +856,25 @@ func (c *Client) initializeHTTPSession(ctx context.Context, upstream Upstream, p return hadSession, nil } +func decodeRPCResponse(result ForwardResult) (rpcResponse, error) { + body, err := decodeRPCPayload(result) + if err != nil { + return rpcResponse{}, err + } + var rpcResp rpcResponse + if err := json.Unmarshal(body, &rpcResp); err != nil { + return rpcResponse{}, err + } + return rpcResp, nil +} + +func decodeRPCPayload(result ForwardResult) ([]byte, error) { + if strings.Contains(strings.ToLower(result.ContentType), "text/event-stream") { + return extractFirstSSEData(bytes.NewReader(result.Body)) + } + return result.Body, nil +} + func (c *Client) invokeStdio(ctx context.Context, upstream Upstream, tool string, input map[string]any) (InvokeResult, error) { if upstream.Command == "" { return InvokeResult{}, fmt.Errorf("stdio upstream %q missing command", upstream.Name) diff --git a/internal/mcp/client_test.go b/internal/mcp/client_test.go index 997c0c0..3f497a1 100644 --- a/internal/mcp/client_test.go +++ b/internal/mcp/client_test.go @@ -132,6 +132,96 @@ func TestListToolsRetriesInitializeWithCompatibleProtocol(t *testing.T) { } } +func TestListToolsDecodesSSEResponseAfterMissingSessionReinitialize(t *testing.T) { + var toolsListCount int + var sessions []string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + switch req.Method { + case "initialize": + sessionID := "sid-1" + if len(sessions) > 0 { + sessionID = "sid-2" + } + sessions = append(sessions, sessionID) + w.Header().Set("Mcp-Session-Id", sessionID) + writeTestRPC(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil) + case "notifications/initialized": + w.WriteHeader(http.StatusAccepted) + case "tools/list": + toolsListCount++ + if toolsListCount == 1 { + writeTestRPC(w, req.ID, nil, map[string]any{"code": -32000, "message": "No session ID provided for non-initialization request"}) + return + } + if got := r.Header.Get("Mcp-Session-Id"); got != "sid-2" { + t.Fatalf("retry tools/list used session %q, want sid-2", got) + } + writeTestRPCSSE(w, req.ID, map[string]any{"tools": []map[string]any{{"name": "stories.search"}}}, nil) + default: + t.Fatalf("unexpected method %q", req.Method) + } + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL}) + if err != nil { + t.Fatalf("ListTools returned error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "stories.search" { + t.Fatalf("unexpected tools: %#v", tools) + } + if toolsListCount != 2 { + t.Fatalf("tools/list count = %d, want 2", toolsListCount) + } + if len(sessions) != 2 { + t.Fatalf("initialize sessions = %#v, want two sessions", sessions) + } +} + +func TestInitializeDecodesSSEResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + switch req.Method { + case "initialize": + w.Header().Set("Mcp-Session-Id", "sid-sse") + writeTestRPCSSE(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil) + case "notifications/initialized": + if got := r.Header.Get("Mcp-Session-Id"); got != "sid-sse" { + t.Fatalf("initialized notification missing session id: %q", got) + } + w.WriteHeader(http.StatusAccepted) + case "tools/list": + if got := r.Header.Get("Mcp-Session-Id"); got != "sid-sse" { + t.Fatalf("tools/list missing session id: %q", got) + } + writeTestRPC(w, req.ID, map[string]any{"tools": []map[string]any{{"name": "stories.search"}}}, nil) + default: + t.Fatalf("unexpected method %q", req.Method) + } + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL}) + if err != nil { + t.Fatalf("ListTools returned error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "stories.search" { + t.Fatalf("unexpected tools: %#v", tools) + } +} + func TestMissingSessionRPCErrorDetection(t *testing.T) { if !isMissingSessionRPCError(json.RawMessage(`{"code":-32000,"message":"No session ID provided for non-initialization request"}`)) { t.Fatal("expected missing session error to be detected") @@ -417,3 +507,19 @@ func writeTestRPC(w http.ResponseWriter, id json.RawMessage, result any, rpcErr panic(err) } } + +func writeTestRPCSSE(w http.ResponseWriter, id json.RawMessage, result any, rpcErr any) { + w.Header().Set("Content-Type", "text/event-stream") + resp := map[string]any{"jsonrpc": "2.0", "id": id} + if rpcErr != nil { + resp["error"] = rpcErr + } else { + resp["result"] = result + } + payload, err := json.Marshal(resp) + if err != nil { + panic(err) + } + _, _ = w.Write([]byte("event: message\n")) + _, _ = w.Write([]byte("data: " + string(payload) + "\n\n")) +}