From a23f120750cd104fabf57b2e4910418626399b1c Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Sat, 18 Apr 2026 10:47:00 -0400 Subject: [PATCH 1/8] feat: add ClickHouse native binary protocol support Signed-off-by: Chris Randles --- go.mod | 8 + go.sum | 35 ++ integration/clickhouse_native_test.go | 72 +++ integration/harness_test.go | 29 ++ pkg/backends/clickhouse/clickhouse.go | 32 ++ pkg/backends/clickhouse/native/client.go | 166 +++++++ .../clickhouse/native/server/fuzz_test.go | 186 ++++++++ .../clickhouse/native/server/handler.go | 249 ++++++++++ .../clickhouse/native/server/handler_test.go | 297 ++++++++++++ .../clickhouse/native/server/handshake.go | 104 +++++ .../clickhouse/native/server/protocol.go | 168 +++++++ .../clickhouse/native/server/protocol_test.go | 424 +++++++++++++++++ .../clickhouse/native/server/query.go | 265 +++++++++++ .../clickhouse/native/server/response.go | 431 ++++++++++++++++++ .../clickhouse/native/server/types.go | 304 ++++++++++++ pkg/backends/options/options.go | 14 + pkg/backends/protocol.go | 28 ++ pkg/daemon/setup/listeners.go | 40 ++ pkg/frontend/options/options.go | 29 +- pkg/frontend/options/protocol_listener.go | 57 +++ pkg/proxy/connhandler/connhandler.go | 31 ++ pkg/proxy/engines/httpproxy.go | 8 +- pkg/proxy/listener/listener.go | 32 +- pkg/proxy/listener/protocol_listener.go | 185 ++++++++ 24 files changed, 3186 insertions(+), 8 deletions(-) create mode 100644 pkg/backends/clickhouse/native/client.go create mode 100644 pkg/backends/clickhouse/native/server/fuzz_test.go create mode 100644 pkg/backends/clickhouse/native/server/handler.go create mode 100644 pkg/backends/clickhouse/native/server/handler_test.go create mode 100644 pkg/backends/clickhouse/native/server/handshake.go create mode 100644 pkg/backends/clickhouse/native/server/protocol.go create mode 100644 pkg/backends/clickhouse/native/server/protocol_test.go create mode 100644 pkg/backends/clickhouse/native/server/query.go create mode 100644 
pkg/backends/clickhouse/native/server/response.go create mode 100644 pkg/backends/clickhouse/native/server/types.go create mode 100644 pkg/backends/protocol.go create mode 100644 pkg/frontend/options/protocol_listener.go create mode 100644 pkg/proxy/connhandler/connhandler.go create mode 100644 pkg/proxy/listener/protocol_listener.go diff --git a/go.mod b/go.mod index 731a95825..d67fa3fe1 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,8 @@ require ( github.com/Antonboom/nilnil v1.1.1 // indirect github.com/Antonboom/testifylint v1.6.4 // indirect github.com/BurntSushi/toml v1.6.0 // indirect + github.com/ClickHouse/ch-go v0.71.0 // indirect + github.com/ClickHouse/clickhouse-go/v2 v2.45.0 // indirect github.com/Djarvur/go-err113 v0.1.1 // indirect github.com/Masterminds/semver/v3 v3.4.0 // indirect github.com/MirrexOne/unqueryvet v1.5.4 // indirect @@ -98,6 +100,8 @@ require ( github.com/fzipp/gocyclo v0.6.0 // indirect github.com/ghostiam/protogetter v0.3.20 // indirect github.com/go-critic/go-critic v0.14.3 // indirect + github.com/go-faster/city v1.0.1 // indirect + github.com/go-faster/errors v0.7.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-toolsmith/astcast v1.1.0 // indirect @@ -180,9 +184,11 @@ require ( github.com/nishanths/exhaustive v0.12.0 // indirect github.com/nishanths/predeclared v0.2.2 // indirect github.com/nunnatsa/ginkgolinter v0.23.0 // indirect + github.com/paulmach/orb v0.12.0 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/philhofer/fwd v1.2.0 // indirect + github.com/pierrec/lz4/v4 v4.1.25 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.20.1 // indirect @@ -201,6 +207,8 @@ require ( github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.29.0 // indirect 
github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08 // indirect + github.com/segmentio/asm v1.2.1 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sonatard/noctx v0.5.1 // indirect diff --git a/go.sum b/go.sum index 78e02198c..096867737 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,10 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/ClickHouse/ch-go v0.71.0 h1:bUdZ/EZj/LcVHsMqaRUP2holqygrPWQKeMjc6nZoyRM= +github.com/ClickHouse/ch-go v0.71.0/go.mod h1:NwbNc+7jaqfY58dmdDUbG4Jl22vThgx1cYjBw0vtgXw= +github.com/ClickHouse/clickhouse-go/v2 v2.45.0 h1:iHt15nA4iYhfde5bDQAcLAat9BAh7B5ksPRNRa4UI7s= +github.com/ClickHouse/clickhouse-go/v2 v2.45.0/go.mod h1:giJfUVlMkcfUEPVfRpt51zZaGEx9i17gCos8gBl392c= github.com/Djarvur/go-err113 v0.1.1 h1:eHfopDqXRwAi+YmCUas75ZE0+hoBHJ2GQNLYRSxao4g= github.com/Djarvur/go-err113 v0.1.1/go.mod h1:IaWJdYFLg76t2ihfflPZnM1LIQszWOsFDh2hhhAVF6k= github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= @@ -208,6 +212,10 @@ github.com/ghostiam/protogetter v0.3.20 h1:oW7OPFit2FxZOpmMRPP9FffU4uUpfeE/rEdE1 github.com/ghostiam/protogetter v0.3.20/go.mod h1:FjIu5Yfs6FT391m+Fjp3fbAYJ6rkL/J6ySpZBfnODuI= github.com/go-critic/go-critic v0.14.3 h1:5R1qH2iFeo4I/RJU8vTezdqs08Egi4u5p6vOESA0pog= github.com/go-critic/go-critic v0.14.3/go.mod h1:xwntfW6SYAd7h1OqDzmN6hBX/JxsEKl5up/Y2bsxgVQ= +github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw= +github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw= 
+github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg= +github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -257,6 +265,7 @@ github.com/godoc-lint/godoc-lint v0.11.2/go.mod h1:iVpGdL1JCikNH2gGeAn3Hh+AgN5Gx github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -284,8 +293,10 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golangci/asciicheck v0.5.0 h1:jczN/BorERZwK8oiFBOGvlGPknhvq0bjnysTj4nUfo0= @@ -408,11 +419,13 @@ github.com/julz/importas v0.2.0 h1:y+MJN/UdL63QbFJHws9BVC5RpA2iq0kpjrFajTGivjQ= github.com/julz/importas v0.2.0/go.mod h1:pThlt589EnCYtMnmhmRYY/qn9lCf/frPOK+WMx3xiJY= github.com/karamaru-alpha/copyloopvar v1.2.2 h1:yfNQvP9YaGQR7VaWLYcfZUlRP2eo2vhExWKxD/fP6q0= github.com/karamaru-alpha/copyloopvar v1.2.2/go.mod h1:oY4rGZqZ879JkJMtX3RRkcXRkmUvH0x35ykgaKgsgJY= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.10.0 h1:Lvs/YAHP24YKg08LA8oDw2z9fJVme090RAXd90S+rrw= github.com/kisielk/errcheck v1.10.0/go.mod h1:kQxWMMVZgIkDq7U8xtG/n2juOjbLgZtedi0D+/VL/i8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkHAIKE/contextcheck v1.1.6 h1:7HIyRcnyzxL9Lz06NGhiKvenXq7Zw6Q0UQu/ttjfJCE= github.com/kkHAIKE/contextcheck v1.1.6/go.mod h1:3dDbMRNBFaq8HFXWC1JyvDSPm43CmE6IuHam8Wr0rkg= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= @@ -484,6 +497,7 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/montanaflynn/stats 
v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/moricho/tparallel v0.3.2 h1:odr8aZVFA3NZrNybggMkYO3rgPRcqjeQUlBBFVxKHTI= github.com/moricho/tparallel v0.3.2/go.mod h1:OQ+K3b4Ln3l2TZveGCywybl68glfLEwFGqvnjok8b+U= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= @@ -512,12 +526,17 @@ github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJ github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= +github.com/paulmach/orb v0.12.0 h1:z+zOwjmG3MyEEqzv92UN49Lg1JFYx0L9GpGKNVDKk1s= +github.com/paulmach/orb v0.12.0/go.mod h1:5mULz1xQfs3bmQm63QEJA6lNGujuRafwA5S/EnuLaLU= +github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKfDlhJCDw2gY= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= +github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0= +github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -584,8 +603,12 @@ github.com/sashamelentyev/usestdlibvars v1.29.0 
h1:8J0MoRrw4/NAXtjQqTHrbW9NN+3iM github.com/sashamelentyev/usestdlibvars v1.29.0/go.mod h1:8PpnjHMk5VdeWlVb4wCdrB8PNbLqZ3wBZTZWkrpZZL8= github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08 h1:AoLtJX4WUtZkhhUUMFy3GgecAALp/Mb4S1iyQOA2s0U= github.com/securego/gosec/v2 v2.24.8-0.20260309165252-619ce2117e08/go.mod h1:+XLCJiRE95ga77XInNELh2M6zQP+PdqiT9Zpm0D9Wpk= +github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= +github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= @@ -624,6 +647,7 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= @@ -635,6 +659,7 @@ github.com/tenntenn/text/transform 
v0.0.0-20200319021203-7eef512accb3 h1:f+jULpR github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3/go.mod h1:ON8b8w4BN/kE1EOhwT0o+d62W65a6aPw1nouo9LMgyY= github.com/tetafro/godot v1.5.4 h1:u1ww+gqpRLiIA16yF2PV1CV1n/X3zhyezbNXC3E14Sg= github.com/tetafro/godot v1.5.4/go.mod h1:eOkMrVQurDui411nBY2FA05EYH01r14LuWY/NrVDVcU= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/timakin/bodyclose v0.0.0-20241222091800-1db5c5ca4d67 h1:9LPGD+jzxMlnk5r6+hJnar67cgpDIz/iyD+rfl5r2Vk= github.com/timakin/bodyclose v0.0.0-20241222091800-1db5c5ca4d67/go.mod h1:mkjARE7Yr8qU23YcGMSALbIxTQ9r9QBVahQOBRfU460= github.com/timonwong/loggercheck v0.11.0 h1:jdaMpYBl+Uq9mWPXv1r8jc5fC3gyXx4/WGwTnnNKn4M= @@ -655,6 +680,9 @@ github.com/uudashr/gocognit v1.2.1 h1:CSJynt5txTnORn/DkhiB4mZjwPuifyASC8/6Q0I/QS github.com/uudashr/gocognit v1.2.1/go.mod h1:acaubQc6xYlXFEMb9nWX2dYBzJ/bIjEkc1zzvyIZg5Q= github.com/uudashr/iface v1.4.1 h1:J16Xl1wyNX9ofhpHmQ9h9gk5rnv2A6lX/2+APLTo0zU= github.com/uudashr/iface v1.4.1/go.mod h1:pbeBPlbuU2qkNDn0mmfrxP2X+wjPMIQAy+r1MBXSXtg= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= +github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= github.com/xen0n/gosmopolitan v1.3.0 h1:zAZI1zefvo7gcpbCOrPSHJZJYA9ZgLfJqtKzZ5pHqQM= github.com/xen0n/gosmopolitan v1.3.0/go.mod h1:rckfr5T6o4lBtM1ga7mLGKZmLxswUoH1zxHgNXOsEt4= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= @@ -667,6 +695,7 @@ github.com/yeya24/promlinter v0.3.0 h1:JVDbMp08lVCP7Y6NP3qHroGAO6z2yGKQtS5Jsjqto github.com/yeya24/promlinter v0.3.0/go.mod h1:cDfJQQYv9uYciW60QT0eeHlFodotkYZlL+YcPQN+mW4= github.com/ykadowak/zerologlint v0.1.5 h1:Gy/fMz1dFQN9JZTPjv1hxEk+sRWm05row04Yoolgdiw= github.com/ykadowak/zerologlint v0.1.5/go.mod 
h1:KaUskqF3e/v59oPmdq1U1DnKcuHokl2/K1U4pmIELKg= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -692,6 +721,7 @@ go.augendre.info/fatcontext v0.9.0 h1:Gt5jGD4Zcj8CDMVzjOJITlSb9cEch54hjRRlN3qDoj go.augendre.info/fatcontext v0.9.0/go.mod h1:L94brOAT1OOUNue6ph/2HnwxoNlds9aXDF2FcUntbNw= go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= +go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -740,6 +770,7 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= @@ -823,6 +854,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v 
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= @@ -967,10 +999,12 @@ golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200724022722-7017fd6b1305/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1-0.20210205202024-ef80cdb6ec6d/go.mod h1:9bzcO0MWcOuT0tm1iBGzDVPshzfwoVvREIui8C+MHqU= golang.org/x/tools v0.1.1-0.20210302220138-2ac05c832e1a/go.mod 
h1:9bzcO0MWcOuT0tm1iBGzDVPshzfwoVvREIui8C+MHqU= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -1073,6 +1107,7 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= diff --git a/integration/clickhouse_native_test.go b/integration/clickhouse_native_test.go index 25b26f352..853954492 100644 --- a/integration/clickhouse_native_test.go +++ b/integration/clickhouse_native_test.go @@ -71,3 +71,75 @@ func TestClickHouseNativeSDK(t *testing.T) { t.Logf("%d typed rows", count) }) } + +// TestClickHouseNativeProtocolListener tests Flow 1: client speaks native +// protocol to Trickster's protocol listener, which proxies through the +// caching engine to ClickHouse's HTTP port. 
+func TestClickHouseNativeProtocolListener(t *testing.T) { + cfg := writeNativeListenerConfig(t, 8582, 8583, 8588, 9100) + h := tricksterHarness{ConfigPath: cfg, BaseAddr: "127.0.0.1:8582", MetricsAddr: "127.0.0.1:8583"} + h.start(t) + waitForClickHouseData(t, "127.0.0.1:8123") + + // Connect via native protocol to Trickster's protocol listener + conn, err := clickhouse.Open(&clickhouse.Options{ + Addr: []string{"127.0.0.1:9100"}, + Protocol: clickhouse.Native, + }) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + t.Cleanup(cancel) + + t.Run("ping", func(t *testing.T) { + require.NoError(t, conn.Ping(ctx)) + }) + + t.Run("select", func(t *testing.T) { + rows, err := conn.Query(ctx, "SELECT count() FROM trips") + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + }) + + t.Run("nullable", func(t *testing.T) { + rows, err := conn.Query(ctx, "SELECT toNullable(1) AS n, NULL AS m") + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + }) + + t.Run("array", func(t *testing.T) { + rows, err := conn.Query(ctx, "SELECT [1,2,3] AS arr") + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + }) + + t.Run("tuple", func(t *testing.T) { + rows, err := conn.Query(ctx, "SELECT tuple(1, 'hello') AS t") + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + }) + + t.Run("map", func(t *testing.T) { + rows, err := conn.Query(ctx, "SELECT map('key', 1) AS m") + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + }) + + t.Run("cache_hit", func(t *testing.T) { + q := "SELECT count() FROM trips" + rows1, err := conn.Query(ctx, q) + require.NoError(t, err) + rows1.Close() + + rows2, err := conn.Query(ctx, q) + require.NoError(t, err) + defer rows2.Close() + require.True(t, rows2.Next()) + }) +} diff --git a/integration/harness_test.go b/integration/harness_test.go index 
cf513b48e..ecd97c31b 100644 --- a/integration/harness_test.go +++ b/integration/harness_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/require" tkconfig "github.com/trickstercache/trickster/v2/pkg/config" "github.com/trickstercache/trickster/v2/pkg/config/mgmt" + fropt "github.com/trickstercache/trickster/v2/pkg/frontend/options" "gopkg.in/yaml.v2" ) @@ -183,6 +184,34 @@ func writeTestConfig(t *testing.T, frontPort, metricsPort, mgmtPort int) string return path } +func writeNativeListenerConfig(t *testing.T, frontPort, metricsPort, mgmtPort, nativePort int) string { + t.Helper() + b, err := os.ReadFile("../docs/developer/environment/trickster-config/trickster.yaml") + require.NoError(t, err) + var c tkconfig.Config + require.NoError(t, yaml.Unmarshal(b, &c)) + c.Frontend.ListenPort = frontPort + c.Metrics.ListenPort = metricsPort + if c.MgmtConfig == nil { + c.MgmtConfig = mgmt.New() + } + c.MgmtConfig.ListenPort = mgmtPort + // Add protocol listener for ClickHouse native protocol + c.Frontend.ProtocolListeners = append(c.Frontend.ProtocolListeners, + &fropt.ProtocolListenerOptions{ + Name: "clickhouse-native", + Protocol: "clickhouse-native", + ListenPort: nativePort, + Backend: "click1", + }, + ) + out, err := yaml.Marshal(&c) + require.NoError(t, err) + path := filepath.Join(t.TempDir(), "trickster.yaml") + require.NoError(t, os.WriteFile(path, out, 0644)) + return path +} + func defaultCacheProviders() []cacheProviderCase { return []cacheProviderCase{ {Name: "memory", Backend: "prom1"}, diff --git a/pkg/backends/clickhouse/clickhouse.go b/pkg/backends/clickhouse/clickhouse.go index 7cb16c11d..9f1acaaf6 100644 --- a/pkg/backends/clickhouse/clickhouse.go +++ b/pkg/backends/clickhouse/clickhouse.go @@ -20,13 +20,19 @@ package clickhouse import ( "net/http" "net/url" + "strings" "time" "github.com/trickstercache/trickster/v2/pkg/backends" modelch "github.com/trickstercache/trickster/v2/pkg/backends/clickhouse/model" + chnative 
"github.com/trickstercache/trickster/v2/pkg/backends/clickhouse/native" + chserver "github.com/trickstercache/trickster/v2/pkg/backends/clickhouse/native/server" bo "github.com/trickstercache/trickster/v2/pkg/backends/options" "github.com/trickstercache/trickster/v2/pkg/backends/providers/registry/types" "github.com/trickstercache/trickster/v2/pkg/cache" + "github.com/trickstercache/trickster/v2/pkg/observability/logging" + "github.com/trickstercache/trickster/v2/pkg/observability/logging/logger" + "github.com/trickstercache/trickster/v2/pkg/proxy/connhandler" "github.com/trickstercache/trickster/v2/pkg/proxy/errors" "github.com/trickstercache/trickster/v2/pkg/proxy/methods" "github.com/trickstercache/trickster/v2/pkg/proxy/request" @@ -39,6 +45,7 @@ var _ backends.TimeseriesBackend = (*Client)(nil) // Client Implements the Proxy Client Interface type Client struct { backends.TimeseriesBackend + nativeClient *chnative.NativeClient } var _ types.NewBackendClientFunc = NewClient @@ -54,9 +61,29 @@ func NewClient(name string, o *bo.Options, router http.Handler, c := &Client{} b, err := backends.NewTimeseriesBackend(name, o, c.RegisterHandlers, router, cache, modelch.NewModeler()) c.TimeseriesBackend = b + if err == nil && o != nil && o.Protocol == "native" { + nc, ncErr := chnative.NewNativeClient(o) + if ncErr != nil { + logger.Error("clickhouse native client setup failed", + logging.Pairs{"backend": name, "detail": ncErr.Error()}) + } else { + c.nativeClient = nc + o.Fetcher = nc.Fetch + } + } return c, err } +// ConnectionHandler implements backends.ConnectionHandlerProvider. 
+func (c *Client) ConnectionHandler(protocol string) connhandler.ConnectionHandler { + if protocol != "clickhouse-native" { + return nil + } + return &chserver.Handler{ + QueryHandler: c.Router(), + } +} + // ParseTimeRangeQuery parses the key parts of a TimeRangeQuery from the inbound HTTP Request func (c *Client) ParseTimeRangeQuery(r *http.Request) (*timeseries.TimeRangeQuery, *timeseries.RequestOptions, bool, error) { var sqlQuery string @@ -83,6 +110,11 @@ func (c *Client) ParseTimeRangeQuery(r *http.Request) (*timeseries.TimeRangeQuer if err != nil { return trq, ro, canOPC, err } + // When the client requests Native format via URL params (official Grafana + // plugin), set the output format so the marshaler returns Native binary. + if ro != nil && strings.EqualFold(r.URL.Query().Get("default_format"), "Native") { + ro.OutputFormat = modelch.OutputFormatNative + } if isBody && trq != nil { trq.OriginalBody = originalBody } diff --git a/pkg/backends/clickhouse/native/client.go b/pkg/backends/clickhouse/native/client.go new file mode 100644 index 000000000..fb9370667 --- /dev/null +++ b/pkg/backends/clickhouse/native/client.go @@ -0,0 +1,166 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package native provides a ClickHouse native protocol (port 9000) egress +// adapter. 
It translates HTTP-shaped proxy requests into native protocol +// queries via clickhouse-go and returns HTTP-shaped responses with JSON bodies +// that the existing ClickHouse Modeler can consume. +package native + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/ClickHouse/clickhouse-go/v2/lib/driver" + bo "github.com/trickstercache/trickster/v2/pkg/backends/options" + "github.com/trickstercache/trickster/v2/pkg/proxy/request" +) + +// NativeClient wraps a clickhouse-go native connection pool and exposes a +// Fetcher compatible with bo.Options.Fetcher. +type NativeClient struct { + conn driver.Conn +} + +// NewNativeClient creates a NativeClient from the backend options. The +// origin URL host:port is used as the native protocol address. +func NewNativeClient(o *bo.Options) (*NativeClient, error) { + addr := o.Host + if addr == "" { + return nil, errors.New("clickhouse native: origin host is empty") + } + if !strings.Contains(addr, ":") { + addr += ":9000" + } + conn, err := clickhouse.Open(&clickhouse.Options{ + Addr: []string{addr}, + Protocol: clickhouse.Native, + }) + if err != nil { + return nil, fmt.Errorf("clickhouse native: open: %w", err) + } + return &NativeClient{conn: conn}, nil +} + +// Close closes the underlying connection pool. +func (nc *NativeClient) Close() error { + return nc.conn.Close() +} + +// Fetch executes the SQL query embedded in the *http.Request (body or query +// param) against ClickHouse using the native protocol, and returns a +// synthetic *http.Response with a JSON body in ClickHouse WFDocument format: +// +// {"meta":[{"name":"col","type":"Type"},...], "data":[{"col":"val",...},...], "rows":N} +// +// This matches what the existing HTTP path returns when FORMAT JSON is used. 
+func (nc *NativeClient) Fetch(r *http.Request) (*http.Response, error) { + sql, err := extractSQL(r) + if err != nil { + return syntheticErrorResponse(http.StatusBadRequest, err), nil + } + + rows, err := nc.conn.Query(r.Context(), sql) + if err != nil { + return syntheticErrorResponse(http.StatusBadGateway, err), nil + } + defer rows.Close() + + colTypes := rows.ColumnTypes() + colNames := rows.Columns() + + meta := make([]map[string]string, len(colNames)) + for i, name := range colNames { + meta[i] = map[string]string{ + "name": name, + "type": colTypes[i].DatabaseTypeName(), + } + } + + data := make([]map[string]any, 0, 64) + scanDest := make([]any, len(colNames)) + for i := range scanDest { + scanDest[i] = new(any) + } + + for rows.Next() { + if err := rows.Scan(scanDest...); err != nil { + return syntheticErrorResponse(http.StatusBadGateway, err), nil + } + row := make(map[string]any, len(colNames)) + for i, name := range colNames { + row[name] = *(scanDest[i].(*any)) + } + data = append(data, row) + } + if err := rows.Err(); err != nil { + return syntheticErrorResponse(http.StatusBadGateway, err), nil + } + + rowCount := len(data) + doc := map[string]any{ + "meta": meta, + "data": data, + "rows": rowCount, + } + + body, err := json.Marshal(doc) + if err != nil { + return syntheticErrorResponse(http.StatusInternalServerError, err), nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Header: http.Header{"Content-Type": {"application/json"}}, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + Request: r, + }, nil +} + +func extractSQL(r *http.Request) (string, error) { + if r.Body != nil && r.Body != http.NoBody { + b, err := request.GetBody(r) + if err != nil { + return "", err + } + if len(b) > 0 { + return string(b), nil + } + } + if q := r.URL.Query().Get("query"); q != "" { + return q, nil + } + return "", errors.New("no SQL query found in request") +} + +func syntheticErrorResponse(code int, 
err error) *http.Response { + body := []byte(err.Error()) + return &http.Response{ + StatusCode: code, + Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), + Header: http.Header{"Content-Type": {"text/plain"}}, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + } +} diff --git a/pkg/backends/clickhouse/native/server/fuzz_test.go b/pkg/backends/clickhouse/native/server/fuzz_test.go new file mode 100644 index 000000000..f75cbf5f5 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/fuzz_test.go @@ -0,0 +1,186 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "bytes" + "encoding/binary" + "testing" +) + +// FuzzReadClientHello feeds arbitrary bytes into readClientHello and verifies +// it never panics. The parser must gracefully handle truncated, oversized, or +// malformed input. +func FuzzReadClientHello(f *testing.F) { + // Seed with a valid ClientHello payload (without the leading packet-type byte). + var seed bytes.Buffer + w := newProtoWriter(&seed) + w.putStr("clickhouse-go") + w.putUvarint(2) + w.putUvarint(0) + w.putUvarint(ServerRevision) + w.putStr("default") + w.putStr("user") + w.putStr("pass") + f.Add(seed.Bytes()) + + // Minimal valid payload. 
+ var minimal bytes.Buffer + wm := newProtoWriter(&minimal) + wm.putStr("") + wm.putUvarint(0) + wm.putUvarint(0) + wm.putUvarint(0) + wm.putStr("") + wm.putStr("") + wm.putStr("") + f.Add(minimal.Bytes()) + + // Empty and tiny inputs. + f.Add([]byte{}) + f.Add([]byte{0}) + f.Add([]byte{0xff, 0xff, 0xff, 0xff}) + + f.Fuzz(func(t *testing.T, data []byte) { + r := newProtoReader(bytes.NewReader(data)) + readClientHello(r) // must not panic + }) +} + +// FuzzReadClientQuery feeds arbitrary bytes into readClientQuery and verifies +// it never panics. This is the most complex parser — it reads client info, +// settings, interserver secret, compression flag, SQL, and parameters. +func FuzzReadClientQuery(f *testing.F) { + // Seed with a valid ClientQuery payload built by sendTestQuery. + var seed bytes.Buffer + w := newProtoWriter(&seed) + w.putStr("") // query ID + w.putByte(1) // query kind + w.putStr("") // initial user + w.putStr("") // initial query ID + w.putStr("127.0.0.1") // initial address + w.putInt64(0) // initial query start time + w.putByte(1) // interface = TCP + w.putStr("") // os user + w.putStr("") // os hostname + w.putStr("test") // client name + w.putUvarint(1) // major + w.putUvarint(0) // minor + w.putUvarint(ServerRevision) + w.putStr("") // quota key + w.putUvarint(0) // distributed depth + w.putUvarint(0) // version patch + w.putByte(0) // no otel span + w.putUvarint(0) // parallel replicas fields + w.putUvarint(0) + w.putUvarint(0) + w.putStr("") // end settings + w.putStr("") // interserver secret + w.putUvarint(2) // state complete + w.putBool(false) // no compression + w.putStr("SELECT 1") + w.putStr("") // end parameters + f.Add(seed.Bytes(), uint64(ServerRevision)) + + // Seed with an older revision. + f.Add(seed.Bytes(), uint64(54060)) + + // Empty and tiny inputs at various revisions. 
+ f.Add([]byte{}, uint64(ServerRevision)) + f.Add([]byte{0}, uint64(ServerRevision)) + f.Add([]byte{0xff, 0xff, 0xff}, uint64(54060)) + + f.Fuzz(func(t *testing.T, data []byte, revision uint64) { + // Cap revision to a reasonable range to avoid nonsensical code paths. + if revision > 100000 { + revision = 100000 + } + r := newProtoReader(bytes.NewReader(data)) + readClientQuery(r, revision) // must not panic + }) +} + +// FuzzWriteJSONAsNativeBlocks feeds arbitrary bytes (treated as a JSON body) +// into writeJSONAsNativeBlocks and verifies it never panics. +func FuzzWriteJSONAsNativeBlocks(f *testing.F) { + f.Add([]byte(`{"meta":[{"name":"x","type":"UInt32"}],"data":[{"x":1}],"rows":1}`)) + f.Add([]byte(`{"meta":[],"data":[]}`)) + f.Add([]byte(`not json at all`)) + f.Add([]byte(`{}`)) + f.Add([]byte{}) + f.Add([]byte(`{"meta":[{"name":"s","type":"String"}],"data":[{"s":"hello"}],"rows":1}`)) + f.Add([]byte(`{"meta":[{"name":"d","type":"DateTime"}],"data":[{"d":1700000000}],"rows":1}`)) + + f.Fuzz(func(t *testing.T, data []byte) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + writeJSONAsNativeBlocks(w, data, false) // must not panic + }) +} + +// FuzzSkipClientData feeds arbitrary bytes into skipClientData and verifies +// it never panics. +func FuzzSkipClientData(f *testing.F) { + // Seed with a valid empty data block (no leading packet type byte). 
+ var seed bytes.Buffer + w := newProtoWriter(&seed) + w.putStr("") // block name + w.putUvarint(0) // is_overflows + w.putBool(false) // bucket_num + w.putUvarint(2) // bucket_size + w.putInt32(-1) // reserved + w.putUvarint(0) // reserved + w.putUvarint(0) // num columns + w.putUvarint(0) // num rows + f.Add(seed.Bytes(), uint64(ServerRevision)) + + f.Add([]byte{}, uint64(ServerRevision)) + f.Add([]byte{0, 0}, uint64(0)) + + f.Fuzz(func(t *testing.T, data []byte, revision uint64) { + if revision > 100000 { + revision = 100000 + } + r := newProtoReader(bytes.NewReader(data)) + skipClientData(r, revision) // must not panic + }) +} + +// FuzzProtoReaderStr feeds arbitrary bytes into the string reader to verify +// it handles oversized length prefixes and truncated data without panicking. +func FuzzProtoReaderStr(f *testing.F) { + // Valid string. + var valid bytes.Buffer + w := newProtoWriter(&valid) + w.putStr("hello") + f.Add(valid.Bytes()) + + // Length prefix claiming huge size with no data. + var huge bytes.Buffer + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, 1<<30) + huge.Write(buf[:n]) + f.Add(huge.Bytes()) + + f.Add([]byte{}) + f.Add([]byte{0xff, 0xff, 0xff, 0xff, 0x0f}) + + f.Fuzz(func(t *testing.T, data []byte) { + r := newProtoReader(bytes.NewReader(data)) + r.str() // must not panic + }) +} diff --git a/pkg/backends/clickhouse/native/server/handler.go b/pkg/backends/clickhouse/native/server/handler.go new file mode 100644 index 000000000..2aea2aa90 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/handler.go @@ -0,0 +1,249 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/*
 * Copyright 2018 The Trickster Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package server

import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
)

// Handler implements connhandler.ConnectionHandler for the ClickHouse native
// protocol. It accepts native connections, extracts SQL queries, routes them
// through a standard http.Handler (which is the ClickHouse backend's query
// handler), and transcodes the JSON response back to native data blocks.
type Handler struct {
	// QueryHandler is the HTTP handler that processes ClickHouse queries
	// (typically the backend's router or QueryHandler).
	QueryHandler http.Handler
}

// HandleConnection implements connhandler.ConnectionHandler. It performs the
// native-protocol handshake, then loops dispatching client packets
// (Ping/Query/Cancel/Data) until the peer disconnects (clean EOF returns nil)
// or a protocol error occurs.
//
// NOTE(review): the credentials in the ClientHello are read but not verified
// here — only hello.Database and the protocol revision are used. Confirm
// authentication is enforced elsewhere before exposing this listener.
// NOTE(review): ctx is only observed between packets; a blocking read is not
// interrupted by cancellation until the next packet boundary.
func (h *Handler) HandleConnection(ctx context.Context, conn net.Conn) error {
	r := newProtoReader(conn)
	// All responses are staged through a buffered writer and flushed at
	// packet boundaries to avoid many small writes to the socket.
	bw := bufio.NewWriterSize(conn, 128*1024)
	w := newProtoWriter(bw)

	// --- handshake ---
	// The first packet must be ClientHello (type 0).
	pktType, err := r.ReadByte()
	if err != nil {
		return fmt.Errorf("read hello packet type: %w", err)
	}
	if pktType != ClientHello {
		return fmt.Errorf("expected ClientHello (0), got %d", pktType)
	}
	hello, err := readClientHello(r)
	if err != nil {
		return fmt.Errorf("read client hello: %w", err)
	}

	if err := writeServerHello(w); err != nil {
		return fmt.Errorf("write server hello: %w", err)
	}
	if err := bw.Flush(); err != nil {
		return fmt.Errorf("flush server hello: %w", err)
	}

	// Speak at the client's advertised revision for all revision-gated fields.
	clientRevision := hello.ProtoRevision

	// After receiving ServerHello, clients with revision >= 54458 send an
	// addendum (currently just a quota key string). Read and discard it.
	if clientRevision >= RevisionAddendum {
		if _, err := r.str(); err != nil {
			return fmt.Errorf("read client addendum: %w", err)
		}
	}

	// --- connection loop ---
	for {
		// Cancellation is only checked here, between packets.
		if ctx.Err() != nil {
			return nil
		}

		pktType, err = r.ReadByte()
		if err != nil {
			if errors.Is(err, io.EOF) {
				// Peer closed the connection cleanly.
				return nil
			}
			return fmt.Errorf("read packet type: %w", err)
		}

		switch pktType {
		case ClientPing:
			if err := writePong(w); err != nil {
				return err
			}
			if err := bw.Flush(); err != nil {
				return err
			}

		case ClientQuery:
			q, err := readClientQuery(r, clientRevision)
			if err != nil {
				// Best-effort exception to the client before tearing down;
				// write errors are deliberately ignored since we return anyway.
				_ = writeException(w, 62, "DB::Exception", err.Error())
				_ = bw.Flush()
				return err
			}

			// After the query, the client sends an empty data block.
			// For SELECT this is always empty; INSERT with inline data
			// is not yet supported via the native protocol proxy.
			// NOTE(review): if the byte read here is not ClientData it is
			// silently consumed — confirm no other packet type is legal at
			// this point in the protocol.
			dataPkt, err := r.ReadByte()
			if err != nil {
				return fmt.Errorf("read post-query data: %w", err)
			}
			if dataPkt == ClientData {
				if err := skipClientData(r, clientRevision); err != nil {
					return fmt.Errorf("skip post-query data block: %w", err)
				}
			}

			if err := h.handleQuery(ctx, w, bw, q, hello.Database, q.Compression); err != nil {
				return err
			}

		case ClientCancel:
			// nothing to cancel in the proxy case
			continue

		case ClientData:
			// unexpected data outside query context, skip it
			if err := skipClientData(r, clientRevision); err != nil {
				return fmt.Errorf("skip unexpected data: %w", err)
			}

		default:
			return fmt.Errorf("unknown packet type: %d", pktType)
		}
	}
}

// handleQuery creates a synthetic HTTP request, invokes the QueryHandler,
// parses the JSON response, and writes native protocol data blocks followed
// by EndOfStream. Query-level failures are converted to protocol exceptions
// via writeQueryError (which returns nil so the connection stays open).
func (h *Handler) handleQuery(ctx context.Context, w *protoWriter, bw *bufio.Writer, q *ClientQueryMsg, database string, compressed bool) error {
	sql := q.SQL
	upper := strings.ToUpper(strings.TrimSpace(sql))
	isSelect := strings.HasPrefix(upper, "SELECT") || strings.HasPrefix(upper, "WITH")
	// Ensure SELECT queries return JSON so the response can be transcoded
	// to native data blocks. Non-SELECT (INSERT, CREATE, etc.) are proxied
	// as-is without format modification.
	// NOTE(review): the FORMAT detection is a plain substring match on the
	// upper-cased SQL, so a literal containing "FORMAT " suppresses the
	// rewrite — confirm this is acceptable.
	if isSelect && !strings.Contains(upper, "FORMAT ") {
		sql = strings.TrimRight(sql, "; \t\n") + " FORMAT JSON"
	}

	// Use "/" as the path — the backend router handles paths relative to
	// the backend, not the full frontend path (e.g. "/click1/").
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/", strings.NewReader(sql))
	if err != nil {
		return writeQueryError(w, bw, err)
	}
	if database != "" {
		// Propagate the session database chosen during the handshake.
		qp := req.URL.Query()
		qp.Set("database", database)
		req.URL.RawQuery = qp.Encode()
	}
	req.Header.Set("Content-Type", "text/plain")

	// Execute the query in-process against the backend's HTTP handler.
	rec := httptest.NewRecorder()
	h.QueryHandler.ServeHTTP(rec, req)
	resp := rec.Result()

	if resp.StatusCode != http.StatusOK {
		// Surface the upstream error body (or status line) as an exception.
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		msg := string(body)
		if msg == "" {
			msg = resp.Status
		}
		return writeQueryError(w, bw, fmt.Errorf("%s", msg))
	}

	body, err := io.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return writeQueryError(w, bw, err)
	}

	if err := writeJSONAsNativeBlocks(w, body, compressed); err != nil {
		return writeQueryError(w, bw, err)
	}

	if err := writeEndOfStream(w); err != nil {
		return err
	}
	return bw.Flush()
}

// writeQueryError sends a best-effort DB::Exception (code 62) to the client
// and returns nil so the connection loop continues; write/flush errors are
// intentionally discarded.
func writeQueryError(w *protoWriter, bw *bufio.Writer, err error) error {
	_ = writeException(w, 62, "DB::Exception", err.Error())
	_ = bw.Flush()
	return nil
}

// wfMetaItem is one entry of the "meta" array in a ClickHouse JSON response.
type wfMetaItem struct {
	Name string `json:"name"`
	Type string `json:"type"`
}

// wfDocument is the subset of a ClickHouse JSON response needed for
// transcoding: column metadata, row-major data, and the row count.
type wfDocument struct {
	Meta []wfMetaItem     `json:"meta"`
	Data []map[string]any `json:"data"`
	Rows *int             `json:"rows"`
}

// writeJSONAsNativeBlocks parses a ClickHouse JSON response and writes it as
// native protocol data blocks. Non-JSON bodies and empty result sets are
// represented by a single empty block.
func writeJSONAsNativeBlocks(w *protoWriter, body []byte, compressed bool) error {
	var doc wfDocument
	if err := json.Unmarshal(body, &doc); err != nil {
		// Not JSON — might be a non-SELECT result. Send empty block.
		return writeEmptyBlock(w)
	}

	if len(doc.Meta) == 0 || len(doc.Data) == 0 {
		return writeEmptyBlock(w)
	}

	columns := make([]Column, len(doc.Meta))
	for i, m := range doc.Meta {
		columns[i] = Column(m)
	}

	numRows := uint64(len(doc.Data))
	// Convert row-major data to column-major
	colValues := make([][]any, len(columns))
	for i := range colValues {
		colValues[i] = make([]any, numRows)
	}
	for rowIdx, row := range doc.Data {
		for colIdx, col := range columns {
			colValues[colIdx][rowIdx] = row[col.Name]
		}
	}

	if compressed {
		return writeCompressedDataBlock(w, columns, colValues, numRows)
	}
	return writeDataBlock(w, columns, colValues, numRows)
}

// ---- pkg/backends/clickhouse/native/server/handler_test.go ----

package server

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net"
	"net/http"
	"testing"
	"time"
)

// echoJSONHandler returns a ClickHouse-shaped JSON response with the SQL
// query echoed back as a single-row result.
+func echoJSONHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + doc := map[string]any{ + "meta": []map[string]string{ + {"name": "query", "type": "String"}, + }, + "data": []map[string]any{ + {"query": string(body)}, + }, + "rows": 1, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(doc) + }) +} + +func TestHandlerPingPong(t *testing.T) { + h := &Handler{ + QueryHandler: echoJSONHandler(), + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- h.HandleConnection(ctx, serverConn) + }() + + cw := newProtoWriter(clientConn) + cr := newProtoReader(clientConn) + + // Send ClientHello + cw.putByte(ClientHello) + cw.putStr("test-client") + cw.putUvarint(1) + cw.putUvarint(0) + cw.putUvarint(ServerRevision) + cw.putStr("default") + cw.putStr("") + cw.putStr("") + + // Read ServerHello + pkt, err := cr.ReadByte() + if err != nil { + t.Fatal(err) + } + if pkt != ServerHello { + t.Fatalf("expected ServerHello, got %d", pkt) + } + // skip rest of hello + cr.str() // name + cr.uvarint() + cr.uvarint() + cr.uvarint() + cr.str() // tz + cr.str() // display + cr.uvarint() + + // Send addendum (quota key, required for revision >= 54458) + cw.putStr("") + + // Send Ping + cw.putByte(ClientPing) + + // Read Pong + pkt, err = cr.ReadByte() + if err != nil { + t.Fatal(err) + } + if pkt != ServerPong { + t.Fatalf("expected ServerPong (%d), got %d", ServerPong, pkt) + } + + // Close client side + clientConn.Close() + cancel() +} + +func TestHandlerQuery(t *testing.T) { + h := &Handler{ + QueryHandler: echoJSONHandler(), + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + errCh := make(chan 
error, 1) + go func() { + errCh <- h.HandleConnection(ctx, serverConn) + }() + + cw := newProtoWriter(clientConn) + cr := newProtoReader(clientConn) + + // Send ClientHello + cw.putByte(ClientHello) + cw.putStr("test-client") + cw.putUvarint(1) + cw.putUvarint(0) + cw.putUvarint(ServerRevision) + cw.putStr("default") + cw.putStr("") + cw.putStr("") + + // Read ServerHello + pkt, _ := cr.ReadByte() + if pkt != ServerHello { + t.Fatalf("expected ServerHello, got %d", pkt) + } + cr.str() + cr.uvarint() + cr.uvarint() + cr.uvarint() + cr.str() + cr.str() + cr.uvarint() + + // Send ClientQuery + sendTestQuery(t, cw, "SELECT 1") + + // Send empty ClientData block (required after query) + sendEmptyDataBlock(cw) + + // Read response — should be ServerData + pkt, err := cr.ReadByte() + if err != nil { + t.Fatal(err) + } + if pkt != ServerData { + t.Fatalf("expected ServerData (%d), got %d", ServerData, pkt) + } + + // Read block name + blockName, _ := cr.str() + if blockName != "" { + t.Fatalf("expected empty block name, got %q", blockName) + } + + // Skip block info + cr.uvarint() + cr.boolean() + cr.uvarint() + cr.int32() + cr.uvarint() + + numCols, _ := cr.uvarint() + numRows, _ := cr.uvarint() + if numCols != 1 { + t.Fatalf("expected 1 column, got %d", numCols) + } + if numRows != 1 { + t.Fatalf("expected 1 row, got %d", numRows) + } + + colName, _ := cr.str() + if colName != "query" { + t.Fatalf("expected column name %q, got %q", "query", colName) + } + + // Read EndOfStream + colType, _ := cr.str() + _ = colType + cr.boolean() // custom serialization flag + val, _ := cr.str() + if val != "SELECT 1 FORMAT JSON" { + t.Fatalf("expected query value %q, got %q", "SELECT 1 FORMAT JSON", val) + } + + pkt, _ = cr.ReadByte() + if pkt != ServerEndOfStream { + t.Fatalf("expected ServerEndOfStream (%d), got %d", ServerEndOfStream, pkt) + } + + clientConn.Close() + cancel() +} + +func TestWriteJSONAsNativeBlocks(t *testing.T) { + doc := map[string]any{ + "meta": 
[]map[string]string{ + {"name": "n", "type": "UInt32"}, + }, + "data": []map[string]any{ + {"n": float64(42)}, + }, + "rows": 1, + } + body, _ := json.Marshal(doc) + + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeJSONAsNativeBlocks(w, body, false); err != nil { + t.Fatal(err) + } + + r := newProtoReader(&buf) + pkt, _ := r.ReadByte() + if pkt != ServerData { + t.Fatalf("expected ServerData, got %d", pkt) + } +} + +// sendTestQuery writes a minimal ClientQuery with the given SQL. +func sendTestQuery(t *testing.T, w *protoWriter, sql string) { + t.Helper() + w.putByte(ClientQuery) + w.putStr("") // query ID + w.putByte(1) // query kind = initial + w.putStr("") // initial user + w.putStr("") // initial query ID + w.putStr("127.0.0.1") // initial address + // initial query start time (8 bytes, revision >= 54449) + w.putInt64(0) + w.putByte(1) // interface = TCP + w.putStr("") // os user + w.putStr("") // os hostname + w.putStr("t") // client name + w.putUvarint(1) + w.putUvarint(0) + w.putUvarint(ServerRevision) // client protocol revision + w.putStr("") // quota key (revision >= 54060) + w.putUvarint(0) // distributed depth (revision >= 54448) + w.putUvarint(0) // version patch (revision >= 54401) + w.putByte(0) // no otel span (revision >= 54442) + // parallel replicas (revision >= 54453) + w.putUvarint(0) + w.putUvarint(0) + w.putUvarint(0) + // end of client info + + // settings (empty) + w.putStr("") + // interserver secret (revision >= 54441) + w.putStr("") + // state + w.putUvarint(2) // StateComplete + // compression + w.putBool(false) + // SQL + w.putStr(sql) + // parameters (revision >= 54459) + w.putStr("") +} + +// sendEmptyDataBlock writes an empty ClientData block. 
+func sendEmptyDataBlock(w *protoWriter) { + w.putByte(ClientData) + w.putStr("") // block name + // block info + w.putUvarint(0) // is_overflows + w.putBool(false) // bucket_num + w.putUvarint(2) // bucket_size + w.putInt32(-1) // reserved + w.putUvarint(0) // reserved + // num columns, num rows + w.putUvarint(0) + w.putUvarint(0) +} diff --git a/pkg/backends/clickhouse/native/server/handshake.go b/pkg/backends/clickhouse/native/server/handshake.go new file mode 100644 index 000000000..138383cb6 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/handshake.go @@ -0,0 +1,104 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import "fmt" + +// ClientHelloMsg contains the fields sent by the client in the Hello packet. +type ClientHelloMsg struct { + ClientName string + VersionMajor uint64 + VersionMinor uint64 + ProtoRevision uint64 + Database string + Username string + Password string +} + +// readClientHello reads a ClientHello packet from the wire. The leading packet +// type byte must already have been consumed. 
+func readClientHello(r *protoReader) (*ClientHelloMsg, error) { + name, err := r.str() + if err != nil { + return nil, fmt.Errorf("read client name: %w", err) + } + major, err := r.uvarint() + if err != nil { + return nil, fmt.Errorf("read major version: %w", err) + } + minor, err := r.uvarint() + if err != nil { + return nil, fmt.Errorf("read minor version: %w", err) + } + revision, err := r.uvarint() + if err != nil { + return nil, fmt.Errorf("read protocol revision: %w", err) + } + db, err := r.str() + if err != nil { + return nil, fmt.Errorf("read database: %w", err) + } + user, err := r.str() + if err != nil { + return nil, fmt.Errorf("read username: %w", err) + } + pass, err := r.str() + if err != nil { + return nil, fmt.Errorf("read password: %w", err) + } + return &ClientHelloMsg{ + ClientName: name, + VersionMajor: major, + VersionMinor: minor, + ProtoRevision: revision, + Database: db, + Username: user, + Password: pass, + }, nil +} + +// writeServerHello sends a ServerHello response to the client. 
+func writeServerHello(w *protoWriter) error { + if err := w.putByte(ServerHello); err != nil { + return err + } + // Server name + if err := w.putStr("Trickster"); err != nil { + return err + } + // Version major/minor + if err := w.putUvarint(2); err != nil { + return err + } + if err := w.putUvarint(0); err != nil { + return err + } + // Revision + if err := w.putUvarint(ServerRevision); err != nil { + return err + } + // Timezone (revision >= 54058) + if err := w.putStr("UTC"); err != nil { + return err + } + // Display name (revision >= 54372) + if err := w.putStr("Trickster"); err != nil { + return err + } + // Version patch (revision >= 54401) + return w.putUvarint(0) +} diff --git a/pkg/backends/clickhouse/native/server/protocol.go b/pkg/backends/clickhouse/native/server/protocol.go new file mode 100644 index 000000000..a391f6289 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/protocol.go @@ -0,0 +1,168 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package server implements the server side of the ClickHouse native TCP +// wire protocol. It speaks enough of the protocol to accept connections from +// clickhouse-go (or the ClickHouse CLI) and proxy queries through Trickster's +// caching engine. +package server + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "math" +) + +// Client-to-server packet types. 
+const ( + ClientHello = 0 + ClientQuery = 1 + ClientData = 2 + ClientCancel = 3 + ClientPing = 4 +) + +// Server-to-client packet types. +const ( + ServerHello = 0 + ServerData = 1 + ServerException = 2 + ServerProgress = 3 + ServerPong = 4 + ServerEndOfStream = 5 + ServerProfileInfo = 6 +) + +// Protocol revision thresholds used by the server handshake and query parsing. +const ( + RevisionTimezone = 54058 + RevisionQuotaKey = 54060 + RevisionDisplayName = 54372 + RevisionVersionPatch = 54401 + RevisionClientWriteInfo = 54420 + RevisionInterserverSecret = 54441 + RevisionOpenTelemetry = 54442 + RevisionDistributedDepth = 54448 + RevisionInitialQueryStart = 54449 + RevisionParallelReplicas = 54453 + RevisionCustomSerialization = 54454 + RevisionParameters = 54459 + RevisionAddendum = 54458 + RevisionServerQueryTime = 54460 +) + +// ServerRevision is the protocol revision we advertise to clients. +// This matches the DBMS_TCP_PROTOCOL_VERSION from clickhouse-go. +const ServerRevision = RevisionServerQueryTime + +// ---------- wire primitives ---------- + +type protoReader struct { + *bufio.Reader +} + +func newProtoReader(r io.Reader) *protoReader { + return &protoReader{bufio.NewReaderSize(r, 128*1024)} +} + +func (r *protoReader) uvarint() (uint64, error) { + return binary.ReadUvarint(r) +} + +func (r *protoReader) str() (string, error) { + n, err := r.uvarint() + if err != nil { + return "", err + } + if n > 10*1024*1024 { + return "", fmt.Errorf("string too long: %d bytes", n) + } + b := make([]byte, n) + if _, err := io.ReadFull(r, b); err != nil { + return "", err + } + return string(b), nil +} + +func (r *protoReader) int32() (int32, error) { + var b [4]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return 0, err + } + return int32(binary.LittleEndian.Uint32(b[:])), nil //nolint:gosec // wire protocol reinterpret +} + +func (r *protoReader) boolean() (bool, error) { + b, err := r.ReadByte() + return b != 0, err +} + +type protoWriter struct { + 
io.Writer + buf [binary.MaxVarintLen64]byte +} + +func newProtoWriter(w io.Writer) *protoWriter { + return &protoWriter{Writer: w} +} + +func (w *protoWriter) putByte(b byte) error { + w.buf[0] = b + _, err := w.Write(w.buf[:1]) + return err +} + +func (w *protoWriter) putUvarint(v uint64) error { + n := binary.PutUvarint(w.buf[:], v) + _, err := w.Write(w.buf[:n]) + return err +} + +func (w *protoWriter) putStr(s string) error { + if err := w.putUvarint(uint64(len(s))); err != nil { + return err + } + if len(s) > 0 { + _, err := w.Write([]byte(s)) + return err + } + return nil +} + +func (w *protoWriter) putInt32(v int32) error { + binary.LittleEndian.PutUint32(w.buf[:4], uint32(v)) //nolint:gosec // wire protocol reinterpret + _, err := w.Write(w.buf[:4]) + return err +} + +func (w *protoWriter) putInt64(v int64) error { + binary.LittleEndian.PutUint64(w.buf[:8], uint64(v)) //nolint:gosec // wire protocol reinterpret + _, err := w.Write(w.buf[:8]) + return err +} + +func (w *protoWriter) putFloat64(v float64) error { + return w.putInt64(int64(math.Float64bits(v))) //nolint:gosec // IEEE 754 bit cast +} + +func (w *protoWriter) putBool(v bool) error { + if v { + return w.putByte(1) + } + return w.putByte(0) +} diff --git a/pkg/backends/clickhouse/native/server/protocol_test.go b/pkg/backends/clickhouse/native/server/protocol_test.go new file mode 100644 index 000000000..64c53a3a0 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/protocol_test.go @@ -0,0 +1,424 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/*
 * Copyright 2018 The Trickster Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package server

import (
	"bytes"
	"encoding/binary"
	"testing"
)

// TestProtoWriterReadRoundTrip writes each primitive encoding once and reads
// it back, verifying byte/uvarint/string/int32/bool symmetry.
func TestProtoWriterReadRoundTrip(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)

	if err := w.putByte(42); err != nil {
		t.Fatal(err)
	}
	if err := w.putUvarint(12345); err != nil {
		t.Fatal(err)
	}
	if err := w.putStr("hello world"); err != nil {
		t.Fatal(err)
	}
	if err := w.putInt32(-99); err != nil {
		t.Fatal(err)
	}
	if err := w.putBool(true); err != nil {
		t.Fatal(err)
	}

	r := newProtoReader(&buf)

	b, err := r.ReadByte()
	if err != nil {
		t.Fatal(err)
	}
	if b != 42 {
		t.Fatalf("expected byte 42, got %d", b)
	}

	uv, err := r.uvarint()
	if err != nil {
		t.Fatal(err)
	}
	if uv != 12345 {
		t.Fatalf("expected uvarint 12345, got %d", uv)
	}

	s, err := r.str()
	if err != nil {
		t.Fatal(err)
	}
	if s != "hello world" {
		t.Fatalf("expected string %q, got %q", "hello world", s)
	}

	i, err := r.int32()
	if err != nil {
		t.Fatal(err)
	}
	if i != -99 {
		t.Fatalf("expected int32 -99, got %d", i)
	}

	bv, err := r.boolean()
	if err != nil {
		t.Fatal(err)
	}
	if !bv {
		t.Fatal("expected true, got false")
	}
}

// TestHandshakeRoundTrip encodes a ClientHello, parses it back with
// readClientHello, then checks writeServerHello's leading fields.
func TestHandshakeRoundTrip(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)

	// Write a ClientHello
	w.putByte(ClientHello)
	w.putStr("test-client")
	w.putUvarint(1) // major
	w.putUvarint(2) // minor
	w.putUvarint(ServerRevision)
	w.putStr("default")
	w.putStr("user")
	w.putStr("pass")

	r := newProtoReader(&buf)
	pkt, _ := r.ReadByte()
	if pkt != ClientHello {
		t.Fatalf("expected ClientHello packet type 0, got %d", pkt)
	}

	hello, err := readClientHello(r)
	if err != nil {
		t.Fatal(err)
	}
	if hello.ClientName != "test-client" {
		t.Fatalf("expected client name %q, got %q", "test-client", hello.ClientName)
	}
	if hello.Database != "default" {
		t.Fatalf("expected database %q, got %q", "default", hello.Database)
	}
	if hello.Username != "user" {
		t.Fatalf("expected username %q, got %q", "user", hello.Username)
	}
	if hello.Password != "pass" {
		t.Fatalf("expected password %q, got %q", "pass", hello.Password)
	}

	// Write ServerHello and verify it starts with the correct packet type
	var serverBuf bytes.Buffer
	sw := newProtoWriter(&serverBuf)
	if err := writeServerHello(sw); err != nil {
		t.Fatal(err)
	}

	sr := newProtoReader(&serverBuf)
	pktType, _ := sr.ReadByte()
	if pktType != ServerHello {
		t.Fatalf("expected ServerHello packet type 0, got %d", pktType)
	}
	name, _ := sr.str()
	if name != "Trickster" {
		t.Fatalf("expected server name %q, got %q", "Trickster", name)
	}
}

// TestWriteException checks the wire layout of a ServerException packet:
// packet type, int32 code, name string, message string.
func TestWriteException(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)

	if err := writeException(w, 60, "DB::SyntaxException", "bad query"); err != nil {
		t.Fatal(err)
	}

	r := newProtoReader(&buf)
	pkt, _ := r.ReadByte()
	if pkt != ServerException {
		t.Fatalf("expected ServerException packet type %d, got %d", ServerException, pkt)
	}

	code, _ := r.int32()
	if code != 60 {
		t.Fatalf("expected error code 60, got %d", code)
	}

	name, _ := r.str()
	if name != "DB::SyntaxException" {
		t.Fatalf("expected exception name %q, got %q", "DB::SyntaxException", name)
	}

	msg, _ := r.str()
	if msg != "bad query" {
		t.Fatalf("expected message %q, got %q", "bad query", msg)
	}
}

// TestWriteEmptyBlock checks an empty block is framed as ServerData.
func TestWriteEmptyBlock(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)

	if err := writeEmptyBlock(w); err != nil {
		t.Fatal(err)
	}

	r := newProtoReader(&buf)
	pkt, _ := r.ReadByte()
	if pkt != ServerData {
		t.Fatalf("expected ServerData packet type %d, got %d", ServerData, pkt)
	}
}

// TestWriteDataBlock decodes a two-column (UInt32, String), two-row block
// field by field, checking header, counts, column metadata, and raw values.
func TestWriteDataBlock(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)

	columns := []Column{
		{Name: "id", Type: "UInt32"},
		{Name: "name", Type: "String"},
	}
	values := [][]any{
		{uint32(1), uint32(2)},
		{"alice", "bob"},
	}

	if err := writeDataBlock(w, columns, values, 2); err != nil {
		t.Fatal(err)
	}

	r := newProtoReader(&buf)
	pkt, _ := r.ReadByte()
	if pkt != ServerData {
		t.Fatalf("expected ServerData, got %d", pkt)
	}

	// block name (empty)
	blockName, _ := r.str()
	if blockName != "" {
		t.Fatalf("expected empty block name, got %q", blockName)
	}

	// skip block info: is_overflows(uvarint), bucket_num(bool), bucket_size(uvarint), reserved(int32), reserved(uvarint)
	r.uvarint()
	r.boolean()
	r.uvarint()
	r.int32()
	r.uvarint()

	numCols, _ := r.uvarint()
	numRows, _ := r.uvarint()
	if numCols != 2 {
		t.Fatalf("expected 2 columns, got %d", numCols)
	}
	if numRows != 2 {
		t.Fatalf("expected 2 rows, got %d", numRows)
	}

	// First column: id UInt32
	colName, _ := r.str()
	if colName != "id" {
		t.Fatalf("expected column name %q, got %q", "id", colName)
	}
	colType, _ := r.str()
	if colType != "UInt32" {
		t.Fatalf("expected column type %q, got %q", "UInt32", colType)
	}
	// custom serialization flag
	r.boolean()
	// Two UInt32 values
	var val [4]byte
	r.Read(val[:])
	if binary.LittleEndian.Uint32(val[:]) != 1 {
		t.Fatal("expected first value 1")
	}
	r.Read(val[:])
	if binary.LittleEndian.Uint32(val[:]) != 2 {
		t.Fatal("expected second value 2")
	}

	// Second column: name String
	colName, _ = r.str()
	if colName != "name" {
		t.Fatalf("expected column name %q, got %q", "name", colName)
	}
	colType, _ = r.str()
	if colType != "String" {
		t.Fatalf("expected column type %q, got %q", "String", colType)
	}
	r.boolean() // custom serialization flag
	s1, _ := r.str()
	if s1 != "alice" {
		t.Fatalf("expected %q, got %q", "alice", s1)
	}
	s2, _ := r.str()
	if s2 != "bob" {
		t.Fatalf("expected %q, got %q", "bob", s2)
	}
}

// TestWritePong checks writePong emits the single ServerPong byte.
func TestWritePong(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)
	if err := writePong(w); err != nil {
		t.Fatal(err)
	}
	if buf.Bytes()[0] != ServerPong {
		t.Fatalf("expected ServerPong (%d), got %d", ServerPong, buf.Bytes()[0])
	}
}

// TestWriteEndOfStream checks writeEndOfStream emits the single
// ServerEndOfStream byte.
func TestWriteEndOfStream(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)
	if err := writeEndOfStream(w); err != nil {
		t.Fatal(err)
	}
	if buf.Bytes()[0] != ServerEndOfStream {
		t.Fatalf("expected ServerEndOfStream (%d), got %d", ServerEndOfStream, buf.Bytes()[0])
	}
}

// TestWriteColumnDataNullable checks Nullable(String) serialization: a null
// bitmap byte per row followed by the values, zero-valued where null.
func TestWriteColumnDataNullable(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)
	values := []any{"hello", nil, "world"}
	if err := writeColumnData(w, "Nullable(String)", values); err != nil {
		t.Fatal(err)
	}
	r := newProtoReader(&buf)
	// Read null bitmap: 0, 1, 0
	b0, _ := r.ReadByte()
	b1, _ := r.ReadByte()
	b2, _ := r.ReadByte()
	if b0 != 0 || b1 != 1 || b2 != 0 {
		t.Fatalf("expected null bitmap [0,1,0], got [%d,%d,%d]", b0, b1, b2)
	}
	// Read values: "hello", "" (zero for null), "world"
	s0, _ := r.str()
	s1, _ := r.str()
	s2, _ := r.str()
	if s0 != "hello" || s1 != "" || s2 != "world" {
		t.Fatalf("expected [hello,,world], got [%s,%s,%s]", s0, s1, s2)
	}
}

// TestWriteColumnDataArray checks Array(UInt32) serialization: cumulative
// uint64 offsets followed by the flattened elements.
// NOTE(review): this function is truncated at the end of the visible chunk;
// the remainder lives outside this view.
func TestWriteColumnDataArray(t *testing.T) {
	var buf bytes.Buffer
	w := newProtoWriter(&buf)
	values := []any{
		[]any{float64(1), float64(2)},
		[]any{float64(3)},
	}
	if err := writeColumnData(w, "Array(UInt32)", values); err != nil {
		t.Fatal(err)
	}
	r := newProtoReader(&buf)
	// Offsets: 2, 3 (cumulative)
	var off [8]byte
	r.Read(off[:])
	if binary.LittleEndian.Uint64(off[:]) != 2 {
		t.Fatal("expected first offset 2")
	}
	r.Read(off[:])
	if binary.LittleEndian.Uint64(off[:]) != 3
{ + t.Fatal("expected second offset 3") + } + // Inner data: 3 UInt32 values + var val [4]byte + r.Read(val[:]) + if binary.LittleEndian.Uint32(val[:]) != 1 { + t.Fatal("expected 1") + } + r.Read(val[:]) + if binary.LittleEndian.Uint32(val[:]) != 2 { + t.Fatal("expected 2") + } + r.Read(val[:]) + if binary.LittleEndian.Uint32(val[:]) != 3 { + t.Fatal("expected 3") + } +} + +func TestWriteColumnDataFixedString(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + values := []any{"ab", "abcdef", "x"} + if err := writeColumnData(w, "FixedString(4)", values); err != nil { + t.Fatal(err) + } + data := buf.Bytes() + if len(data) != 12 { // 3 * 4 bytes + t.Fatalf("expected 12 bytes, got %d", len(data)) + } + if string(data[0:4]) != "ab\x00\x00" { + t.Fatalf("expected 'ab\\0\\0', got %q", data[0:4]) + } + if string(data[4:8]) != "abcd" { + t.Fatalf("expected 'abcd' (truncated), got %q", data[4:8]) + } + if string(data[8:12]) != "x\x00\x00\x00" { + t.Fatalf("expected 'x\\0\\0\\0', got %q", data[8:12]) + } +} + +func TestWriteColumnDataTuple(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + values := []any{ + []any{"hello", float64(1)}, + []any{"world", float64(2)}, + } + if err := writeColumnData(w, "Tuple(String, UInt32)", values); err != nil { + t.Fatal(err) + } + r := newProtoReader(&buf) + // First sub-column: two strings + s0, _ := r.str() + s1, _ := r.str() + if s0 != "hello" || s1 != "world" { + t.Fatalf("expected [hello,world], got [%s,%s]", s0, s1) + } + // Second sub-column: two UInt32 + var val [4]byte + r.Read(val[:]) + if binary.LittleEndian.Uint32(val[:]) != 1 { + t.Fatal("expected 1") + } + r.Read(val[:]) + if binary.LittleEndian.Uint32(val[:]) != 2 { + t.Fatal("expected 2") + } +} + +func TestWriteCompressedDataBlock(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + columns := []Column{{Name: "x", Type: "UInt32"}} + values := [][]any{{uint32(42)}} + if err := writeCompressedDataBlock(w, columns, values, 
1); err != nil { + t.Fatal(err) + } + if buf.Len() == 0 { + t.Fatal("expected non-empty compressed output") + } + // First byte should be ServerData + if buf.Bytes()[0] != ServerData { + t.Fatalf("expected ServerData, got %d", buf.Bytes()[0]) + } +} diff --git a/pkg/backends/clickhouse/native/server/query.go b/pkg/backends/clickhouse/native/server/query.go new file mode 100644 index 000000000..fa84bf8c3 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/query.go @@ -0,0 +1,265 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import "fmt" + +// ClientQueryMsg contains the relevant fields from a ClientQuery packet. +// We parse enough to extract the SQL but skip over the full client info +// and settings blocks. +type ClientQueryMsg struct { + QueryID string + SQL string + Compression bool +} + +// readClientQuery reads a ClientQuery packet from the wire. The leading +// packet type byte must already have been consumed. revision is the +// protocol revision the client announced in its Hello. 
+func readClientQuery(r *protoReader, revision uint64) (*ClientQueryMsg, error) { + queryID, err := r.str() + if err != nil { + return nil, fmt.Errorf("read query id: %w", err) + } + + // --- client info --- + if err := skipClientInfo(r, revision); err != nil { + return nil, fmt.Errorf("skip client info: %w", err) + } + + // --- settings --- + if err := skipSettings(r, revision); err != nil { + return nil, fmt.Errorf("skip settings: %w", err) + } + + // --- interserver secret (revision >= 54441) --- + if revision >= RevisionInterserverSecret { + if _, err := r.str(); err != nil { + return nil, fmt.Errorf("read interserver secret: %w", err) + } + } + + // state + compression + state, err := r.uvarint() + if err != nil { + return nil, fmt.Errorf("read state: %w", err) + } + _ = state + + compBool, err := r.boolean() + if err != nil { + return nil, fmt.Errorf("read compression flag: %w", err) + } + + // query body + sql, err := r.str() + if err != nil { + return nil, fmt.Errorf("read query body: %w", err) + } + + // parameters (revision >= 54459) + if revision >= RevisionParameters { + if err := skipKeyValuePairs(r); err != nil { + return nil, fmt.Errorf("skip parameters: %w", err) + } + } + + return &ClientQueryMsg{ + QueryID: queryID, + SQL: sql, + Compression: compBool, + }, nil +} + +// skipClientInfo reads and discards the client info section. 
+func skipClientInfo(r *protoReader, revision uint64) error { + // query kind + if _, err := r.ReadByte(); err != nil { + return err + } + // initial user, initial query id, initial address + for range 3 { + if _, err := r.str(); err != nil { + return err + } + } + // initial query start time (revision >= 54449) + if revision >= RevisionInitialQueryStart { + var b [8]byte + if _, err := r.Read(b[:]); err != nil { + return err + } + } + // interface type + if _, err := r.ReadByte(); err != nil { + return err + } + // os user, os hostname, client name + for range 3 { + if _, err := r.str(); err != nil { + return err + } + } + // client version major, minor, protocol revision + for range 3 { + if _, err := r.uvarint(); err != nil { + return err + } + } + // quota key (revision >= 54060) + if revision >= RevisionQuotaKey { + if _, err := r.str(); err != nil { + return err + } + } + // distributed depth (revision >= 54448) + if revision >= RevisionDistributedDepth { + if _, err := r.uvarint(); err != nil { + return err + } + } + // client version patch (revision >= 54401) + if revision >= RevisionVersionPatch { + if _, err := r.uvarint(); err != nil { + return err + } + } + // OpenTelemetry (revision >= 54442) + if revision >= RevisionOpenTelemetry { + hasSpan, err := r.ReadByte() + if err != nil { + return err + } + if hasSpan != 0 { + // trace id (16) + span id (8) + trace state (string) + flags (1) + var skip [24]byte + if _, err := r.Read(skip[:]); err != nil { + return err + } + if _, err := r.str(); err != nil { + return err + } + if _, err := r.ReadByte(); err != nil { + return err + } + } + } + // parallel replicas (revision >= 54453) + if revision >= RevisionParallelReplicas { + for range 3 { + if _, err := r.uvarint(); err != nil { + return err + } + } + } + return nil +} + +// skipSettings reads and discards the settings key=value list. 
+func skipSettings(r *protoReader, revision uint64) error { + for { + name, err := r.str() + if err != nil { + return err + } + if name == "" { + break + } + if revision < 54429 { + // old format (revision < DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS): plain uvarint value + if _, err := r.uvarint(); err != nil { + return err + } + } else { + // new format: flags (uvarint) + string value + if _, err := r.uvarint(); err != nil { + return err + } + if _, err := r.str(); err != nil { + return err + } + } + } + return nil +} + +// skipKeyValuePairs reads string key/value pairs terminated by an empty key. +func skipKeyValuePairs(r *protoReader) error { + for { + name, err := r.str() + if err != nil { + return err + } + if name == "" { + break + } + // flags + value + if _, err := r.uvarint(); err != nil { + return err + } + if _, err := r.str(); err != nil { + return err + } + } + return nil +} + +// skipClientData reads and discards a ClientData block (the empty data block +// that clients send after a query). The leading packet type byte must already +// have been consumed. +func skipClientData(r *protoReader, revision uint64) error { + // block name (external table name, typically empty) + if _, err := r.str(); err != nil { + return err + } + // block info + if revision > 0 { + // field num 1 (uvarint), is_overflows (bool) + if _, err := r.uvarint(); err != nil { + return err + } + if _, err := r.boolean(); err != nil { + return err + } + // field num 2 (uvarint) + if _, err := r.uvarint(); err != nil { + return err + } + // bucket_num (int32) + if _, err := r.int32(); err != nil { + return err + } + // end-of-fields marker (uvarint) + if _, err := r.uvarint(); err != nil { + return err + } + } + // num columns, num rows + numCols, err := r.uvarint() + if err != nil { + return err + } + numRows, err := r.uvarint() + if err != nil { + return err + } + _ = numCols + _ = numRows + // For the empty data block after query, both should be 0. + // If not, we'd need to read column data, but we don't support INSERT.
+ return nil +} diff --git a/pkg/backends/clickhouse/native/server/response.go b/pkg/backends/clickhouse/native/server/response.go new file mode 100644 index 000000000..186a28258 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/response.go @@ -0,0 +1,431 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "bytes" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/ClickHouse/ch-go/compress" +) + +// writeException sends a ServerException packet. +func writeException(w *protoWriter, code int32, name, message string) error { + if err := w.putByte(ServerException); err != nil { + return err + } + if err := w.putInt32(code); err != nil { + return err + } + if err := w.putStr(name); err != nil { + return err + } + if err := w.putStr(message); err != nil { + return err + } + if err := w.putStr(""); err != nil { // stack trace + return err + } + return w.putBool(false) // has nested +} + +// writeEndOfStream sends a ServerEndOfStream packet. +func writeEndOfStream(w *protoWriter) error { + return w.putByte(ServerEndOfStream) +} + +// writePong sends a ServerPong packet. +func writePong(w *protoWriter) error { + return w.putByte(ServerPong) +} + +// Column represents a named and typed column for writing data blocks. +type Column struct { + Name string + Type string +} + +// writeDataBlock writes a ServerData packet containing a result block. 
+// values is a column-major 2D slice: values[col][row]. +func writeDataBlock(w *protoWriter, columns []Column, values [][]any, numRows uint64) error { + if err := w.putByte(ServerData); err != nil { + return err + } + if err := w.putStr(""); err != nil { + return err + } + return writeBlockContent(w, columns, values, numRows) +} + +// writeEmptyBlock sends an empty ServerData block (0 columns, 0 rows). +func writeEmptyBlock(w *protoWriter) error { + return writeDataBlock(w, nil, nil, 0) +} + +// writeCompressedDataBlock writes a ServerData packet with LZ4-compressed +// block content. +func writeCompressedDataBlock(w *protoWriter, columns []Column, values [][]any, numRows uint64) error { + if err := w.putByte(ServerData); err != nil { + return err + } + if err := w.putStr(""); err != nil { + return err + } + + var buf bytes.Buffer + bw := newProtoWriter(&buf) + if err := writeBlockContent(bw, columns, values, numRows); err != nil { + return err + } + + cw := compress.NewWriter(compress.LevelZero, compress.LZ4) + if err := cw.Compress(buf.Bytes()); err != nil { + return fmt.Errorf("compress block: %w", err) + } + _, err := w.Write(cw.Data) + return err +} + +// writeBlockContent writes block info + columns + data (no packet header). 
+func writeBlockContent(w *protoWriter, columns []Column, values [][]any, numRows uint64) error { + if err := w.putUvarint(1); err != nil { // field num 1 + return err + } + if err := w.putBool(false); err != nil { // is_overflows + return err + } + if err := w.putUvarint(2); err != nil { // field num 2 + return err + } + if err := w.putInt32(-1); err != nil { // bucket_num + return err + } + if err := w.putUvarint(0); err != nil { // end-of-fields marker + return err + } + + numCols := uint64(len(columns)) + if err := w.putUvarint(numCols); err != nil { + return err + } + if err := w.putUvarint(numRows); err != nil { + return err + } + + for i, col := range columns { + if err := w.putStr(col.Name); err != nil { + return err + } + if err := w.putStr(col.Type); err != nil { + return err + } + if err := w.putBool(false); err != nil { // custom serialization flag + return err + } + if numRows > 0 { + if err := writeColumnData(w, col.Type, values[i]); err != nil { + return fmt.Errorf("write column %q data: %w", col.Name, err) + } + } + } + return nil +} + +// writeColumnData writes the raw column data for a single column.
+// +//nolint:gosec // intentional wire protocol conversions +func writeColumnData(w *protoWriter, colType string, values []any) error { + // Nullable(T): null bitmap + inner column data + if inner, ok := strings.CutPrefix(colType, "Nullable("); ok { + inner = strings.TrimSuffix(inner, ")") + for _, v := range values { + if v == nil { + if err := w.putByte(1); err != nil { + return err + } + } else { + if err := w.putByte(0); err != nil { + return err + } + } + } + for _, v := range values { + if v == nil { + v = zeroForType(inner) + } + if err := writeValue(w, inner, v); err != nil { + return err + } + } + return nil + } + + // Array(T): cumulative offsets (UInt64) + flattened inner data + if inner, ok := strings.CutPrefix(colType, "Array("); ok { + inner = strings.TrimSuffix(inner, ")") + var offset uint64 + var flattened []any + for _, v := range values { + arr := toSlice(v) + offset += uint64(len(arr)) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], offset) + if _, err := w.Write(b[:]); err != nil { + return err + } + flattened = append(flattened, arr...) 
+ } + return writeColumnData(w, inner, flattened) + } + + // Map(K, V): offsets + keys + values + if rest, ok := strings.CutPrefix(colType, "Map("); ok { + rest = strings.TrimSuffix(rest, ")") + keyType, valType := splitMapTypes(rest) + var offset uint64 + var allKeys, allVals []any + for _, v := range values { + m := toMap(v) + offset += uint64(len(m)) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], offset) + if _, err := w.Write(b[:]); err != nil { + return err + } + for mk, mv := range m { + allKeys = append(allKeys, mk) + allVals = append(allVals, mv) + } + } + if err := writeColumnData(w, keyType, allKeys); err != nil { + return err + } + return writeColumnData(w, valType, allVals) + } + + // Tuple(T1, T2, ...): each sub-type written as a column + if rest, ok := strings.CutPrefix(colType, "Tuple("); ok { + rest = strings.TrimSuffix(rest, ")") + subTypes := splitTupleTypes(rest) + for colIdx, st := range subTypes { + subVals := make([]any, len(values)) + for rowIdx, v := range values { + subVals[rowIdx] = tupleElement(v, colIdx) + } + if err := writeColumnData(w, st, subVals); err != nil { + return err + } + } + return nil + } + + // FixedString(N): exactly N bytes per row + if rest, ok := strings.CutPrefix(colType, "FixedString("); ok { + rest = strings.TrimSuffix(rest, ")") + n := parseInt(rest) + for _, v := range values { + b := []byte(toString(v)) + if len(b) >= n { + b = b[:n] + } else { + b = append(b, make([]byte, n-len(b))...) + } + if _, err := w.Write(b); err != nil { + return err + } + } + return nil + } + + // LowCardinality(T): dictionary + index + if inner, ok := strings.CutPrefix(colType, "LowCardinality("); ok { + inner = strings.TrimSuffix(inner, ")") + return writeLowCardinality(w, inner, values) + } + + // Default: per-value encoding + for _, v := range values { + if err := writeValue(w, colType, v); err != nil { + return err + } + } + return nil +} + +// zeroForType returns the zero value for a ClickHouse type. 
+func zeroForType(colType string) any { + switch colType { + case "UInt8", "Int8": + return uint8(0) + case "UInt16", "Int16": + return uint16(0) + case "UInt32", "Int32", "DateTime": + return uint32(0) + case "UInt64", "Int64", "DateTime64": + return uint64(0) + case "Float32": + return float32(0) + case "Float64": + return float64(0) + default: + return "" + } +} + +func toSlice(v any) []any { + if x, ok := v.([]any); ok { + return x + } + return nil +} + +func toMap(v any) map[string]any { + if x, ok := v.(map[string]any); ok { + return x + } + return nil +} + +func splitMapTypes(s string) (string, string) { + depth := 0 + for i, c := range s { + switch c { + case '(': + depth++ + case ')': + depth-- + case ',': + if depth == 0 { + return strings.TrimSpace(s[:i]), strings.TrimSpace(s[i+1:]) + } + } + } + return s, "" +} + +func splitTupleTypes(s string) []string { + var result []string + depth := 0 + start := 0 + for i, c := range s { + switch c { + case '(': + depth++ + case ')': + depth-- + case ',': + if depth == 0 { + result = append(result, strings.TrimSpace(s[start:i])) + start = i + 1 + } + } + } + if start < len(s) { + result = append(result, strings.TrimSpace(s[start:])) + } + return result +} + +func tupleElement(v any, idx int) any { + if x, ok := v.([]any); ok { + if idx < len(x) { + return x[idx] + } + } + return nil +} + +func parseInt(s string) int { + n, _ := strconv.Atoi(strings.TrimSpace(s)) + return n +} + +// writeLowCardinality writes a LowCardinality(T) column using dictionary encoding. 
+// +//nolint:gosec // intentional wire protocol conversions +func writeLowCardinality(w *protoWriter, innerType string, values []any) error { + dict := make(map[string]int) + var dictValues []any + indices := make([]int, len(values)) + for i, v := range values { + key := fmt.Sprint(v) + idx, ok := dict[key] + if !ok { + idx = len(dictValues) + dict[key] = idx + dictValues = append(dictValues, v) + } + indices[i] = idx + } + + if err := w.putInt64(1); err != nil { // serialization version + return err + } + + var indexType uint64 + switch { + case len(dictValues) <= 256: + indexType = 0 + case len(dictValues) <= 65536: + indexType = 1 + default: + indexType = 2 + } + + flags := indexType | (1 << 9) // need_update_dictionary=1 + if err := w.putInt64(int64(flags)); err != nil { + return err + } + + if err := w.putInt64(int64(len(dictValues))); err != nil { + return err + } + for _, dv := range dictValues { + if err := writeValue(w, innerType, dv); err != nil { + return err + } + } + + if err := w.putInt64(int64(len(indices))); err != nil { + return err + } + + for _, idx := range indices { + switch indexType { + case 0: + if err := w.putByte(byte(idx)); err != nil { + return err + } + case 1: + var b [2]byte + binary.LittleEndian.PutUint16(b[:], uint16(idx)) + if _, err := w.Write(b[:]); err != nil { + return err + } + default: + var b [4]byte + binary.LittleEndian.PutUint32(b[:], uint32(idx)) + if _, err := w.Write(b[:]); err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/backends/clickhouse/native/server/types.go b/pkg/backends/clickhouse/native/server/types.go new file mode 100644 index 000000000..71ea44942 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/types.go @@ -0,0 +1,304 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "math/big" + "net" + "strconv" + "strings" + "time" +) + +// writeValue encodes a single value according to its ClickHouse type. +// +//nolint:gosec // intentional wire protocol integer conversions throughout +func writeValue(w *protoWriter, colType string, v any) error { + switch { + case colType == "UInt8" || colType == "Bool": + return w.putByte(toNumeric[uint8](v)) + case colType == "UInt16": + var b [2]byte + binary.LittleEndian.PutUint16(b[:], toNumeric[uint16](v)) + _, err := w.Write(b[:]) + return err + case colType == "UInt32": + var b [4]byte + binary.LittleEndian.PutUint32(b[:], toNumeric[uint32](v)) + _, err := w.Write(b[:]) + return err + case colType == "UInt64": + var b [8]byte + binary.LittleEndian.PutUint64(b[:], toNumeric[uint64](v)) + _, err := w.Write(b[:]) + return err + case colType == "Int8": + return w.putByte(byte(toNumeric[int8](v))) + case colType == "Int16": + var b [2]byte + binary.LittleEndian.PutUint16(b[:], uint16(toNumeric[int16](v))) //nolint:gosec + _, err := w.Write(b[:]) + return err + case colType == "Int32": + return w.putInt32(toNumeric[int32](v)) + case colType == "Int64": + return w.putInt64(toNumeric[int64](v)) + case colType == "Float32": + var b [4]byte + binary.LittleEndian.PutUint32(b[:], math.Float32bits(toNumeric[float32](v))) + _, err := w.Write(b[:]) + return err + case colType == "Float64": + return w.putFloat64(toNumeric[float64](v)) + case colType == "DateTime": + return w.putInt32(int32(toDateTime(v))) 
//nolint:gosec + case strings.HasPrefix(colType, "DateTime64"): + return w.putInt64(toDateTime64(v)) + case colType == "Date": + var b [2]byte + binary.LittleEndian.PutUint16(b[:], toDate(v)) + _, err := w.Write(b[:]) + return err + case colType == "Date32": + return w.putInt32(toDate32(v)) + case colType == "UUID": + return writeUUID(w, v) + case colType == "IPv4": + return writeIPv4(w, v) + case colType == "IPv6": + return writeIPv6(w, v) + case colType == "Enum8" || strings.HasPrefix(colType, "Enum8("): + return w.putByte(byte(toNumeric[int8](v))) + case colType == "Enum16" || strings.HasPrefix(colType, "Enum16("): + var b [2]byte + binary.LittleEndian.PutUint16(b[:], uint16(toNumeric[int16](v))) //nolint:gosec + _, err := w.Write(b[:]) + return err + case colType == "Int128" || colType == "UInt128": + return writeBigInt(w, v, 16) + case colType == "Int256" || colType == "UInt256": + return writeBigInt(w, v, 32) + case strings.HasPrefix(colType, "Decimal32"): + return w.putInt32(toNumeric[int32](v)) + case strings.HasPrefix(colType, "Decimal64"): + return w.putInt64(toNumeric[int64](v)) + case strings.HasPrefix(colType, "Decimal128"): + return writeBigInt(w, v, 16) + case strings.HasPrefix(colType, "Decimal256"): + return writeBigInt(w, v, 32) + case strings.HasPrefix(colType, "Decimal("): + return w.putInt64(toNumeric[int64](v)) + default: + return w.putStr(toString(v)) + } +} + +// toNumeric converts an any value to a numeric type T. 
+// +//nolint:gosec // intentional wire protocol conversions +func toNumeric[T interface { + ~int8 | ~int16 | ~int32 | ~int64 | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 +}](v any) T { + switch x := v.(type) { + case float64: + return T(x) + case float32: + return T(x) + case int64: + return T(x) + case int32: + return T(x) + case int: + return T(x) + case uint64: + return T(x) + case uint32: + return T(x) + case uint16: + return T(x) + case uint8: + return T(x) + default: + var zero T + return zero + } +} + +func toString(v any) string { + switch x := v.(type) { + case string: + return x + case []byte: + return string(x) + default: + return fmt.Sprint(v) + } +} + +//nolint:gosec // intentional wire protocol conversions +func toDateTime(v any) uint32 { + switch x := v.(type) { + case time.Time: + return uint32(x.Unix()) + case uint32: + return x + case float64: + return uint32(x) + default: + return uint32(toNumeric[int64](v)) + } +} + +//nolint:gosec // intentional wire protocol conversions +func toDate(v any) uint16 { + switch x := v.(type) { + case time.Time: + epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) + return uint16(x.Sub(epoch).Hours() / 24) + case uint16: + return x + default: + return uint16(toNumeric[int64](v)) + } +} + +//nolint:gosec // intentional wire protocol conversions +func toDate32(v any) int32 { + switch x := v.(type) { + case time.Time: + epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) + return int32(x.Sub(epoch).Hours() / 24) + case int32: + return x + default: + return toNumeric[int32](v) + } +} + +//nolint:gosec // intentional wire protocol conversions +func toDateTime64(v any) int64 { + switch x := v.(type) { + case time.Time: + return x.UnixMilli() + default: + return toNumeric[int64](x) + } +} + +// writeBigInt writes an integer as a little-endian byte sequence of the given +// size (16 for Int128/UInt128, 32 for Int256/UInt256). 
+func writeBigInt(w *protoWriter, v any, size int) error { + bi := toBigInt(v) + b := make([]byte, size) + if bi.Sign() >= 0 { + raw := bi.Bytes() + for i, j := 0, len(raw)-1; j >= 0 && i < size; i, j = i+1, j-1 { + b[i] = raw[j] + } + } else { + tc := new(big.Int).Add(bi, new(big.Int).Lsh(big.NewInt(1), uint(size*8))) //nolint:gosec // size is always 16 or 32 + raw := tc.Bytes() + for i, j := 0, len(raw)-1; j >= 0 && i < size; i, j = i+1, j-1 { + b[i] = raw[j] + } + for i := len(raw); i < size; i++ { + b[i] = 0xFF + } + } + _, err := w.Write(b) + return err +} + +func toBigInt(v any) *big.Int { + switch x := v.(type) { + case *big.Int: + return x + case int64: + return big.NewInt(x) + case uint64: + return new(big.Int).SetUint64(x) + case float64: + bi, _ := new(big.Int).SetString(strconv.FormatFloat(x, 'f', 0, 64), 10) + if bi == nil { + return big.NewInt(0) + } + return bi + case string: + bi, ok := new(big.Int).SetString(x, 10) + if !ok { + return big.NewInt(0) + } + return bi + default: + return big.NewInt(0) + } +} + +// writeUUID writes a UUID as 16 bytes. ClickHouse stores UUIDs as two +// little-endian uint64s with the halves swapped relative to RFC 4122. +func writeUUID(w *protoWriter, v any) error { + s := toString(v) + s = strings.ReplaceAll(s, "-", "") + b, err := hex.DecodeString(s) + if err != nil || len(b) != 16 { + b = make([]byte, 16) + } + var buf [16]byte + binary.LittleEndian.PutUint64(buf[0:8], binary.BigEndian.Uint64(b[0:8])) + binary.LittleEndian.PutUint64(buf[8:16], binary.BigEndian.Uint64(b[8:16])) + _, err = w.Write(buf[:]) + return err +} + +// writeIPv4 writes an IPv4 address as a 4-byte little-endian uint32. 
+func writeIPv4(w *protoWriter, v any) error { + s := toString(v) + ip := net.ParseIP(s) + if ip == nil { + ip = net.IPv4zero + } + ip4 := ip.To4() + if ip4 == nil { + ip4 = net.IPv4zero.To4() + } + var b [4]byte + b[0] = ip4[3] + b[1] = ip4[2] + b[2] = ip4[1] + b[3] = ip4[0] + _, err := w.Write(b[:]) + return err +} + +// writeIPv6 writes an IPv6 address as 16 bytes in network byte order. +func writeIPv6(w *protoWriter, v any) error { + s := toString(v) + ip := net.ParseIP(s) + if ip == nil { + ip = net.IPv6zero + } + ip16 := ip.To16() + if ip16 == nil { + ip16 = net.IPv6zero + } + _, err := w.Write([]byte(ip16)) + return err +} diff --git a/pkg/backends/options/options.go b/pkg/backends/options/options.go index 589861e99..bf402267d 100644 --- a/pkg/backends/options/options.go +++ b/pkg/backends/options/options.go @@ -65,6 +65,10 @@ type Options struct { // OriginURL provides the base upstream URL for all proxied requests to this Backend. // it can be as simple as http://example.com or as complex as https://example.com:8443/path/prefix OriginURL string `yaml:"origin_url,omitempty"` + // Protocol selects the upstream wire protocol used to communicate with the origin. + // When empty, HTTP is used. Supported values are provider-specific (e.g., "native" + // for ClickHouse to use the binary protocol on port 9000). + Protocol string `yaml:"protocol,omitempty"` // Timeout defines how long the HTTP request will wait for a response before timing out Timeout time.Duration `yaml:"timeout,omitempty"` // KeepAliveTimeout defines how long an open keep-alive HTTP connection remains idle before closing @@ -153,6 +157,11 @@ type Options struct { // ForwardedHeaders indicates the class of 'Forwarded' header to attach to upstream requests ForwardedHeaders string `yaml:"forwarded_headers,omitempty"` + // DPCFallbackWarning, when true (default), logs a warning when a query cannot + // be parsed as a time range query and falls back from DPC to OPC. 
Set to false + // to suppress these warnings (they will still appear at debug level). + DPCFallbackWarning *bool `yaml:"dpc_fallback_warning,omitempty"` + // IsDefault indicates if this is the d.Default backend for any request not matching a configured route IsDefault bool `yaml:"is_default,omitempty"` // FastForwardDisable indicates whether the FastForward feature should be disabled for this backend @@ -223,6 +232,11 @@ type Options struct { // DoesShard is true when sharding will be used with this origin, based on how the // sharding options have been configured DoesShard bool `yaml:"-"` + // Fetcher, when non-nil, replaces HTTPClient.Do for upstream requests. + // This allows non-HTTP backends (e.g. ClickHouse native protocol) to + // intercept the fetch and translate between HTTP-shaped requests and the + // native wire format. When nil, HTTPClient.Do is used. + Fetcher func(*http.Request) (*http.Response, error) `yaml:"-"` } var _ types.ConfigOptions[Options] = &Options{} diff --git a/pkg/backends/protocol.go b/pkg/backends/protocol.go new file mode 100644 index 000000000..00f2eadf2 --- /dev/null +++ b/pkg/backends/protocol.go @@ -0,0 +1,28 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package backends + +import "github.com/trickstercache/trickster/v2/pkg/proxy/connhandler" + +// ConnectionHandlerProvider is an optional interface that backends may +// implement to accept raw TCP connections for non-HTTP protocols (e.g. +// ClickHouse native binary protocol, MySQL wire protocol). +type ConnectionHandlerProvider interface { + // ConnectionHandler returns a ConnectionHandler for the given protocol + // name, or nil if the backend does not support the requested protocol. + ConnectionHandler(protocol string) connhandler.ConnectionHandler +} diff --git a/pkg/daemon/setup/listeners.go b/pkg/daemon/setup/listeners.go index 00f8d8b37..24c7d7bcc 100644 --- a/pkg/daemon/setup/listeners.go +++ b/pkg/daemon/setup/listeners.go @@ -191,4 +191,44 @@ func applyListenerConfigs(conf, oldConf *config.Config, } else { lg.UpdateRouter("mgmtListener", mr) } + + applyProtocolListenerConfigs(conf, lg, o, errorFunc, drainTimeout) +} + +func applyProtocolListenerConfigs(conf *config.Config, lg *listener.Group, + o backends.Backends, errorFunc func(), drainTimeout time.Duration, +) { + if conf.Frontend == nil || len(conf.Frontend.ProtocolListeners) == 0 { + return + } + for _, plo := range conf.Frontend.ProtocolListeners { + if err := plo.Validate(); err != nil { + logger.Error("invalid protocol listener config", + logging.Pairs{"detail": err.Error()}) + continue + } + listenerName := "proto_" + plo.Name + bc := o.Get(plo.Backend) + if bc == nil { + logger.Error("protocol listener references unknown backend", + logging.Pairs{"listener": plo.Name, "backend": plo.Backend}) + continue + } + chp, ok := bc.(backends.ConnectionHandlerProvider) + if !ok { + logger.Error("backend does not support protocol listeners", + logging.Pairs{"listener": plo.Name, "backend": plo.Backend}) + continue + } + handler := chp.ConnectionHandler(plo.Protocol) + if handler == nil { + logger.Error("backend does not support requested protocol", + logging.Pairs{"listener": plo.Name, "protocol": 
plo.Protocol, "backend": plo.Backend}) + continue + } + lg.DrainAndCloseProtocol(listenerName, drainTimeout) + go lg.StartProtocolListener(listenerName, + plo.ListenAddress, plo.ListenPort, plo.ConnectionsLimit, + handler, errorFunc) + } } diff --git a/pkg/frontend/options/options.go b/pkg/frontend/options/options.go index ccb6daed7..0e8fdb113 100644 --- a/pkg/frontend/options/options.go +++ b/pkg/frontend/options/options.go @@ -19,6 +19,7 @@ package options import ( "crypto/tls" "errors" + "slices" "time" "github.com/trickstercache/trickster/v2/pkg/util/pointers" @@ -49,6 +50,9 @@ type Options struct { // ServeTLS indicates whether to listen and serve on the TLS port, meaning // at least one backend options has a valid certificate and key file configured. ServeTLS bool `yaml:"-"` + // ProtocolListeners configures raw TCP listeners for non-HTTP protocols + // (e.g. ClickHouse native binary, MySQL wire protocol). + ProtocolListeners []*ProtocolListenerOptions `yaml:"protocol_listeners,omitempty"` } type TLSConfigFunc func() (*tls.Config, error) @@ -67,7 +71,24 @@ func New() *Options { // Equal returns true if the FrontendConfigs are identical in value. 
func (o *Options) Equal(o2 *Options) bool { - return *o == *o2 + if o.ListenAddress != o2.ListenAddress || + o.ListenPort != o2.ListenPort || + o.TLSListenAddress != o2.TLSListenAddress || + o.TLSListenPort != o2.TLSListenPort || + o.ConnectionsLimit != o2.ConnectionsLimit || + o.TruncateRequestBodyTooLarge != o2.TruncateRequestBodyTooLarge || + o.ReadHeaderTimeout != o2.ReadHeaderTimeout || + o.ServeTLS != o2.ServeTLS { + return false + } + if (o.MaxRequestBodySizeBytes == nil) != (o2.MaxRequestBodySizeBytes == nil) { + return false + } + if o.MaxRequestBodySizeBytes != nil && *o.MaxRequestBodySizeBytes != *o2.MaxRequestBodySizeBytes { + return false + } + return slices.EqualFunc(o.ProtocolListeners, o2.ProtocolListeners, + func(a, b *ProtocolListenerOptions) bool { return *a == *b }) } // Clone returns a clone of the Options @@ -76,6 +97,12 @@ func (o *Options) Clone() *Options { if o.MaxRequestBodySizeBytes != nil { out.MaxRequestBodySizeBytes = new(*o.MaxRequestBodySizeBytes) } + if len(o.ProtocolListeners) > 0 { + out.ProtocolListeners = make([]*ProtocolListenerOptions, len(o.ProtocolListeners)) + for i, pl := range o.ProtocolListeners { + out.ProtocolListeners[i] = pointers.Clone(pl) + } + } return out } diff --git a/pkg/frontend/options/protocol_listener.go b/pkg/frontend/options/protocol_listener.go new file mode 100644 index 000000000..25110a4f6 --- /dev/null +++ b/pkg/frontend/options/protocol_listener.go @@ -0,0 +1,57 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package options + +import ( + "errors" + "fmt" +) + +// ProtocolListenerOptions configures a raw TCP listener for a non-HTTP +// protocol such as the ClickHouse native binary protocol or MySQL wire +// protocol. +type ProtocolListenerOptions struct { + // Name is a human-readable identifier for this listener (e.g. "clickhouse-native"). + Name string `yaml:"name"` + // Protocol selects the wire protocol handler (e.g. "clickhouse-native", "mysql"). + Protocol string `yaml:"protocol"` + // ListenAddress is the IP address to bind (default: all interfaces). + ListenAddress string `yaml:"listen_address,omitempty"` + // ListenPort is the TCP port to bind. + ListenPort int `yaml:"listen_port"` + // Backend is the name of the configured backend this listener routes to. + Backend string `yaml:"backend"` + // ConnectionsLimit is the maximum number of concurrent connections. + ConnectionsLimit int `yaml:"connections_limit,omitempty"` +} + +// Validate checks that required fields are present. +func (o *ProtocolListenerOptions) Validate() error { + if o.Name == "" { + return errors.New("protocol_listener: name is required") + } + if o.Protocol == "" { + return fmt.Errorf("protocol_listener %q: protocol is required", o.Name) + } + if o.ListenPort < 1 { + return fmt.Errorf("protocol_listener %q: listen_port must be > 0", o.Name) + } + if o.Backend == "" { + return fmt.Errorf("protocol_listener %q: backend is required", o.Name) + } + return nil +} diff --git a/pkg/proxy/connhandler/connhandler.go b/pkg/proxy/connhandler/connhandler.go new file mode 100644 index 000000000..2934cbdee --- /dev/null +++ b/pkg/proxy/connhandler/connhandler.go @@ -0,0 +1,31 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package connhandler defines the ConnectionHandler interface used by +// protocol listeners and backends to handle raw TCP connections for non-HTTP +// protocols. +package connhandler + +import ( + "context" + "net" +) + +// ConnectionHandler handles raw TCP connections for non-HTTP protocols +// such as the ClickHouse native binary protocol or MySQL wire protocol. +type ConnectionHandler interface { + HandleConnection(ctx context.Context, conn net.Conn) error +} diff --git a/pkg/proxy/engines/httpproxy.go b/pkg/proxy/engines/httpproxy.go index e622f8a86..90cc76943 100644 --- a/pkg/proxy/engines/httpproxy.go +++ b/pkg/proxy/engines/httpproxy.go @@ -215,7 +215,13 @@ func PrepareFetchReader(r *http.Request) (io.ReadCloser, *http.Response, int64) r.Header.Set(headers.NameAcceptEncoding, ep.SupportedHeaderVal) } - resp, err := o.HTTPClient.Do(r) + var resp *http.Response + var err error + if o.Fetcher != nil { + resp, err = o.Fetcher(r) + } else { + resp, err = o.HTTPClient.Do(r) + } if err != nil { if rsc == nil || !rsc.Cancelable || !errors.Is(err, context.Canceled) { logger.Error("error downloading url", diff --git a/pkg/proxy/listener/listener.go b/pkg/proxy/listener/listener.go index c5c507028..04c64490b 100644 --- a/pkg/proxy/listener/listener.go +++ b/pkg/proxy/listener/listener.go @@ -108,16 +108,18 @@ func (l *Listener) RouteSwapper() *switcher.SwitchHandler { // Group is a collection of listeners type Group struct { - members map[string]*Listener - listenersLock sync.Mutex - done chan struct{} + members map[string]*Listener + 
protocolMembers map[string]*ProtocolListener + listenersLock sync.Mutex + done chan struct{} } // NewGroup returns a new Group func NewGroup() *Group { return &Group{ - members: make(map[string]*Listener), - done: make(chan struct{}), + members: make(map[string]*Listener), + protocolMembers: make(map[string]*ProtocolListener), + done: make(chan struct{}), } } @@ -368,9 +370,15 @@ func (lg *Group) WaitForReady(timeout time.Duration) error { listeners = append(listeners, l) } } + protoListeners := make([]*ProtocolListener, 0, len(lg.protocolMembers)) + for _, pl := range lg.protocolMembers { + if pl != nil { + protoListeners = append(protoListeners, pl) + } + } lg.listenersLock.Unlock() - if len(listeners) == 0 { + if len(listeners) == 0 && len(protoListeners) == 0 { return nil } @@ -379,6 +387,9 @@ func (lg *Group) WaitForReady(timeout time.Duration) error { for _, l := range listeners { l.WaitForReady(0) } + for _, pl := range protoListeners { + pl.WaitForReady(0) + } close(done) }() @@ -404,6 +415,10 @@ func (lg *Group) Shutdown(drainWait time.Duration) error { for name := range lg.members { names = append(names, name) } + protoNames := make([]string, 0, len(lg.protocolMembers)) + for name := range lg.protocolMembers { + protoNames = append(protoNames, name) + } lg.listenersLock.Unlock() var firstErr error @@ -412,6 +427,11 @@ func (lg *Group) Shutdown(drainWait time.Duration) error { firstErr = err } } + for _, name := range protoNames { + if err := lg.DrainAndCloseProtocol(name, drainWait); err != nil && firstErr == nil { + firstErr = err + } + } select { case <-lg.done: diff --git a/pkg/proxy/listener/protocol_listener.go b/pkg/proxy/listener/protocol_listener.go new file mode 100644 index 000000000..51dac02ec --- /dev/null +++ b/pkg/proxy/listener/protocol_listener.go @@ -0,0 +1,185 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package listener + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/trickstercache/trickster/v2/pkg/observability/logging" + "github.com/trickstercache/trickster/v2/pkg/observability/logging/logger" + "github.com/trickstercache/trickster/v2/pkg/observability/metrics" + "github.com/trickstercache/trickster/v2/pkg/proxy/connhandler" + "golang.org/x/net/netutil" +) + +// ProtocolListener accepts raw TCP connections and dispatches them to a +// connhandler.ConnectionHandler. It mirrors the lifecycle semantics of the HTTP Listener +// (state machine, readiness signaling, graceful drain) without wrapping an +// http.Server. +type ProtocolListener struct { + listener net.Listener + handler connhandler.ConnectionHandler + + state atomic.Int32 + readyCh chan struct{} + readyOnce sync.Once + + activeConns sync.WaitGroup + cancel context.CancelFunc +} + +// StartProtocolListener creates a raw TCP listener, registers it with the +// Group, and runs the accept loop. It blocks until the listener is closed. 
+func (lg *Group) StartProtocolListener( + name, address string, port, connectionsLimit int, + handler connhandler.ConnectionHandler, errorFunc func(), +) error { + pl := &ProtocolListener{ + handler: handler, + readyCh: make(chan struct{}), + } + pl.setState(StateStarting) + + ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", address, port)) + if err != nil { + logger.ErrorSynchronous("protocol listener startup failed", + logging.Pairs{"listenerName": name, "detail": err}) + pl.setState(StateStopped) + if errorFunc != nil { + errorFunc() + } + return err + } + if connectionsLimit > 0 { + ln = netutil.LimitListener(ln, connectionsLimit) + } + pl.listener = ln + + logger.Info("protocol listener starting", + logging.Pairs{"listenerName": name, "port": port, "address": address}) + + lg.listenersLock.Lock() + lg.protocolMembers[name] = pl + lg.listenersLock.Unlock() + + pl.setState(StateReady) + pl.readyOnce.Do(func() { close(pl.readyCh) }) + + ctx, cancel := context.WithCancel(context.Background()) + pl.cancel = cancel + + defer func() { + cancel() + pl.activeConns.Wait() + pl.setState(StateStopped) + }() + + for { + conn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil { + return nil + } + logger.Error("protocol listener accept error", + logging.Pairs{"listenerName": name, "detail": err}) + if errorFunc != nil { + errorFunc() + } + return err + } + metrics.ProxyConnectionRequested.Inc() + metrics.ProxyConnectionAccepted.Inc() + metrics.ProxyActiveConnections.Inc() + + pl.activeConns.Add(1) + go func() { + defer func() { + _ = conn.Close() + metrics.ProxyActiveConnections.Dec() + metrics.ProxyConnectionClosed.Inc() + pl.activeConns.Done() + }() + if err := handler.HandleConnection(ctx, conn); err != nil { + logger.Debug("protocol connection handler error", + logging.Pairs{"listenerName": name, "detail": err}) + } + }() + } +} + +// DrainAndCloseProtocol drains and closes the named protocol listener. 
+func (lg *Group) DrainAndCloseProtocol(name string, drainWait time.Duration) error { + lg.listenersLock.Lock() + pl, ok := lg.protocolMembers[name] + if !ok || pl == nil { + lg.listenersLock.Unlock() + return nil + } + delete(lg.protocolMembers, name) + lg.listenersLock.Unlock() + + pl.setState(StateStopping) + if pl.cancel != nil { + pl.cancel() + } + if pl.listener != nil { + _ = pl.listener.Close() + } + + done := make(chan struct{}) + go func() { + pl.activeConns.Wait() + close(done) + }() + + ctx, cancel := context.WithTimeout(context.Background(), drainWait) + defer cancel() + select { + case <-done: + case <-ctx.Done(): + } + pl.setState(StateStopped) + return nil +} + +func (pl *ProtocolListener) setState(state ListenerState) { + pl.state.Store(int32(state)) +} + +// WaitForReady waits for the protocol listener to become ready. +func (pl *ProtocolListener) WaitForReady(timeout time.Duration) bool { + if pl.readyCh == nil { + return ListenerState(pl.state.Load()) == StateReady + } + if timeout > 0 { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + select { + case <-pl.readyCh: + return true + case <-ctx.Done(): + return false + } + } + <-pl.readyCh + return true +} From 4bb260ab8b23484c2ee8c33cc2a4856756f20a3b Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Sat, 18 Apr 2026 10:47:13 -0400 Subject: [PATCH 2/8] feat: extend ClickHouse parser and type coverage Signed-off-by: Chris Randles --- pkg/backends/clickhouse/lexing.go | 26 +++++ pkg/backends/clickhouse/model/datatypes.go | 18 ++++ pkg/backends/clickhouse/parsing.go | 118 +++++++++++++++++++-- pkg/backends/clickhouse/parsing_test.go | 69 ++++++++++++ pkg/backends/clickhouse/url.go | 27 ++++- pkg/backends/clickhouse/url_test.go | 64 +++++++++++ pkg/parsing/lex/sql/token.go | 8 +- 7 files changed, 318 insertions(+), 12 deletions(-) diff --git a/pkg/backends/clickhouse/lexing.go b/pkg/backends/clickhouse/lexing.go index a21c8a3c9..690deb8a6 100644 --- 
a/pkg/backends/clickhouse/lexing.go +++ b/pkg/backends/clickhouse/lexing.go @@ -38,12 +38,20 @@ const ( tokenToStartOfInterval tokenToStartOf tokenInterval + tokenDateTrunc + tokenTimeSlot tokenToDateFunc tokenWeek tokenDay tokenHour tokenMinute tokenSecond + tokenMonth + tokenQuarter + tokenYear + tokenMicrosecond + tokenNanosecond + tokenMillisecond ) // tokens for ClickHouse select @@ -69,6 +77,15 @@ var chKey = map[string]token.Typ{ "tostartoftenminutes": tokenToStartOf, "tostartoffifteenminutes": tokenToStartOf, "tostartofinterval": tokenToStartOfInterval, + "tostartofmonth": tokenToStartOf, + "tostartofquarter": tokenToStartOf, + "tostartofyear": tokenToStartOf, + "tostartofmillisecond": tokenToStartOf, + "tostartofsecond": tokenToStartOf, + "tostartofmicrosecond": tokenToStartOf, + "tostartofnanosecond": tokenToStartOf, + "date_trunc": tokenDateTrunc, + "timeslot": tokenTimeSlot, tokenValPreWhere: tokenPreWhere, tokenValFormat: tokenFormat, tokenValUnionAll: tokenUnionAll, @@ -82,7 +99,16 @@ var chKey = map[string]token.Typ{ "minute": tokenMinute, "second": tokenSecond, "todatetime": tokenToDateFunc, + "todatetime64": tokenToDateFunc, "todate": tokenToDateFunc, + "todate32": tokenToDateFunc, + "now64": lsql.TokenNowFunc, + "month": tokenMonth, + "quarter": tokenQuarter, + "year": tokenYear, + "microsecond": tokenMicrosecond, + "nanosecond": tokenNanosecond, + "millisecond": tokenMillisecond, } // LexerOptions returns a Clickhouse-crafted Lexer Options Pointer diff --git a/pkg/backends/clickhouse/model/datatypes.go b/pkg/backends/clickhouse/model/datatypes.go index 38892fcff..157c7d5bf 100644 --- a/pkg/backends/clickhouse/model/datatypes.go +++ b/pkg/backends/clickhouse/model/datatypes.go @@ -18,6 +18,24 @@ package model type DataType int +// ClickHouse type name strings used across marshal/unmarshal. 
+const ( + TypeUInt8 = "UInt8" + TypeUInt16 = "UInt16" + TypeUInt32 = "UInt32" + TypeUInt64 = "UInt64" + TypeInt8 = "Int8" + TypeInt16 = "Int16" + TypeInt32 = "Int32" + TypeInt64 = "Int64" + TypeFloat32 = "Float32" + TypeFloat64 = "Float64" + TypeDateTime = "DateTime" + TypeDate = "Date" + TypeString = "String" + TypeBool = "Bool" +) + const ( UInt8 = DataType(iota) UInt16 diff --git a/pkg/backends/clickhouse/parsing.go b/pkg/backends/clickhouse/parsing.go index 74a7082f0..3e5ab6b50 100644 --- a/pkg/backends/clickhouse/parsing.go +++ b/pkg/backends/clickhouse/parsing.go @@ -101,6 +101,12 @@ func parse(statement string) (*timeseries.TimeRangeQuery, *timeseries.RequestOpt if t, err = parseWhereTokens(results, trq, ro); err != nil { return returnWithKey(parsing.ParserError(err, t)) } + // Ensure the statement has a FORMAT token. Queries from the official Grafana + // ClickHouse plugin omit FORMAT and rely on default_format=Native. DPC needs + // to control the format via the SQL clause so default_format doesn't override. + if !strings.Contains(trq.Statement, tkFormat) { + trq.Statement = strings.TrimRight(trq.Statement, "; \t\n") + " FORMAT " + tkFormat + } return trq, ro, canObjectCache, nil } @@ -110,6 +116,10 @@ const ( tkTS1 = "<$TS1$>" tkTS2 = "<$TS2$>" tkFormat = "<$FORMAT$>" + + // sdtDate marks the timestamp field as returning ClickHouse Date type + // (not DateTime). Used to wrap interpolated boundaries in toDate(). 
+ sdtDate = "Date" ) const ( @@ -126,14 +136,52 @@ var tokenToStartOfLookup = map[string]time.Duration{ "tostartoffiveminute": time.Minute * 5, "tostartoftenminutes": time.Minute * 10, "tostartoffifteenminutes": time.Minute * 15, + "tostartofmonth": day * 30, + "tostartofquarter": day * 90, + "tostartofyear": day * 365, + "tostartofmillisecond": time.Millisecond, + "tostartofsecond": time.Second, + "tostartofmicrosecond": time.Microsecond, + "tostartofnanosecond": time.Nanosecond, + "timeslot": time.Minute * 30, } var tokenDurations = map[token.Typ]time.Duration{ - tokenWeek: week, - tokenDay: day, - tokenHour: time.Hour, - tokenMinute: time.Minute, - tokenSecond: time.Second, + tokenWeek: week, + tokenDay: day, + tokenHour: time.Hour, + tokenMinute: time.Minute, + tokenSecond: time.Second, + tokenMonth: day * 30, + tokenQuarter: day * 90, + tokenYear: day * 365, + tokenMillisecond: time.Millisecond, + tokenMicrosecond: time.Microsecond, + tokenNanosecond: time.Nanosecond, +} + +var dateReturningFunctions = map[string]bool{ + "tostartofmonth": true, + "tostartofquarter": true, + "tostartofyear": true, +} + +var dateReturningTruncUnits = map[string]bool{ + "month": true, + "quarter": true, + "year": true, +} + +var dateTruncUnits = map[string]time.Duration{ + "year": day * 365, + "quarter": day * 90, + "month": day * 30, + "week": week, + "day": day, + "hour": time.Hour, + "minute": time.Minute, + "second": time.Second, + "millisecond": time.Millisecond, } var supportedFormats = map[string]byte{ @@ -300,7 +348,7 @@ func parseSelectTokens(results ts.Lookup, for _, fieldParts := range st { var prev *token.Token var isTimeSeries, isIntDiv, needStep, checkMultiplier, expectBaseTimeField, - isToStartOfInterval, expectAlias bool + isToStartOfInterval, isDateTrunc, expectAlias bool var d time.Duration var x int var err error @@ -344,6 +392,27 @@ func parseSelectTokens(results ts.Lookup, } goto nextIteration } + // date_trunc('unit', column) — unit is a string literal, 
column follows + if isDateTrunc { + if t.Typ == token.String { + unit := strings.ToLower(strings.Trim(t.Val, "'\"")) + d, ok := dateTruncUnits[unit] + if !ok { + return t, sqlparser.ErrStepParse + } + trq.Step = d + isTimeSeries = true + foundTimeSeries = true + checkMultiplier = true + if dateReturningTruncUnits[unit] { + trq.TimestampDefinition.SDataType = sdtDate + } + } else if !token.IsComma(t.Typ) && isTimeSeries && ro.BaseTimestampFieldName == "" { + ro.BaseTimestampFieldName = t.Val + isDateTrunc = false + } + goto nextIteration + } if isTimeSeries { if t.Typ == lsql.TokenAs { expectAlias = true @@ -409,6 +478,17 @@ func parseSelectTokens(results ts.Lookup, isToStartOfInterval = true goto nextIteration } + if t.Typ == tokenDateTrunc { + isDateTrunc = true + goto nextIteration + } + if t.Typ == tokenTimeSlot { + isTimeSeries = true + foundTimeSeries = true + expectBaseTimeField = true + trq.Step = time.Minute * 30 + goto nextIteration + } if (prev != nil && prev.Typ == tokenIntDiv && t.Typ == tokenToInt32) || ((prev == nil || prev.Typ == tokenToInt32) && t.Typ == tokenToStartOf) { isTimeSeries = true @@ -422,6 +502,9 @@ func parseSelectTokens(results ts.Lookup, return t, ErrUnsupportedToStartOfFunc } trq.Step = d + if dateReturningFunctions[t.Val] { + trq.TimestampDefinition.SDataType = sdtDate + } } } nextIteration: @@ -477,7 +560,7 @@ func SolveMathExpression(fieldParts token.Tokens, startValue int64, t.Typ == token.Space || t.Typ == lsql.TokenComment { continue } - if t.Typ.IsBreakable() || token.IsLogicalOperator(t.Typ) { + if t.Typ.IsBreakable() || token.IsLogicalOperator(t.Typ) || token.IsComma(t.Typ) { i-- break } @@ -563,14 +646,33 @@ func parseWhereTokens(results ts.Lookup, for n, fieldParts := range wt { var atLowerBound bool var state int + var dateFuncDepth int var i int // c-style iteration, rather than ranging, allows passing a subslice of fieldParts to other funcs for // processing, while also advancing the iterator lfp := len(fieldParts) for 
i = 0; i < lfp; i++ { t := fieldParts[i] + // Skip date function wrappers like toDateTime64(value, precision). + // Track paren depth so we can skip the trailing precision argument + // after the comma while still parsing the time value normally. + if t.Typ == tokenToDateFunc { + dateFuncDepth = 1 + continue + } if t.Typ == token.LeftParen || t.Typ == token.RightParen || - t.Typ == token.Space || t.Typ == lsql.TokenComment || t.Typ == tokenToDateFunc { + t.Typ == token.Space || t.Typ == lsql.TokenComment { + if dateFuncDepth > 0 && t.Typ == token.RightParen { + dateFuncDepth-- + } + continue + } + // Inside a date function after the time value, skip comma + precision args + if dateFuncDepth > 0 && token.IsComma(t.Typ) { + dateFuncDepth++ + continue + } + if dateFuncDepth > 1 { continue } if t.Typ.IsBreakable() { diff --git a/pkg/backends/clickhouse/parsing_test.go b/pkg/backends/clickhouse/parsing_test.go index 40ae9bf34..4b95de32f 100644 --- a/pkg/backends/clickhouse/parsing_test.go +++ b/pkg/backends/clickhouse/parsing_test.go @@ -617,6 +617,75 @@ func TestSolveMathExpression(t *testing.T) { } } +func TestNewTimeBucketingFunctions(t *testing.T) { + tests := []struct { + name string + query string + step time.Duration + }{ + {"toStartOfMonth", `SELECT toStartOfMonth(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, 24 * time.Hour * 30}, + {"toStartOfQuarter", `SELECT toStartOfQuarter(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, 24 * time.Hour * 90}, + {"toStartOfYear", `SELECT toStartOfYear(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, 24 * time.Hour * 365}, + {"toStartOfSecond", `SELECT toStartOfSecond(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, time.Second}, + {"toStartOfMillisecond", `SELECT 
toStartOfMillisecond(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, time.Millisecond}, + {"toStartOfMicrosecond", `SELECT toStartOfMicrosecond(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, time.Microsecond}, + {"toStartOfNanosecond", `SELECT toStartOfNanosecond(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, time.Nanosecond}, + {"date_trunc_hour", `SELECT date_trunc('hour', datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, time.Hour}, + {"date_trunc_month", `SELECT date_trunc('month', datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, 24 * time.Hour * 30}, + {"date_trunc_quarter", `SELECT date_trunc('quarter', datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, 24 * time.Hour * 90}, + {"date_trunc_second", `SELECT date_trunc('second', datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, time.Second}, + {"toStartOfInterval_month", `SELECT toStartOfInterval(datetime, INTERVAL 1 month) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN 1589904000 AND 1589997600 GROUP BY t FORMAT JSON`, 24 * time.Hour * 30}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + trq, _, _, err := parse(tc.query) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if trq.Step != tc.step { + t.Errorf("expected step %v, got %v", tc.step, trq.Step) + } + }) + } +} + +func TestToDateTime64InWhere(t *testing.T) { + // Simulates the official Grafana ClickHouse plugin v4+ $__timeFilter expansion. + // Tests both integer and float epoch values inside toDateTime64(). 
+ q := `SELECT toStartOfFiveMinute(datetime) AS t, count() AS cnt FROM tbl ` + + `WHERE datetime >= toDateTime64(1589904000, 3) AND datetime <= toDateTime64(1589997600, 3) ` + + `GROUP BY t FORMAT JSON` + trq, _, _, err := parse(q) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if trq.Step != 5*time.Minute { + t.Errorf("expected step 5m, got %v", trq.Step) + } + if trq.Extent.Start.Unix() != 1589904000 { + t.Errorf("expected start 1589904000, got %d", trq.Extent.Start.Unix()) + } + if trq.Extent.End.Unix() != 1589997600 { + t.Errorf("expected end 1589997600, got %d", trq.Extent.End.Unix()) + } +} + +func TestToDateTime64FloatEpoch(t *testing.T) { + q := `SELECT toStartOfFiveMinute(datetime) AS t, count() AS cnt FROM tbl ` + + `WHERE datetime >= toDateTime64(1589904000.123, 3) AND datetime <= toDateTime64(1589997600.456, 3) ` + + `GROUP BY t FORMAT JSON` + trq, _, _, err := parse(q) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if trq.Extent.Start.Unix() != 1589904000 { + t.Errorf("expected start 1589904000, got %d", trq.Extent.Start.Unix()) + } + if trq.Extent.End.Unix() != 1589997600 { + t.Errorf("expected end 1589997600, got %d", trq.Extent.End.Unix()) + } +} + func TestBadQueries(t *testing.T) { tests := []struct { query string diff --git a/pkg/backends/clickhouse/url.go b/pkg/backends/clickhouse/url.go index 4b4fca879..566d9474e 100644 --- a/pkg/backends/clickhouse/url.go +++ b/pkg/backends/clickhouse/url.go @@ -42,9 +42,10 @@ func (c *Client) SetExtent(r *http.Request, trq *timeseries.TimeRangeQuery, } qi := r.URL.Query() isBody := methods.HasBody(r.Method) - q := interpolateTimeQuery(trq.Statement, trq.TimestampDefinition, extent) + q := interpolateTimeQuery(trq.Statement, trq.TimestampDefinition, extent, qi.Get("default_format")) if isBody { request.SetBody(r, []byte(q)) + r.URL.RawQuery = qi.Encode() } else { qi.Set(upQuery, q) r.URL.RawQuery = qi.Encode() @@ -71,7 +72,7 @@ func formatTimestampValues(dataType 
timeseries.FieldDataType, extent *timeseries } func interpolateTimeQuery(template string, tfd timeseries.FieldDefinition, - extent *timeseries.Extent, + extent *timeseries.Extent, defaultFormat string, ) string { // tfd.DataType holds the database internal format for the timestamp used // when setting extents @@ -79,12 +80,32 @@ func interpolateTimeQuery(template string, tfd timeseries.FieldDefinition, // ProviderData1 holds the format of a secondary time field tStart, tEnd := formatTimestampValues(timeseries.FieldDataType(tfd.ProviderData1), extent) + // toStartOfMonth/Quarter/Year and toDate return Date type, not DateTime. + // ClickHouse requires Date comparisons to use toDate() wrapped values. + if tfd.SDataType == sdtDate { + start = "toDate(" + start + ")" + end = "toDate(" + end + ")" + } trange := fmt.Sprintf("%s BETWEEN %s AND %s", tfd.Name, start, end) + // When the client provides default_format=Native via URL params, ClickHouse + // uses that format regardless of any FORMAT clause. Remove the FORMAT token + // from the SQL so ClickHouse doesn't see a conflicting clause. Otherwise, + // inject TSVWithNamesAndTypes as the DPC wire format. 
+ wireFormat := "TSVWithNamesAndTypes" + formatReplacement := wireFormat + if strings.EqualFold(defaultFormat, "Native") { + formatReplacement = "" + } + // Clean up "FORMAT " prefix if the replacement is empty + if formatReplacement == "" { + template = strings.Replace(template, " FORMAT "+tkFormat, "", 1) + template = strings.Replace(template, "FORMAT "+tkFormat, "", 1) + } + out := strings.NewReplacer( + tkRange, trange, + tkTS1, tStart, + tkTS2, tEnd, - tkFormat, "TSVWithNamesAndTypes", + tkFormat, formatReplacement, + ).Replace(template) + return out + } diff --git a/pkg/backends/clickhouse/url_test.go b/pkg/backends/clickhouse/url_test.go index 9de7af0bb..365b234a3 100644 --- a/pkg/backends/clickhouse/url_test.go +++ b/pkg/backends/clickhouse/url_test.go @@ -20,6 +20,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "testing" "time" @@ -56,3 +57,66 @@ func TestSetExtent(t *testing.T) { t.Errorf("\nexpected [%s]\ngot [%s]", expected, r.URL.RawQuery) } } + +func TestSetExtentNativeFormatPreserved(t *testing.T) { + client := &Client{} + start := time.Unix(1589904000, 0) + end := time.Unix(1589997600, 0) + e := &timeseries.Extent{Start: start, End: end} + + trq := &timeseries.TimeRangeQuery{ + Statement: `SELECT toStartOfFiveMinute(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN <$TS1$> AND <$TS2$> GROUP BY t ORDER BY t FORMAT <$FORMAT$>`, + } + + // GET request with default_format=Native (official Grafana plugin style) + tu := &url.URL{RawQuery: url.Values{ + "query": {trq.Statement}, + "default_format": {"Native"}, + "client_protocol_version": {"54460"}, + "database": {"default"}, + }.Encode()} + r, _ := http.NewRequest(http.MethodGet, tu.String(), nil) + r.URL = tu + + client.SetExtent(r, trq, e) + + q := r.URL.Query() + if !q.Has("default_format") { + t.Error("expected default_format to be preserved in URL params") + } + if !q.Has("database") { + t.Error("expected database param to be preserved") + } + // FORMAT clause should be removed from SQL when
default_format=Native + sql := q.Get("query") + if strings.Contains(sql, "FORMAT") { + t.Errorf("expected FORMAT clause to be removed when default_format=Native, got: %s", sql) + } +} + +func TestSetExtentTSVFormatInjected(t *testing.T) { + client := &Client{} + start := time.Unix(1589904000, 0) + end := time.Unix(1589997600, 0) + e := &timeseries.Extent{Start: start, End: end} + + trq := &timeseries.TimeRangeQuery{ + Statement: `SELECT toStartOfFiveMinute(datetime) AS t, count() AS cnt FROM tbl WHERE datetime BETWEEN <$TS1$> AND <$TS2$> GROUP BY t ORDER BY t FORMAT <$FORMAT$>`, + } + + // GET request WITHOUT default_format (standard TSV path) + tu := &url.URL{RawQuery: url.Values{ + "query": {trq.Statement}, + "database": {"default"}, + }.Encode()} + r, _ := http.NewRequest(http.MethodGet, tu.String(), nil) + r.URL = tu + + client.SetExtent(r, trq, e) + + q := r.URL.Query() + sql := q.Get("query") + if !strings.Contains(sql, "TSVWithNamesAndTypes") { + t.Errorf("expected FORMAT TSVWithNamesAndTypes when no default_format, got: %s", sql) + } +} diff --git a/pkg/parsing/lex/sql/token.go b/pkg/parsing/lex/sql/token.go index b60be41cd..25311b9d4 100644 --- a/pkg/parsing/lex/sql/token.go +++ b/pkg/parsing/lex/sql/token.go @@ -17,6 +17,7 @@ package sql import ( + "strconv" "time" "github.com/trickstercache/trickster/v2/pkg/parsing/token" @@ -201,7 +202,12 @@ func TokenToTime(i *token.Token) (time.Time, byte, error) { case token.Number: n, err := i.Int64() if err != nil { - return t, 0, err + // Handle float epoch values like 1589904000.000 from toDateTime64() + f, ferr := strconv.ParseFloat(i.Val, 64) + if ferr != nil { + return t, 0, err + } + n = int64(f) } if n > 999999999999999999 { return time.Unix(0, n), TimeFormatUnixNano, nil From 73280711f98d2cc123046506ec3a61ff8336d58a Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Sat, 18 Apr 2026 10:47:28 -0400 Subject: [PATCH 3/8] feat: add Native DPC modeler, DPC fallback warning, official Grafana ClickHouse plugin
Signed-off-by: Chris Randles --- .../trickster-clickhouse-grafana.json | 183 +++++++++++ .../provisioning/datasources/datasource.yaml | 28 +- docs/developer/environment/docker-compose.yml | 2 +- integration/clickhouse_test.go | 56 ++++ pkg/backends/clickhouse/model/marshal.go | 2 + .../clickhouse/model/marshal_native.go | 285 +++++++++++++++++ pkg/backends/clickhouse/model/model.go | 7 +- pkg/backends/clickhouse/model/unmarshal.go | 16 + .../clickhouse/model/unmarshal_native.go | 297 ++++++++++++++++++ pkg/proxy/engines/deltaproxycache.go | 32 +- pkg/timeseries/format_hint.go | 32 ++ 11 files changed, 932 insertions(+), 8 deletions(-) create mode 100644 docs/developer/environment/docker-compose-data/dashboards/trickster-clickhouse-grafana.json create mode 100644 pkg/backends/clickhouse/model/marshal_native.go create mode 100644 pkg/backends/clickhouse/model/unmarshal_native.go create mode 100644 pkg/timeseries/format_hint.go diff --git a/docs/developer/environment/docker-compose-data/dashboards/trickster-clickhouse-grafana.json b/docs/developer/environment/docker-compose-data/dashboards/trickster-clickhouse-grafana.json new file mode 100644 index 000000000..33a0ef6ff --- /dev/null +++ b/docs/developer/environment/docker-compose-data/dashboards/trickster-clickhouse-grafana.json @@ -0,0 +1,183 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "links": [], + "panels": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "${datasource}" + }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + 
"barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { "type": "linear" }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { "group": "A", "mode": "none" }, + "thresholdsStyle": { "mode": "off" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { "color": "green" }, + { "color": "red", "value": 80 } + ] + } + }, + "overrides": [] + }, + "gridPos": { "h": 22, "w": 12, "x": 0, "y": 0 }, + "id": 1, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "hideZeros": false, "mode": "single", "sort": "none" } + }, + "targets": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "${datasource}" + }, + "editorType": "sql", + "format": 1, + "meta": { "builderOptions": { "database": "default", "table": "trips" } }, + "queryType": "sql", + "rawSql": "SELECT toStartOfFiveMinute(pickup_datetime) AS t, count() AS Count FROM default.trips WHERE $__timeFilter(pickup_datetime) GROUP BY t ORDER BY t", + "refId": "A" + } + ], + "title": "# Trips Over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "${datasource}" + }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { "type": "linear" }, + 
"showPoints": "auto", + "spanNulls": false, + "stacking": { "group": "A", "mode": "none" }, + "thresholdsStyle": { "mode": "off" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { "color": "green" }, + { "color": "red", "value": 80 } + ] + } + }, + "overrides": [] + }, + "gridPos": { "h": 22, "w": 12, "x": 12, "y": 0 }, + "id": 2, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "hideZeros": false, "mode": "single", "sort": "none" } + }, + "targets": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "${datasource}" + }, + "editorType": "sql", + "format": 1, + "meta": { "builderOptions": { "database": "default", "table": "trips" } }, + "queryType": "sql", + "rawSql": "WITH sumIf(trip_id, payment_type = 'CSH') AS cash_trips, sumIf(trip_id, payment_type != 'CSH') AS non_cash_trips SELECT toStartOfFiveMinute(pickup_datetime) AS t, non_cash_trips / (cash_trips + non_cash_trips) * 100 AS \"Card Use Rate\" FROM default.trips WHERE $__timeFilter(pickup_datetime) GROUP BY t ORDER BY t", + "refId": "A" + } + ], + "title": "% Trips Paid w/ Credit Card", + "type": "timeseries" + } + ], + "preload": false, + "refresh": "5s", + "schemaVersion": 41, + "tags": [], + "templating": { + "list": [ + { + "allowCustomValue": false, + "current": { + "text": "clickhouse-grafana-direct", + "value": "clickhouse-grafana-direct" + }, + "description": "", + "name": "datasource", + "options": [], + "query": "grafana-clickhouse-datasource", + "refresh": 1, + "regex": "", + "type": "datasource" + } + ] + }, + "time": { + "from": "now-15m", + "to": "now" + }, + "timepicker": {}, + "timezone": "browser", + "title": "ClickHouse (Grafana Plugin)", + "uid": "clickhouse-grafana-plugin", + "version": 1 +} diff --git a/docs/developer/environment/docker-compose-data/grafana-config/provisioning/datasources/datasource.yaml 
b/docs/developer/environment/docker-compose-data/grafana-config/provisioning/datasources/datasource.yaml index e567c6f5c..e17302341 100644 --- a/docs/developer/environment/docker-compose-data/grafana-config/provisioning/datasources/datasource.yaml +++ b/docs/developer/environment/docker-compose-data/grafana-config/provisioning/datasources/datasource.yaml @@ -230,4 +230,30 @@ datasources: defaultDatabase: default username: default password: "" - usePOST: true \ No newline at end of file + usePOST: true + +# ClickHouse (Official Grafana Plugin) +- name: clickhouse-grafana-direct + type: grafana-clickhouse-datasource + access: proxy + editable: true + isDefault: false + jsonData: + host: clickhouse + port: 8123 + protocol: http + defaultDatabase: default + username: default + +- name: clickhouse-grafana-trickster + type: grafana-clickhouse-datasource + access: proxy + editable: true + isDefault: false + jsonData: + host: host.docker.internal + port: 8480 + path: /click1/ + protocol: http + defaultDatabase: default + username: default \ No newline at end of file diff --git a/docs/developer/environment/docker-compose.yml b/docs/developer/environment/docker-compose.yml index 5e91ee446..0b57e68a5 100644 --- a/docs/developer/environment/docker-compose.yml +++ b/docs/developer/environment/docker-compose.yml @@ -22,7 +22,7 @@ services: image: grafana/grafana:11.6.0 user: nobody environment: - - GF_INSTALL_PLUGINS=vertamedia-clickhouse-datasource + - GF_INSTALL_PLUGINS=vertamedia-clickhouse-datasource,grafana-clickhouse-datasource volumes: - ./docker-compose-data/grafana-config:/etc/grafana - ./docker-compose-data/dashboards:/var/lib/grafana/dashboards diff --git a/integration/clickhouse_test.go b/integration/clickhouse_test.go index 57f47dedc..5d6305ffc 100644 --- a/integration/clickhouse_test.go +++ b/integration/clickhouse_test.go @@ -98,6 +98,58 @@ func TestClickHouse(t *testing.T) { "multi-line SQL must reach DeltaProxyCache (issue #967)") }) + 
t.Run("grafana_official_plugin_native", func(t *testing.T) { + // The official Grafana ClickHouse plugin sends default_format=Native + // and client_protocol_version=54460 as URL params, with NO FORMAT in + // the SQL. This triggers TCP-style Native responses (block info + + // customSerialization flags). + now := time.Now().Unix() + q := fmt.Sprintf( + "SELECT toStartOfFiveMinute(pickup_datetime) AS t, count() AS cnt "+ + "FROM trips "+ + "WHERE pickup_datetime BETWEEN toDateTime(%d) AND toDateTime(%d) "+ + "GROUP BY t ORDER BY t", + now-3600, now, + ) + params := url.Values{ + "query": {q}, + "default_format": {"Native"}, + "client_protocol_version": {"54460"}, + "database": {"default"}, + } + u := "http://" + clickAddr + "/click1/?" + params.Encode() + + // First request — should go through DPC and return Native binary + resp, err := http.Get(u) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, "unexpected status: %s", string(body)) + require.Greater(t, len(body), 10, "expected Native binary response with data") + + hdr := parseTricksterResult(resp.Header.Get("X-Trickster-Result")) + require.Equal(t, "DeltaProxyCache", hdr["engine"]) + + // Verify the response is Native binary — first byte is numCols (uvarint), + // or block info field 1 (0x01) if TCP-style. Either way, it should be + // a small positive number, not a printable ASCII character. 
+ require.Less(t, body[0], byte(0x20), "expected Native binary, got text (byte 0x%02x)", body[0]) + + // Second request — should hit DPC cache + resp2, err := http.Get(u) + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + resp2.Body.Close() + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp2.StatusCode) + require.Greater(t, len(body2), 10, "expected cached Native response with data") + + hdr2 := parseTricksterResult(resp2.Header.Get("X-Trickster-Result")) + require.Contains(t, []string{"hit", "phit"}, hdr2["status"], + "second request should hit the cache, got %s", hdr2["status"]) + }) + aggCases := []struct { name string group string @@ -105,6 +157,10 @@ func TestClickHouse(t *testing.T) { {"five_minute", "toStartOfFiveMinute(pickup_datetime)"}, {"fifteen_minute", "toStartOfInterval(pickup_datetime, INTERVAL 15 MINUTE)"}, {"one_hour", "toStartOfHour(pickup_datetime)"}, + {"one_day", "toStartOfDay(pickup_datetime)"}, + {"one_month", "toStartOfMonth(pickup_datetime)"}, + {"date_trunc_hour", "date_trunc('hour', pickup_datetime)"}, + {"date_trunc_day", "date_trunc('day', pickup_datetime)"}, } for _, tc := range aggCases { t.Run("aggregation_"+tc.name, func(t *testing.T) { diff --git a/pkg/backends/clickhouse/model/marshal.go b/pkg/backends/clickhouse/model/marshal.go index 4bbc0dd32..bfdc91f75 100644 --- a/pkg/backends/clickhouse/model/marshal.go +++ b/pkg/backends/clickhouse/model/marshal.go @@ -62,6 +62,8 @@ func MarshalTimeseriesWriter(ts timeseries.Timeseries, return marshalTimeseriesXSV(w, ds, rlo, true, false, '\t') case 5: return marshalTimeseriesXSV(w, ds, rlo, true, true, '\t') + case OutputFormatNative: + return marshalTimeseriesNative(w, ds, rlo) } return timeseries.ErrUnknownFormat } diff --git a/pkg/backends/clickhouse/model/marshal_native.go b/pkg/backends/clickhouse/model/marshal_native.go new file mode 100644 index 000000000..79f065e50 --- /dev/null +++ b/pkg/backends/clickhouse/model/marshal_native.go @@ -0,0 +1,285 @@ 
+/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + "time" + + lsql "github.com/trickstercache/trickster/v2/pkg/parsing/lex/sql" + "github.com/trickstercache/trickster/v2/pkg/timeseries" + "github.com/trickstercache/trickster/v2/pkg/timeseries/dataset" + "github.com/trickstercache/trickster/v2/pkg/timeseries/epoch" +) + +// OutputFormatNative is the output format index for ClickHouse Native binary. +const OutputFormatNative byte = 6 + +// marshalTimeseriesNative writes a DataSet as a ClickHouse Native binary block. 
+func marshalTimeseriesNative(w io.Writer, ds *dataset.DataSet, rlo *timeseries.RequestOptions) error { + // Collect all points across all series into row-major format + type colDef struct { + name string + typ string + } + + if len(ds.Results) == 0 || len(ds.Results[0].SeriesList) == 0 { + return writeEmptyNativeBlock(w) + } + + res := ds.Results[0] + sh := res.SeriesList[0].Header + + // Build column definitions from the series header + var cols []colDef + cols = append(cols, colDef{name: sh.TimestampField.Name, typ: sh.TimestampField.SDataType}) + for _, tf := range sh.TagFieldsList { + cols = append(cols, colDef{name: tf.Name, typ: tf.SDataType}) + } + for _, vf := range sh.ValueFieldsList { + cols = append(cols, colDef{name: vf.Name, typ: vf.SDataType}) + } + + // Collect all rows + var rows [][]string + for _, series := range res.SeriesList { + for _, pt := range series.Points { + row := make([]string, len(cols)) + // Timestamp + row[0] = formatEpochForType(pt.Epoch, sh.TimestampField) + // Tags + for i, tf := range sh.TagFieldsList { + row[1+i] = series.Header.Tags[tf.Name] + } + // Values + for i, v := range pt.Values { + row[1+len(sh.TagFieldsList)+i] = fmt.Sprint(v) + } + rows = append(rows, row) + } + } + + numCols := uint64(len(cols)) + numRows := uint64(len(rows)) + + // Write TCP-style block info so clickhouse-go clients (including the + // official Grafana plugin) can parse the response. 
+ if err := writeNativeBlockInfo(w); err != nil { + return err + } + + if err := writeUvarint(w, numCols); err != nil { + return err + } + if err := writeUvarint(w, numRows); err != nil { + return err + } + + for c, col := range cols { + if err := writeNativeString(w, col.name); err != nil { + return err + } + if err := writeNativeString(w, col.typ); err != nil { + return err + } + // customSerialization flag (matches TCP-style protocol) + if _, err := w.Write([]byte{0}); err != nil { + return err + } + for r := range numRows { + if err := writeNativeValue(w, col.typ, rows[r][c]); err != nil { + return fmt.Errorf("marshal native col %d row %d: %w", c, r, err) + } + } + } + + return nil +} + +func writeEmptyNativeBlock(w io.Writer) error { + if err := writeNativeBlockInfo(w); err != nil { + return err + } + if err := writeUvarint(w, 0); err != nil { + return err + } + return writeUvarint(w, 0) +} + +func writeNativeBlockInfo(w io.Writer) error { + // field 1: is_overflows = 0 + if err := writeUvarint(w, 1); err != nil { + return err + } + if _, err := w.Write([]byte{0}); err != nil { + return err + } + // field 2: bucket_num = -1 + if err := writeUvarint(w, 2); err != nil { + return err + } + var b [4]byte + binary.LittleEndian.PutUint32(b[:], uint32(0xFFFFFFFF)) //nolint:gosec + if _, err := w.Write(b[:]); err != nil { + return err + } + // end marker + return writeUvarint(w, 0) +} + +func writeUvarint(w io.Writer, v uint64) error { + var buf [binary.MaxVarintLen64]byte + n := binary.PutUvarint(buf[:], v) + _, err := w.Write(buf[:n]) + return err +} + +func writeNativeString(w io.Writer, s string) error { + if err := writeUvarint(w, uint64(len(s))); err != nil { + return err + } + if len(s) > 0 { + _, err := w.Write([]byte(s)) + return err + } + return nil +} + +//nolint:gosec // intentional wire protocol conversions +func writeNativeValue(w io.Writer, typ, val string) error { + switch typ { + case TypeUInt8, TypeBool: + v, _ := strconv.ParseUint(val, 10, 8) + _, 
err := w.Write([]byte{byte(v)}) + return err + case TypeUInt16: + v, _ := strconv.ParseUint(val, 10, 16) + var b [2]byte + binary.LittleEndian.PutUint16(b[:], uint16(v)) + _, err := w.Write(b[:]) + return err + case TypeUInt32: + v, _ := strconv.ParseUint(val, 10, 32) + var b [4]byte + binary.LittleEndian.PutUint32(b[:], uint32(v)) + _, err := w.Write(b[:]) + return err + case TypeUInt64: + v, _ := strconv.ParseUint(val, 10, 64) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], v) + _, err := w.Write(b[:]) + return err + case TypeInt8: + v, _ := strconv.ParseInt(val, 10, 8) + _, err := w.Write([]byte{byte(v)}) + return err + case TypeInt16: + v, _ := strconv.ParseInt(val, 10, 16) + var b [2]byte + binary.LittleEndian.PutUint16(b[:], uint16(v)) + _, err := w.Write(b[:]) + return err + case TypeInt32: + v, _ := strconv.ParseInt(val, 10, 32) + var b [4]byte + binary.LittleEndian.PutUint32(b[:], uint32(v)) + _, err := w.Write(b[:]) + return err + case TypeInt64: + v, _ := strconv.ParseInt(val, 10, 64) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(v)) + _, err := w.Write(b[:]) + return err + case TypeFloat32: + v, _ := strconv.ParseFloat(val, 32) + var b [4]byte + binary.LittleEndian.PutUint32(b[:], math.Float32bits(float32(v))) + _, err := w.Write(b[:]) + return err + case TypeFloat64: + v, _ := strconv.ParseFloat(val, 64) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], math.Float64bits(v)) + _, err := w.Write(b[:]) + return err + case TypeDateTime: + t, err := time.Parse(lsql.SQLDateTimeLayout, val) + if err != nil { + // Try as epoch seconds + v, _ := strconv.ParseInt(val, 10, 64) + t = time.Unix(v, 0) + } + var b [4]byte + binary.LittleEndian.PutUint32(b[:], uint32(t.Unix())) + _, err = w.Write(b[:]) + return err + case TypeDate: + t, err := time.Parse("2006-01-02", val) + if err != nil { + return writeNativeString(w, val) + } + epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) + days := uint16(t.Sub(epoch).Hours() / 24) + var b [2]byte + 
binary.LittleEndian.PutUint16(b[:], days) + _, err = w.Write(b[:]) + return err + case TypeString: + return writeNativeString(w, val) + default: + // DateTime64 variants + if strings.HasPrefix(typ, "DateTime64") { + t, err := parseClickHouseTimestamp("", val) + if err != nil { + v, _ := strconv.ParseInt(val, 10, 64) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(v)) + _, err = w.Write(b[:]) + return err + } + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(t.UnixMilli())) + _, err = w.Write(b[:]) + return err + } + return writeNativeString(w, val) + } +} + +func formatEpochForType(ep epoch.Epoch, tfd timeseries.FieldDefinition) string { + nanos := int64(ep) + t := time.Unix(nanos/1e9, nanos%1e9).UTC() + switch tfd.SDataType { + case TypeDateTime: + return t.Format(lsql.SQLDateTimeLayout) + case TypeDate: + return t.Format("2006-01-02") + default: + if strings.HasPrefix(tfd.SDataType, "DateTime64") { + return t.Format("2006-01-02 15:04:05.000") + } + // Default: epoch seconds as string + return strconv.FormatInt(t.Unix(), 10) + } +} diff --git a/pkg/backends/clickhouse/model/model.go b/pkg/backends/clickhouse/model/model.go index 8378db5b0..af69293be 100644 --- a/pkg/backends/clickhouse/model/model.go +++ b/pkg/backends/clickhouse/model/model.go @@ -23,13 +23,14 @@ import ( const formatHeader = "X-Clickhouse-Format" -// NewModeler returns a collection of modeling functions for clickhouse interoperability +// NewModeler returns a collection of modeling functions for clickhouse interoperability. +// The wire unmarshaler auto-detects Native vs TSV format. 
func NewModeler() *timeseries.Modeler { return &timeseries.Modeler{ - WireUnmarshalerReader: UnmarshalTimeseriesReader, + WireUnmarshalerReader: UnmarshalTimeseriesAutoReader, WireMarshaler: MarshalTimeseries, WireMarshalWriter: MarshalTimeseriesWriter, - WireUnmarshaler: UnmarshalTimeseries, + WireUnmarshaler: UnmarshalTimeseriesAuto, CacheMarshaler: dataset.MarshalDataSet, CacheUnmarshaler: dataset.UnmarshalDataSet, } diff --git a/pkg/backends/clickhouse/model/unmarshal.go b/pkg/backends/clickhouse/model/unmarshal.go index 4daf7f6f0..2f1646d8a 100644 --- a/pkg/backends/clickhouse/model/unmarshal.go +++ b/pkg/backends/clickhouse/model/unmarshal.go @@ -42,6 +42,22 @@ func UnmarshalTimeseries(data []byte, trq *timeseries.TimeRangeQuery) (timeserie return UnmarshalTimeseriesReader(buf, trq) } +// UnmarshalTimeseriesAuto dispatches to Native or TSV based on data content. +// Without header hints, falls back to TSV (the default DPC wire format). +func UnmarshalTimeseriesAuto(data []byte, trq *timeseries.TimeRangeQuery) (timeseries.Timeseries, error) { + return UnmarshalTimeseries(data, trq) +} + +// UnmarshalTimeseriesAutoReader auto-detects Native vs TSV from an io.Reader. +// It checks the FormatHintReader wrapper set by the DPC engine from the +// X-ClickHouse-Format response header. 
+func UnmarshalTimeseriesAutoReader(reader io.Reader, trq *timeseries.TimeRangeQuery) (timeseries.Timeseries, error) { + if hr, ok := reader.(*timeseries.FormatHintReader); ok && strings.EqualFold(hr.Format, "Native") { + return UnmarshalTimeseriesNativeReader(hr.Reader, trq) + } + return UnmarshalTimeseriesReader(reader, trq) +} + // parser is safe for concurrency var parser = dcsv.NewParserMust(buildFieldDefinitions, typeToFieldDataType, parseTimeField, dataStartRow) diff --git a/pkg/backends/clickhouse/model/unmarshal_native.go b/pkg/backends/clickhouse/model/unmarshal_native.go new file mode 100644 index 000000000..1e5abd704 --- /dev/null +++ b/pkg/backends/clickhouse/model/unmarshal_native.go @@ -0,0 +1,297 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "time" + + "github.com/trickstercache/trickster/v2/pkg/timeseries" + dcsv "github.com/trickstercache/trickster/v2/pkg/timeseries/dataset/csv" +) + +// nativeParser uses the same CSV→DataSet pipeline but with rows built from +// Native binary blocks instead of TSV text. +var nativeParser = dcsv.NewParserMust(buildFieldDefinitions, typeToFieldDataType, + parseTimeField, dataStartRow) + +// UnmarshalTimeseriesNative decodes a ClickHouse Native binary response into +// a Timeseries (DataSet). 
The Native format is: block info, then per-column +// (name, type, data[numRows]). We decode the columnar data into rows and +// feed the same CSV→DataSet pipeline used by TSV. +func UnmarshalTimeseriesNative(data []byte, trq *timeseries.TimeRangeQuery) (timeseries.Timeseries, error) { + return UnmarshalTimeseriesNativeReader(bufio.NewReader( + io.NopCloser(&nopReader{data, 0})), trq) +} + +type nopReader struct { + data []byte + pos int +} + +func (r *nopReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + n := copy(p, r.data[r.pos:]) + r.pos += n + return n, nil +} + +// UnmarshalTimeseriesNativeReader decodes a ClickHouse Native binary response +// from an io.Reader into a Timeseries (DataSet). ClickHouse sends multiple +// blocks for large result sets; we read all of them. +func UnmarshalTimeseriesNativeReader(reader io.Reader, trq *timeseries.TimeRangeQuery) (timeseries.Timeseries, error) { + br := bufio.NewReaderSize(reader, 128*1024) + + var colNames, colTypes []string + var allRows [][]string + var hasBlockInfo bool + + for { + // Peek to detect block info header. When client_protocol_version is + // present, ClickHouse includes TCP-style block info (starts with + // field_num 1 = is_overflows). Without it, the response starts + // directly with numCols (uvarint > 1 for any useful query). 
+ peek, err := br.Peek(1) + if err != nil { + break + } + if peek[0] == 1 { + hasBlockInfo = true + if err := skipBlockInfo(br); err != nil { + break + } + } + + numCols, err := readUvarint(br) + if err != nil { + break + } + numRows, err := readUvarint(br) + if err != nil { + break + } + if numCols == 0 || numRows == 0 { + break + } + + blockColNames := make([]string, numCols) + blockColTypes := make([]string, numCols) + columns := make([][]string, numCols) + + for c := range numCols { + name, err := readString(br) + if err != nil { + return nil, fmt.Errorf("native: column %d name: %w", c, err) + } + blockColNames[c] = name + + typ, err := readString(br) + if err != nil { + return nil, fmt.Errorf("native: column %d type: %w", c, err) + } + blockColTypes[c] = typ + + // When block info is present (TCP-style format via + // client_protocol_version), a customSerialization bool follows + // each column type header. + if hasBlockInfo { + if _, err := br.ReadByte(); err != nil { + return nil, fmt.Errorf("native: column %d custom serialization flag: %w", c, err) + } + } + + vals := make([]string, numRows) + for r := range numRows { + vals[r], err = readValueAsString(br, typ) + if err != nil { + return nil, fmt.Errorf("native: column %d row %d: %w", c, r, err) + } + } + columns[c] = vals + } + + if colNames == nil { + colNames = blockColNames + colTypes = blockColTypes + } + + for r := range numRows { + row := make([]string, numCols) + for c := range numCols { + row[c] = columns[c][r] + } + allRows = append(allRows, row) + } + } + + if colNames == nil || len(allRows) == 0 { + return nil, timeseries.ErrInvalidBody + } + + rows := make([][]string, dataStartRow+len(allRows)) + rows[0] = colNames + rows[1] = colTypes + copy(rows[dataStartRow:], allRows) + + return nativeParser.ToDataSet(rows, trq) +} + +func skipBlockInfo(r *bufio.Reader) error { + for { + fieldNum, err := readUvarint(r) + if err != nil { + return err + } + if fieldNum == 0 { + break + } + switch fieldNum { + 
case 1: + _, err = r.ReadByte() + case 2: + _, err = readFixed(r, 4) + default: + return fmt.Errorf("unknown block info field: %d", fieldNum) + } + if err != nil { + return err + } + } + return nil +} + +// ---------- Native binary reading primitives ---------- + +func readUvarint(r io.ByteReader) (uint64, error) { + return binary.ReadUvarint(r) +} + +func readString(r *bufio.Reader) (string, error) { + n, err := readUvarint(r) + if err != nil { + return "", err + } + if n == 0 { + return "", nil + } + b := make([]byte, n) + _, err = io.ReadFull(r, b) + return string(b), err +} + +func readFixed(r io.Reader, n int) ([]byte, error) { + b := make([]byte, n) + _, err := io.ReadFull(r, b) + return b, err +} + +//nolint:gosec // intentional wire protocol conversions +func readValueAsString(r *bufio.Reader, typ string) (string, error) { + switch typ { + case TypeUInt8, TypeBool: + b, err := r.ReadByte() + return strconv.FormatUint(uint64(b), 10), err + case TypeUInt16: + b, err := readFixed(r, 2) + if err != nil { + return "", err + } + return strconv.FormatUint(uint64(binary.LittleEndian.Uint16(b)), 10), nil + case TypeUInt32: + b, err := readFixed(r, 4) + if err != nil { + return "", err + } + return strconv.FormatUint(uint64(binary.LittleEndian.Uint32(b)), 10), nil + case TypeUInt64: + b, err := readFixed(r, 8) + if err != nil { + return "", err + } + return strconv.FormatUint(binary.LittleEndian.Uint64(b), 10), nil + case TypeInt8: + b, err := r.ReadByte() + return strconv.Itoa(int(int8(b))), err + case TypeInt16: + b, err := readFixed(r, 2) + if err != nil { + return "", err + } + return strconv.Itoa(int(int16(binary.LittleEndian.Uint16(b)))), nil + case TypeInt32: + b, err := readFixed(r, 4) + if err != nil { + return "", err + } + return strconv.Itoa(int(int32(binary.LittleEndian.Uint32(b)))), nil + case TypeInt64: + b, err := readFixed(r, 8) + if err != nil { + return "", err + } + return strconv.FormatInt(int64(binary.LittleEndian.Uint64(b)), 10), nil + case 
TypeFloat32: + b, err := readFixed(r, 4) + if err != nil { + return "", err + } + return fmt.Sprint(math.Float32frombits(binary.LittleEndian.Uint32(b))), nil + case TypeFloat64: + b, err := readFixed(r, 8) + if err != nil { + return "", err + } + return fmt.Sprint(math.Float64frombits(binary.LittleEndian.Uint64(b))), nil + case TypeDateTime: + b, err := readFixed(r, 4) + if err != nil { + return "", err + } + ts := binary.LittleEndian.Uint32(b) + return time.Unix(int64(ts), 0).UTC().Format("2006-01-02 15:04:05"), nil + case TypeDate: + b, err := readFixed(r, 2) + if err != nil { + return "", err + } + days := binary.LittleEndian.Uint16(b) + t := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC).AddDate(0, 0, int(days)) + return t.Format("2006-01-02"), nil + case TypeString: + return readString(r) + default: + // For complex/unknown types, try reading as string + if len(typ) > 8 && typ[:8] == TypeDateTime { + // DateTime64, DateTime('tz'), etc. — 8 bytes + b, err := readFixed(r, 8) + if err != nil { + return "", err + } + v := int64(binary.LittleEndian.Uint64(b)) + // DateTime64 with default precision 3 (milliseconds) + return time.Unix(v/1000, (v%1000)*1000000).UTC().Format("2006-01-02 15:04:05.000"), nil + } + return readString(r) + } +} diff --git a/pkg/proxy/engines/deltaproxycache.go b/pkg/proxy/engines/deltaproxycache.go index a70640ed6..25713b71f 100644 --- a/pkg/proxy/engines/deltaproxycache.go +++ b/pkg/proxy/engines/deltaproxycache.go @@ -39,11 +39,14 @@ import ( "github.com/trickstercache/trickster/v2/pkg/observability/logging/logger" "github.com/trickstercache/trickster/v2/pkg/observability/metrics" tspan "github.com/trickstercache/trickster/v2/pkg/observability/tracing/span" + "github.com/trickstercache/trickster/v2/pkg/parsing" + "github.com/trickstercache/trickster/v2/pkg/parsing/sql" tctx "github.com/trickstercache/trickster/v2/pkg/proxy/context" tpe "github.com/trickstercache/trickster/v2/pkg/proxy/errors" 
"github.com/trickstercache/trickster/v2/pkg/proxy/headers" "github.com/trickstercache/trickster/v2/pkg/proxy/request" "github.com/trickstercache/trickster/v2/pkg/timeseries" + "github.com/trickstercache/trickster/v2/pkg/timeseries/sqlparser" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" @@ -208,8 +211,12 @@ func DeltaProxyCacheRequest(w http.ResponseWriter, r *http.Request, modeler *tim rsc.Unlock() if err != nil { if canOPC { - logger.Debug("could not parse time range query, using object proxy cache", - logging.Pairs{"error": err.Error()}) + pairs := logging.Pairs{"backendName": o.Name, "error": err.Error()} + if (o.DPCFallbackWarning == nil || *o.DPCFallbackWarning) && !isStructuralParseError(err) { + logger.Warn("query fell back from DPC to OPC (not cached as time series)", pairs) + } else { + logger.Debug("query fell back from DPC to OPC (not cached as time series)", pairs) + } rsc.AlternateCacheTTL = time.Minute ObjectProxyCacheRequest(w, r) return @@ -731,7 +738,13 @@ func fetchExtents(el timeseries.ExtentList, rsc *request.Resources, h http.Heade } if resp.StatusCode == http.StatusOK && len(body) > 0 { - nts, ferr := wur(getDecoderReader(resp), rsc.TimeRangeQuery) + var dr io.Reader + if chFmt := resp.Header.Get("X-ClickHouse-Format"); chFmt != "" { + dr = timeseries.NewFormatHintReader(bytes.NewReader(body), chFmt) + } else { + dr = getDecoderReader(resp) + } + nts, ferr := wur(dr, rsc.TimeRangeQuery) if ferr != nil { logger.Error("proxy object unmarshaling failed", logging.Pairs{"detail": ferr.Error()}) @@ -784,3 +797,16 @@ func fetchExtents(el timeseries.ExtentList, rsc *request.Resources, h http.Heade eg.Wait() return mts, uncachedValueCount.Load(), mresp, errors.Join(errs...) } + +// isStructuralParseError returns true if the error indicates the query is +// structurally not a time series query (no FROM, no SELECT, no WHERE, etc.) +// rather than a time series query that failed to parse. 
These are expected +// for health checks, version queries, and other non-timeseries SELECTs. +func isStructuralParseError(err error) bool { + return errors.Is(err, sql.ErrNotAtFrom) || + errors.Is(err, sql.ErrNotAtSelect) || + errors.Is(err, sql.ErrNotAtWhere) || + errors.Is(err, sql.ErrNotAtGroupBy) || + errors.Is(err, sqlparser.ErrNotTimeRangeQuery) || + errors.Is(err, parsing.ErrUnexpectedToken) +} diff --git a/pkg/timeseries/format_hint.go b/pkg/timeseries/format_hint.go new file mode 100644 index 000000000..7fbae39a0 --- /dev/null +++ b/pkg/timeseries/format_hint.go @@ -0,0 +1,32 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package timeseries + +import "io" + +// FormatHintReader wraps an io.Reader with a format hint from the upstream +// response (e.g., the X-ClickHouse-Format header). Provider modelers can +// type-assert on this to detect the wire format without byte-peeking. +type FormatHintReader struct { + io.Reader + Format string +} + +// NewFormatHintReader wraps a reader with a format hint. 
+func NewFormatHintReader(r io.Reader, format string) *FormatHintReader { + return &FormatHintReader{Reader: r, Format: format} +} From 9b1f828dd1c93b125605df5baa3274a2fa871ce0 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Sat, 18 Apr 2026 10:48:07 -0400 Subject: [PATCH 4/8] ci: skip CodeQL on forks + update ClickHouse docs Signed-off-by: Chris Randles --- .github/workflows/codeql-analysis.yml | 1 + docs/clickhouse.md | 91 ++++++++++++++++++++++----- 2 files changed, 77 insertions(+), 15 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 3f44ee6ab..bd46781f4 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -21,6 +21,7 @@ on: jobs: analyze: name: Analyze + if: github.repository_owner == 'trickstercache' runs-on: ubuntu-latest permissions: actions: read diff --git a/docs/clickhouse.md b/docs/clickhouse.md index 8fc07881c..18000d547 100644 --- a/docs/clickhouse.md +++ b/docs/clickhouse.md @@ -4,9 +4,40 @@ Trickster will accelerate ClickHouse queries that return time series data normal ## Scope of Support -Trickster is tested with the [ClickHouse DataSource Plugin for Grafana](https://grafana.com/grafana/plugins/vertamedia-clickhouse-datasource) v1.9.3 by Vertamedia, and supports acceleration of queries constructed by this plugin using the plugin's built-in `$timeSeries` macro. Trickster also supports several other query formats that return "time series like" data. +Trickster is tested with the [ClickHouse DataSource Plugin for Grafana](https://grafana.com/grafana/plugins/vertamedia-clickhouse-datasource) v1.9.3 by Vertamedia, and supports acceleration of queries constructed by this plugin using the plugin's built-in `$timeSeries` macro. Trickster also supports the official [Grafana ClickHouse plugin](https://grafana.com/grafana/plugins/grafana-clickhouse-datasource/) (v4+), including `toDateTime64()` in WHERE clauses. 
Trickster also supports several other query formats that return "time series like" data. -Trickster also supports the ClickHouse Go SDK (`clickhouse-go/v2`). Queries made through the SDK's HTTP protocol — including those using `clickhouse.OpenDB` — are proxied and cached through Trickster. The SDK's Native binary protocol is supported via transparent proxying (OPC). +Trickster also supports the ClickHouse Go SDK (`clickhouse-go/v2`). Queries made through the SDK's HTTP protocol — including those using `clickhouse.OpenDB` — are proxied and cached through Trickster. + +### Native Binary Protocol Support + +Trickster supports the ClickHouse native binary protocol (port 9000) in two configurations: + +**Flow 1 — Native Protocol Listener:** Trickster can accept native protocol connections directly via a protocol listener. Clients connect using the native binary protocol, and Trickster proxies queries through its caching engine. Configure a protocol listener in the `frontend` section: + +```yaml +frontend: + protocol_listeners: + - name: clickhouse-native + protocol: clickhouse-native + listen_port: 9000 + backend: click1 +``` + +**Flow 2 — Native Upstream:** Trickster can speak native protocol to ClickHouse upstream while accepting HTTP from clients. Set `protocol: native` on the backend: + +```yaml +backends: + click1: + provider: clickhouse + origin_url: 'http://127.0.0.1:9000' + protocol: native +``` + +The native protocol implementation supports SELECT queries, ping/pong, handshake with addendum, and LZ4 block compression. INSERT operations over the native protocol are proxied as SQL but inline data blocks are not yet supported. + +Supported native protocol data types: all integer types (8–256 bit), Float32/64, String, FixedString(N), DateTime, DateTime64, Date, Date32, UUID, IPv4, IPv6, Enum8/16, Bool, Nullable(T), Array(T), Map(K,V), Tuple(T1,T2,...), LowCardinality(T), and Decimal. 
+ +**Native Format in DPC:** When a client sends `default_format=Native` (as the official Grafana ClickHouse plugin does), Trickster's DPC automatically requests Native format from ClickHouse, deserializes it into the internal DataSet representation for caching and delta merging, and returns Native binary to the client. This is detected via the `X-ClickHouse-Format` response header — no configuration needed. Because ClickHouse does not ship an official Go query parser, Trickster includes its own custom SQL parser and lexer to deconstruct incoming ClickHouse queries, determine if they are cacheable and, if so, what elements are factored into the cache key derivation. Trickster also determines the requested time range and step based on the provided absolute values, in order to normalize the query before hashing the cache key. @@ -20,44 +51,74 @@ SELECT intDiv(toUInt32(time_col, 60) * 60) [* 1000] [as] [alias] ``` This is the approach used by the Grafana plugin. The time_col and/or alias is used to determine the requested time range from the WHERE or PREWHERE clause of the query. The argument to the ClickHouse intDiv function is the step value in seconds, since the toUInt32 function on a datetime column returns the Unix epoch seconds. 
-#### ClickHouse Time Grouping Function +#### ClickHouse Time Grouping Functions ```sql SELECT toStartOf[Period](time_col) [as] [alias] ``` -This is the approach that uses the following optimized ClickHouse functions to group timeseries queries: +Supported `toStartOf` functions: ``` -toStartOfMinute -toStartOfFiveMinute -toStartOfTenMinutes -toStartOfFifteenMinutes -toStartOfHour -toDate +toStartOfNanosecond toStartOfMicrosecond toStartOfMillisecond +toStartOfSecond toStartOfMinute toStartOfFiveMinute +toStartOfTenMinutes toStartOfFifteenMinutes toStartOfHour +toStartOfDay toStartOfWeek toStartOfMonth +toStartOfQuarter toStartOfYear toMonday ``` -Again the time_col and/or alias is used to determine the request time range from the WHERE or PREWHERE clause, and the step is derived from the function name. + +#### Custom Interval Grouping +```sql +SELECT toStartOfInterval(time_col, INTERVAL N unit) [as] [alias] +``` +Supported units: `second`, `minute`, `hour`, `day`, `week`, `month`, `quarter`, `year`, `millisecond`, `microsecond`, `nanosecond`. + +#### SQL-Standard date_trunc +```sql +SELECT date_trunc('unit', time_col) [as] [alias] +``` +Supported units: `second`, `minute`, `hour`, `day`, `week`, `month`, `quarter`, `year`, `millisecond`. + +#### timeSlot +```sql +SELECT timeSlot(time_col) [as] [alias] +``` +Rounds to 30-minute boundaries. + +The time_col and/or alias is used to determine the request time range from the WHERE or PREWHERE clause, and the step is derived from the function name or interval. #### Determining the requested time range Once the time column (or alias) and step are derived, Trickster parses each WHERE or PREWHERE clause to find comparison operations that mark the requested time range. To be cacheable, the WHERE clause must contain either a `[timecol|alias] BETWEEN` phrase or a `[time_col|alias] >[=]` phrase. 
The BETWEEN or >= arguments must be a parsable ClickHouse string date in the form `2006-01-02 15:04:05`, -a ten digit integer representing epoch seconds, or the `now()` ClickHouse function with optional subtraction. +a ten digit integer representing epoch seconds, or the `now()` / `now64()` ClickHouse functions with optional subtraction. If a `>` phrase is used, a similar `<` phrase can be used to specify the end of the time period. If none is found, Trickster will still cache results up to the current time, but future queries must also have no end time phrase, or Trickster will be unable to find the correct cache key. Examples of cacheable time range WHERE clauses: ```sql -WHERE t >= "2020-10-15 00:00:00" and t <= "2020-10-16 12:00:00" -WHERE t >= "2020-10-15 12:00:00" and t < now() - 60 * 60 +WHERE t >= '2020-10-15 00:00:00' AND t <= '2020-10-16 12:00:00' +WHERE t >= '2020-10-15 12:00:00' AND t < now() - 60 * 60 WHERE datetime BETWEEN 1574686300 AND 1574689900 +WHERE datetime >= toDateTime64(1589904000, 3) AND datetime <= toDateTime64(1589997600, 3) ``` -Note that these values can be wrapped in the ClickHouse toDateTime function, but ClickHouse will make that conversion implicitly and it is not required. All string times are assumed to be UTC. +These values can be wrapped in `toDateTime()` or `toDateTime64()` (including with a precision argument). All string times are assumed to be UTC. + +Queries using `toStartOfMonth`, `toStartOfQuarter`, or `toStartOfYear` (which return Date rather than DateTime) automatically wrap time range boundaries in `toDate()` during interpolation. ### Non-Time-Series Queries Queries that are not cacheable as time series — such as `LIMIT`-based queries, `SELECT 1` health checks, or SDK handshake requests — are transparently proxied to the upstream ClickHouse server. 
These requests are cached using the Object Proxy Cache (OPC) with per-query cache keys derived from the `query` and `database` URL parameters, ensuring that different SQL statements receive distinct cache entries. +When a query falls back from DPC to OPC, Trickster logs a warning at the `warn` level to help identify queries that aren't getting time series acceleration. This can be suppressed per-backend: + +```yaml +backends: + click1: + provider: clickhouse + dpc_fallback_warning: false +``` + ### Health and Ping Endpoint Trickster exposes a `/ping` endpoint that returns a health check response, matching the endpoint provided by ClickHouse itself. This enables compatibility with clients and SDKs that probe `/ping` during connection initialization. From bb8c3db4a72fe66bcfc1cb3848084f3d52fd8f6a Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Sun, 19 Apr 2026 23:19:35 -0400 Subject: [PATCH 5/8] docs: [clickhouse] bump provider summary, document protocol + dpc_fallback_warning + protocol_listeners Signed-off-by: Chris Randles --- docs/supported-backend-providers.md | 2 +- examples/conf/example.full.yaml | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/supported-backend-providers.md b/docs/supported-backend-providers.md index 918188b0d..b3b6d6846 100644 --- a/docs/supported-backend-providers.md +++ b/docs/supported-backend-providers.md @@ -22,6 +22,6 @@ See the [InfluxDB Support Document](./influxdb.md) for more information. ### ClickHouse -Trickster supports accelerating ClickHouse time series. Specify `'clickhouse'` as the Provider when configuring Trickster. +Trickster supports accelerating ClickHouse time series over both HTTP and the ClickHouse native binary protocol (port 9000), and is tested against the Vertamedia and official Grafana ClickHouse (v4+) datasource plugins. Specify `'clickhouse'` as the Provider when configuring Trickster. See the [ClickHouse Support Document](./clickhouse.md) for more information. 
diff --git a/examples/conf/example.full.yaml b/examples/conf/example.full.yaml index 87d2a29c0..1822a3cc1 100644 --- a/examples/conf/example.full.yaml +++ b/examples/conf/example.full.yaml @@ -58,6 +58,18 @@ frontend: # # Default: false # # truncate_request_body_too_large: false +# # protocol_listeners configures raw TCP listeners for non-HTTP wire protocols +# # (e.g. the ClickHouse native binary protocol on port 9000). Each listener routes +# # all accepted connections to a single named backend. Supported protocols today: +# # clickhouse-native. See /docs/clickhouse.md for the ClickHouse flow. +# # protocol_listeners: +# # - name: clickhouse-native # human-readable identifier +# # protocol: clickhouse-native # wire protocol handler +# # listen_address: '' # bind all interfaces by default +# # listen_port: 9000 # TCP port +# # backend: click1 # name of a configured backend (below) +# # connections_limit: 0 # 0 = unlimited + # caches: # default: # # provider defines what kind of cache Trickster uses @@ -254,6 +266,12 @@ backends: # origin_url is a required configuration value origin_url: http://prometheus:9090 +# # protocol selects the upstream wire protocol used to communicate with the origin. +# # When empty (default), HTTP is used. Supported values are provider-specific: +# # clickhouse: 'native' — speak the native binary protocol to an upstream on port 9000 +# # (this is independent of frontend protocol_listeners, which accept non-HTTP clients.) 
+# protocol: '' + # is_default describes whether this backend is the default backend considered when routing http requests # it is false, by default; but if you only have a single backend configured, is_default will be true unless explicitly set to false is_default: true @@ -339,6 +357,12 @@ backends: # # These next 7 settings only apply to Time Series backends +# # dpc_fallback_warning, when true (default), logs a warning when a query cannot be parsed +# # as a time range query and falls back from DPC (Delta Proxy Cache) to OPC (Object Proxy Cache). +# # Set to false to suppress these warnings (they will still appear at debug level). +# # Applies to SQL-based TSDB backends like ClickHouse that may receive non-time-series queries. +# dpc_fallback_warning: true + # # backfill_tolerance prevents new datapoints that fall within the tolerance window (relative to time.Now) from being permanently # # cached. Think of it as "the newest N milliseconds of real-time data are preliminary and subject to updates, so refresh them periodically" # # default is 0ms From 1886ad8e16e74df50258f97552c3ecd37628187f Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Mon, 20 Apr 2026 07:10:58 -0400 Subject: [PATCH 6/8] test/fix: [clickhouse] unit tests + native client scan fix Signed-off-by: Chris Randles --- pkg/backends/clickhouse/model/native_test.go | 597 ++++++++++++++++++ pkg/backends/clickhouse/native/client.go | 39 +- pkg/backends/clickhouse/native/client_test.go | 393 ++++++++++++ .../clickhouse/native/server/response_test.go | 241 +++++++ .../clickhouse/native/server/types_test.go | 477 ++++++++++++++ 5 files changed, 1729 insertions(+), 18 deletions(-) create mode 100644 pkg/backends/clickhouse/model/native_test.go create mode 100644 pkg/backends/clickhouse/native/client_test.go create mode 100644 pkg/backends/clickhouse/native/server/response_test.go create mode 100644 pkg/backends/clickhouse/native/server/types_test.go diff --git a/pkg/backends/clickhouse/model/native_test.go 
b/pkg/backends/clickhouse/model/native_test.go new file mode 100644 index 000000000..85a4b730b --- /dev/null +++ b/pkg/backends/clickhouse/model/native_test.go @@ -0,0 +1,597 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "bufio" + "bytes" + "errors" + "io" + "math" + "strconv" + "strings" + "testing" + + "github.com/trickstercache/trickster/v2/pkg/timeseries" + "github.com/trickstercache/trickster/v2/pkg/timeseries/dataset" + "github.com/trickstercache/trickster/v2/pkg/timeseries/epoch" +) + +func TestWriteUvarint(t *testing.T) { + cases := []uint64{0, 1, 127, 128, 16383, 16384, 1 << 32, math.MaxUint64} + for _, v := range cases { + b := new(bytes.Buffer) + if err := writeUvarint(b, v); err != nil { + t.Fatalf("writeUvarint(%d): %v", v, err) + } + got, err := readUvarint(bufio.NewReader(bytes.NewReader(b.Bytes()))) + if err != nil { + t.Fatalf("readUvarint for %d: %v", v, err) + } + if got != v { + t.Fatalf("uvarint roundtrip: want %d got %d", v, got) + } + } +} + +func TestWriteUvarint_WriterError(t *testing.T) { + if err := writeUvarint(errWriter{}, 12345); err == nil { + t.Fatal("expected error from failing writer") + } +} + +func TestWriteNativeString(t *testing.T) { + cases := []string{"", "a", "hello world", strings.Repeat("x", 300)} + for _, s := range cases { + b := new(bytes.Buffer) + if err := writeNativeString(b, s); err != nil { + 
t.Fatalf("writeNativeString(%q): %v", s, err) + } + got, err := readString(bufio.NewReader(bytes.NewReader(b.Bytes()))) + if err != nil { + t.Fatalf("readString: %v", err) + } + if got != s { + t.Fatalf("string roundtrip: want %q got %q", s, got) + } + } +} + +func TestWriteNativeString_WriterError(t *testing.T) { + if err := writeNativeString(errWriter{}, "anything"); err == nil { + t.Fatal("expected error writing length") + } + // empty string only writes the length; a writer that fails after first + // byte still surfaces the error for non-empty payloads. + if err := writeNativeString(shortWriter{max: 1}, "abc"); err == nil { + t.Fatal("expected error writing body") + } +} + +func TestReadString_EmptyAndError(t *testing.T) { + // length 0 + br := bufio.NewReader(bytes.NewReader([]byte{0})) + s, err := readString(br) + if err != nil || s != "" { + t.Fatalf("empty string: %v %q", err, s) + } + // length 5 but only 3 bytes available + br = bufio.NewReader(bytes.NewReader([]byte{5, 'a', 'b', 'c'})) + if _, err := readString(br); err == nil { + t.Fatal("expected short read error") + } + // uvarint read error + br = bufio.NewReader(bytes.NewReader(nil)) + if _, err := readString(br); err == nil { + t.Fatal("expected error on empty input") + } +} + +func TestReadUvarint_Error(t *testing.T) { + br := bufio.NewReader(bytes.NewReader([]byte{0xff})) + if _, err := readUvarint(br); err == nil { + t.Fatal("expected error on truncated uvarint") + } +} + +func TestReadFixed(t *testing.T) { + r := bytes.NewReader([]byte{1, 2, 3, 4, 5}) + b, err := readFixed(r, 3) + if err != nil { + t.Fatalf("readFixed: %v", err) + } + if !bytes.Equal(b, []byte{1, 2, 3}) { + t.Fatalf("want [1 2 3] got %v", b) + } + if _, err := readFixed(r, 10); err == nil { + t.Fatal("expected EOF") + } +} + +func TestWriteReadNativeValueRoundTrip(t *testing.T) { + cases := []struct { + typ, in, want string + }{ + {TypeUInt8, "200", "200"}, + {TypeBool, "1", "1"}, + {TypeUInt16, "65535", "65535"}, + 
{TypeUInt32, "4294967295", "4294967295"}, + {TypeUInt64, "18446744073709551615", "18446744073709551615"}, + {TypeInt8, "-5", "-5"}, + {TypeInt16, "-20000", "-20000"}, + {TypeInt32, "-2147483648", "-2147483648"}, + {TypeInt64, "-9223372036854775808", "-9223372036854775808"}, + {TypeString, "hello", "hello"}, + {TypeString, "", ""}, + {TypeDateTime, "2020-01-01 00:00:00", "2020-01-01 00:00:00"}, + {TypeDateTime, "1577836800", "2020-01-01 00:00:00"}, + {TypeDate, "2020-01-01", "2020-01-01"}, + } + for _, c := range cases { + b := new(bytes.Buffer) + if err := writeNativeValue(b, c.typ, c.in); err != nil { + t.Fatalf("writeNativeValue(%s,%q): %v", c.typ, c.in, err) + } + br := bufio.NewReader(bytes.NewReader(b.Bytes())) + got, err := readValueAsString(br, c.typ) + if err != nil { + t.Fatalf("readValueAsString(%s): %v", c.typ, err) + } + if got != c.want { + t.Fatalf("%s: want %q got %q", c.typ, c.want, got) + } + } +} + +func TestWriteReadFloatValues(t *testing.T) { + // floats don't exactly roundtrip string<->bits; verify the numeric value. 
+ b := new(bytes.Buffer) + if err := writeNativeValue(b, TypeFloat32, "3.5"); err != nil { + t.Fatal(err) + } + got, err := readValueAsString(bufio.NewReader(bytes.NewReader(b.Bytes())), TypeFloat32) + if err != nil { + t.Fatal(err) + } + if got != "3.5" { + t.Fatalf("float32: want 3.5 got %q", got) + } + + b.Reset() + if err := writeNativeValue(b, TypeFloat64, "1.25"); err != nil { + t.Fatal(err) + } + got, err = readValueAsString(bufio.NewReader(bytes.NewReader(b.Bytes())), TypeFloat64) + if err != nil { + t.Fatal(err) + } + if got != "1.25" { + t.Fatalf("float64: want 1.25 got %q", got) + } +} + +func TestWriteNativeValue_DateTime64(t *testing.T) { + b := new(bytes.Buffer) + if err := writeNativeValue(b, "DateTime64(3)", "2020-01-01 00:00:00.000"); err != nil { + t.Fatal(err) + } + got, err := readValueAsString(bufio.NewReader(bytes.NewReader(b.Bytes())), "DateTime64(3)") + if err != nil { + t.Fatal(err) + } + if got != "2020-01-01 00:00:00.000" { + t.Fatalf("DateTime64 roundtrip: got %q", got) + } + + // epoch-ms fallback path: unparsable text falls through to ParseInt + b.Reset() + if err := writeNativeValue(b, "DateTime64(3)", "1577836800000"); err != nil { + t.Fatal(err) + } + got, err = readValueAsString(bufio.NewReader(bytes.NewReader(b.Bytes())), "DateTime64(3)") + if err != nil { + t.Fatal(err) + } + if got != "2020-01-01 00:00:00.000" { + t.Fatalf("DateTime64 epoch-ms: got %q", got) + } +} + +func TestWriteNativeValue_UnknownTypeAsString(t *testing.T) { + b := new(bytes.Buffer) + if err := writeNativeValue(b, "SomeUnknownType", "payload"); err != nil { + t.Fatal(err) + } + got, err := readValueAsString(bufio.NewReader(bytes.NewReader(b.Bytes())), "SomeUnknownType") + if err != nil { + t.Fatal(err) + } + if got != "payload" { + t.Fatalf("unknown type: got %q", got) + } +} + +func TestWriteNativeValue_DateUnparsable(t *testing.T) { + // Unparsable Date falls through to writeNativeString + b := new(bytes.Buffer) + if err := writeNativeValue(b, TypeDate, 
"not-a-date"); err != nil { + t.Fatal(err) + } + // reading as Date consumes 2 bytes — here the first byte is the uvarint + // length written by the string fallback. Just verify bytes were written. + if b.Len() == 0 { + t.Fatal("expected bytes written on Date fallback") + } +} + +func TestReadValueAsString_ShortBuffers(t *testing.T) { + types := []string{TypeUInt16, TypeUInt32, TypeUInt64, TypeInt16, + TypeInt32, TypeInt64, TypeFloat32, TypeFloat64, TypeDateTime, TypeDate, + "DateTime64(3)"} + for _, typ := range types { + br := bufio.NewReader(bytes.NewReader(nil)) + if _, err := readValueAsString(br, typ); err == nil { + t.Fatalf("%s: expected error on empty reader", typ) + } + } +} + +func TestReadValueAsString_Int8UInt8Empty(t *testing.T) { + // UInt8/Bool/Int8 call ReadByte which surfaces io.EOF. + for _, typ := range []string{TypeUInt8, TypeInt8, TypeBool} { + br := bufio.NewReader(bytes.NewReader(nil)) + if _, err := readValueAsString(br, typ); err == nil { + t.Fatalf("%s: expected EOF", typ) + } + } +} + +func TestFormatEpochForType(t *testing.T) { + ep := epoch.Epoch(int64(1577836800) * 1e9) // 2020-01-01 00:00:00 UTC + cases := []struct { + sdt, want string + }{ + {TypeDateTime, "2020-01-01 00:00:00"}, + {TypeDate, "2020-01-01"}, + {"DateTime64(3)", "2020-01-01 00:00:00.000"}, + {"UInt64", "1577836800"}, + } + for _, c := range cases { + got := formatEpochForType(ep, timeseries.FieldDefinition{SDataType: c.sdt}) + if got != c.want { + t.Fatalf("%s: want %q got %q", c.sdt, c.want, got) + } + } +} + +func TestWriteNativeBlockInfo(t *testing.T) { + b := new(bytes.Buffer) + if err := writeNativeBlockInfo(b); err != nil { + t.Fatal(err) + } + if err := skipBlockInfo(bufio.NewReader(bytes.NewReader(b.Bytes()))); err != nil { + t.Fatalf("skipBlockInfo: %v", err) + } +} + +func TestWriteEmptyNativeBlock(t *testing.T) { + b := new(bytes.Buffer) + if err := writeEmptyNativeBlock(b); err != nil { + t.Fatal(err) + } + br := bufio.NewReader(bytes.NewReader(b.Bytes())) 
+ if err := skipBlockInfo(br); err != nil { + t.Fatal(err) + } + ncols, err := readUvarint(br) + if err != nil || ncols != 0 { + t.Fatalf("numCols: %d err=%v", ncols, err) + } + nrows, err := readUvarint(br) + if err != nil || nrows != 0 { + t.Fatalf("numRows: %d err=%v", nrows, err) + } +} + +func TestSkipBlockInfo_UnknownField(t *testing.T) { + // fieldNum=7 is unknown, handler returns error. + buf := &bytes.Buffer{} + _ = writeUvarint(buf, 7) + if err := skipBlockInfo(bufio.NewReader(bytes.NewReader(buf.Bytes()))); err == nil { + t.Fatal("expected unknown field error") + } +} + +func TestSkipBlockInfo_ErrorOnRead(t *testing.T) { + // truncated field-1 payload (need 1 byte after fieldNum) + buf := &bytes.Buffer{} + _ = writeUvarint(buf, 1) + if err := skipBlockInfo(bufio.NewReader(bytes.NewReader(buf.Bytes()))); err == nil { + t.Fatal("expected read error on field 1 body") + } + // truncated field-2 payload (need 4 bytes) + buf.Reset() + _ = writeUvarint(buf, 2) + if err := skipBlockInfo(bufio.NewReader(bytes.NewReader(buf.Bytes()))); err == nil { + t.Fatal("expected read error on field 2 body") + } + // no fieldNum at all + if err := skipBlockInfo(bufio.NewReader(bytes.NewReader(nil))); err == nil { + t.Fatal("expected error reading fieldNum") + } +} + +func TestMarshalTimeseriesNative_RoundTrip(t *testing.T) { + ds := testDataSet() + b := new(bytes.Buffer) + if err := marshalTimeseriesNative(b, ds, ×eries.RequestOptions{}); err != nil { + t.Fatalf("marshal: %v", err) + } + + ts, err := UnmarshalTimeseriesNative(b.Bytes(), testTRQ.Clone()) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } + got, ok := ts.(*dataset.DataSet) + if !ok || got == nil { + t.Fatal("expected non-nil DataSet") + } + if len(got.Results) != 1 || len(got.Results[0].SeriesList) != 1 { + t.Fatalf("unexpected shape: results=%d", len(got.Results)) + } + gotSeries := got.Results[0].SeriesList[0] + wantSeries := ds.Results[0].SeriesList[0] + if len(gotSeries.Points) != len(wantSeries.Points) 
{ + t.Fatalf("points: want %d got %d", len(wantSeries.Points), len(gotSeries.Points)) + } + for i := range wantSeries.Points { + if gotSeries.Points[i].Epoch != wantSeries.Points[i].Epoch { + t.Errorf("point %d epoch: want %d got %d", i, + wantSeries.Points[i].Epoch, gotSeries.Points[i].Epoch) + } + } +} + +func TestMarshalTimeseriesNative_EmptyDataSet(t *testing.T) { + b := new(bytes.Buffer) + ds := &dataset.DataSet{} + if err := marshalTimeseriesNative(b, ds, ×eries.RequestOptions{}); err != nil { + t.Fatal(err) + } + // empty block — UnmarshalTimeseriesNative should see no rows and return + // ErrInvalidBody. + if _, err := UnmarshalTimeseriesNative(b.Bytes(), testTRQ.Clone()); err == nil { + t.Fatal("expected error on empty block") + } +} + +func TestMarshalTimeseriesNative_WriterError(t *testing.T) { + ds := testDataSet() + for i := 0; i < 40; i++ { + err := marshalTimeseriesNative(shortWriter{max: i}, ds, ×eries.RequestOptions{}) + if err == nil { + // Once the short writer accepts enough bytes it should succeed. + // Just ensure small limits produce errors. + if i < 5 { + t.Fatalf("limit=%d: expected error", i) + } + } + } +} + +func TestUnmarshalTimeseriesNative_NoBlockInfo(t *testing.T) { + // Build a minimal block without block-info header. First byte (numCols + // uvarint) must NOT be 1 so the peek branch in the reader skips + // skipBlockInfo. Use 2 cols, 1 row. 
+ b := new(bytes.Buffer) + _ = writeUvarint(b, 2) // numCols + _ = writeUvarint(b, 1) // numRows + _ = writeNativeString(b, "t") + _ = writeNativeString(b, "UInt64") + // no customSerialization byte (since no block info) + _ = writeNativeValue(b, "UInt64", "1577836800000") + _ = writeNativeString(b, "hostname") + _ = writeNativeString(b, "String") + _ = writeNativeValue(b, "String", "localhost") + + ts, err := UnmarshalTimeseriesNative(b.Bytes(), testTRQ.Clone()) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } + ds, ok := ts.(*dataset.DataSet) + if !ok || ds == nil { + t.Fatal("expected non-nil DataSet") + } +} + +func TestUnmarshalTimeseriesNative_EmptyReaderError(t *testing.T) { + if _, err := UnmarshalTimeseriesNative(nil, testTRQ.Clone()); err == nil { + t.Fatal("expected error on empty input") + } +} + +func TestUnmarshalTimeseriesNative_ColumnNameReadError(t *testing.T) { + b := new(bytes.Buffer) + _ = writeNativeBlockInfo(b) + _ = writeUvarint(b, 2) // numCols + _ = writeUvarint(b, 1) // numRows + // truncate here: reader expects a column name but hits EOF + if _, err := UnmarshalTimeseriesNativeReader(bytes.NewReader(b.Bytes()), testTRQ.Clone()); err == nil { + t.Fatal("expected error on truncated column name") + } +} + +func TestUnmarshalTimeseriesNative_ColumnTypeReadError(t *testing.T) { + b := new(bytes.Buffer) + _ = writeNativeBlockInfo(b) + _ = writeUvarint(b, 1) // numCols + _ = writeUvarint(b, 1) // numRows + _ = writeNativeString(b, "t") + // truncate before type + if _, err := UnmarshalTimeseriesNativeReader(bytes.NewReader(b.Bytes()), testTRQ.Clone()); err == nil { + t.Fatal("expected error on truncated column type") + } +} + +func TestUnmarshalTimeseriesNative_ValueReadError(t *testing.T) { + b := new(bytes.Buffer) + _ = writeNativeBlockInfo(b) + _ = writeUvarint(b, 1) // numCols + _ = writeUvarint(b, 1) // numRows + _ = writeNativeString(b, "t") + _ = writeNativeString(b, "UInt64") + _, _ = b.Write([]byte{0}) // customSerialization flag 
+ // truncate — UInt64 needs 8 bytes + if _, err := UnmarshalTimeseriesNativeReader(bytes.NewReader(b.Bytes()), testTRQ.Clone()); err == nil { + t.Fatal("expected error on truncated value") + } +} + +func TestUnmarshalTimeseriesNative_CustomSerializationEOF(t *testing.T) { + b := new(bytes.Buffer) + _ = writeNativeBlockInfo(b) + _ = writeUvarint(b, 1) // numCols + _ = writeUvarint(b, 1) // numRows + _ = writeNativeString(b, "t") + _ = writeNativeString(b, "UInt64") + // missing customSerialization flag byte + if _, err := UnmarshalTimeseriesNativeReader(bytes.NewReader(b.Bytes()), testTRQ.Clone()); err == nil { + t.Fatal("expected error on missing custom serialization byte") + } +} + +func TestUnmarshalTimeseriesNativeReader_MultiBlock(t *testing.T) { + // concatenate two blocks — unmarshal should aggregate rows from both. + b := new(bytes.Buffer) + writeBlock := func(rows int) { + _ = writeNativeBlockInfo(b) + _ = writeUvarint(b, 2) // numCols + _ = writeUvarint(b, uint64(rows)) // numRows + _ = writeNativeString(b, "t") + _ = writeNativeString(b, "UInt64") + _, _ = b.Write([]byte{0}) + for i := 0; i < rows; i++ { + _ = writeNativeValue(b, "UInt64", + strconv.FormatInt(1577836800000+int64(i)*60000, 10)) + } + _ = writeNativeString(b, "hostname") + _ = writeNativeString(b, "String") + _, _ = b.Write([]byte{0}) + for i := 0; i < rows; i++ { + _ = writeNativeValue(b, "String", "localhost") + } + } + writeBlock(2) + writeBlock(1) + + ts, err := UnmarshalTimeseriesNativeReader(bytes.NewReader(b.Bytes()), testTRQ.Clone()) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } + ds := ts.(*dataset.DataSet) + total := 0 + for _, s := range ds.Results[0].SeriesList { + total += len(s.Points) + } + if total != 3 { + t.Fatalf("expected 3 total points across blocks, got %d", total) + } +} + +// Assemble a block with a terminating zero-cols block (as ClickHouse sends). 
+func TestUnmarshalTimeseriesNativeReader_ZeroColsTerminator(t *testing.T) { + b := new(bytes.Buffer) + _ = writeNativeBlockInfo(b) + _ = writeUvarint(b, 2) + _ = writeUvarint(b, 1) + _ = writeNativeString(b, "t") + _ = writeNativeString(b, "UInt64") + _, _ = b.Write([]byte{0}) + _ = writeNativeValue(b, "UInt64", "1577836800000") + _ = writeNativeString(b, "hostname") + _ = writeNativeString(b, "String") + _, _ = b.Write([]byte{0}) + _ = writeNativeValue(b, "String", "localhost") + // terminator block: info + 0 cols + 0 rows + _ = writeNativeBlockInfo(b) + _ = writeUvarint(b, 0) + _ = writeUvarint(b, 0) + + ts, err := UnmarshalTimeseriesNativeReader(bytes.NewReader(b.Bytes()), testTRQ.Clone()) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } + if ts == nil { + t.Fatal("nil timeseries") + } +} + +func TestNopReaderBehavior(t *testing.T) { + r := &nopReader{data: []byte("world"), pos: 0} + buf := make([]byte, 3) + n, err := r.Read(buf) + if err != nil || n != 3 || string(buf) != "wor" { + t.Fatalf("read1: n=%d err=%v buf=%q", n, err, buf) + } + n, err = r.Read(buf) + if err != nil || n != 2 || string(buf[:n]) != "ld" { + t.Fatalf("read2: n=%d err=%v buf=%q", n, err, buf[:n]) + } + n, err = r.Read(buf) + if n != 0 || !errors.Is(err, io.EOF) { + t.Fatalf("read3 want EOF: n=%d err=%v", n, err) + } +} + +// Cover via the higher-level MarshalTimeseries dispatch for OutputFormatNative. +func TestMarshalTimeseries_NativeFormat(t *testing.T) { + out, err := MarshalTimeseries(testDataSet(), + ×eries.RequestOptions{OutputFormat: OutputFormatNative}, 200) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if len(out) == 0 { + t.Fatal("expected non-empty output") + } + // Verify it parses back. 
+ if _, err := UnmarshalTimeseriesNative(out, testTRQ.Clone()); err != nil { + t.Fatalf("roundtrip unmarshal: %v", err) + } +} + +// --- test helpers --- + +type errWriter struct{} + +func (errWriter) Write(_ []byte) (int, error) { return 0, errors.New("boom") } + +type shortWriter struct { + max, written int +} + +func (s shortWriter) Write(p []byte) (int, error) { + remaining := s.max - s.written + if remaining <= 0 { + return 0, errors.New("write limit exceeded") + } + n := len(p) + if n > remaining { + return remaining, errors.New("write limit exceeded") + } + return n, nil +} + diff --git a/pkg/backends/clickhouse/native/client.go b/pkg/backends/clickhouse/native/client.go index fb9370667..8951930fb 100644 --- a/pkg/backends/clickhouse/native/client.go +++ b/pkg/backends/clickhouse/native/client.go @@ -22,6 +22,7 @@ package native import ( "bytes" + "database/sql" "encoding/json" "errors" "fmt" @@ -30,7 +31,6 @@ import ( "strings" "github.com/ClickHouse/clickhouse-go/v2" - "github.com/ClickHouse/clickhouse-go/v2/lib/driver" bo "github.com/trickstercache/trickster/v2/pkg/backends/options" "github.com/trickstercache/trickster/v2/pkg/proxy/request" ) @@ -38,7 +38,7 @@ import ( // NativeClient wraps a clickhouse-go native connection pool and exposes a // Fetcher compatible with bo.Options.Fetcher. type NativeClient struct { - conn driver.Conn + db *sql.DB } // NewNativeClient creates a NativeClient from the backend options. The @@ -51,19 +51,16 @@ func NewNativeClient(o *bo.Options) (*NativeClient, error) { if !strings.Contains(addr, ":") { addr += ":9000" } - conn, err := clickhouse.Open(&clickhouse.Options{ + db := clickhouse.OpenDB(&clickhouse.Options{ Addr: []string{addr}, Protocol: clickhouse.Native, }) - if err != nil { - return nil, fmt.Errorf("clickhouse native: open: %w", err) - } - return &NativeClient{conn: conn}, nil + return &NativeClient{db: db}, nil } // Close closes the underlying connection pool. 
func (nc *NativeClient) Close() error {
-	return nc.conn.Close()
+	return nc.db.Close()
 }
 
 // Fetch executes the SQL query embedded in the *http.Request (body or query
@@ -79,14 +76,20 @@
 		return syntheticErrorResponse(http.StatusBadRequest, err), nil
 	}
 
-	rows, err := nc.conn.Query(r.Context(), sql)
+	rows, err := nc.db.QueryContext(r.Context(), sql) // #nosec G201 -- proxy passthrough; SQL is forwarded verbatim from client
 	if err != nil {
 		return syntheticErrorResponse(http.StatusBadGateway, err), nil
 	}
 	defer rows.Close()
 
-	colTypes := rows.ColumnTypes()
-	colNames := rows.Columns()
+	colNames, err := rows.Columns()
+	if err != nil {
+		return syntheticErrorResponse(http.StatusBadGateway, err), nil
+	}
+	colTypes, err := rows.ColumnTypes()
+	if err != nil {
+		return syntheticErrorResponse(http.StatusBadGateway, err), nil
+	}
 
 	meta := make([]map[string]string, len(colNames))
 	for i, name := range colNames {
@@ -97,18 +100,18 @@
 	}
 
 	data := make([]map[string]any, 0, 64)
-	scanDest := make([]any, len(colNames))
-	for i := range scanDest {
-		scanDest[i] = new(any)
-	}
-
 	for rows.Next() {
-		if err := rows.Scan(scanDest...); err != nil {
+		scanDest := make([]any, len(colNames))
+		ptrs := make([]any, len(colNames))
+		for i := range scanDest {
+			ptrs[i] = &scanDest[i]
+		}
+		if err := rows.Scan(ptrs...); err != nil {
 			return syntheticErrorResponse(http.StatusBadGateway, err), nil
 		}
 		row := make(map[string]any, len(colNames))
 		for i, name := range colNames {
-			row[name] = *(scanDest[i].(*any))
+			row[name] = scanDest[i]
 		}
 		data = append(data, row)
 	}
diff --git a/pkg/backends/clickhouse/native/client_test.go b/pkg/backends/clickhouse/native/client_test.go
new file mode 100644
index 000000000..811344d48
--- /dev/null
+++ b/pkg/backends/clickhouse/native/client_test.go
@@ -0,0 +1,393 @@
+/*
+ * Copyright 2018 The Trickster Authors
+ *
+ * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package native + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/trickstercache/trickster/v2/pkg/backends/clickhouse/native/server" + bo "github.com/trickstercache/trickster/v2/pkg/backends/options" +) + +func TestNewNativeClient(t *testing.T) { + tests := []struct { + name string + host string + wantErr bool + }{ + {"empty host", "", true}, + {"host with port", "localhost:9000", false}, + {"host without port", "localhost", false}, + {"ipv4 with port", "127.0.0.1:19000", false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + o := &bo.Options{Host: tc.host} + c, err := NewNativeClient(o) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c == nil || c.db == nil { + t.Fatal("expected non-nil client with non-nil db") + } + if err := c.Close(); err != nil { + t.Fatalf("close: %v", err) + } + }) + } +} + +func TestClose(t *testing.T) { + c, err := NewNativeClient(&bo.Options{Host: "127.0.0.1:19000"}) + if err != nil { + t.Fatalf("new: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("close: %v", err) + } +} + +func TestExtractSQL(t *testing.T) { + tests := []struct { + name string + method string + body string + query string + want string + wantErr bool 
+ }{ + { + name: "post body", + method: http.MethodPost, + body: "SELECT 1", + want: "SELECT 1", + }, + { + name: "get query param", + method: http.MethodGet, + query: "SELECT 2", + want: "SELECT 2", + }, + { + name: "post body preferred over query param", + method: http.MethodPost, + body: "SELECT body", + query: "SELECT param", + want: "SELECT body", + }, + { + name: "no sql", + method: http.MethodGet, + wantErr: true, + }, + { + name: "empty post body falls back to missing", + method: http.MethodPost, + wantErr: true, + }, + { + name: "empty post body falls back to query", + method: http.MethodPost, + query: "SELECT fallback", + want: "SELECT fallback", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + u := &url.URL{Path: "/"} + if tc.query != "" { + q := u.Query() + q.Set("query", tc.query) + u.RawQuery = q.Encode() + } + var body io.ReadCloser + if tc.body != "" { + body = io.NopCloser(strings.NewReader(tc.body)) + } + r := &http.Request{Method: tc.method, URL: u, Body: body} + got, err := extractSQL(r) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.want { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } +} + +func TestSyntheticErrorResponse(t *testing.T) { + err := errors.New("boom") + resp := syntheticErrorResponse(http.StatusBadRequest, err) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status: got %d", resp.StatusCode) + } + if resp.Status != "400 Bad Request" { + t.Fatalf("status text: got %q", resp.Status) + } + if ct := resp.Header.Get("Content-Type"); ct != "text/plain" { + t.Fatalf("content-type: got %q", ct) + } + b, readErr := io.ReadAll(resp.Body) + if readErr != nil { + t.Fatal(readErr) + } + if string(b) != "boom" { + t.Fatalf("body: got %q", string(b)) + } + if resp.ContentLength != int64(len("boom")) { + t.Fatalf("length: got %d", resp.ContentLength) + } +} + +func 
TestFetchExtractError(t *testing.T) { + c, err := NewNativeClient(&bo.Options{Host: "127.0.0.1:19000"}) + if err != nil { + t.Fatalf("new: %v", err) + } + defer c.Close() + + r := &http.Request{Method: http.MethodGet, URL: &url.URL{Path: "/"}} + resp, err := c.Fetch(r) + if err != nil { + t.Fatalf("fetch: %v", err) + } + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status: got %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestFetchQueryError(t *testing.T) { + addr, stop := startNativeServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad query", http.StatusBadRequest) + })) + defer stop() + + c, err := NewNativeClient(&bo.Options{Host: addr}) + if err != nil { + t.Fatalf("new: %v", err) + } + defer c.Close() + + r := buildQueryRequest(t, "SELECT 1") + resp, err := c.Fetch(r) + if err != nil { + t.Fatalf("fetch: %v", err) + } + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("status: got %d, want %d", resp.StatusCode, http.StatusBadGateway) + } +} + +func TestFetchSuccess(t *testing.T) { + addr, stop := startNativeServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + doc := map[string]any{ + "meta": []map[string]string{ + {"name": "a", "type": "String"}, + {"name": "b", "type": "String"}, + }, + "data": []map[string]any{}, + "rows": 0, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + })) + defer stop() + + c, err := NewNativeClient(&bo.Options{Host: addr}) + if err != nil { + t.Fatalf("new: %v", err) + } + defer c.Close() + + r := buildQueryRequest(t, "SELECT a, b FROM t") + resp, err := c.Fetch(r) + if err != nil { + t.Fatalf("fetch: %v", err) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status: got %d body=%q", resp.StatusCode, body) + } + if ct := resp.Header.Get("Content-Type"); ct != "application/json" { + t.Fatalf("content-type: got %q", ct) + } + + b, err := 
io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + var out struct { + Meta []map[string]string `json:"meta"` + Data []map[string]any `json:"data"` + Rows int `json:"rows"` + } + if err := json.Unmarshal(b, &out); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if out.Rows != 0 { + t.Fatalf("rows: got %d", out.Rows) + } +} + +func TestFetchSuccessWithRows(t *testing.T) { + addr, stop := startNativeServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + doc := map[string]any{ + "meta": []map[string]string{ + {"name": "id", "type": "Int64"}, + {"name": "label", "type": "String"}, + }, + "data": []map[string]any{ + {"id": 1, "label": "alpha"}, + {"id": 2, "label": "beta"}, + {"id": 3, "label": "gamma"}, + }, + "rows": 3, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + })) + defer stop() + + c, err := NewNativeClient(&bo.Options{Host: addr}) + if err != nil { + t.Fatalf("new: %v", err) + } + defer c.Close() + + r := buildQueryRequest(t, "SELECT id, label FROM t") + resp, err := c.Fetch(r) + if err != nil { + t.Fatalf("fetch: %v", err) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status: got %d body=%q", resp.StatusCode, body) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + var out struct { + Meta []map[string]string `json:"meta"` + Data []map[string]any `json:"data"` + Rows int `json:"rows"` + } + if err := json.Unmarshal(b, &out); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if out.Rows != 3 { + t.Fatalf("rows: got %d, want 3", out.Rows) + } + if len(out.Data) != 3 { + t.Fatalf("data len: got %d, want 3", len(out.Data)) + } + wantLabels := []string{"alpha", "beta", "gamma"} + for i, row := range out.Data { + id, ok := row["id"].(float64) + if !ok { + t.Fatalf("row %d id: got %T %v", i, row["id"], row["id"]) + } + if int64(id) != int64(i+1) { + t.Fatalf("row %d id: got %v, want %d", i, id, i+1) + } + label, 
ok := row["label"].(string) + if !ok { + t.Fatalf("row %d label: got %T %v", i, row["label"], row["label"]) + } + if label != wantLabels[i] { + t.Fatalf("row %d label: got %q, want %q", i, label, wantLabels[i]) + } + } +} + +func buildQueryRequest(t *testing.T, sql string) *http.Request { + t.Helper() + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(sql)) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + return r.WithContext(ctx) +} + +func startNativeServer(t *testing.T, handler http.Handler) (string, func()) { + t.Helper() + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + h := &server.Handler{QueryHandler: handler} + ctx, cancel := context.WithCancel(context.Background()) + var ( + wg sync.WaitGroup + mu sync.Mutex + conns []net.Conn + ) + wg.Add(1) + go func() { + defer wg.Done() + for { + conn, err := lis.Accept() + if err != nil { + return + } + mu.Lock() + conns = append(conns, conn) + mu.Unlock() + wg.Add(1) + go func(c net.Conn) { + defer wg.Done() + defer c.Close() + _ = h.HandleConnection(ctx, c) + }(conn) + } + }() + stop := func() { + cancel() + lis.Close() + mu.Lock() + for _, c := range conns { + c.Close() + } + mu.Unlock() + wg.Wait() + } + return lis.Addr().String(), stop +} diff --git a/pkg/backends/clickhouse/native/server/response_test.go b/pkg/backends/clickhouse/native/server/response_test.go new file mode 100644 index 000000000..3b63c5303 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/response_test.go @@ -0,0 +1,241 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func TestZeroForType(t *testing.T) { + cases := []struct { + in string + want any + }{ + {"UInt8", uint8(0)}, + {"Int8", uint8(0)}, + {"UInt16", uint16(0)}, + {"Int16", uint16(0)}, + {"UInt32", uint32(0)}, + {"Int32", uint32(0)}, + {"DateTime", uint32(0)}, + {"UInt64", uint64(0)}, + {"Int64", uint64(0)}, + {"DateTime64", uint64(0)}, + {"Float32", float32(0)}, + {"Float64", float64(0)}, + {"String", ""}, + {"UUID", ""}, + } + for _, c := range cases { + t.Run(c.in, func(t *testing.T) { + if got := zeroForType(c.in); got != c.want { + t.Fatalf("zeroForType(%q): want %v (%T), got %v (%T)", c.in, c.want, c.want, got, got) + } + }) + } +} + +func TestToMap(t *testing.T) { + m := map[string]any{"a": 1, "b": "two"} + if got := toMap(m); len(got) != 2 || got["a"] != 1 || got["b"] != "two" { + t.Fatalf("map passthrough failed: %v", got) + } + if got := toMap("not a map"); got != nil { + t.Fatalf("non-map should be nil, got %v", got) + } + if got := toMap(nil); got != nil { + t.Fatalf("nil should be nil, got %v", got) + } +} + +func TestSplitMapTypes(t *testing.T) { + cases := []struct { + in string + wantK, wantV string + }{ + {"String, UInt32", "String", "UInt32"}, + {"String,Array(UInt32)", "String", "Array(UInt32)"}, + {"Tuple(Int32, Int32), String", "Tuple(Int32, Int32)", "String"}, + {"OnlyOne", "OnlyOne", ""}, + } + for _, c := range cases { + t.Run(c.in, func(t *testing.T) { + k, v := splitMapTypes(c.in) + if k != c.wantK || v != c.wantV { + 
t.Fatalf("splitMapTypes(%q): want (%q,%q), got (%q,%q)", c.in, c.wantK, c.wantV, k, v)
+			}
+		})
+	}
+}
+
+func TestWriteLowCardinality(t *testing.T) {
+	var buf bytes.Buffer
+	w := newProtoWriter(&buf)
+	values := []any{"a", "b", "a", "c", "a"}
+	if err := writeLowCardinality(w, "String", values); err != nil {
+		t.Fatal(err)
+	}
+
+	r := newProtoReader(&buf)
+	ver, _ := readInt64(r, t)
+	if ver != 1 {
+		t.Fatalf("serialization version: want 1, got %d", ver)
+	}
+	flags, _ := readInt64(r, t)
+	if flags&0x7 != 0 { // <=256 dict -> index type 0 (byte)
+		t.Fatalf("flags low bits should be 0, got %d", flags&0x7)
+	}
+	if flags&(1<<9) == 0 { // bit 9 is has_additional_keys; need_update_dictionary is bit 10
+		t.Fatal("has_additional_keys flag (bit 9) should be set")
+	}
+
+	dictLen, _ := readInt64(r, t)
+	if dictLen != 3 {
+		t.Fatalf("dict len: want 3, got %d", dictLen)
+	}
+	s0, _ := r.str()
+	s1, _ := r.str()
+	s2, _ := r.str()
+	if s0 != "a" || s1 != "b" || s2 != "c" {
+		t.Fatalf("dict values: got %q %q %q", s0, s1, s2)
+	}
+
+	indicesLen, _ := readInt64(r, t)
+	if indicesLen != int64(len(values)) {
+		t.Fatalf("indices len: want %d, got %d", len(values), indicesLen)
+	}
+	for i, want := range []byte{0, 1, 0, 2, 0} {
+		got, err := r.ReadByte()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got != want {
+			t.Fatalf("index[%d]: want %d, got %d", i, want, got)
+		}
+	}
+}
+
+func TestWriteLowCardinalityUInt16Index(t *testing.T) {
+	// >256 unique values triggers uint16 index encoding.
+ var buf bytes.Buffer + w := newProtoWriter(&buf) + values := make([]any, 0, 300) + for i := 0; i < 300; i++ { + values = append(values, int32(i)) + } + if err := writeLowCardinality(w, "Int32", values); err != nil { + t.Fatal(err) + } + + r := newProtoReader(&buf) + _, _ = readInt64(r, t) // version + flags, _ := readInt64(r, t) + if flags&0x7 != 1 { + t.Fatalf("expected index type 1 (uint16), got %d", flags&0x7) + } + dictLen, _ := readInt64(r, t) + if dictLen != 300 { + t.Fatalf("dict len: want 300, got %d", dictLen) + } + // skip 300 Int32 dict entries + for i := 0; i < 300; i++ { + if _, err := r.int32(); err != nil { + t.Fatal(err) + } + } + indicesLen, _ := readInt64(r, t) + if indicesLen != 300 { + t.Fatalf("indices len: want 300, got %d", indicesLen) + } + // first index should be 0 + var b [2]byte + r.Read(b[:]) + if binary.LittleEndian.Uint16(b[:]) != 0 { + t.Fatalf("first index: want 0, got %d", binary.LittleEndian.Uint16(b[:])) + } +} + +func TestWriteColumnDataMap(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + values := []any{ + map[string]any{"k1": uint32(1)}, + map[string]any{"k2": uint32(2), "k3": uint32(3)}, + } + if err := writeColumnData(w, "Map(String, UInt32)", values); err != nil { + t.Fatal(err) + } + + r := newProtoReader(&buf) + // offsets: 1, 3 + var off [8]byte + r.Read(off[:]) + if binary.LittleEndian.Uint64(off[:]) != 1 { + t.Fatalf("first offset: want 1, got %d", binary.LittleEndian.Uint64(off[:])) + } + r.Read(off[:]) + if binary.LittleEndian.Uint64(off[:]) != 3 { + t.Fatalf("second offset: want 3, got %d", binary.LittleEndian.Uint64(off[:])) + } + // keys: 3 strings (order within a map is not deterministic, so just count) + for i := 0; i < 3; i++ { + if _, err := r.str(); err != nil { + t.Fatalf("reading key %d: %v", i, err) + } + } + // values: 3 UInt32 + for i := 0; i < 3; i++ { + if _, err := r.int32(); err != nil { + t.Fatalf("reading value %d: %v", i, err) + } + } +} + +func 
TestWriteColumnDataLowCardinality(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeColumnData(w, "LowCardinality(String)", []any{"x", "y", "x"}); err != nil { + t.Fatal(err) + } + if buf.Len() == 0 { + t.Fatal("expected non-empty output") + } +} + +func TestWriteBlockContentNoRows(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + cols := []Column{{Name: "x", Type: "UInt32"}} + if err := writeBlockContent(w, cols, nil, 0); err != nil { + t.Fatal(err) + } + if buf.Len() == 0 { + t.Fatal("expected non-empty output") + } +} + +// readInt64 is a tiny helper for tests reading back putInt64-encoded values. +func readInt64(r *protoReader, t *testing.T) (int64, error) { + t.Helper() + var b [8]byte + if _, err := r.Read(b[:]); err != nil { + return 0, err + } + return int64(binary.LittleEndian.Uint64(b[:])), nil //nolint:gosec +} diff --git a/pkg/backends/clickhouse/native/server/types_test.go b/pkg/backends/clickhouse/native/server/types_test.go new file mode 100644 index 000000000..469e67d9a --- /dev/null +++ b/pkg/backends/clickhouse/native/server/types_test.go @@ -0,0 +1,477 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package server + +import ( + "bytes" + "encoding/binary" + "math" + "math/big" + "testing" + "time" +) + +func TestToNumeric(t *testing.T) { + if got := toNumeric[int32](float64(7.9)); got != 7 { + t.Fatalf("float64 -> int32: got %d", got) + } + if got := toNumeric[uint32](float32(3.5)); got != 3 { + t.Fatalf("float32 -> uint32: got %d", got) + } + if got := toNumeric[int64](int64(-42)); got != -42 { + t.Fatalf("int64 passthrough: got %d", got) + } + if got := toNumeric[uint16](int32(99)); got != 99 { + t.Fatalf("int32 -> uint16: got %d", got) + } + if got := toNumeric[int32](int(123)); got != 123 { + t.Fatalf("int -> int32: got %d", got) + } + if got := toNumeric[uint64](uint64(1 << 40)); got != 1<<40 { + t.Fatalf("uint64 passthrough: got %d", got) + } + if got := toNumeric[uint32](uint32(5)); got != 5 { + t.Fatalf("uint32 passthrough: got %d", got) + } + if got := toNumeric[uint16](uint16(6)); got != 6 { + t.Fatalf("uint16 passthrough: got %d", got) + } + if got := toNumeric[uint8](uint8(7)); got != 7 { + t.Fatalf("uint8 passthrough: got %d", got) + } + if got := toNumeric[float64]("not a number"); got != 0 { + t.Fatalf("unsupported -> zero: got %v", got) + } +} + +func TestToString(t *testing.T) { + if s := toString("hello"); s != "hello" { + t.Fatalf("string passthrough: got %q", s) + } + if s := toString([]byte("bytes")); s != "bytes" { + t.Fatalf("[]byte: got %q", s) + } + if s := toString(42); s != "42" { + t.Fatalf("int via fmt.Sprint: got %q", s) + } + if s := toString(nil); s != "" { + t.Fatalf("nil via fmt.Sprint: got %q", s) + } +} + +func TestToDateTime(t *testing.T) { + ts := time.Unix(1_700_000_000, 0).UTC() + if got := toDateTime(ts); got != 1_700_000_000 { + t.Fatalf("time.Time: got %d", got) + } + if got := toDateTime(uint32(42)); got != 42 { + t.Fatalf("uint32 passthrough: got %d", got) + } + if got := toDateTime(float64(12345.0)); got != 12345 { + t.Fatalf("float64: got %d", got) + } + if got := toDateTime(int64(99)); got != 99 { + 
t.Fatalf("int64 fallback: got %d", got) + } +} + +func TestToDate(t *testing.T) { + ts := time.Date(1970, 1, 11, 0, 0, 0, 0, time.UTC) + if got := toDate(ts); got != 10 { + t.Fatalf("time.Time: expected 10 days, got %d", got) + } + if got := toDate(uint16(7)); got != 7 { + t.Fatalf("uint16 passthrough: got %d", got) + } + if got := toDate(int64(3)); got != 3 { + t.Fatalf("int64 fallback: got %d", got) + } +} + +func TestToDate32(t *testing.T) { + ts := time.Date(1970, 1, 6, 0, 0, 0, 0, time.UTC) + if got := toDate32(ts); got != 5 { + t.Fatalf("time.Time: expected 5 days, got %d", got) + } + if got := toDate32(int32(-100)); got != -100 { + t.Fatalf("int32 passthrough: got %d", got) + } + if got := toDate32(int64(42)); got != 42 { + t.Fatalf("int64 fallback: got %d", got) + } +} + +func TestToDateTime64(t *testing.T) { + ts := time.Unix(1_700_000_000, 500_000_000).UTC() + want := ts.UnixMilli() + if got := toDateTime64(ts); got != want { + t.Fatalf("time.Time: want %d, got %d", want, got) + } + if got := toDateTime64(int64(12345)); got != 12345 { + t.Fatalf("int64 fallback: got %d", got) + } +} + +func TestToBigInt(t *testing.T) { + bi := big.NewInt(999) + if got := toBigInt(bi); got != bi { + t.Fatal("*big.Int passthrough failed") + } + if got := toBigInt(int64(-7)); got.Int64() != -7 { + t.Fatalf("int64: got %s", got.String()) + } + if got := toBigInt(uint64(1 << 60)); got.Uint64() != 1<<60 { + t.Fatalf("uint64: got %s", got.String()) + } + if got := toBigInt(float64(1234)); got.Int64() != 1234 { + t.Fatalf("float64: got %s", got.String()) + } + // non-numeric float falls back to 0 via the SetString("NaN") guard + if got := toBigInt(math.NaN()); got.Sign() != 0 { + t.Fatalf("NaN float: got %s", got.String()) + } + if got := toBigInt("12345678901234567890"); got.String() != "12345678901234567890" { + t.Fatalf("string: got %s", got.String()) + } + if got := toBigInt("not a number"); got.Sign() != 0 { + t.Fatalf("bad string should be 0: got %s", got.String()) + } + if 
got := toBigInt(struct{}{}); got.Sign() != 0 { + t.Fatalf("unknown type should be 0: got %s", got.String()) + } +} + +func TestWriteBigIntPositive(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeBigInt(w, int64(0x0102030405060708), 16); err != nil { + t.Fatal(err) + } + got := buf.Bytes() + if len(got) != 16 { + t.Fatalf("expected 16 bytes, got %d", len(got)) + } + // little-endian: low byte first + want := []byte{0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01} + if !bytes.Equal(got[:8], want) { + t.Fatalf("low 8 bytes: want %x, got %x", want, got[:8]) + } + for i := 8; i < 16; i++ { + if got[i] != 0 { + t.Fatalf("byte %d should be 0, got %x", i, got[i]) + } + } +} + +func TestWriteBigIntNegative(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeBigInt(w, int64(-1), 16); err != nil { + t.Fatal(err) + } + got := buf.Bytes() + if len(got) != 16 { + t.Fatalf("expected 16 bytes, got %d", len(got)) + } + for i, b := range got { + if b != 0xFF { + t.Fatalf("byte %d: two's complement -1 should be 0xFF, got %x", i, b) + } + } +} + +func TestWriteBigInt256(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + bi, _ := new(big.Int).SetString("340282366920938463463374607431768211456", 10) // 2^128 + if err := writeBigInt(w, bi, 32); err != nil { + t.Fatal(err) + } + if buf.Len() != 32 { + t.Fatalf("expected 32 bytes, got %d", buf.Len()) + } + // 2^128 has bit 128 set — byte index 16 should be 0x01, others 0. 
+ got := buf.Bytes() + if got[16] != 0x01 { + t.Fatalf("byte 16 should be 0x01, got %x", got[16]) + } +} + +func TestWriteUUID(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeUUID(w, "00112233-4455-6677-8899-aabbccddeeff"); err != nil { + t.Fatal(err) + } + if buf.Len() != 16 { + t.Fatalf("expected 16 bytes, got %d", buf.Len()) + } + // first half: BE 0x0011223344556677 stored as LE + got := buf.Bytes() + if binary.LittleEndian.Uint64(got[0:8]) != 0x0011223344556677 { + t.Fatalf("first half wrong: %x", got[0:8]) + } + if binary.LittleEndian.Uint64(got[8:16]) != 0x8899aabbccddeeff { + t.Fatalf("second half wrong: %x", got[8:16]) + } +} + +func TestWriteUUIDInvalid(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeUUID(w, "not-a-uuid"); err != nil { + t.Fatal(err) + } + if buf.Len() != 16 { + t.Fatalf("expected 16 zero bytes, got %d", buf.Len()) + } + for i, b := range buf.Bytes() { + if b != 0 { + t.Fatalf("byte %d expected 0, got %x", i, b) + } + } +} + +func TestWriteIPv4(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeIPv4(w, "1.2.3.4"); err != nil { + t.Fatal(err) + } + got := buf.Bytes() + // little-endian order: [4,3,2,1] + want := []byte{4, 3, 2, 1} + if !bytes.Equal(got, want) { + t.Fatalf("want %v, got %v", want, got) + } +} + +func TestWriteIPv4Invalid(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeIPv4(w, "not-an-ip"); err != nil { + t.Fatal(err) + } + if buf.Len() != 4 { + t.Fatalf("expected 4 bytes, got %d", buf.Len()) + } +} + +func TestWriteIPv6(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeIPv6(w, "::1"); err != nil { + t.Fatal(err) + } + got := buf.Bytes() + if len(got) != 16 { + t.Fatalf("expected 16 bytes, got %d", len(got)) + } + if got[15] != 1 { + t.Fatalf("last byte of ::1 should be 1, got %x", got[15]) + } + for i := 0; i < 15; i++ { + if got[i] != 0 { + 
t.Fatalf("byte %d of ::1 should be 0, got %x", i, got[i]) + } + } +} + +func TestWriteIPv6Invalid(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeIPv6(w, "garbage"); err != nil { + t.Fatal(err) + } + if buf.Len() != 16 { + t.Fatalf("expected 16 bytes, got %d", buf.Len()) + } +} + +func TestWriteValueTypes(t *testing.T) { + cases := []struct { + colType string + v any + want []byte + }{ + {"UInt8", uint8(0xAB), []byte{0xAB}}, + {"Bool", uint8(1), []byte{0x01}}, + {"UInt16", uint16(0x1234), []byte{0x34, 0x12}}, + {"UInt32", uint32(0x01020304), []byte{0x04, 0x03, 0x02, 0x01}}, + {"UInt64", uint64(0x0102030405060708), []byte{0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01}}, + {"Int8", int64(-2), []byte{0xFE}}, + {"Int16", int64(-1), []byte{0xFF, 0xFF}}, + {"Int32", int64(-1), []byte{0xFF, 0xFF, 0xFF, 0xFF}}, + {"Int64", int64(-1), []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}, + {"Enum8", int64(-2), []byte{0xFE}}, + {"Enum8('a'=1)", int64(1), []byte{0x01}}, + {"Enum16('a'=1)", int64(1), []byte{0x01, 0x00}}, + {"Decimal32(2)", int64(12345), []byte{0x39, 0x30, 0x00, 0x00}}, + {"Decimal64(2)", int64(1), []byte{0x01, 0, 0, 0, 0, 0, 0, 0}}, + {"Decimal(9,2)", int64(2), []byte{0x02, 0, 0, 0, 0, 0, 0, 0}}, + } + for _, c := range cases { + t.Run(c.colType, func(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeValue(w, c.colType, c.v); err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf.Bytes(), c.want) { + t.Fatalf("want %x, got %x", c.want, buf.Bytes()) + } + }) + } +} + +func TestWriteValueFloat(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeValue(w, "Float32", float64(1.5)); err != nil { + t.Fatal(err) + } + got := math.Float32frombits(binary.LittleEndian.Uint32(buf.Bytes())) + if got != 1.5 { + t.Fatalf("want 1.5, got %v", got) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "Float64", float64(2.5)); err != nil { + t.Fatal(err) + 
} + gotD := math.Float64frombits(binary.LittleEndian.Uint64(buf.Bytes())) + if gotD != 2.5 { + t.Fatalf("want 2.5, got %v", gotD) + } +} + +func TestWriteValueTemporal(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeValue(w, "DateTime", uint32(1_700_000_000)); err != nil { + t.Fatal(err) + } + if v := binary.LittleEndian.Uint32(buf.Bytes()); v != 1_700_000_000 { + t.Fatalf("DateTime: got %d", v) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "DateTime64(3)", int64(123456789)); err != nil { + t.Fatal(err) + } + if v := int64(binary.LittleEndian.Uint64(buf.Bytes())); v != 123456789 { + t.Fatalf("DateTime64: got %d", v) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "Date", uint16(100)); err != nil { + t.Fatal(err) + } + if v := binary.LittleEndian.Uint16(buf.Bytes()); v != 100 { + t.Fatalf("Date: got %d", v) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "Date32", int32(-365)); err != nil { + t.Fatal(err) + } + if v := int32(binary.LittleEndian.Uint32(buf.Bytes())); v != -365 { + t.Fatalf("Date32: got %d", v) + } +} + +func TestWriteValueBigAndAddress(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeValue(w, "Int128", int64(1)); err != nil { + t.Fatal(err) + } + if buf.Len() != 16 { + t.Fatalf("Int128 should be 16 bytes, got %d", buf.Len()) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "UInt256", uint64(1)); err != nil { + t.Fatal(err) + } + if buf.Len() != 32 { + t.Fatalf("UInt256 should be 32 bytes, got %d", buf.Len()) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "Decimal128(10)", int64(1)); err != nil { + t.Fatal(err) + } + if buf.Len() != 16 { + t.Fatalf("Decimal128 should be 16 bytes, got %d", buf.Len()) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "Decimal256(20)", int64(1)); err != nil { + t.Fatal(err) + } + if buf.Len() != 32 { + 
t.Fatalf("Decimal256 should be 32 bytes, got %d", buf.Len()) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "UUID", "00112233-4455-6677-8899-aabbccddeeff"); err != nil { + t.Fatal(err) + } + if buf.Len() != 16 { + t.Fatalf("UUID should be 16 bytes, got %d", buf.Len()) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "IPv4", "1.2.3.4"); err != nil { + t.Fatal(err) + } + if buf.Len() != 4 { + t.Fatalf("IPv4 should be 4 bytes, got %d", buf.Len()) + } + + buf.Reset() + w = newProtoWriter(&buf) + if err := writeValue(w, "IPv6", "::1"); err != nil { + t.Fatal(err) + } + if buf.Len() != 16 { + t.Fatalf("IPv6 should be 16 bytes, got %d", buf.Len()) + } +} + +func TestWriteValueDefaultString(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + if err := writeValue(w, "String", "hi"); err != nil { + t.Fatal(err) + } + r := newProtoReader(&buf) + s, _ := r.str() + if s != "hi" { + t.Fatalf("want %q, got %q", "hi", s) + } +} From fce45b0030d5938e486692f1420e5e4e49b903c7 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Mon, 20 Apr 2026 23:47:18 -0400 Subject: [PATCH 7/8] ci: suppress codeql go/sql-injection for proxy passthrough Signed-off-by: Chris Randles --- .github/codeql/codeql-config.yml | 11 +++++++++++ .github/workflows/codeql-analysis.yml | 5 +---- 2 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 .github/codeql/codeql-config.yml diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml new file mode 100644 index 000000000..eb2bfb465 --- /dev/null +++ b/.github/codeql/codeql-config.yml @@ -0,0 +1,11 @@ +name: "trickster CodeQL config" + +query-filters: + # trickster is a passthrough proxy; backends forward client SQL verbatim + # to the upstream. companion to the gosec G701 suppression in code. 
+ - exclude: + id: go/sql-injection + +paths-ignore: + - vendor + - testdata diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index bd46781f4..0183408bb 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -45,10 +45,7 @@ jobs: uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} - # If you wish to specify custom queries, you can do so here or in a config file. - # By default, queries listed here will override any specified in a config file. - # Prefix the list here with "+" to use these queries and those in the config file. - # queries: ./path/to/local/query, your-org/your-repo/queries@main + config-file: ./.github/codeql/codeql-config.yml # ℹ️ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl From 623e3bc49e80ee4ffbde7a77a2f5e84a59bb9231 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Thu, 23 Apr 2026 07:18:29 -0400 Subject: [PATCH 8/8] fix: [clickhouse] datetime(tz) wire size and skipClientInfo short reads Signed-off-by: Chris Randles --- pkg/backends/clickhouse/model/native_test.go | 35 +++++ .../clickhouse/model/unmarshal_native.go | 14 +- .../clickhouse/native/server/handler_test.go | 3 + .../clickhouse/native/server/query.go | 9 +- .../clickhouse/native/server/query_test.go | 124 ++++++++++++++++++ 5 files changed, 178 insertions(+), 7 deletions(-) create mode 100644 pkg/backends/clickhouse/native/server/query_test.go diff --git a/pkg/backends/clickhouse/model/native_test.go b/pkg/backends/clickhouse/model/native_test.go index 85a4b730b..ae05bf2e4 100644 --- a/pkg/backends/clickhouse/model/native_test.go +++ b/pkg/backends/clickhouse/model/native_test.go @@ -19,6 +19,7 @@ package model import ( "bufio" "bytes" + "encoding/binary" "errors" "io" "math" @@ -238,6 +239,40 @@ func TestWriteNativeValue_DateUnparsable(t *testing.T) { } } +func TestReadValueAsString_DateTimeWithTimezone(t *testing.T) { + // DateTime('UTC') is a 
4-byte UInt32 on the ClickHouse native wire; + // the timezone is metadata only. Reading 8 bytes (as DateTime64 would) + // desynchronizes the column stream and corrupts every subsequent value. + for _, typ := range []string{ + "DateTime('UTC')", + "DateTime('Asia/Tokyo')", + "DateTime('America/New_York')", + } { + b := new(bytes.Buffer) + // 1577836800 = 2020-01-01 00:00:00 UTC, written as 4-byte LE UInt32. + _ = binary.Write(b, binary.LittleEndian, uint32(1577836800)) + // Sentinel UInt32 so we can detect over-read. + _ = binary.Write(b, binary.LittleEndian, uint32(0xDEADBEEF)) + + br := bufio.NewReader(bytes.NewReader(b.Bytes())) + got, err := readValueAsString(br, typ) + if err != nil { + t.Fatalf("%s: readValueAsString: %v", typ, err) + } + if got != "2020-01-01 00:00:00" { + t.Fatalf("%s: want 2020-01-01 00:00:00, got %q", typ, got) + } + // If readValueAsString consumed more than 4 bytes, the sentinel is lost. + next, err := readFixed(br, 4) + if err != nil { + t.Fatalf("%s: sentinel read: %v", typ, err) + } + if binary.LittleEndian.Uint32(next) != 0xDEADBEEF { + t.Fatalf("%s: sentinel mismatch: got %#x", typ, binary.LittleEndian.Uint32(next)) + } + } +} + func TestReadValueAsString_ShortBuffers(t *testing.T) { types := []string{TypeUInt16, TypeUInt32, TypeUInt64, TypeInt16, TypeInt32, TypeInt64, TypeFloat32, TypeFloat64, TypeDateTime, TypeDate, diff --git a/pkg/backends/clickhouse/model/unmarshal_native.go b/pkg/backends/clickhouse/model/unmarshal_native.go index 1e5abd704..dc53db979 100644 --- a/pkg/backends/clickhouse/model/unmarshal_native.go +++ b/pkg/backends/clickhouse/model/unmarshal_native.go @@ -23,6 +23,7 @@ import ( "io" "math" "strconv" + "strings" "time" "github.com/trickstercache/trickster/v2/pkg/timeseries" @@ -281,17 +282,22 @@ func readValueAsString(r *bufio.Reader, typ string) (string, error) { case TypeString: return readString(r) default: - // For complex/unknown types, try reading as string - if len(typ) > 8 && typ[:8] == TypeDateTime { 
- // DateTime64, DateTime('tz'), etc. — 8 bytes + if strings.HasPrefix(typ, "DateTime64") { b, err := readFixed(r, 8) if err != nil { return "", err } v := int64(binary.LittleEndian.Uint64(b)) - // DateTime64 with default precision 3 (milliseconds) return time.Unix(v/1000, (v%1000)*1000000).UTC().Format("2006-01-02 15:04:05.000"), nil } + if strings.HasPrefix(typ, "DateTime(") { + b, err := readFixed(r, 4) + if err != nil { + return "", err + } + ts := binary.LittleEndian.Uint32(b) + return time.Unix(int64(ts), 0).UTC().Format("2006-01-02 15:04:05"), nil + } return readString(r) } } diff --git a/pkg/backends/clickhouse/native/server/handler_test.go b/pkg/backends/clickhouse/native/server/handler_test.go index 1b555dbf3..58be2431f 100644 --- a/pkg/backends/clickhouse/native/server/handler_test.go +++ b/pkg/backends/clickhouse/native/server/handler_test.go @@ -154,6 +154,9 @@ func TestHandlerQuery(t *testing.T) { cr.str() cr.uvarint() + // Send addendum (quota key, required for revision >= 54458) + cw.putStr("") + // Send ClientQuery sendTestQuery(t, cw, "SELECT 1") diff --git a/pkg/backends/clickhouse/native/server/query.go b/pkg/backends/clickhouse/native/server/query.go index fa84bf8c3..f6c93e713 100644 --- a/pkg/backends/clickhouse/native/server/query.go +++ b/pkg/backends/clickhouse/native/server/query.go @@ -16,7 +16,10 @@ package server -import "fmt" +import ( + "fmt" + "io" +) // ClientQueryMsg contains the relevant fields from a ClientQuery packet. 
// We parse enough to extract the SQL but skip over the full client info @@ -100,7 +103,7 @@ func skipClientInfo(r *protoReader, revision uint64) error { // initial query start time (revision >= 54449) if revision >= RevisionInitialQueryStart { var b [8]byte - if _, err := r.Read(b[:]); err != nil { + if _, err := io.ReadFull(r, b[:]); err != nil { return err } } @@ -147,7 +150,7 @@ func skipClientInfo(r *protoReader, revision uint64) error { if hasSpan != 0 { // trace id (16) + span id (8) + trace state (string) + flags (1) var skip [24]byte - if _, err := r.Read(skip[:]); err != nil { + if _, err := io.ReadFull(r, skip[:]); err != nil { return err } if _, err := r.str(); err != nil { diff --git a/pkg/backends/clickhouse/native/server/query_test.go b/pkg/backends/clickhouse/native/server/query_test.go new file mode 100644 index 000000000..9f34eb5f3 --- /dev/null +++ b/pkg/backends/clickhouse/native/server/query_test.go @@ -0,0 +1,124 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "bytes" + "io" + "testing" +) + +// oneByteReader returns at most one byte per Read, forcing the embedded +// bufio.Reader to short-read on multi-byte fills. This simulates the TCP +// fragmentation case that exposes bare r.Read calls on the wire. 
+type oneByteReader struct { + data []byte + pos int +} + +func (r *oneByteReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + if len(p) == 0 { + return 0, nil + } + p[0] = r.data[r.pos] + r.pos++ + return 1, nil +} + +// writeClientInfo writes a client-info block matching the field layout +// skipClientInfo reads at ServerRevision. hasSpan controls whether the +// OpenTelemetry 24-byte trace/span block is emitted. +func writeClientInfo(w *protoWriter, hasSpan bool) { + w.putByte(1) // query kind = initial + w.putStr("user") + w.putStr("iqid") + w.putStr("127.0.0.1") + w.putInt64(1700000000) // initial_query_start_time — 8 bytes + w.putByte(1) // interface type = TCP + w.putStr("os-user") + w.putStr("host") + w.putStr("client") + w.putUvarint(1) // version major + w.putUvarint(0) // version minor + w.putUvarint(ServerRevision) // client protocol revision + w.putStr("") // quota key + w.putUvarint(0) // distributed depth + w.putUvarint(0) // version patch + if hasSpan { + w.putByte(1) + // trace id (16) + span id (8) = 24 bytes + span := make([]byte, 24) + for i := range span { + span[i] = byte(i + 1) + } + w.Write(span) + w.putStr("trace-state") + w.putByte(0) // flags + } else { + w.putByte(0) + } + w.putUvarint(0) // parallel replicas + w.putUvarint(0) + w.putUvarint(0) +} + +// TestSkipClientInfo_ShortReadStartTime verifies that skipClientInfo +// correctly consumes the 8-byte initial_query_start_time even when the +// underlying reader delivers one byte at a time. With a bare r.Read, +// the bufio.Reader short-reads and desynchronizes the stream. 
+func TestSkipClientInfo_ShortReadStartTime(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + writeClientInfo(w, false) + w.putByte(0x5A) // sentinel immediately after client info + + r := newProtoReader(&oneByteReader{data: buf.Bytes()}) + if err := skipClientInfo(r, ServerRevision); err != nil { + t.Fatalf("skipClientInfo: %v", err) + } + sentinel, err := r.ReadByte() + if err != nil { + t.Fatalf("sentinel read: %v", err) + } + if sentinel != 0x5A { + t.Fatalf("stream desync: want sentinel %#x, got %#x", 0x5A, sentinel) + } +} + +// TestSkipClientInfo_ShortReadOpenTelemetry covers the 24-byte OTel +// trace/span block, which is also bare-Read. +func TestSkipClientInfo_ShortReadOpenTelemetry(t *testing.T) { + var buf bytes.Buffer + w := newProtoWriter(&buf) + writeClientInfo(w, true) + w.putByte(0x5A) + + r := newProtoReader(&oneByteReader{data: buf.Bytes()}) + if err := skipClientInfo(r, ServerRevision); err != nil { + t.Fatalf("skipClientInfo: %v", err) + } + sentinel, err := r.ReadByte() + if err != nil { + t.Fatalf("sentinel read: %v", err) + } + if sentinel != 0x5A { + t.Fatalf("stream desync: want sentinel %#x, got %#x", 0x5A, sentinel) + } +}