diff --git a/examples/main.go b/examples/main.go index a3d07076b..44cbcf91a 100644 --- a/examples/main.go +++ b/examples/main.go @@ -27,6 +27,7 @@ func main() { // FuturesOrder() // DeliveryOrder() // WalletBalance() - WatchMiniMarketsStat() + // WatchMiniMarketsStat() // RunOrderListExamples() + MultiplexedFuturesWebSocketExample() } diff --git a/examples/multiplexed_websocket.go b/examples/multiplexed_websocket.go new file mode 100644 index 000000000..02c2a7979 --- /dev/null +++ b/examples/multiplexed_websocket.go @@ -0,0 +1,530 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/adshao/go-binance/v2/common/websocket" + "github.com/adshao/go-binance/v2/futures" +) + +const ( + requestsCount = 100 + responseTimeout = 30 * time.Second + requestRateLimit = 100 * time.Millisecond + maxConcurrency = 10 + channelBufferSize = 10 +) + +// MultiplexedFuturesWebSocketExample demonstrates concurrent WebSocket requests +// with waiter channels for 1:1 request/response matching +func MultiplexedFuturesWebSocketExample() { + manager, err := NewWebSocketManager() + if err != nil { + slog.Error("Failed to create WebSocket manager", "error", err) + return + } + + manager.logger.Info("=== Enhanced Multiplexed Futures WebSocket Example ===") + manager.logger.Info("Testing request/response tracking and waiter channels") + + // Create context for this operation + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + wg.Add(2) // 2 services will run + + // Error channel to collect errors from goroutines + errorChan := make(chan error, 3) // Buffer for all potential errors + + // Start monitoring and services + go manager.startFallbackMonitoring(ctx) + go manager.startOrderPlaceService(ctx, &wg, errorChan) + go manager.startOrderCancelService(ctx, &wg, errorChan) + + // Monitor for errors in a separate goroutine + go func() { + for err := range errorChan { + if err != nil { + manager.logger.Error("Service error", "error", err) + } + } + }() + + wg.Wait() + + manager.logger.Info("All services completed") + manager.printFinalStats() +} + +// createWebSocketClient creates a new WebSocket client +func createWebSocketClient(logger *slog.Logger) (websocket.Client, error) { + conn, err := websocket.NewConnection( + futures.WsApiInitReadWriteConn, + futures.WebsocketKeepalive, + futures.WebsocketTimeoutReadWriteConnection, + ) + if err != nil { + return nil, fmt.Errorf("failed to create WebSocket connection: %w", err) + } + + client, err := websocket.NewClient(conn) + if err != nil { + return nil, fmt.Errorf("failed to create WebSocket client: %w", err) + } + + logger.Info("Shared futures WebSocket client created successfully") + return client, nil +} + +// startFallbackMonitoring monitors the shared channel for fallback messages +func (wm *WebSocketManager) startFallbackMonitoring(ctx context.Context) { + wm.logger.Info("Starting fallback message monitoring") + + for { + select { + case data := <-wm.client.GetReadChannel(): + wm.logger.Error("FALLBACK: Message received on shared channel (should be routed to waiter)", + "data", string(data)) + case err := <-wm.client.GetReadErrorChannel(): + if err != nil { + wm.logger.Error("Shared channel error", "error", err) + return + } + case <-ctx.Done(): + wm.logger.Info("Fallback monitoring stopped") + return + } + } +} + +// startOrderPlaceService handles order placement requests +func (wm *WebSocketManager) startOrderPlaceService(ctx context.Context, wg *sync.WaitGroup, errorChan chan<- error) { + defer wg.Done() + + // Create order place service + orderPlaceService, err := futures.NewOrderPlaceWsService( + AppConfig.APIKey, + AppConfig.SecretKey, + websocket.WithWebSocketClient(wm.client), + ) + if err != nil { + errorChan <- fmt.Errorf("failed to create order place service: %w", err) + return + } + + // Create dedicated waiter channel with proper buffer + waiterChannel := make(chan []byte, channelBufferSize) + defer close(waiterChannel) + + // Initialize stats + wm.placeStats.SetStartTime() + + // Start request sender and response handler concurrently + var serviceWg sync.WaitGroup + serviceWg.Add(2) + + requestErrors := make(chan error, requestsCount) // Buffer for all possible request errors + + go func() { + defer serviceWg.Done() + wm.sendOrderPlaceRequests(ctx, orderPlaceService, waiterChannel, requestErrors) + close(requestErrors) + }() + + go func() { + defer serviceWg.Done() + wm.handleOrderPlaceResponses(ctx, waiterChannel, requestErrors, errorChan) + }() + + serviceWg.Wait() +} + +// handleOrderPlaceResponses handles responses from order place service +func (wm *WebSocketManager) handleOrderPlaceResponses(ctx context.Context, waiterChannel <-chan []byte, requestErrors <-chan error, errorChan chan<- error) { + serviceLogger := wm.logger.With("service", "order_place") + timeout := time.NewTimer(responseTimeout) + defer timeout.Stop() + + // Monitor request errors + go func() { + for err := range requestErrors { + if err != nil { + serviceLogger.Error("Request error", "error", err) + } + } + }() + + for { + select { + case data, ok := <-waiterChannel: + if !ok { + serviceLogger.Info("WaiterChannel closed") + wm.placeStats.SetEndTime() + return + } + response := string(data) + serviceLogger.Debug("Response received", "response", response) + + if !strings.Contains(response, "place_") { + err := fmt.Errorf("response does not contain expected prefix 'place_': %s", response) + errorChan <- err + wm.placeStats.SetEndTime() + return + } + + wm.placeStats.IncrementResponses() + sent, received := wm.placeStats.GetCounts() + serviceLogger.Info("Progress", "responses_received", received, "requests_sent", sent) + + if received >= requestsCount { + serviceLogger.Info("All responses received successfully", "total", received) + wm.placeStats.SetEndTime() + return + } + + case <-timeout.C: + sent, received := wm.placeStats.GetCounts() + err := fmt.Errorf("timeout reached: sent=%d, received=%d, missing=%d", sent, received, sent-received) + errorChan <- err + wm.placeStats.SetEndTime() + return + + case <-ctx.Done(): + serviceLogger.Info("Response handler stopped") + wm.placeStats.SetEndTime() + return + } + } +} + +// sendOrderPlaceRequests sends multiple order placement requests concurrently +func (wm *WebSocketManager) sendOrderPlaceRequests(ctx context.Context, orderPlaceService *futures.OrderPlaceWsService, waiterChannel chan []byte, errorChan chan<- error) { + serviceLogger := wm.logger.With("service", "order_place") + + // Build request + request := futures.NewOrderPlaceWsRequest(). + Symbol("BTCUSDT"). + Side(futures.SideTypeBuy). + Type(futures.OrderTypeMarket). + TimeInForce(futures.TimeInForceTypeFOK). + Quantity("0.00001") + + var wg sync.WaitGroup + semaphore := make(chan struct{}, maxConcurrency) + + for i := 0; i < requestsCount; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + semaphore <- struct{}{} // Acquire semaphore + defer func() { <-semaphore }() // Release semaphore + + requestID := fmt.Sprintf("place_%d_%d", time.Now().UnixNano(), index) + serviceLogger.Debug("Sending request", "request_id", requestID, "index", index) + + if err := orderPlaceService.Do(requestID, request, websocket.WithWaiter(waiterChannel)); err != nil { + errorChan <- fmt.Errorf("failed to send request %s: %w", requestID, err) + return + } + wm.placeStats.IncrementRequests() + }(i) + + // Rate limiting + time.Sleep(requestRateLimit) + } + + wg.Wait() + serviceLogger.Info("All order place requests sent") +} + +// startOrderCancelService handles order cancellation requests +func (wm *WebSocketManager) startOrderCancelService(ctx context.Context, wg *sync.WaitGroup, errorChan chan<- error) { + defer wg.Done() + + // Create order cancel service + orderCancelService, err := futures.NewOrderCancelWsService( + AppConfig.APIKey, + AppConfig.SecretKey, + websocket.WithWebSocketClient(wm.client), + ) + if err != nil { + errorChan <- fmt.Errorf("failed to create order cancel service: %w", err) + return + } + + // Create dedicated waiter channel with proper buffer + waiterChannel := make(chan []byte, channelBufferSize) + defer close(waiterChannel) + + // Initialize stats + wm.cancelStats.SetStartTime() + + // Start request sender and response handler concurrently + var serviceWg sync.WaitGroup + serviceWg.Add(2) + + requestErrors := make(chan error, requestsCount) // Buffer for all possible request errors + + go func() { + defer serviceWg.Done() + wm.sendOrderCancelRequests(ctx, orderCancelService, waiterChannel, requestErrors) + close(requestErrors) + }() + + go func() { + defer serviceWg.Done() + wm.handleOrderCancelResponses(ctx, waiterChannel, requestErrors, errorChan) + }() + + serviceWg.Wait() +} + +// handleOrderCancelResponses handles responses from order cancel service +func (wm *WebSocketManager) handleOrderCancelResponses(ctx context.Context, waiterChannel <-chan []byte, requestErrors <-chan error, errorChan chan<- error) { + serviceLogger := wm.logger.With("service", "order_cancel") + timeout := time.NewTimer(responseTimeout) + defer timeout.Stop() + + // Monitor request errors + go func() { + for err := range requestErrors { + if err != nil { + serviceLogger.Error("Request error", "error", err) + } + } + }() + + for { + select { + case data, ok := <-waiterChannel: + if !ok { + serviceLogger.Info("WaiterChannel closed") + wm.cancelStats.SetEndTime() + return + } + response := string(data) + serviceLogger.Debug("Response received", "response", response) + + if !strings.Contains(response, "cancel_") { + err := fmt.Errorf("response does not contain expected prefix 'cancel_': %s", response) + errorChan <- err + wm.cancelStats.SetEndTime() + return + } + + wm.cancelStats.IncrementResponses() + sent, received := wm.cancelStats.GetCounts() + serviceLogger.Info("Progress", "responses_received", received, "requests_sent", sent) + + if received >= requestsCount { + serviceLogger.Info("All responses received successfully", "total", received) + wm.cancelStats.SetEndTime() + return + } + + case <-timeout.C: + sent, received := wm.cancelStats.GetCounts() + err := fmt.Errorf("timeout reached: sent=%d, received=%d, missing=%d", sent, received, sent-received) + errorChan <- err + wm.cancelStats.SetEndTime() + return + + case <-ctx.Done(): + serviceLogger.Info("Response handler stopped") + wm.cancelStats.SetEndTime() + return + } + } +} + +// sendOrderCancelRequests sends multiple order cancellation requests concurrently +func (wm *WebSocketManager) sendOrderCancelRequests(ctx context.Context, orderCancelService *futures.OrderCancelWsService, waiterChannel chan []byte, errorChan chan<- error) { + serviceLogger := wm.logger.With("service", "order_cancel") + + // Build request + request := futures.NewOrderCancelRequest(). + Symbol("BTCUSDT"). + OrderID(123123) // Non-existing order for testing + + var wg sync.WaitGroup + semaphore := make(chan struct{}, maxConcurrency) + + for i := 0; i < requestsCount; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + semaphore <- struct{}{} // Acquire semaphore + defer func() { <-semaphore }() // Release semaphore + + requestID := fmt.Sprintf("cancel_%d_%d", time.Now().UnixNano(), index) + serviceLogger.Debug("Sending request", "request_id", requestID, "index", index) + + if err := orderCancelService.Do(requestID, request, websocket.WithWaiter(waiterChannel)); err != nil { + errorChan <- fmt.Errorf("failed to send request %s: %w", requestID, err) + return + } + wm.cancelStats.IncrementRequests() + }(i) + + // Rate limiting + time.Sleep(requestRateLimit) + } + + wg.Wait() + serviceLogger.Info("All order cancel requests sent") +} + +// printFinalStats prints comprehensive statistics for both services +func (wm *WebSocketManager) printFinalStats() { + wm.logger.Info("=== FINAL STATISTICS ===") + + // Get counts atomically to avoid race conditions + placeSent, placeReceived := wm.placeStats.GetCounts() + placeDuration := wm.placeStats.GetDuration() + placeSuccess := placeSent == placeReceived && placeSent == requestsCount + + // Calculate average response time safely + var placeAvgTime time.Duration + if placeReceived > 0 { + placeAvgTime = placeDuration / time.Duration(placeReceived) + } + + wm.logger.Info("Order Place Service Statistics", + "requests_sent", placeSent, + "responses_received", placeReceived, + "expected", requestsCount, + "missing", placeSent-placeReceived, + "success", placeSuccess, + "duration", placeDuration, + "avg_response_time", placeAvgTime, + ) + + // Get counts atomically to avoid race conditions + cancelSent, cancelReceived := wm.cancelStats.GetCounts() + cancelDuration := wm.cancelStats.GetDuration() + cancelSuccess := cancelSent == cancelReceived && cancelSent == requestsCount + + // Calculate average response time safely + var cancelAvgTime time.Duration + if cancelReceived > 0 { + cancelAvgTime = cancelDuration / time.Duration(cancelReceived) + } + + wm.logger.Info("Order Cancel Service Statistics", + "requests_sent", cancelSent, + "responses_received", cancelReceived, + "expected", requestsCount, + "missing", cancelSent-cancelReceived, + "success", cancelSuccess, + "duration", cancelDuration, + "avg_response_time", cancelAvgTime, + ) + + // Overall stats + totalSent := placeSent + cancelSent + totalReceived := placeReceived + cancelReceived + totalExpected := int64(requestsCount * 2) + overallSuccess := totalSent == totalReceived && totalSent == totalExpected + + // Calculate success rate safely + var successRate float64 + if totalSent > 0 { + successRate = float64(totalReceived) / float64(totalSent) * 100 + } + + wm.logger.Info("Overall Statistics", + "total_requests_sent", totalSent, + "total_responses_received", totalReceived, + "total_expected", totalExpected, + "total_missing", totalSent-totalReceived, + "overall_success", overallSuccess, + "success_rate", fmt.Sprintf("%.2f%%", successRate), + ) + + if !overallSuccess { + wm.logger.Error("VERIFICATION FAILED: Request/Response counts do not match!", + "sent", totalSent, + "received", totalReceived, + "expected", totalExpected, + ) + } else { + wm.logger.Info("✅ VERIFICATION PASSED: All requests received responses!") + } +} + +// WebSocketManager encapsulates the WebSocket client and statistics +type WebSocketManager struct { + logger *slog.Logger + placeStats *Stats + cancelStats *Stats + client websocket.Client +} + +// NewWebSocketManager creates a new manager with proper initialization +func NewWebSocketManager() (*WebSocketManager, error) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + + client, err := createWebSocketClient(logger) + if err != nil { + return nil, fmt.Errorf("failed to create WebSocket client: %w", err) + } + + return &WebSocketManager{ + logger: logger, + placeStats: &Stats{}, + cancelStats: &Stats{}, + client: client, + }, nil +} + +// Stats tracks request/response statistics using atomic operations +type Stats struct { + requestsSent int64 + responsesReceived int64 + startTime time.Time + endTime time.Time + mu sync.RWMutex // Only for time operations +} + +func (s *Stats) IncrementRequests() { + atomic.AddInt64(&s.requestsSent, 1) +} + +func (s *Stats) IncrementResponses() { + atomic.AddInt64(&s.responsesReceived, 1) +} + +func (s *Stats) GetCounts() (int64, int64) { + return atomic.LoadInt64(&s.requestsSent), atomic.LoadInt64(&s.responsesReceived) +} + +func (s *Stats) SetStartTime() { + s.mu.Lock() + defer s.mu.Unlock() + s.startTime = time.Now() +} + +func (s *Stats) SetEndTime() { + s.mu.Lock() + defer s.mu.Unlock() + s.endTime = time.Now() +} + +func (s *Stats) GetDuration() time.Duration { + s.mu.RLock() + defer s.mu.RUnlock() + if s.endTime.IsZero() { + return time.Since(s.startTime) + } + return s.endTime.Sub(s.startTime) +} diff --git a/v2/common/websocket/client.go b/v2/common/websocket/client.go index e21f54ba6..8c408cf08 100644 --- a/v2/common/websocket/client.go +++ b/v2/common/websocket/client.go @@ -87,8 +87,22 @@ func NewClient(conn Connection) (Client, error) { return client, nil } +type request struct { + waiter chan []byte +} + +// RequestOption define option type for request +type RequestOption func(*request) + +// WithWaiter set waiter channel param for the request +func WithWaiter(waiter chan []byte) RequestOption { + return func(r *request) { + r.waiter = waiter + } +} + type Client interface { - Write(id string, data []byte) error + Write(id string, data []byte, opts ...RequestOption) error WriteSync(id string, data []byte, timeout time.Duration) ([]byte, error) GetReadChannel() <-chan []byte GetReadErrorChannel() <-chan error @@ -98,7 +112,7 @@ type Client interface { } // Write sends data into websocket connection -func (c *client) Write(id string, data []byte) error { +func (c *client) Write(id string, data []byte, opts ...RequestOption) error { c.connMu.Lock() defer c.connMu.Unlock() @@ -106,12 +120,17 @@ func (c *client) Write(id string, data []byte) error { return ErrorWsIdAlreadySent } + req := &request{} + for _, opt := range opts { + opt(req) + } + if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { c.debug("write: unable to write message into websocket conn '%v'", err) return err } - c.requestsList.Add(id) + c.requestsList.Add(id, req.waiter) return nil } @@ -203,8 +222,13 @@ func (c *client) read() { continue } - c.debug("read: sending message into read channel '%v'", msg) - c.readC <- message + if waiter := c.requestsList.Get(msg.Id); waiter != nil { + c.debug("read: send message into waiter channel '%v'", msg.Id) + waiter <- message + } else { + c.debug("read: sending message into read channel '%v'", msg) + c.readC <- message + } c.debug("read: remove message from request list '%v'", msg) c.requestsList.Remove(msg.Id) @@ -282,28 +306,34 @@ func (c *client) GetReconnectCount() int64 { func NewRequestList() RequestList { return RequestList{ mu: sync.Mutex{}, - requests: make(map[string]struct{}), // TODO preallocate buckets + requests: make(map[string]chan []byte), // TODO preallocate buckets } } -// RequestList state of requests that was sent/received +// RequestList state of requests that was sent/received with or without waiter channel type RequestList struct { mu sync.Mutex - requests map[string]struct{} + requests map[string]chan []byte } // Add adds request into list -func (l *RequestList) Add(id string) { +func (l *RequestList) Add(id string, waiterChan chan []byte) { + l.mu.Lock() + defer l.mu.Unlock() + l.requests[id] = waiterChan +} + +func (l *RequestList) Get(id string) chan []byte { l.mu.Lock() defer l.mu.Unlock() - l.requests[id] = struct{}{} + return l.requests[id] } // RecreateList creates new request list func (l *RequestList) RecreateList() { l.mu.Lock() defer l.mu.Unlock() - l.requests = make(map[string]struct{}) + l.requests = make(map[string]chan []byte) } // Remove adds request from list diff --git a/v2/common/websocket/client_test.go b/v2/common/websocket/client_test.go index 62853a268..79110ff5b 100644 --- a/v2/common/websocket/client_test.go +++ b/v2/common/websocket/client_test.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "errors" + "fmt" "log" + "net" "net/http" "testing" "time" @@ -31,19 +33,43 @@ type clientTestSuite struct { secretKey string } +// findAvailablePort finds and returns an available port on localhost +func findAvailablePort() (int, error) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return 0, fmt.Errorf("failed to find available port: %w", err) + } + defer listener.Close() + + addr := listener.Addr().(*net.TCPAddr) + return addr.Port, nil +} + +// createTestWebSocketURL creates a websocket URL for the given port +func createTestWebSocketURL(port int) string { + return fmt.Sprintf("ws://localhost:%d/ws", port) +} + func TestClient(t *testing.T) { suite.Run(t, new(clientTestSuite)) } func (s *clientTestSuite) TestReadWriteSync() { + // Find an available port + port, err := findAvailablePort() + s.Require().NoError(err) + stopCh := make(chan struct{}) go func() { - startWsTestServer(stopCh) + startWsTestServer(port, stopCh) }() defer func() { stopCh <- struct{}{} }() + // Give server time to start + time.Sleep(100 * time.Millisecond) + conn, err := NewConnection(func() (*websocket.Conn, error) { Dialer := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, @@ -51,7 +77,8 @@ func (s *clientTestSuite) TestReadWriteSync() { EnableCompression: false, } - c, _, err := Dialer.Dial("ws://localhost:8080/ws", nil) + wsURL := createTestWebSocketURL(port) + c, _, err := Dialer.Dial(wsURL, nil) if err != nil { return nil, err } @@ -171,6 +198,39 @@ func (s *clientTestSuite) TestReadWriteSync() { } }, }, + { + name: "WriteAsync success with waiter channel", + testCallback: func() { + id, err := uuid.NewRandom() + s.Require().NoError(err) + requestID := id.String() + + req := testApiRequest{ + Id: requestID, + Method: "some-method-with-waiter", + Params: map[string]interface{}{}, + } + reqRaw, err := json.Marshal(req) + s.Require().NoError(err) + + waiter := make(chan []byte) + + err = client.Write(requestID, reqRaw, WithWaiter(waiter)) + s.Require().NoError(err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + select { + case <-ctx.Done(): + s.T().Fatal("timeout waiting for write") + case responseRaw := <-waiter: + s.Require().Equal(reqRaw, responseRaw) + case err := <-client.GetReadErrorChannel(): + s.T().Fatalf("unexpected error: '%v'", err) + } + }, + }, } for _, tt := range tests { @@ -231,13 +291,14 @@ func wsHandler(w http.ResponseWriter, r *http.Request) { } } -func startWsTestServer(stopCh chan struct{}) { +func startWsTestServer(port int, stopCh chan struct{}) { + addr := fmt.Sprintf("localhost:%d", port) server := &http.Server{ - Addr: "localhost:8080", + Addr: addr, } http.HandleFunc("/ws", wsHandler) - log.Println("WebSocket server started on :8080") + log.Printf("WebSocket server started on :%d", port) go func() { if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { @@ -256,3 +317,101 @@ func startWsTestServer(stopCh chan struct{}) { } log.Println("Graceful shutdown complete.") } + +// Memory benchmark tests for channel references in RequestList +func BenchmarkRequestList_ChannelMemory(b *testing.B) { + tests := []struct { + name string + mapSize int + channelType string + }{ + {"SameChannel_1000entries", 1000, "same"}, + {"SameChannel_10000entries", 10000, "same"}, + {"DifferentChannels_1000entries", 1000, "different"}, + {"DifferentChannels_10000entries", 10000, "different"}, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + b.ReportAllocs() + + for n := 0; n < b.N; n++ { + requestList := NewRequestList() + + if tt.channelType == "same" { + // Test: Multiple map entries with SAME channel reference + sharedChannel := make(chan []byte, 1) + + for i := 0; i < tt.mapSize; i++ { + requestID := fmt.Sprintf("req_%d_%d", n, i) + requestList.Add(requestID, sharedChannel) + } + + } else { + // Test: Multiple map entries with DIFFERENT channels + for i := 0; i < tt.mapSize; i++ { + requestID := fmt.Sprintf("req_%d_%d", n, i) + uniqueChannel := make(chan []byte, 1) + requestList.Add(requestID, uniqueChannel) + } + } + + // Verify the map size + if requestList.Len() != tt.mapSize { + b.Fatalf("Expected %d entries, got %d", tt.mapSize, requestList.Len()) + } + } + }) + } +} + +// Benchmark to demonstrate memory efficiency of channel references +func BenchmarkChannelReference_vs_ChannelCopy(b *testing.B) { + const numEntries = 1000 + + b.Run("ChannelReferences", func(b *testing.B) { + b.ReportAllocs() + + for n := 0; n < b.N; n++ { + // Simulate the actual RequestList behavior + requests := make(map[string]chan []byte, numEntries) + sharedChannel := make(chan []byte, 1) + + for i := 0; i < numEntries; i++ { + requestID := fmt.Sprintf("req_%d", i) + requests[requestID] = sharedChannel // Store reference + } + + // Verify all entries point to the same channel + firstChan := requests["req_0"] + for i := 1; i < numEntries; i++ { + requestID := fmt.Sprintf("req_%d", i) + if requests[requestID] != firstChan { + b.Fatal("Channels should be identical references") + } + } + } + }) + + b.Run("UniqueChannels", func(b *testing.B) { + b.ReportAllocs() + + for n := 0; n < b.N; n++ { + requests := make(map[string]chan []byte, numEntries) + + for i := 0; i < numEntries; i++ { + requestID := fmt.Sprintf("req_%d", i) + requests[requestID] = make(chan []byte, 1) // Create unique channel + } + + // Verify all channels are different + for i := 1; i < numEntries; i++ { + req0 := fmt.Sprintf("req_%d", 0) + reqI := fmt.Sprintf("req_%d", i) + if requests[req0] == requests[reqI] { + b.Fatal("Channels should be unique") + } + } + } + }) +} diff --git a/v2/common/websocket/mock/client.go b/v2/common/websocket/mock/client.go index 688a0fedf..5001fa5e8 100644 --- a/v2/common/websocket/mock/client.go +++ b/v2/common/websocket/mock/client.go @@ -35,6 +35,20 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } +// Close mocks base method. +func (m *MockClient) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockClientMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClient)(nil).Close)) +} + // GetReadChannel mocks base method. func (m *MockClient) GetReadChannel() <-chan []byte { m.ctrl.T.Helper() @@ -90,17 +104,22 @@ func (mr *MockClientMockRecorder) Wait(timeout interface{}) *gomock.Call { } // Write mocks base method. -func (m *MockClient) Write(id string, data []byte) error { +func (m *MockClient) Write(id string, data []byte, opts ...websocket.RequestOption) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", id, data) + varargs := []interface{}{id, data} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Write", varargs...) ret0, _ := ret[0].(error) return ret0 } // Write indicates an expected call of Write. -func (mr *MockClientMockRecorder) Write(id, data interface{}) *gomock.Call { +func (mr *MockClientMockRecorder) Write(id, data interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockClient)(nil).Write), id, data) + varargs := append([]interface{}{id, data}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockClient)(nil).Write), varargs...) } // WriteSync mocks base method. @@ -112,13 +131,6 @@ func (m *MockClient) WriteSync(id string, data []byte, timeout time.Duration) ([ return ret0, ret1 } -func (m *MockClient) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - // WriteSync indicates an expected call of WriteSync. func (mr *MockClientMockRecorder) WriteSync(id, data, timeout interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() @@ -148,6 +160,20 @@ func (m *MockConnection) EXPECT() *MockConnectionMockRecorder { return m.recorder } +// Close mocks base method. +func (m *MockConnection) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close)) +} + // ReadMessage mocks base method. func (m *MockConnection) ReadMessage() (int, []byte, error) { m.ctrl.T.Helper() diff --git a/v2/common/websocket/service_options.go b/v2/common/websocket/service_options.go new file mode 100644 index 000000000..8288167a5 --- /dev/null +++ b/v2/common/websocket/service_options.go @@ -0,0 +1,23 @@ +package websocket + +// WebSocketServiceOption represents a functional option for WebSocket services +type WebSocketServiceOption func(serviceOpt *WebSocketServiceCreateOption) + +type WebSocketServiceCreateOption struct { + Client Client + RecvWindow int64 +} + +// WithWebSocketClient creates an option to set the websocket Client for any WebSocket service +func WithWebSocketClient(client Client) WebSocketServiceOption { + return func(opt *WebSocketServiceCreateOption) { + opt.Client = client + } +} + +// WithRecvWindow creates an option to set the receive window for WebSocket services +func WithRecvWindow(recvWindow int64) WebSocketServiceOption { + return func(opt *WebSocketServiceCreateOption) { + opt.RecvWindow = recvWindow + } +} diff --git a/v2/futures/account_service_ws.go b/v2/futures/account_service_ws.go index 21421299a..fe09c0944 100644 --- a/v2/futures/account_service_ws.go +++ b/v2/futures/account_service_ws.go @@ -19,28 +19,46 @@ type WsAccountService struct { } func NewWsAccountService(apiKey, secretKey string, recvWindow ...int64) (*WsAccountService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err + opts := []websocket.WebSocketServiceOption{} + if len(recvWindow) > 0 { + opts = append(opts, websocket.WithRecvWindow(recvWindow[0])) } + return NewWsAccountServiceWithOptions(apiKey, secretKey, opts...) +} - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err +func NewWsAccountServiceWithOptions(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*WsAccountService, error) { + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } window := int64(5000) - if len(recvWindow) > 0 { - window = recvWindow[0] + if createOpts.RecvWindow > 0 { + window = createOpts.RecvWindow } - return &WsAccountService{ - c: client, + service := &WsAccountService{ ApiKey: apiKey, SecretKey: secretKey, KeyType: common.KeyTypeHmac, RecvWindow: window, - }, nil + } + + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } type WsAccountV2InfoResponse struct { diff --git a/v2/futures/order_cancel_service_ws.go b/v2/futures/order_cancel_service_ws.go index f47bfae88..6f591ca04 100644 --- a/v2/futures/order_cancel_service_ws.go +++ b/v2/futures/order_cancel_service_ws.go @@ -74,7 +74,6 @@ type OrderCancelWsResponse struct { Error *common.APIError `json:"error,omitempty"` } -// OrderCancelWsService cancel order type OrderCancelWsService struct { c websocket.Client ApiKey string @@ -84,27 +83,37 @@ type OrderCancelWsService struct { } // NewOrderCancelWsService init OrderCancelWsService -func NewOrderCancelWsService(apiKey, secretKey string) (*OrderCancelWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderCancelWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderCancelWsService, error) { + service := &OrderCancelWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderCancelWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // Do - sends 'order.cancel' request -func (s *OrderCancelWsService) Do(requestID string, request *OrderCancelRequest) error { +func (s *OrderCancelWsService) Do(requestID string, request *OrderCancelRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -120,7 +129,7 @@ func (s *OrderCancelWsService) Do(requestID string, request *OrderCancelRequest) return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/futures/order_place_service_ws.go b/v2/futures/order_place_service_ws.go index 1155cc039..290f5a2c5 100644 --- a/v2/futures/order_place_service_ws.go +++ b/v2/futures/order_place_service_ws.go @@ -8,7 +8,6 @@ import ( "github.com/adshao/go-binance/v2/common/websocket" ) -// OrderPlaceWsService creates order type OrderPlaceWsService struct { c websocket.Client ApiKey string @@ -18,23 +17,33 @@ type OrderPlaceWsService struct { } // NewOrderPlaceWsService init OrderPlaceWsService -func NewOrderPlaceWsService(apiKey, secretKey string) (*OrderPlaceWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderPlaceWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderPlaceWsService, error) { + service := &OrderPlaceWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderPlaceWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderPlaceWsRequest parameters for 'order.place' websocket API @@ -238,7 +247,7 @@ func (s *OrderPlaceWsRequest) buildParams() params { } // Do - sends 'order.place' request -func (s *OrderPlaceWsService) Do(requestID string, request *OrderPlaceWsRequest) error { +func (s *OrderPlaceWsService) Do(requestID string, request *OrderPlaceWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -254,7 +263,7 @@ func (s *OrderPlaceWsService) Do(requestID string, request *OrderPlaceWsRequest) return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/futures/order_status_service_ws.go b/v2/futures/order_status_service_ws.go index 3c98aae62..5e801dd50 100644 --- a/v2/futures/order_status_service_ws.go +++ b/v2/futures/order_status_service_ws.go @@ -18,23 +18,33 @@ type OrderStatusWsService struct { } // NewOrderStatusWsService init OrderStatusWsService -func NewOrderStatusWsService(apiKey, secretKey string) (*OrderStatusWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderStatusWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderStatusWsService, error) { + service := &OrderStatusWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderStatusWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderStatusWsRequest parameters for 'order.status' websocket API @@ -128,7 +138,7 @@ func (s *OrderStatusWsRequest) buildParams() params { } // Do - sends 'order.status' request -func (s *OrderStatusWsService) Do(requestID string, request *OrderStatusWsRequest) error { +func (s *OrderStatusWsService) Do(requestID string, request *OrderStatusWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -144,7 +154,7 @@ func (s *OrderStatusWsService) Do(requestID string, request *OrderStatusWsReques return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/order_list_cancel_service_ws.go b/v2/order_list_cancel_service_ws.go index 44cab9a0d..ee7f28228 100644 --- a/v2/order_list_cancel_service_ws.go +++ b/v2/order_list_cancel_service_ws.go @@ -18,23 +18,33 @@ type OrderListCancelWsService struct { } // NewOrderListCancelWsService init OrderListCancelWsService -func NewOrderListCancelWsService(apiKey, secretKey string) (*OrderListCancelWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderListCancelWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderListCancelWsService, error) { + service := &OrderListCancelWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderListCancelWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderListCancelWsRequest parameters for 'orderList.cancel' websocket API @@ -76,7 +86,7 @@ func (s *OrderListCancelWsRequest) buildParams() params { } // Do - sends 'orderList.cancel' request -func (s *OrderListCancelWsService) Do(requestID string, request *OrderListCancelWsRequest) error { +func (s *OrderListCancelWsService) Do(requestID string, request *OrderListCancelWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -92,7 +102,7 @@ func (s *OrderListCancelWsService) Do(requestID string, request *OrderListCancel return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/order_list_place_oto_service_ws.go b/v2/order_list_place_oto_service_ws.go index 00edc12ba..1a39abf22 100644 --- a/v2/order_list_place_oto_service_ws.go +++ b/v2/order_list_place_oto_service_ws.go @@ -18,23 +18,33 @@ type OrderListPlaceOtoWsService struct { } // NewOrderListPlaceOtoWsService init OrderListPlaceOtoWsService -func NewOrderListPlaceOtoWsService(apiKey, secretKey string) (*OrderListPlaceOtoWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderListPlaceOtoWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderListPlaceOtoWsService, error) { + service := &OrderListPlaceOtoWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderListPlaceOtoWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderListPlaceOtoWsRequest parameters for 'orderList.place.oto' websocket API @@ -148,7 +158,7 @@ func (s *OrderListPlaceOtoWsRequest) buildParams() params { } // Do - sends 'orderList.place.oto' request -func (s *OrderListPlaceOtoWsService) Do(requestID string, request *OrderListPlaceOtoWsRequest) error { +func (s *OrderListPlaceOtoWsService) Do(requestID string, request *OrderListPlaceOtoWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -164,7 +174,7 @@ func (s *OrderListPlaceOtoWsService) Do(requestID string, request *OrderListPlac return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/order_list_place_otoco_service_ws.go b/v2/order_list_place_otoco_service_ws.go index dbe4a52da..a916d23ea 100644 --- a/v2/order_list_place_otoco_service_ws.go +++ b/v2/order_list_place_otoco_service_ws.go @@ -18,23 +18,33 @@ type OrderListPlaceOtocoWsService struct { } // NewOrderListPlaceOtocoWsService init OrderListPlaceOtocoWsService -func NewOrderListPlaceOtocoWsService(apiKey, secretKey string) (*OrderListPlaceOtocoWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderListPlaceOtocoWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderListPlaceOtocoWsService, error) { + service := &OrderListPlaceOtocoWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderListPlaceOtocoWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderListPlaceOtocoWsRequest parameters for 'orderList.place.otoco' websocket API @@ -186,7 +196,7 @@ func (s *OrderListPlaceOtocoWsRequest) buildParams() params { } // Do - sends 'orderList.place.otoco' request -func (s *OrderListPlaceOtocoWsService) Do(requestID string, request *OrderListPlaceOtocoWsRequest) error { +func (s *OrderListPlaceOtocoWsService) Do(requestID string, request *OrderListPlaceOtocoWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -202,7 +212,7 @@ func (s *OrderListPlaceOtocoWsService) Do(requestID string, request *OrderListPl return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/order_list_place_service_ws.go b/v2/order_list_place_service_ws.go index 9ae3a0d8f..d46559b98 100644 --- a/v2/order_list_place_service_ws.go +++ b/v2/order_list_place_service_ws.go @@ -18,23 +18,33 @@ type OrderListPlaceWsService struct { } // NewOrderListPlaceWsService init OrderListPlaceWsService -func NewOrderListPlaceWsService(apiKey, secretKey string) (*OrderListPlaceWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderListPlaceWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderListPlaceWsService, error) { + service := &OrderListPlaceWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderListPlaceWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderListPlaceWsRequest parameters for 'orderList.place' websocket API (deprecated OCO) @@ -136,7 +146,7 @@ func (s *OrderListPlaceWsRequest) buildParams() params { } // Do - sends 'orderList.place' request -func (s *OrderListPlaceWsService) Do(requestID string, request *OrderListPlaceWsRequest) error { +func (s *OrderListPlaceWsService) Do(requestID string, request *OrderListPlaceWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -152,7 +162,7 @@ func (s *OrderListPlaceWsService) Do(requestID string, request *OrderListPlaceWs return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/order_list_service_ws_create.go b/v2/order_list_service_ws_create.go index 58d34c8da..bb319976f 100644 --- a/v2/order_list_service_ws_create.go +++ b/v2/order_list_service_ws_create.go @@ -18,23 +18,33 @@ type OrderListCreateWsService struct { } // NewOrderListCreateWsService init OrderListCreateWsService -func NewOrderListCreateWsService(apiKey, secretKey string) (*OrderListCreateWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderListCreateWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderListCreateWsService, error) { + service := &OrderListCreateWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderListCreateWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderListCreateWsRequest parameters for 'orderList.place.oco' websocket API @@ -152,7 +162,7 @@ func (s *OrderListCreateWsRequest) buildParams() params { } // Do - sends 'orderList.place.oco' request -func (s *OrderListCreateWsService) Do(requestID string, request *OrderListCreateWsRequest) error { +func (s *OrderListCreateWsService) Do(requestID string, request *OrderListCreateWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -168,7 +178,7 @@ func (s *OrderListCreateWsService) Do(requestID string, request *OrderListCreate return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/order_service_ws_create.go b/v2/order_service_ws_create.go index e62c220c0..668bd7f13 100644 --- a/v2/order_service_ws_create.go +++ b/v2/order_service_ws_create.go @@ -18,23 +18,33 @@ type OrderCreateWsService struct { } // NewOrderCreateWsService init OrderCreateWsService -func NewOrderCreateWsService(apiKey, secretKey string) (*OrderCreateWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewOrderCreateWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*OrderCreateWsService, error) { + service := &OrderCreateWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &OrderCreateWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // OrderCreateWsRequest parameters for 'order.place' websocket API @@ -112,7 +122,7 @@ func (s *OrderCreateWsRequest) buildParams() params { } // Do - sends 'order.place' request -func (s *OrderCreateWsService) Do(requestID string, request *OrderCreateWsRequest) error { +func (s *OrderCreateWsService) Do(requestID string, request *OrderCreateWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -128,7 +138,7 @@ func (s *OrderCreateWsService) Do(requestID string, request *OrderCreateWsReques return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/sor_order_place_service_ws.go b/v2/sor_order_place_service_ws.go index 5c3bc5d70..868e3abbf 100644 --- a/v2/sor_order_place_service_ws.go +++ b/v2/sor_order_place_service_ws.go @@ -18,23 +18,33 @@ type SorOrderPlaceWsService struct { } // NewSorOrderPlaceWsService init SorOrderPlaceWsService -func NewSorOrderPlaceWsService(apiKey, secretKey string) (*SorOrderPlaceWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewSorOrderPlaceWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*SorOrderPlaceWsService, error) { + service := &SorOrderPlaceWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &SorOrderPlaceWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // SorOrderPlaceWsRequest parameters for 'sor.order.place' websocket API @@ -104,7 +114,7 @@ func (s *SorOrderPlaceWsRequest) buildParams() params { } // Do - sends 'sor.order.place' request -func (s *SorOrderPlaceWsService) Do(requestID string, request *SorOrderPlaceWsRequest) error { +func (s *SorOrderPlaceWsService) Do(requestID string, request *SorOrderPlaceWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -120,7 +130,7 @@ func (s *SorOrderPlaceWsService) Do(requestID string, request *SorOrderPlaceWsRe return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err } diff --git a/v2/sor_order_test_service_ws.go b/v2/sor_order_test_service_ws.go index 7e343ad4d..30c014f54 100644 --- a/v2/sor_order_test_service_ws.go +++ b/v2/sor_order_test_service_ws.go @@ -18,23 +18,33 @@ type SorOrderTestWsService struct { } // NewSorOrderTestWsService init SorOrderTestWsService -func NewSorOrderTestWsService(apiKey, secretKey string) (*SorOrderTestWsService, error) { - conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) - if err != nil { - return nil, err +func NewSorOrderTestWsService(apiKey, secretKey string, opts ...websocket.WebSocketServiceOption) (*SorOrderTestWsService, error) { + service := &SorOrderTestWsService{ + ApiKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, } - client, err := websocket.NewClient(conn) - if err != nil { - return nil, err + createOpts := &websocket.WebSocketServiceCreateOption{} + for _, opt := range opts { + opt(createOpts) } - return &SorOrderTestWsService{ - c: client, - ApiKey: apiKey, - SecretKey: secretKey, - KeyType: common.KeyTypeHmac, - }, nil + if createOpts.Client != nil { + service.c = createOpts.Client + } else { + conn, err := websocket.NewConnection(WsApiInitReadWriteConn, WebsocketKeepalive, WebsocketTimeoutReadWriteConnection) + if err != nil { + return nil, err + } + client, err := websocket.NewClient(conn) + if err != nil { + return nil, err + } + service.c = client + } + + return service, nil } // SorOrderTestWsRequest parameters for 'sor.order.test' websocket API @@ -102,7 +112,7 @@ func (s *SorOrderTestWsRequest) buildParams() params { } // Do - sends 'sor.order.test' request -func (s *SorOrderTestWsService) Do(requestID string, request *SorOrderTestWsRequest) error { +func (s *SorOrderTestWsService) Do(requestID string, request *SorOrderTestWsRequest, opts ...websocket.RequestOption) error { rawData, err := websocket.CreateRequest( websocket.NewRequestData( requestID, @@ -118,7 +128,7 @@ func (s *SorOrderTestWsService) Do(requestID string, request *SorOrderTestWsRequ return err } - if err := s.c.Write(requestID, rawData); err != nil { + if err := s.c.Write(requestID, rawData, opts...); err != nil { return err }