Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

New features:

- Added `ArrowStreamBatch.Reset()` method that closes any existing stream and clears the cached reader, allowing callers to retry `GetStream` after a mid-stream failure (e.g. TCP RST) without re-executing the entire query. Inline (RowSetBase64) batches are restored from cached bytes on reset.

Bug fixes:

- Fixed stale OCSP cache `.lck` directory permanently blocking cache writes (and forcing an online OCSP validation when OCSP is enabled, as it is by default) by using `os.RemoveAll` instead of `os.Remove` for stale lock recovery (snowflakedb/gosnowflake#1793).
Expand Down
45 changes: 37 additions & 8 deletions arrow_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,49 @@ type QueryResultFormatProvider interface {
// QueryResultFormat: Arrow IPC record batches (use ipc.NewReader) when
// the format is "arrow", or JSON (row fragments) when it is "json".
type ArrowStreamBatch struct {
	// NOTE(review): this span is a diff rendering — the five fields below
	// appear to be the pre-change copy of the field list that follows;
	// confirm against the committed file before relying on this block.
	idx int
	numrows int64
	scd *snowflakeArrowStreamChunkDownloader
	Loc *time.Location
	rr io.ReadCloser
	idx int
	numrows int64 // row count stated by the chunk metadata (see NumRows)
	scd *snowflakeArrowStreamChunkDownloader // downloader used to (re)fetch the chunk
	Loc *time.Location
	rr io.ReadCloser // cached stream reader; nil until downloaded, cleared by Reset
	inlineData []byte // non-nil for inline (RowSetBase64) batches
}

// NumRows reports how many rows the chunk metadata declared for this
// stream of record batches.
func (asb *ArrowStreamBatch) NumRows() int64 {
	return asb.numrows
}

// Reset closes any existing stream and clears the cached reader, allowing
// GetStream to re-download the chunk on the next call. This enables callers
// to retry after a mid-stream failure (e.g. TCP RST) without re-executing
// the entire query.
//
// Inline batches (built from RowSetBase64 in the initial query response)
// have no remote chunk to fetch again, so Reset instead rewinds them by
// re-wrapping the cached inline bytes. The rewind happens even when the
// underlying Close fails, leaving the batch usable for retry; any close
// error is still reported to the caller.
func (asb *ArrowStreamBatch) Reset() error {
	var err error
	if rc := asb.rr; rc != nil {
		asb.rr = nil
		err = rc.Close()
	}
	if data := asb.inlineData; data != nil {
		// Restore the reader from the cached inline payload so the
		// batch can be consumed again from the beginning.
		asb.rr = io.NopCloser(bytes.NewReader(data))
	}
	return err
}

// GetStream downloads the chunk (if not already cached) and returns a
// stream of bytes. The content may be Arrow IPC or JSON (row fragments)
// depending on the current QueryResultFormat. Close should be called
// on the returned stream when done to ensure no leaked memory.
//
// If a previous stream failed mid-read, call Reset() first to clear the
// cached reader and allow re-download.
func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) {
if asb.rr == nil {
if err := asb.downloadChunkStreamHelper(ctx); err != nil {
Expand Down Expand Up @@ -334,9 +362,10 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB
if len(rowSetBytes) > 0 {
out = out[:chunkMetaLen+1]
out[0] = ArrowStreamBatch{
scd: scd,
Loc: loc,
rr: io.NopCloser(bytes.NewReader(rowSetBytes)),
scd: scd,
Loc: loc,
rr: io.NopCloser(bytes.NewReader(rowSetBytes)),
inlineData: rowSetBytes,
}
toFill = out[1:]
}
Expand Down
29 changes: 4 additions & 25 deletions arrow_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,6 @@ func (c *closeCountingReadCloser) Close() error {
return nil
}

// newTestArrowStreamBatch builds an ArrowStreamBatch wired to a stub chunk
// downloader whose FuncGet always returns an HTTP 200 response carrying the
// supplied body, so tests can drive GetStream without a real network call.
// The single ChunkMetas entry's URL is never actually fetched because
// FuncGet is stubbed.
func newTestArrowStreamBatch(body io.ReadCloser) ArrowStreamBatch {
	return ArrowStreamBatch{
		idx: 0,
		scd: &snowflakeArrowStreamChunkDownloader{
			sc: &snowflakeConn{
				// Short timeout keeps cancellation tests fast.
				rest: &snowflakeRestful{RequestTimeout: time.Second},
			},
			ChunkMetas: []query.ExecResponseChunk{{
				URL: "https://example.com/chunk",
				RowCount: 1,
			}},
			// Stubbed transport: always succeed with the caller-provided body.
			FuncGet: func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) {
				return &http.Response{
					StatusCode: http.StatusOK,
					Body: body,
				}, nil
			},
		},
	}
}

// gzipBody gzip-compresses payload and wraps the result in a ReadCloser,
// suitable for use as a stub HTTP response body in tests.
func gzipBody(payload []byte) io.ReadCloser {
	compressed := gzipBytes(payload)
	return io.NopCloser(bytes.NewReader(compressed))
}
Expand All @@ -109,7 +88,7 @@ func TestArrowStreamBatchGetStreamCancellationUnblocksStalledRead(t *testing.T)
defer cancel()

body := newBlockingReadCloser([]byte("ok"), io.EOF)
batch := newTestArrowStreamBatch(body)
batch := newTestArrowStreamBatch(static200OK(body))

stream, err := batch.GetStream(ctx)
assertNilF(t, err, "GetStream should succeed")
Expand Down Expand Up @@ -157,7 +136,7 @@ func TestArrowStreamBatchGetStreamCancellationNormalizesCloseErrors(t *testing.T

terminalErr := errors.New("read from closed body")
body := newBlockingReadCloser([]byte("ok"), terminalErr)
batch := newTestArrowStreamBatch(body)
batch := newTestArrowStreamBatch(static200OK(body))

stream, err := batch.GetStream(ctx)
assertNilF(t, err, "GetStream should succeed")
Expand Down Expand Up @@ -199,7 +178,7 @@ func TestArrowStreamBatchGetStreamCancellationNormalizesGzipBlockedRead(t *testi
terminalErr := errors.New("read from closed body")
body := newBlockingReadCloser(gzipBytes([]byte("ok")), terminalErr)
body.firstChunkSize = len(body.payload) / 2
batch := newTestArrowStreamBatch(body)
batch := newTestArrowStreamBatch(static200OK(body))

stream, err := batch.GetStream(ctx)
assertNilF(t, err, "GetStream should succeed")
Expand Down Expand Up @@ -336,7 +315,7 @@ func TestArrowStreamBatchGetStreamPreservesCompletedEOF(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
batch := newTestArrowStreamBatch(tc.body)
batch := newTestArrowStreamBatch(static200OK(tc.body))
stream, err := batch.GetStream(context.Background())
assertNilF(t, err, "GetStream should succeed")
defer func() {
Expand Down
Loading
Loading