Make QueryArrowStream chunk reads cancelable (snowflakedb/gosnowflake#1789).

ArrowStreamBatch.GetStream now wraps the chunk body in a cancelableStream. A
watcher goroutine closes the underlying body when ctx is canceled, unless the
stream has already finished or been closed, and Read then reports ctx.Err() in
place of the terminal read error. A final read that returns payload together
with io.EOF is passed through unchanged.

Note: this patch was recovered from a copy whose line breaks and indentation
had been collapsed. Indentation and blank context lines are reconstructed
(gofmt layout assumed); the added-line counts match every hunk header.

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b24c25998..f8f5e8448 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ New features:
 Bug fixes:
 
 - Fixed stale OCSP cache `.lck` directory permanently blocking cache writes (and forcing an online OCSP validation if OCSP is enabled, as is by default) by using `os.RemoveAll` instead of `os.Remove` for stale lock recovery (snowflakedb/gosnowflake#1793).
+- Fixed `QueryArrowStream` chunk reads so canceling the query context now interrupts stalled Arrow stream downloads and reports the cancellation instead of hanging on the HTTP body read (snowflakedb/gosnowflake#1789).
 - Fixed `baseName` silently dropping files whose name ends with a dot (e.g. `myfile.txt.`), which caused PUT uploads to discard such files without error (snowflakedb/gosnowflake#1788).
 - Improved error message when `Host` is incorrectly configured with a URL scheme (e.g. `https://myorg-myaccount.snowflakecomputing.com`), previously this produced a cryptic `260004: failed to parse a port number` error (snowflakedb/gosnowflake#1784).
 - Fixed minicore build on OpenBSD by skipping the `-ldl` linker flag, since libdl is not a separate library on OpenBSD (`dlopen`/`dlsym` are provided by libc) (snowflakedb/gosnowflake#1791).
diff --git a/arrow_stream.go b/arrow_stream.go
index 2d84dc631..b13801d86 100644
--- a/arrow_stream.go
+++ b/arrow_stream.go
@@ -11,6 +11,8 @@ import (
 	"maps"
 	"net/http"
 	"strconv"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/apache/arrow-go/v18/arrow/ipc"
@@ -76,7 +78,7 @@ func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, erro
 			return nil, err
 		}
 	}
-	return asb.rr, nil
+	return newCancelableStream(ctx, asb.rr), nil
 }
 
 // streamWrapReader wraps an io.Reader so that Close closes the underlying body.
@@ -94,6 +96,118 @@ func (w *streamWrapReader) Close() error {
 	return w.wrapped.Close()
 }
 
+// interrupt closes the raw response body directly. This bypasses gzip.Reader
+// cleanup because the gzip layer may still be blocked waiting for more bytes,
+// while closing the transport body is what unblocks the in-flight read.
+func (w *streamWrapReader) interrupt() error {
+	return w.wrapped.Close()
+}
+
+// interruptibleReader lets cancelableStream bypass wrapper cleanup and close
+// the raw transport body that actually unblocks a stalled read.
+type interruptibleReader interface {
+	interrupt() error
+}
+
+// cancelableStream watches the stream context so cancellation can interrupt a
+// stalled read, normalizes terminal transport errors caused by that interrupt
+// to ctx.Err(), and still preserves true successful completion and EOF.
+type cancelableStream struct {
+	ctx context.Context
+
+	inner     io.ReadCloser
+	interrupt func() error
+
+	done          chan struct{}
+	stopOnce      sync.Once
+	interruptOnce sync.Once
+	completed     atomic.Bool
+}
+
+func newCancelableStream(ctx context.Context, inner io.ReadCloser) io.ReadCloser {
+	out := &cancelableStream{
+		ctx:       ctx,
+		inner:     inner,
+		interrupt: inner.Close,
+		done:      make(chan struct{}),
+	}
+	if interruptible, ok := inner.(interruptibleReader); ok {
+		out.interrupt = interruptible.interrupt
+	}
+	if ctx.Done() == nil {
+		return out
+	}
+
+	go func() {
+		select {
+		case <-ctx.Done():
+			if out.watchingStopped() || out.completed.Load() {
+				return
+			}
+			out.interruptOnce.Do(func() {
+				_ = out.interrupt()
+			})
+		case <-out.done:
+		}
+	}()
+	return out
+}
+
+func (r *cancelableStream) stopWatching() {
+	r.stopOnce.Do(func() {
+		close(r.done)
+	})
+}
+
+func (r *cancelableStream) watchingStopped() bool {
+	select {
+	case <-r.done:
+		return true
+	default:
+		return false
+	}
+}
+
+func (r *cancelableStream) Read(p []byte) (int, error) {
+	n, err := r.inner.Read(p)
+	if err == nil {
+		return n, nil
+	}
+	defer r.stopWatching()
+
+	// Preserve the successful final read where io.Reader returns payload and EOF
+	// together. Only terminal reads with no delivered bytes are candidates for
+	// cancellation rewriting.
+	if err == io.EOF && n > 0 {
+		r.completed.Store(true)
+		return n, err
+	}
+
+	if normalized := r.canceledReadErr(); normalized != nil {
+		return n, normalized
+	}
+
+	if err == io.EOF {
+		r.completed.Store(true)
+	}
+	return n, err
+}
+
+func (r *cancelableStream) canceledReadErr() error {
+	ctxErr := r.ctx.Err()
+	if ctxErr == nil || r.completed.Load() {
+		return nil
+	}
+	// A canceled context wins over terminal transport errors even if the watcher
+	// goroutine has not yet unblocked the read before the terminal error arrives.
+	return ctxErr
+}
+
+func (r *cancelableStream) Close() error {
+	r.stopWatching()
+	return r.inner.Close()
+}
+
 func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error {
 	headers := make(map[string]string)
 	if len(asb.scd.ChunkHeader) > 0 {
diff --git a/arrow_stream_test.go b/arrow_stream_test.go
new file mode 100644
index 000000000..983d2df31
--- /dev/null
+++ b/arrow_stream_test.go
@@ -0,0 +1,355 @@
+package gosnowflake
+
+import (
+	"bytes"
+	"compress/gzip"
+	"context"
+	"errors"
+	"io"
+	"net/http"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/snowflakedb/gosnowflake/v2/internal/query"
+)
+
+type finalEOFStreamBody struct {
+	blockingReadStarted chan struct{}
+	unblockRead         chan struct{}
+
+	blockOnce sync.Once
+	closeOnce sync.Once
+	closed    atomic.Bool
+	delivered bool
+	payload   []byte
+}
+
+func newFinalEOFStreamBody(payload []byte) *finalEOFStreamBody {
+	return &finalEOFStreamBody{
+		blockingReadStarted: make(chan struct{}),
+		unblockRead:         make(chan struct{}),
+		payload:             append([]byte(nil), payload...),
+	}
+}
+
+func (b *finalEOFStreamBody) Read(p []byte) (int, error) {
+	if b.delivered {
+		return 0, io.EOF
+	}
+	b.blockOnce.Do(func() {
+		close(b.blockingReadStarted)
+	})
+	<-b.unblockRead
+	b.delivered = true
+	copy(p, b.payload)
+	return len(b.payload), io.EOF
+}
+
+func (b *finalEOFStreamBody) Close() error {
+	b.closed.Store(true)
+	b.closeOnce.Do(func() {
+		close(b.unblockRead)
+	})
+	return nil
+}
+
+type closeCountingReadCloser struct {
+	closeCalls  atomic.Int32
+	secondClose chan struct{}
+}
+
+func (c *closeCountingReadCloser) Read(_ []byte) (int, error) {
+	return 0, io.EOF
+}
+
+func (c *closeCountingReadCloser) Close() error {
+	if c.closeCalls.Add(1) == 2 && c.secondClose != nil {
+		close(c.secondClose)
+	}
+	return nil
+}
+
+func newTestArrowStreamBatch(body io.ReadCloser) ArrowStreamBatch {
+	return ArrowStreamBatch{
+		idx: 0,
+		scd: &snowflakeArrowStreamChunkDownloader{
+			sc: &snowflakeConn{
+				rest: &snowflakeRestful{RequestTimeout: time.Second},
+			},
+			ChunkMetas: []query.ExecResponseChunk{{
+				URL:      "https://example.com/chunk",
+				RowCount: 1,
+			}},
+			FuncGet: func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) {
+				return &http.Response{
+					StatusCode: http.StatusOK,
+					Body:       body,
+				}, nil
+			},
+		},
+	}
+}
+
+func gzipBody(payload []byte) io.ReadCloser {
+	return io.NopCloser(bytes.NewReader(gzipBytes(payload)))
+}
+
+func gzipBytes(payload []byte) []byte {
+	var buf bytes.Buffer
+	zw := gzip.NewWriter(&buf)
+	_, _ = zw.Write(payload)
+	_ = zw.Close()
+	return buf.Bytes()
+}
+
+func TestArrowStreamBatchGetStreamCancellationUnblocksStalledRead(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	body := newBlockingReadCloser([]byte("ok"), io.EOF)
+	batch := newTestArrowStreamBatch(body)
+
+	stream, err := batch.GetStream(ctx)
+	assertNilF(t, err, "GetStream should succeed")
+	defer func() {
+		_ = stream.Close()
+	}()
+
+	type readResult struct {
+		data []byte
+		err  error
+	}
+
+	readResultCh := make(chan readResult, 1)
+	go func() {
+		data, readErr := io.ReadAll(stream)
+		readResultCh <- readResult{data: data, err: readErr}
+	}()
+
+	readBlocked := false
+	select {
+	case <-body.blockingReadStarted:
+		readBlocked = true
+	case <-time.After(100 * time.Millisecond):
+	}
+	assertTrueF(t, readBlocked, "read should block after the stream opens")
+
+	cancel()
+
+	var result readResult
+	readFinished := false
+	select {
+	case result = <-readResultCh:
+		readFinished = true
+	case <-time.After(200 * time.Millisecond):
+	}
+	assertTrueF(t, readFinished, "cancellation should unblock a stalled read promptly")
+	assertEqualF(t, string(result.data), "ok", "buffered bytes should still be returned before cancellation")
+	assertErrIsF(t, result.err, context.Canceled, "stalled read should surface cancellation")
+	assertTrueF(t, body.closed.Load(), "cancellation should close the underlying body")
+}
+
+func TestArrowStreamBatchGetStreamCancellationNormalizesCloseErrors(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	terminalErr := errors.New("read from closed body")
+	body := newBlockingReadCloser([]byte("ok"), terminalErr)
+	batch := newTestArrowStreamBatch(body)
+
+	stream, err := batch.GetStream(ctx)
+	assertNilF(t, err, "GetStream should succeed")
+	defer func() {
+		_ = stream.Close()
+	}()
+
+	type readResult struct {
+		data []byte
+		err  error
+	}
+
+	readResultCh := make(chan readResult, 1)
+	go func() {
+		data, readErr := io.ReadAll(stream)
+		readResultCh <- readResult{data: data, err: readErr}
+	}()
+
+	readBlocked := false
+	select {
+	case <-body.blockingReadStarted:
+		readBlocked = true
+	case <-time.After(100 * time.Millisecond):
+	}
+	assertTrueF(t, readBlocked, "read should block after the stream opens")
+
+	cancel()
+
+	result := <-readResultCh
+	assertEqualF(t, string(result.data), "ok", "buffered bytes should still be returned before cancellation")
+	assertErrIsF(t, result.err, context.Canceled, "canceled reads should surface context cancellation")
+	assertFalseF(t, errors.Is(result.err, terminalErr), "canceled reads should not leak the transport close error")
+}
+
+func TestArrowStreamBatchGetStreamCancellationNormalizesGzipBlockedRead(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	terminalErr := errors.New("read from closed body")
+	body := newBlockingReadCloser(gzipBytes([]byte("ok")), terminalErr)
+	body.firstChunkSize = len(body.payload) / 2
+	batch := newTestArrowStreamBatch(body)
+
+	stream, err := batch.GetStream(ctx)
+	assertNilF(t, err, "GetStream should succeed")
+	defer func() {
+		_ = stream.Close()
+	}()
+
+	type readResult struct {
+		data []byte
+		err  error
+	}
+
+	readResultCh := make(chan readResult, 1)
+	go func() {
+		data, readErr := io.ReadAll(stream)
+		readResultCh <- readResult{data: data, err: readErr}
+	}()
+
+	readBlocked := false
+	select {
+	case <-body.blockingReadStarted:
+		readBlocked = true
+	case <-time.After(100 * time.Millisecond):
+	}
+	assertTrueF(t, readBlocked, "gzip read should block waiting for more compressed bytes")
+
+	cancel()
+
+	result := <-readResultCh
+	assertEqualF(t, string(result.data), "ok", "buffered gzip payload should still be returned before cancellation")
+	assertErrIsF(t, result.err, context.Canceled, "canceled gzip reads should surface context cancellation")
+	assertFalseF(t, errors.Is(result.err, terminalErr), "canceled gzip reads should not leak the transport close error")
+	assertTrueF(t, body.closed.Load(), "cancellation should close the underlying body")
+}
+
+func TestArrowStreamBatchGetStreamAlreadyCanceledContext(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	batch := ArrowStreamBatch{
+		idx: 0,
+		scd: &snowflakeArrowStreamChunkDownloader{
+			sc: &snowflakeConn{
+				rest: &snowflakeRestful{RequestTimeout: time.Second},
+			},
+			ChunkMetas: []query.ExecResponseChunk{{
+				URL:      "https://example.com/chunk",
+				RowCount: 1,
+			}},
+			FuncGet: func(ctx context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (*http.Response, error) {
+				return nil, ctx.Err()
+			},
+		},
+	}
+
+	stream, err := batch.GetStream(ctx)
+	assertNilF(t, stream, "GetStream should not return a stream for an already-canceled context")
+	assertErrIsF(t, err, context.Canceled, "GetStream should surface the canceled context before opening the stream")
+}
+
+func TestCancelableStreamCancellationAfterFinalPayloadPreservesCompletion(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	body := newFinalEOFStreamBody([]byte("ok"))
+	stream := newCancelableStream(ctx, body)
+	defer func() {
+		assertNilF(t, stream.Close(), "closing stream should succeed")
+	}()
+
+	type readResult struct {
+		data []byte
+		err  error
+	}
+
+	readResultCh := make(chan readResult, 1)
+	go func() {
+		data, readErr := io.ReadAll(stream)
+		readResultCh <- readResult{data: data, err: readErr}
+	}()
+
+	readBlocked := false
+	select {
+	case <-body.blockingReadStarted:
+		readBlocked = true
+	case <-time.After(100 * time.Millisecond):
+	}
+	assertTrueF(t, readBlocked, "read should block before the final payload is delivered")
+
+	cancel()
+
+	result := <-readResultCh
+	assertEqualF(t, string(result.data), "ok", "final payload should still be returned")
+	assertNilF(t, result.err, "final payload delivery should still complete successfully")
+}
+
+func TestCancelableStreamCloseStopsWatcherBeforeCancel(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	body := &closeCountingReadCloser{secondClose: make(chan struct{})}
+	stream := newCancelableStream(ctx, body)
+
+	assertNilF(t, stream.Close(), "closing stream should succeed")
+	assertEqualF(t, body.closeCalls.Load(), int32(1), "closing the stream should close the body exactly once")
+
+	cancel()
+
+	secondClose := false
+	select {
+	case <-body.secondClose:
+		secondClose = true
+	case <-time.After(50 * time.Millisecond):
+	}
+
+	assertFalseF(t, secondClose, "canceling after close should not re-close the body")
+	assertEqualF(t, body.closeCalls.Load(), int32(1), "canceling after close should not re-close the body")
+}
+
+func TestArrowStreamBatchGetStreamPreservesCompletedEOF(t *testing.T) {
+	payload := []byte(`{"key":"value"}`)
+
+	testCases := []struct {
+		name string
+		body io.ReadCloser
+	}{
+		{
+			name: "plain",
+			body: io.NopCloser(bytes.NewReader(payload)),
+		},
+		{
+			name: "gzip",
+			body: gzipBody(payload),
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			batch := newTestArrowStreamBatch(tc.body)
+			stream, err := batch.GetStream(context.Background())
+			assertNilF(t, err, "GetStream should succeed")
+			defer func() {
+				assertNilF(t, stream.Close(), "closing stream should succeed")
+			}()
+
+			data, err := io.ReadAll(stream)
+			assertNilF(t, err, "completed streams should read successfully")
+			assertEqualF(t, string(data), string(payload), "stream content should match the response body")
+
+			n, err := stream.Read(make([]byte, 1))
+			assertEqualF(t, n, 0, "completed streams should not return extra bytes")
+			assertErrIsF(t, err, io.EOF, "completed streams should preserve EOF")
+		})
+	}
+}
diff --git a/blocking_read_closer_test.go b/blocking_read_closer_test.go
new file mode 100644
index 000000000..f0a85e3e9
--- /dev/null
+++ b/blocking_read_closer_test.go
@@ -0,0 +1,54 @@
+package gosnowflake
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+type blockingReadCloser struct {
+	payload             []byte
+	firstChunkSize      int
+	blockingReadStarted chan struct{}
+	unblockRead         chan struct{}
+	terminalErr         error
+
+	blockOnce sync.Once
+	closeOnce sync.Once
+	closed    atomic.Bool
+	sentBody  bool
+}
+
+func newBlockingReadCloser(payload []byte, terminalErr error) *blockingReadCloser {
+	return &blockingReadCloser{
+		payload:             append([]byte(nil), payload...),
+		firstChunkSize:      len(payload),
+		blockingReadStarted: make(chan struct{}),
+		unblockRead:         make(chan struct{}),
+		terminalErr:         terminalErr,
+	}
+}
+
+func (b *blockingReadCloser) Read(p []byte) (int, error) {
+	if !b.sentBody {
+		b.sentBody = true
+		chunk := b.payload
+		if b.firstChunkSize >= 0 && b.firstChunkSize < len(chunk) {
+			chunk = chunk[:b.firstChunkSize]
+		}
+		copy(p, chunk)
+		return len(chunk), nil
+	}
+	b.blockOnce.Do(func() {
+		close(b.blockingReadStarted)
+	})
+	<-b.unblockRead
+	return 0, b.terminalErr
+}
+
+func (b *blockingReadCloser) Close() error {
+	b.closed.Store(true)
+	b.closeOnce.Do(func() {
+		close(b.unblockRead)
+	})
+	return nil
+}