Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions internal/mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ func (c *Client) invokeHTTP(ctx context.Context, upstream Upstream, tool string,
if err != nil {
return InvokeResult{}, err
}
var rpcResp rpcResponse
if err := json.Unmarshal(result.Body, &rpcResp); err != nil {
rpcResp, err := decodeRPCResponse(result)
if err != nil {
return InvokeResult{}, err
}
if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" {
Expand All @@ -545,8 +545,8 @@ func (c *Client) invokeHTTP(ctx context.Context, upstream Upstream, tool string,
if err != nil {
return InvokeResult{}, err
}
rpcResp = rpcResponse{}
if err := json.Unmarshal(result.Body, &rpcResp); err != nil {
rpcResp, err = decodeRPCResponse(result)
if err != nil {
return InvokeResult{}, err
}
}
Expand Down Expand Up @@ -575,8 +575,8 @@ func (c *Client) listToolsHTTP(ctx context.Context, upstream Upstream) ([]Tool,
if err != nil {
return nil, err
}
var rpcResp rpcResponse
if err := json.Unmarshal(result.Body, &rpcResp); err != nil {
rpcResp, err := decodeRPCResponse(result)
if err != nil {
return nil, err
}
if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" {
Expand All @@ -588,8 +588,8 @@ func (c *Client) listToolsHTTP(ctx context.Context, upstream Upstream) ([]Tool,
if err != nil {
return nil, err
}
rpcResp = rpcResponse{}
if err := json.Unmarshal(result.Body, &rpcResp); err != nil {
rpcResp, err = decodeRPCResponse(result)
if err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -664,13 +664,6 @@ func (c *Client) doHTTPEnvelope(ctx context.Context, upstream Upstream, body []b
bodyBytes, _ := io.ReadAll(resp.Body)
return ForwardResult{StatusCode: resp.StatusCode, Body: bodyBytes, ContentType: contentType, ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionExpired: true, SessionID: sessionID}, nil
}
if strings.Contains(contentType, "text/event-stream") {
data, sseErr := extractFirstSSEData(resp.Body)
if sseErr != nil {
return ForwardResult{}, sseErr
}
return ForwardResult{StatusCode: resp.StatusCode, Body: data, ContentType: "application/json", ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionID: sessionID}, nil
}
respBody := new(bytes.Buffer)
_, err = respBody.ReadFrom(resp.Body)
if err != nil {
Expand Down Expand Up @@ -834,9 +827,14 @@ func (c *Client) initializeHTTPSession(ctx context.Context, upstream Upstream, p
c.clearSession(upstream.Name)
return false, fmt.Errorf("upstream initialize using MCP %s failed: http %d: %s", protocolVersion, result.StatusCode, extractErrorDetail(result.Body))
}
if len(bytes.TrimSpace(result.Body)) > 0 {
resultBody, err := decodeRPCPayload(result)
if err != nil {
c.clearSession(upstream.Name)
return false, fmt.Errorf("upstream initialize using MCP %s returned invalid SSE JSON-RPC: %w", protocolVersion, err)
}
if len(bytes.TrimSpace(resultBody)) > 0 {
var rpcResp rpcResponse
if err := json.Unmarshal(result.Body, &rpcResp); err != nil {
if err := json.Unmarshal(resultBody, &rpcResp); err != nil {
c.clearSession(upstream.Name)
return false, fmt.Errorf("upstream initialize using MCP %s returned invalid JSON-RPC: %w", protocolVersion, err)
}
Expand All @@ -858,6 +856,25 @@ func (c *Client) initializeHTTPSession(ctx context.Context, upstream Upstream, p
return hadSession, nil
}

func decodeRPCResponse(result ForwardResult) (rpcResponse, error) {
body, err := decodeRPCPayload(result)
if err != nil {
return rpcResponse{}, err
}
var rpcResp rpcResponse
if err := json.Unmarshal(body, &rpcResp); err != nil {
return rpcResponse{}, err
}
return rpcResp, nil
}

func decodeRPCPayload(result ForwardResult) ([]byte, error) {
if strings.Contains(strings.ToLower(result.ContentType), "text/event-stream") {
return extractFirstSSEData(bytes.NewReader(result.Body))
}
return result.Body, nil
}

func (c *Client) invokeStdio(ctx context.Context, upstream Upstream, tool string, input map[string]any) (InvokeResult, error) {
if upstream.Command == "" {
return InvokeResult{}, fmt.Errorf("stdio upstream %q missing command", upstream.Name)
Expand Down
106 changes: 106 additions & 0 deletions internal/mcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,96 @@ func TestListToolsRetriesInitializeWithCompatibleProtocol(t *testing.T) {
}
}

func TestListToolsDecodesSSEResponseAfterMissingSessionReinitialize(t *testing.T) {
var toolsListCount int
var sessions []string

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req Envelope
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
switch req.Method {
case "initialize":
sessionID := "sid-1"
if len(sessions) > 0 {
sessionID = "sid-2"
}
sessions = append(sessions, sessionID)
w.Header().Set("Mcp-Session-Id", sessionID)
writeTestRPC(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil)
case "notifications/initialized":
w.WriteHeader(http.StatusAccepted)
case "tools/list":
toolsListCount++
if toolsListCount == 1 {
writeTestRPC(w, req.ID, nil, map[string]any{"code": -32000, "message": "No session ID provided for non-initialization request"})
return
}
if got := r.Header.Get("Mcp-Session-Id"); got != "sid-2" {
t.Fatalf("retry tools/list used session %q, want sid-2", got)
}
writeTestRPCSSE(w, req.ID, map[string]any{"tools": []map[string]any{{"name": "stories.search"}}}, nil)
default:
t.Fatalf("unexpected method %q", req.Method)
}
}))
defer server.Close()

client := NewHTTPClient()
client.httpClient = server.Client()
tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL})
if err != nil {
t.Fatalf("ListTools returned error: %v", err)
}
if len(tools) != 1 || tools[0].Name != "stories.search" {
t.Fatalf("unexpected tools: %#v", tools)
}
if toolsListCount != 2 {
t.Fatalf("tools/list count = %d, want 2", toolsListCount)
}
if len(sessions) != 2 {
t.Fatalf("initialize sessions = %#v, want two sessions", sessions)
}
}

func TestInitializeDecodesSSEResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req Envelope
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
switch req.Method {
case "initialize":
w.Header().Set("Mcp-Session-Id", "sid-sse")
writeTestRPCSSE(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil)
case "notifications/initialized":
if got := r.Header.Get("Mcp-Session-Id"); got != "sid-sse" {
t.Fatalf("initialized notification missing session id: %q", got)
}
w.WriteHeader(http.StatusAccepted)
case "tools/list":
if got := r.Header.Get("Mcp-Session-Id"); got != "sid-sse" {
t.Fatalf("tools/list missing session id: %q", got)
}
writeTestRPC(w, req.ID, map[string]any{"tools": []map[string]any{{"name": "stories.search"}}}, nil)
default:
t.Fatalf("unexpected method %q", req.Method)
}
}))
defer server.Close()

client := NewHTTPClient()
client.httpClient = server.Client()
tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL})
if err != nil {
t.Fatalf("ListTools returned error: %v", err)
}
if len(tools) != 1 || tools[0].Name != "stories.search" {
t.Fatalf("unexpected tools: %#v", tools)
}
}

func TestMissingSessionRPCErrorDetection(t *testing.T) {
if !isMissingSessionRPCError(json.RawMessage(`{"code":-32000,"message":"No session ID provided for non-initialization request"}`)) {
t.Fatal("expected missing session error to be detected")
Expand Down Expand Up @@ -417,3 +507,19 @@ func writeTestRPC(w http.ResponseWriter, id json.RawMessage, result any, rpcErr
panic(err)
}
}

func writeTestRPCSSE(w http.ResponseWriter, id json.RawMessage, result any, rpcErr any) {
w.Header().Set("Content-Type", "text/event-stream")
resp := map[string]any{"jsonrpc": "2.0", "id": id}
if rpcErr != nil {
resp["error"] = rpcErr
} else {
resp["result"] = result
}
payload, err := json.Marshal(resp)
if err != nil {
panic(err)
}
_, _ = w.Write([]byte("event: message\n"))
_, _ = w.Write([]byte("data: " + string(payload) + "\n\n"))
}
Loading