diff --git a/CHANGELOG.md b/CHANGELOG.md index e4818541f..972a40a87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ New features: +- Added `ArrowStreamBatch.Reset()` method that closes any existing stream and clears the cached reader, allowing callers to retry `GetStream` after a mid-stream failure (e.g. TCP RST) without re-executing the entire query. Inline (RowSetBase64) batches are restored from cached bytes on reset. + Bug fixes: - Fixed `baseName` silently dropping files whose name ends with a dot (e.g. `myfile.txt.`), which caused PUT uploads to discard such files without error (snowflakedb/gosnowflake#1788). diff --git a/arrow_stream.go b/arrow_stream.go index 2d84dc631..52b262664 100644 --- a/arrow_stream.go +++ b/arrow_stream.go @@ -55,21 +55,49 @@ type QueryResultFormatProvider interface { // QueryResultFormat: Arrow IPC record batches (use ipc.NewReader) when // the format is "arrow", or JSON (row fragments) when it is "json". type ArrowStreamBatch struct { - idx int - numrows int64 - scd *snowflakeArrowStreamChunkDownloader - Loc *time.Location - rr io.ReadCloser + idx int + numrows int64 + scd *snowflakeArrowStreamChunkDownloader + Loc *time.Location + rr io.ReadCloser + inlineData []byte // non-nil for inline (RowSetBase64) batches; retained so Reset can re-wrap it as rr } // NumRows returns the total number of rows that the metadata stated should // be in this stream of record batches. func (asb *ArrowStreamBatch) NumRows() int64 { return asb.numrows } +// Reset closes the stream cached by an earlier GetStream call, if any, and +// clears it, so the next GetStream call downloads the chunk again. This lets +// callers retry after a mid-stream failure (e.g. TCP RST) without re-executing +// the entire query. A stream obtained before Reset is closed by it; stop using it. +// +// For inline batches (those produced from RowSetBase64 in the initial +// query response), there is no remote chunk to re-download. In that +// case Reset rewinds the reader by re-wrapping the cached inline bytes +// so the batch can be read again from the beginning.
This is done +// even when the underlying Close returns an error, so the batch is +// always left in a usable state for retry; the close error is still +// returned to the caller. Reset returns nil when no reader was cached. +func (asb *ArrowStreamBatch) Reset() error { + var closeErr error + if asb.rr != nil { + closeErr = asb.rr.Close() + asb.rr = nil + } + if asb.inlineData != nil { + asb.rr = io.NopCloser(bytes.NewReader(asb.inlineData)) + } + return closeErr +} + // GetStream downloads the chunk (if not already cached) and returns a // stream of bytes. The content may be Arrow IPC or JSON (row fragments) // depending on the current QueryResultFormat. Close should be called // on the returned stream when done to ensure no leaked memory. +// +// If a previous stream failed mid-read, call Reset() first to clear the +// cached reader; GetStream only downloads when no reader is cached. func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) { if asb.rr == nil { if err := asb.downloadChunkStreamHelper(ctx); err != nil { @@ -220,9 +248,10 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB if len(rowSetBytes) > 0 { out = out[:chunkMetaLen+1] out[0] = ArrowStreamBatch{ - scd: scd, - Loc: loc, - rr: io.NopCloser(bytes.NewReader(rowSetBytes)), + scd: scd, + Loc: loc, + rr: io.NopCloser(bytes.NewReader(rowSetBytes)), + inlineData: rowSetBytes, } toFill = out[1:] } diff --git a/arrow_test.go b/arrow_test.go index 087983e12..136439ee6 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -2,9 +2,14 @@ package gosnowflake import ( "bytes" + "compress/gzip" "context" + "database/sql/driver" + "errors" "fmt" + "io" "math/big" + "net/http" "reflect" "strings" "testing" @@ -12,8 +17,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" - - "database/sql/driver" + "github.com/snowflakedb/gosnowflake/v2/internal/query" ) func TestArrowBatchDataProvider(t *testing.T) { @@ -575,3 +579,356 @@ func TestArrowMemoryCleanedUp(t *testing.T) {
assertFalseE(t, rows.Next()) }) } + +// errReadCloser is a test ReadCloser whose Close records the call in closed and returns the preset closeErr. +type errReadCloser struct { + io.Reader + closeErr error + closed bool +} + +func (e *errReadCloser) Close() error { + e.closed = true + return e.closeErr +} + +func TestArrowStreamBatchResetNilReader(t *testing.T) { + // Reset on a batch with no reader should be a no-op and return nil. + batch := ArrowStreamBatch{} + err := batch.Reset() + assertNilF(t, err, "Reset on nil reader should return nil") +} + +func TestArrowStreamBatchResetClosesReader(t *testing.T) { + // Reset should close the underlying reader and nil it out. + rc := &errReadCloser{Reader: bytes.NewReader([]byte("data"))} + batch := ArrowStreamBatch{rr: rc} + + err := batch.Reset() + assertNilF(t, err, "Reset should not return error on successful close") + assertTrueF(t, rc.closed, "underlying reader should have been closed") + assertNilF(t, batch.rr, "rr should be nil after Reset") +} + +func TestArrowStreamBatchResetPropagatesCloseError(t *testing.T) { + // Reset should propagate the error from Close. + expected := errors.New("close failed") + rc := &errReadCloser{Reader: bytes.NewReader(nil), closeErr: expected} + batch := ArrowStreamBatch{rr: rc} + + err := batch.Reset() + assertErrIsF(t, err, expected, "Reset should propagate close error") + // rr should still be nilled out even when Close returns an error. + assertNilF(t, batch.rr, "rr should be nil after Reset even on error") +} + +func TestArrowStreamBatchResetClearsCachedReader(t *testing.T) { + // Reset should close and nil-out the cached reader so that a + // subsequent GetStream (tested separately with a full downloader + // in TestArrowStreamBatchResetThenGetStreamRedownloads) would + // attempt a fresh download. + rc := &errReadCloser{Reader: bytes.NewReader([]byte("stale"))} + batch := ArrowStreamBatch{rr: rc} + + // Confirm GetStream returns the cached reader before Reset.
+ stream, err := batch.GetStream(context.Background()) + assertNilF(t, err, "GetStream on cached reader should succeed") + assertEqualF(t, stream, io.ReadCloser(rc), "GetStream should return cached reader") + + // Reset clears the cache (this batch has no inlineData, so rr ends up nil). + err = batch.Reset() + assertNilF(t, err, "Reset should succeed") + assertNilF(t, batch.rr, "rr should be nil after Reset") +} + +func TestArrowStreamBatchResetRestoresInlineReaderOnCloseError(t *testing.T) { + // For inline (RowSetBase64) batches, Reset must restore rr from + // inlineData even when the underlying Close returns an error, so + // that the batch remains usable for retry. Otherwise GetStream + // would fall through to downloadChunkStreamHelper, which has no + // chunk URL for inline batches. + expected := errors.New("close failed") + rc := &errReadCloser{Reader: bytes.NewReader(nil), closeErr: expected} + inline := []byte("inline payload") + batch := ArrowStreamBatch{rr: rc, inlineData: inline} + + err := batch.Reset() + assertErrIsF(t, err, expected, "Reset should propagate close error") + assertTrueF(t, rc.closed, "underlying reader should have been closed") + assertNotNilF(t, batch.rr, "rr should be restored from inlineData even after close error") + + // Verify the restored reader yields the inline payload. + got, readErr := io.ReadAll(batch.rr) + assertNilF(t, readErr, "reading restored inline reader should succeed") + assertEqualE(t, string(got), string(inline), "restored reader should yield inlineData") +} + +func TestArrowStreamBatchResetRestoresInlineReader(t *testing.T) { + // Reset on an inline batch should leave rr pointing at a fresh + // reader over inlineData (successful Close path).
+ rc := &errReadCloser{Reader: bytes.NewReader([]byte("consumed"))} + inline := []byte("inline payload") + batch := ArrowStreamBatch{rr: rc, inlineData: inline} + + err := batch.Reset() + assertNilF(t, err, "Reset should not return error on successful close") + assertTrueF(t, rc.closed, "underlying reader should have been closed") + assertNotNilF(t, batch.rr, "rr should be restored from inlineData") + + got, readErr := io.ReadAll(batch.rr) + assertNilF(t, readErr, "reading restored inline reader should succeed") + assertEqualE(t, string(got), string(inline), "restored reader should yield inlineData") +} + +func TestArrowStreamBatchDoubleResetIsIdempotent(t *testing.T) { + // Calling Reset twice should not error the second time: the first call leaves rr nil, so the second has nothing to close. + rc := &errReadCloser{Reader: bytes.NewReader([]byte("data"))} + batch := ArrowStreamBatch{rr: rc} + + assertNilF(t, batch.Reset(), "first Reset should succeed") + assertNilF(t, batch.Reset(), "second Reset should be a no-op") +} + +// newTestArrowStreamBatch creates an ArrowStreamBatch wired to a mock FuncGet +// for unit testing without a live Snowflake connection. The batch has idx 0, a single fake chunk URL, and no cached reader. +func newTestArrowStreamBatch(funcGet func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error)) ArrowStreamBatch { + sc := &snowflakeConn{rest: &snowflakeRestful{RequestTimeout: 0}} + scd := &snowflakeArrowStreamChunkDownloader{ + sc: sc, + ChunkMetas: []query.ExecResponseChunk{{URL: "http://fake/chunk0"}}, + Qrmk: "testQrmk", + FuncGet: funcGet, + } + return ArrowStreamBatch{idx: 0, scd: scd} +} + +func TestArrowStreamBatchResetThenGetStreamRedownloads(t *testing.T) { + // After Reset, calling GetStream should invoke FuncGet again, + // proving the chunk is actually re-downloaded.
+ callCount := 0 + mockGet := func(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (*http.Response, error) { + callCount++ + body := fmt.Appendf(nil, "payload-%d", callCount) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(body)), + }, nil + } + + batch := newTestArrowStreamBatch(mockGet) + + // First download. + stream1, err := batch.GetStream(context.Background()) + assertNilF(t, err, "first GetStream should succeed") + defer stream1.Close() + data1, err := io.ReadAll(stream1) + assertNilF(t, err, "reading first stream should succeed") + assertEqualE(t, callCount, 1, "FuncGet should have been called once") + + // Reset closes the cached reader and clears it. NOTE(review): the deferred stream1.Close above then presumably closes the same reader a second time; confirm that is harmless. + assertNilF(t, batch.Reset(), "Reset after first read should succeed") + + // Second download — FuncGet is called again. + stream2, err := batch.GetStream(context.Background()) + assertNilF(t, err, "second GetStream should succeed") + defer stream2.Close() + data2, err := io.ReadAll(stream2) + assertNilF(t, err, "reading second stream should succeed") + assertEqualE(t, callCount, 2, "FuncGet should have been called twice after Reset") + + // Content should differ (the mock embeds callCount in the payload), proving it was re-downloaded. + assertFalseE(t, bytes.Equal(data1, data2), "re-downloaded data should differ from original") +} + +func TestArrowStreamBatchResetAfterPartialRead(t *testing.T) { + // Simulates the primary use case: a mid-stream failure followed by + // Reset + full re-download. + fullPayload := []byte("complete-data-here") + mockGet := func(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(fullPayload)), + }, nil + } + + batch := newTestArrowStreamBatch(mockGet) + + // Read only a few bytes (simulate partial/failed read).
+ stream, err := batch.GetStream(context.Background()) + assertNilF(t, err, "initial GetStream should succeed") + partial := make([]byte, 5) + _, err = stream.Read(partial) + assertNilF(t, err, "partial read should succeed") + + // Reset and re-download the full payload. The partially read stream is not closed here; Reset is relied on to close the cached reader. + assertNilF(t, batch.Reset(), "Reset after partial read should succeed") + stream2, err := batch.GetStream(context.Background()) + assertNilF(t, err, "GetStream after Reset should succeed") + defer stream2.Close() + full, err := io.ReadAll(stream2) + assertNilF(t, err, "reading full payload should succeed") + assertEqualE(t, string(full), string(fullPayload), "re-downloaded payload should match original") +} + +func TestArrowStreamBatchGetStreamGzipped(t *testing.T) { + // Verify that gzip-compressed responses are transparently decompressed. + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + _, err := gw.Write([]byte("decompressed payload")) + assertNilF(t, err, "gzip Write should succeed") + assertNilF(t, gw.Close(), "gzip Close should succeed") + gzBytes := buf.Bytes() + + mockGet := func(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(gzBytes)), + }, nil + } + + batch := newTestArrowStreamBatch(mockGet) + stream, err := batch.GetStream(context.Background()) + assertNilF(t, err, "GetStream should succeed") + defer stream.Close() + data, err := io.ReadAll(stream) + assertNilF(t, err, "reading decompressed stream should succeed") + assertEqualE(t, string(data), "decompressed payload", "gzip-decoded payload should match") +} + +func TestArrowStreamBatchGetStreamNon200(t *testing.T) { + // Non-200 responses should surface as a SnowflakeError.
+ mockGet := func(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: io.NopCloser(bytes.NewReader([]byte("error"))), + }, nil + } + + batch := newTestArrowStreamBatch(mockGet) + _, err := batch.GetStream(context.Background()) + var sfErr *SnowflakeError + assertErrorsAsF(t, err, &sfErr, "should return SnowflakeError") + assertEqualF(t, sfErr.Number, ErrFailedToGetChunk, "error number should be ErrFailedToGetChunk") +} + +func TestArrowStreamBatchGetStreamFuncGetError(t *testing.T) { + // Network-level errors from FuncGet should propagate directly. + expected := errors.New("network failure") + mockGet := func(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (*http.Response, error) { + return nil, expected + } + + batch := newTestArrowStreamBatch(mockGet) + _, err := batch.GetStream(context.Background()) + assertErrIsF(t, err, expected, "should propagate FuncGet error") +} + +func TestArrowStreamBatchResetThenGetStreamAfterError(t *testing.T) { + // If the first GetStream fails, Reset should allow a successful retry. + calls := 0 + mockGet := func(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (*http.Response, error) { + calls++ + if calls == 1 { + return nil, errors.New("transient error") + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte("ok"))), + }, nil + } + + batch := newTestArrowStreamBatch(mockGet) + + // First attempt fails (the mock returns an error on its first call). + _, err := batch.GetStream(context.Background()) + assertNotNilE(t, err, "first GetStream should fail with transient error") + + // Reset (rr is still nil since download failed, but Reset is safe). + assertNilF(t, batch.Reset(), "Reset after failed download should succeed") + + // Retry succeeds.
+ stream, err := batch.GetStream(context.Background()) + assertNilF(t, err, "retry GetStream should succeed") + defer stream.Close() + data, err := io.ReadAll(stream) + assertNilF(t, err, "reading retried stream should succeed") + assertEqualE(t, string(data), "ok", "retried payload should match") +} + +func TestStreamWrapReaderCloseClosesInnerAndWrapped(t *testing.T) { + // When inner Reader implements io.ReadCloser, both inner and wrapped + // should be closed. + inner := &errReadCloser{Reader: bytes.NewReader(nil)} + wrapped := &errReadCloser{ + Reader: bytes.NewReader(nil), + closed: false, + } + w := &streamWrapReader{Reader: inner, wrapped: wrapped} + err := w.Close() + assertNilF(t, err, "Close should succeed when neither close errors") + assertTrueF(t, inner.closed, "inner ReadCloser should be closed") + assertTrueF(t, wrapped.closed, "wrapped body should be closed") +} + +func TestStreamWrapReaderCloseNonClosableInner(t *testing.T) { + // When inner Reader is not an io.ReadCloser, only wrapped is closed. + wrapped := &errReadCloser{Reader: bytes.NewReader(nil)} + w := &streamWrapReader{Reader: bytes.NewReader(nil), wrapped: wrapped} + err := w.Close() + assertNilF(t, err, "Close should succeed when only wrapped is closable") + assertTrueF(t, wrapped.closed, "wrapped body should be closed") +} + +func TestStreamWrapReaderClosePropagatesInnerError(t *testing.T) { + // If the inner ReadCloser's Close fails, the error should propagate. + expected := errors.New("inner close fail") + inner := &errReadCloser{Reader: bytes.NewReader(nil), closeErr: expected} + wrapped := &errReadCloser{Reader: bytes.NewReader(nil)} + w := &streamWrapReader{Reader: inner, wrapped: wrapped} + err := w.Close() + assertErrIsF(t, err, expected, "should propagate inner close error") + // wrapped should NOT have been closed because inner errored first. NOTE(review): this pins the wrapped body staying open when the inner Close fails; confirm that is intended.
+ assertFalseE(t, wrapped.closed, "wrapped should not close if inner errors") +} + +func TestArrowStreamBatchResetClosesStreamWrapReader(t *testing.T) { + // Reset should properly close a streamWrapReader, including the + // wrapped HTTP body. + wrappedClosed := false + body := &fakeResponseBody{ + body: []byte("data"), + onClose: func() { wrappedClosed = true }, + } + batch := ArrowStreamBatch{ + rr: &streamWrapReader{ + Reader: bytes.NewReader([]byte("inner")), + wrapped: body, + }, + } + assertNilF(t, batch.Reset(), "Reset should succeed") + assertNilF(t, batch.rr, "rr should be nil after Reset") + assertTrueF(t, wrappedClosed, "wrapped HTTP body should be closed") +} + +func TestArrowStreamBatchResetClosesGzipStreamWrapReader(t *testing.T) { + // Verify Reset properly tears down a gzip + streamWrapReader stack. The gzip reader decodes its own copy of the bytes; the onClose hook on body records the close. + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + _, err := gw.Write([]byte("hello")) + assertNilF(t, err, "gzip Write should succeed") + assertNilF(t, gw.Close(), "gzip Close should succeed") + + wrappedClosed := false + body := &fakeResponseBody{ + body: buf.Bytes(), + onClose: func() { wrappedClosed = true }, + } + gr, err := gzip.NewReader(bytes.NewReader(buf.Bytes())) + assertNilF(t, err, "gzip.NewReader should succeed") + + batch := ArrowStreamBatch{ + rr: &streamWrapReader{Reader: gr, wrapped: body}, + } + assertNilF(t, batch.Reset(), "Reset should succeed") + assertTrueF(t, wrappedClosed, "wrapped HTTP body should be closed via gzip reader") +} diff --git a/chunk_test.go b/chunk_test.go index 6fb5b22d9..095679645 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -653,3 +653,44 @@ func TestQueryArrowStreamArrowResponseExposesFormat(t *testing.T) { assertTrueF(t, len(batches) > 0, "should have Arrow batches") }) } + +func TestQueryArrowStreamResetAndRereadBatch(t *testing.T) { + runSnowflakeConnTest(t, func(sct *SCTest) { + numrows := 50000 + query := fmt.Sprintf(selectRandomGenerator, numrows) + loader, err :=
sct.sc.QueryArrowStream(sct.sc.ctx, query) + assertNilF(t, err, "QueryArrowStream should succeed") + + batches, err := loader.GetBatches() + assertNilF(t, err, "GetBatches should succeed") + assertTrueF(t, len(batches) > 0, "should have at least one batch") + + // Pick a batch that requires a download (skip index 0 which may be inline). When only one batch exists, index 0 is used. + idx := 0 + if len(batches) > 1 { + idx = 1 + } + + // First read. + stream1, err := batches[idx].GetStream(sct.sc.ctx) + assertNilF(t, err, "first GetStream should succeed") + defer stream1.Close() + data1, err := io.ReadAll(stream1) + assertNilF(t, err, "reading first stream should succeed") + assertTrueF(t, len(data1) > 0, "first read should return data") + + // Reset the batch. + assertNilF(t, batches[idx].Reset(), "Reset should succeed") + + // Re-read the same batch. + stream2, err := batches[idx].GetStream(sct.sc.ctx) + assertNilF(t, err, "second GetStream should succeed") + defer stream2.Close() + data2, err := io.ReadAll(stream2) + assertNilF(t, err, "reading second stream should succeed") + assertTrueF(t, len(data2) > 0, "re-read after Reset should return data") + + // Both reads should return the same content. + assertBytesEqualE(t, data1, data2, "data should match after Reset + re-download") + }) +}