diff --git a/CHANGELOG.md b/CHANGELOG.md index f8f5e8448..1677d4374 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ New features: Bug fixes: - Fixed stale OCSP cache `.lck` directory permanently blocking cache writes (and forcing an online OCSP validation if OCSP is enabled, as is by default) by using `os.RemoveAll` instead of `os.Remove` for stale lock recovery (snowflakedb/gosnowflake#1793). +- Fixed regular chunk downloader reads so canceling the query context now interrupts stalled chunk downloads and wakes waiting row readers instead of hanging on the HTTP body read (snowflakedb/gosnowflake#1789). - Fixed `QueryArrowStream` chunk reads so canceling the query context now interrupts stalled Arrow stream downloads and reports the cancellation instead of hanging on the HTTP body read (snowflakedb/gosnowflake#1789). - Fixed `baseName` silently dropping files whose name ends with a dot (e.g. `myfile.txt.`), which caused PUT uploads to discard such files without error (snowflakedb/gosnowflake#1788). - Improved error message when `Host` is incorrectly configured with a URL scheme (e.g. `https://myorg-myaccount.snowflakecomputing.com`), previously this produced a cryptic `260004: failed to parse a port number` error (snowflakedb/gosnowflake#1784). diff --git a/chunk_downloader.go b/chunk_downloader.go index 2e6348b71..691ad60f2 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -425,14 +425,15 @@ func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx if err != nil { return fmt.Errorf("getting chunk: %w", err) } + body := newCancelableStream(ctx, resp.Body) defer func() { - if err = resp.Body.Close(); err != nil { + if err = body.Close(); err != nil { logger.Warnf("downloadChunkHelper: closing response body %v: %v", scd.ChunkMetas[idx].URL, err) } }() logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL) if resp.StatusCode != http.StatusOK { - b, err := io.ReadAll(resp.Body) + b, err := io.ReadAll(body) if err != nil { logger.WithContext(ctx).Errorf("reading response body: %v", err) } @@ -445,7 +446,7 @@ func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx } } - bufStream := bufio.NewReader(resp.Body) + bufStream := bufio.NewReader(body) return decodeChunk(ctx, scd, idx, bufStream) } diff --git a/chunk_downloader_test.go b/chunk_downloader_test.go index 1d6a442a9..1229dd9af 100644 --- a/chunk_downloader_test.go +++ b/chunk_downloader_test.go @@ -1,13 +1,67 @@ package gosnowflake import ( + "bytes" "context" "database/sql/driver" + "errors" + "io" + "net/http" "testing" + "time" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" ia "github.com/snowflakedb/gosnowflake/v2/internal/arrow" + "github.com/snowflakedb/gosnowflake/v2/internal/query" ) +func newCancelableTestChunkDownloader(ctx context.Context, body io.ReadCloser) *snowflakeChunkDownloader { + one := "1" + return &snowflakeChunkDownloader{ + sc: &snowflakeConn{ + cfg: &Config{}, + rest: &snowflakeRestful{RequestTimeout: time.Second}, + syncParams: syncParams{params: map[string]*string{clientPrefetchThreadsKey: &one}}, + }, + ctx: ctx, + CellCount: 1, + ChunkMetas: []query.ExecResponseChunk{{URL: "https://example.com/chunk", RowCount: 1}}, + TotalRowIndex: int64(-1), + FuncDownload: downloadChunk, + FuncDownloadHelper: downloadChunkHelper, + FuncGet: func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK, Body: body}, nil + }, + QueryResultFormat: "json", + RowSet: rowSetType{ + RowType: []query.ExecResponseRowType{{Type: "TEXT"}}, + }, + } +} + +func buildArrowChunkBytes(t *testing.T) []byte { + t.Helper() + + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64}}, nil) + bldr := array.NewRecordBuilder(pool, schema) + defer bldr.Release() + bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1}, nil) + rec := bldr.NewRecord() + defer rec.Release() + + var buf bytes.Buffer + w := ipc.NewWriter(&buf, ipc.WithSchema(schema), ipc.WithAllocator(pool)) + assertNilF(t, w.Write(rec), "writing arrow IPC stream should succeed") + assertNilF(t, w.Close(), "closing arrow IPC writer should succeed") + return append([]byte(nil), buf.Bytes()...) +} + func TestChunkDownloaderDoesNotStartWhenArrowParsingCausesError(t *testing.T) { tcs := []string{ "invalid base64", @@ -30,6 +84,119 @@ func TestChunkDownloaderDoesNotStartWhenArrowParsingCausesError(t *testing.T) { } } +func TestDownloadChunkHelperCancellationUnblocksStalledRead(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + terminalErr := errors.New("read from closed body") + body := newBlockingReadCloser([]byte(`["value"]`), terminalErr) + scd := newCancelableTestChunkDownloader(ctx, body) + + errCh := make(chan error, 1) + go func() { + errCh <- downloadChunkHelper(ctx, scd, 0) + }() + + readBlocked := false + select { + case <-body.blockingReadStarted: + readBlocked = true + case <-time.After(100 * time.Millisecond): + } + assertTrueF(t, readBlocked, "decode should block on the live response body before cancellation") + + cancel() + + var err error + readFinished := false + select { + case err = <-errCh: + readFinished = true + case <-time.After(200 * time.Millisecond): + } + assertTrueF(t, readFinished, "cancellation should unblock the chunk download promptly") + assertErrIsF(t, err, context.Canceled, "interrupted chunk downloads should surface context cancellation") + assertFalseF(t, errors.Is(err, terminalErr), "interrupted chunk downloads should not leak transport close errors") + assertTrueF(t, body.closed.Load(), "cancellation should close the underlying response body") +} + +func TestDownloadChunkHelperArrowCancellationUnblocksStalledRead(t *testing.T) { + ctx, cancel := context.WithCancel(ia.EnableArrowBatches(context.Background())) + defer cancel() + + terminalErr := errors.New("read from closed body") + body := newBlockingReadCloser(buildArrowChunkBytes(t), terminalErr) + body.firstChunkSize = len(body.payload) / 2 + scd := newCancelableTestChunkDownloader(ctx, body) + scd.QueryResultFormat = "arrow" + scd.pool = memory.DefaultAllocator + scd.rawBatches = []*rawArrowBatchData{{}} + + errCh := make(chan error, 1) + go func() { + errCh <- downloadChunkHelper(ctx, scd, 0) + }() + + readBlocked := false + select { + case <-body.blockingReadStarted: + readBlocked = true + case <-time.After(100 * time.Millisecond): + } + assertTrueF(t, readBlocked, "arrow decode should block on the live response body before cancellation") + + cancel() + + var err error + readFinished := false + select { + case err = <-errCh: + readFinished = true + case <-time.After(200 * time.Millisecond): + } + assertTrueF(t, readFinished, "cancellation should unblock the stalled Arrow chunk download promptly") + assertErrIsF(t, err, context.Canceled, "interrupted Arrow chunk downloads should surface context cancellation") + assertFalseF(t, errors.Is(err, terminalErr), "interrupted Arrow chunk downloads should not leak transport close errors") + assertTrueF(t, body.closed.Load(), "cancellation should close the underlying response body") +} + +func TestChunkDownloaderNextCancellationUnblocksWaitingRows(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + body := newBlockingReadCloser([]byte(`["value"]`), errors.New("read from closed body")) + scd := newCancelableTestChunkDownloader(ctx, body) + + assertNilF(t, scd.start(), "chunk downloader should start") + + nextErrCh := make(chan error, 1) + go func() { + _, err := scd.next() + nextErrCh <- err + }() + + readBlocked := false + select { + case <-body.blockingReadStarted: + readBlocked = true + case <-time.After(100 * time.Millisecond): + } + assertTrueF(t, readBlocked, "background chunk decode should block on the live response body") + + cancel() + + var err error + nextFinished := false + select { + case err = <-nextErrCh: + nextFinished = true + case <-time.After(200 * time.Millisecond): + } + assertTrueF(t, nextFinished, "cancellation should unblock waiting row reads promptly") + assertErrIsF(t, err, context.Canceled, "waiting row reads should surface context cancellation") + assertTrueF(t, body.closed.Load(), "cancellation should close the underlying response body") +} + func TestWithArrowBatchesWhenQueryReturnsNoRowsWhenUsingNativeGoSQLInterface(t *testing.T) { runDBTest(t, func(dbt *DBTest) { var rows driver.Rows