diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 184adfb8..08874b95 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -28,6 +28,9 @@ jobs: - name: Install tree-sitter-cli run: npm install -g tree-sitter-cli + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0937dc62..7d1aef25 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -32,6 +32,9 @@ jobs: - name: Install LLVM run: sudo apt-get install -y llvm + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 5d85746a..a64c898b 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -33,6 +33,9 @@ jobs: - name: Install LLVM run: sudo apt-get install -y llvm + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/Cargo.lock b/Cargo.lock index 32029591..29f8c42a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,6 +64,31 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "android-activity" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f2a1bb052857d5dd49572219344a7332b31b76405648eabac5bc68978251bcd" +dependencies = [ + "android-properties", + "bitflags 2.11.0", + "cc", + "jni 0.22.4", + "libc", + "log", + "ndk", + "ndk-context", + "ndk-sys", + "num_enum", + "thiserror 2.0.18", +] + +[[package]] +name = "android-properties" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"fc7eb209b1518d6bb87b283c20095f5228ecda460da70b44f0802523dea6da04" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -557,6 +582,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block2" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c132eebf10f5cad5289222520a4a058514204aed6d791f1cf4fe8088b82d15f" +dependencies = [ + "objc2", +] + [[package]] name = "borrow-or-share" version = "0.2.4" @@ -688,6 +722,20 @@ dependencies = [ "serde", ] +[[package]] +name = "calloop" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b99da2f8558ca23c71f4fd15dc57c906239752dd27ff3c00a1d56b685b7cbfec" +dependencies = [ + "bitflags 2.11.0", + "log", + "polling", + "rustix 0.38.44", + "slab", + "thiserror 1.0.69", +] + [[package]] name = "cast" version = "0.3.0" @@ -1081,6 +1129,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "cursor-icon" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27ae1dd37df86211c42e150270f82743308803d90a6f6e6651cd730d5e1732f" + [[package]] name = "dashmap" version = "5.5.3" @@ -1149,6 +1203,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "dispatch" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" + [[package]] name = "displaydoc" version = "0.2.5" @@ -1169,6 +1229,12 @@ dependencies = [ "libloading", ] +[[package]] +name = "dpi" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b14ccef22fc6f5a8f4d7d768562a182c04ce9a3b3157b91390b52ddfdf1a76" + [[package]] name = "duckdb" version = "1.4.4" @@ -1495,7 +1561,7 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" dependencies = [ - 
"rustix", + "rustix 1.1.4", "windows-sys 0.59.0", ] @@ -1661,6 +1727,7 @@ dependencies = [ "csscolorparser", "duckdb", "jsonschema", + "odbc-api", "palette", "plotters", "polars", @@ -1674,7 +1741,9 @@ dependencies = [ "serde", "serde_json", "sprintf", + "tempfile", "thiserror 1.0.69", + "toml_edit 0.22.27", "tree-sitter", "tree-sitter-ggsql", "ureq", @@ -1851,6 +1920,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -2196,19 +2271,68 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.0", "log", "thiserror 1.0.69", "walkdir", "windows-sys 0.45.0", ] +[[package]] +name = "jni" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efd9a482cf3a427f00d6b35f14332adc7902ce91efb778580e180ff90fa3498" +dependencies = [ + "cfg-if", + "combine", + "jni-macros", + "jni-sys 0.4.1", + "log", + "simd_cesu8", + "thiserror 2.0.18", + "walkdir", + "windows-link", +] + +[[package]] +name = "jni-macros" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00109accc170f0bdb141fed3e393c565b6f5e072365c3bd58f5b062591560a3" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "simd_cesu8", + "syn 2.0.117", +] + [[package]] name = "jni-sys" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies 
= [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn 2.0.117", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -2395,6 +2519,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -2511,6 +2641,36 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "ndk" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3f42e7bbe13d351b6bead8286a43aac9534b82bd3cc43e47037f012ebfd62d4" +dependencies = [ + "bitflags 2.11.0", + "jni-sys 0.3.0", + "log", + "ndk-sys", + "num_enum", + "raw-window-handle", + "thiserror 1.0.69", +] + +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + +[[package]] +name = "ndk-sys" +version = "0.6.0+11769913" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee6cda3051665f1fb8d9e08fc35c96d5a244fb1be711a03b71118828afc9a873" +dependencies = [ + "jni-sys 0.3.0", +] + [[package]] name = "now" version = "0.1.3" @@ -2609,6 +2769,96 @@ dependencies = [ "libm", ] +[[package]] +name = "num_enum" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "objc-sys" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb91bdd390c7ce1a8607f35f3ca7151b65afc0ff5ff3b34fa350f7d7c7e4310" + +[[package]] +name = "objc2" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46a785d4eeff09c14c487497c162e92766fbb3e4059a71840cecc03d9a50b804" +dependencies = [ + "objc-sys", + "objc2-encode", +] + +[[package]] +name = "objc2-app-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" +dependencies = [ + "bitflags 2.11.0", + "block2", + "libc", + "objc2", + "objc2-core-data", + "objc2-core-image", + "objc2-foundation", + "objc2-quartz-core", +] + +[[package]] +name = "objc2-cloud-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-core-location", + "objc2-foundation", +] + +[[package]] +name = "objc2-contacts" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5ff520e9c33812fd374d8deecef01d4a840e7b41862d849513de77e44aa4889" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-data" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-foundation", +] + [[package]] name = "objc2-core-foundation" version = "0.3.2" @@ -2618,6 +2868,96 @@ dependencies = [ "bitflags 2.11.0", ] +[[package]] +name = 
"objc2-core-image" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", + "objc2-metal", +] + +[[package]] +name = "objc2-core-location" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "000cfee34e683244f284252ee206a27953279d370e309649dc3ee317b37e5781" +dependencies = [ + "block2", + "objc2", + "objc2-contacts", + "objc2-foundation", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" +dependencies = [ + "bitflags 2.11.0", + "block2", + "dispatch", + "libc", + "objc2", +] + +[[package]] +name = "objc2-link-presentation" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1a1ae721c5e35be65f01a03b6d2ac13a54cb4fa70d8a5da293d7b0020261398" +dependencies = [ + "block2", + "objc2", + "objc2-app-kit", + "objc2-foundation", +] + +[[package]] +name = "objc2-metal" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-quartz-core" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-foundation", + "objc2-metal", +] + +[[package]] +name = "objc2-symbols" 
+version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a684efe3dec1b305badae1a28f6555f6ddd3bb2c2267896782858d5a78404dc" +dependencies = [ + "objc2", + "objc2-foundation", +] + [[package]] name = "objc2-system-configuration" version = "0.3.2" @@ -2627,6 +2967,51 @@ dependencies = [ "objc2-core-foundation", ] +[[package]] +name = "objc2-ui-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-cloud-kit", + "objc2-core-data", + "objc2-core-image", + "objc2-core-location", + "objc2-foundation", + "objc2-link-presentation", + "objc2-quartz-core", + "objc2-symbols", + "objc2-uniform-type-identifiers", + "objc2-user-notifications", +] + +[[package]] +name = "objc2-uniform-type-identifiers" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44fa5f9748dbfe1ca6c0b79ad20725a11eca7c2218bceb4b005cb1be26273bfe" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-user-notifications" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-core-location", + "objc2-foundation", +] + [[package]] name = "object" version = "0.37.3" @@ -2671,6 +3056,26 @@ dependencies = [ "web-time", ] +[[package]] +name = "odbc-api" +version = "13.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44e14665455e2817ac5b0dd9f65a3dc97e76e8f85eeac6e4301b7cf9da451884" +dependencies = [ + "atoi", + "log", + "odbc-sys", + "thiserror 2.0.18", + "widestring", + "winit", +] + +[[package]] +name = "odbc-sys" +version = "0.25.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ecdb20f7c165083ad1bc9f55122f677725e257716a5bc83e5413d5654b7d6f1" + [[package]] name = "once_cell" version = "1.21.4" @@ -2695,6 +3100,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "orbclient" +version = "0.3.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59aed3b33578edcfa1bc96a321d590d31832b6ad55a26f0313362ce687e9abd6" +dependencies = [ + "libc", + "libredox", +] + [[package]] name = "outref" version = "0.5.2" @@ -2884,6 +3299,26 @@ dependencies = [ "uncased", ] +[[package]] +name = "pin-project" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "pin-project-lite" version = "0.2.17" @@ -3517,6 +3952,20 @@ dependencies = [ "version_check", ] +[[package]] +name = "polling" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" +dependencies = [ + "cfg-if", + "concurrent-queue", + "hermit-abi", + "pin-project-lite", + "rustix 1.1.4", + "windows-sys 0.61.2", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -3600,7 +4049,7 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ - "toml_edit", + "toml_edit 0.25.4+spec-1.1.0", ] 
[[package]] @@ -3907,6 +4356,12 @@ dependencies = [ "bitflags 2.11.0", ] +[[package]] +name = "raw-window-handle" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" + [[package]] name = "rayon" version = "1.11.0" @@ -3947,6 +4402,15 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -4251,6 +4715,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + [[package]] name = "rustix" version = "1.1.4" @@ -4260,7 +4737,7 @@ dependencies = [ "bitflags 2.11.0", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.12.1", "windows-sys 0.61.2", ] @@ -4310,7 +4787,7 @@ checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ "core-foundation 0.10.1", "core-foundation-sys", - "jni", + "jni 0.21.1", "log", "once_cell", "rustls", @@ -4559,6 +5036,16 @@ dependencies = [ "value-trait", ] +[[package]] +name = "simd_cesu8" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f90157bb87cddf702797c5dadfa0be7d266cdf49e22da2fcaa32eff75b2c33" +dependencies = [ + "rustc_version", + "simdutf8", +] + [[package]] name = "simdutf8" version = "0.1.5" @@ -4592,6 +5079,15 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" 
+[[package]] +name = "smol_str" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd538fb6910ac1099850255cf94a94df6551fbdd602454387d0adb2d1ca6dead" +dependencies = [ + "serde", +] + [[package]] name = "snap" version = "1.1.1" @@ -4821,7 +5317,7 @@ dependencies = [ "fastrand", "getrandom 0.4.2", "once_cell", - "rustix", + "rustix 1.1.4", "windows-sys 0.61.2", ] @@ -5008,6 +5504,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" + [[package]] name = "toml_datetime" version = "1.0.0+spec-1.1.0" @@ -5017,6 +5519,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "toml_datetime 0.6.11", + "toml_write", + "winnow", +] + [[package]] name = "toml_edit" version = "0.25.4+spec-1.1.0" @@ -5024,7 +5538,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" dependencies = [ "indexmap", - "toml_datetime", + "toml_datetime 1.0.0+spec-1.1.0", "toml_parser", "winnow", ] @@ -5038,6 +5552,12 @@ dependencies = [ "winnow", ] +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "tower" version = "0.5.3" @@ -5639,6 +6159,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "widestring" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" + [[package]] name = "winapi" version 
= "0.3.9" @@ -5960,6 +6486,46 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winit" +version = "0.30.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6755fa58a9f8350bd1e472d4c3fcc25f824ec358933bba33306d0b63df5978d" +dependencies = [ + "android-activity", + "atomic-waker", + "bitflags 2.11.0", + "block2", + "calloop", + "cfg_aliases", + "concurrent-queue", + "core-foundation 0.9.4", + "core-graphics", + "cursor-icon", + "dpi", + "js-sys", + "libc", + "ndk", + "objc2", + "objc2-app-kit", + "objc2-foundation", + "objc2-ui-kit", + "orbclient", + "pin-project", + "raw-window-handle", + "redox_syscall 0.4.1", + "rustix 0.38.44", + "smol_str", + "tracing", + "unicode-segmentation", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "web-time", + "windows-sys 0.52.0", + "xkbcommon-dl", +] + [[package]] name = "winnow" version = "0.7.15" @@ -6088,9 +6654,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" dependencies = [ "libc", - "rustix", + "rustix 1.1.4", +] + +[[package]] +name = "xkbcommon-dl" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039de8032a9a8856a6be89cea3e5d12fdd82306ab7c94d74e6deab2460651c5" +dependencies = [ + "bitflags 2.11.0", + "dlib", + "log", + "once_cell", + "xkeysym", ] +[[package]] +name = "xkeysym" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56" + [[package]] name = "xxhash-rust" version = "0.8.15" diff --git a/Cargo.toml b/Cargo.toml index bba38f7f..a7320945 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,10 @@ arrow = { version = "56", default-features = false, features = ["ipc"] } postgres = "0.19" 
rusqlite = { version = "0.38", features = ["bundled", "chrono", "functions", "window"] } +# ODBC +odbc-api = "13" +toml_edit = "0.22" + # Writers plotters = "0.3" diff --git a/ggsql-jupyter/Cargo.toml b/ggsql-jupyter/Cargo.toml index 5b31ef8f..1bcc08d4 100644 --- a/ggsql-jupyter/Cargo.toml +++ b/ggsql-jupyter/Cargo.toml @@ -56,6 +56,13 @@ hex = "0.4" # UUID for message IDs uuid = { version = "1.0", features = ["v4"] } +[features] +default = ["all-readers"] +all-readers = ["sqlite", "odbc", "duckdb"] +odbc = ["ggsql/odbc"] +sqlite = ["ggsql/sqlite"] +duckdb = ["ggsql/duckdb"] + [dev-dependencies] # Test utilities tokio-test = "0.4" diff --git a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs new file mode 100644 index 00000000..c55385ae --- /dev/null +++ b/ggsql-jupyter/src/connection.rs @@ -0,0 +1,209 @@ +//! Database schema introspection for the Positron Connections pane. +//! +//! Delegates introspection SQL to the reader's `SqlDialect`, which provides +//! backend-specific queries (e.g. `information_schema` for DuckDB/PostgreSQL, +//! `sqlite_master` / `PRAGMA` for SQLite). + +use crate::util::find_column; +use ggsql::reader::Reader; +use serde::Serialize; +use serde_json::Value; + +/// An object in the schema hierarchy (catalog, schema, table, or view). +#[derive(Debug, Serialize)] +pub struct ObjectSchema { + pub name: String, + pub kind: String, +} + +/// A field (column) in a table. +#[derive(Debug, Serialize)] +pub struct FieldSchema { + pub name: String, + pub dtype: String, +} + +/// List objects at the given path depth. 
+/// +/// Path semantics (catalog → schema → table): +/// - `[]` → list catalogs +/// - `[catalog]` → list schemas in that catalog +/// - `[catalog, schema]` → list tables and views +pub fn list_objects(reader: &dyn Reader, path: &[String]) -> Result, String> { + match path.len() { + 0 => list_catalogs(reader), + 1 => list_schemas(reader, &path[0]), + 2 => list_tables(reader, &path[0], &path[1]), + _ => Ok(vec![]), + } +} + +/// List fields (columns) for the object at the given path. +/// +/// - `[catalog, schema, table]` → list columns +pub fn list_fields(reader: &dyn Reader, path: &[String]) -> Result, String> { + if path.len() == 3 { + list_columns(reader, &path[0], &path[1], &path[2]) + } else { + Ok(vec![]) + } +} + +/// Whether the path points to an object that contains data (table or view). +pub fn contains_data(path: &[Value]) -> bool { + path.last() + .and_then(|v| v.get("kind")) + .and_then(|k| k.as_str()) + .map(|k| k == "table" || k == "view") + .unwrap_or(false) +} + +fn list_catalogs(reader: &dyn Reader) -> Result, String> { + let sql = reader.dialect().sql_list_catalogs(); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list catalogs: {}", e))?; + + let col = find_column(&df, &["catalog_name", "name"]) + .map_err(|e| format!("Missing catalog_name/name column: {}", e))?; + + let mut catalogs = Vec::new(); + for i in 0..df.height() { + if let Ok(val) = col.get(i) { + let name = val.to_string().trim_matches('"').to_string(); + catalogs.push(ObjectSchema { + name, + kind: "catalog".to_string(), + }); + } + } + Ok(catalogs) +} + +fn list_schemas(reader: &dyn Reader, catalog: &str) -> Result, String> { + let sql = reader.dialect().sql_list_schemas(catalog); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list schemas: {}", e))?; + + let col = find_column(&df, &["schema_name", "name"]) + .map_err(|e| format!("Missing schema_name/name column: {}", e))?; + + let mut schemas = Vec::new(); + for i in 
0..df.height() { + if let Ok(val) = col.get(i) { + let name = val.to_string().trim_matches('"').to_string(); + schemas.push(ObjectSchema { + name, + kind: "schema".to_string(), + }); + } + } + Ok(schemas) +} + +fn list_tables( + reader: &dyn Reader, + catalog: &str, + schema: &str, +) -> Result, String> { + let sql = reader.dialect().sql_list_tables(catalog, schema); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list tables: {}", e))?; + + let name_col = find_column(&df, &["table_name", "name"]) + .map_err(|e| format!("Missing table_name/name column: {}", e))?; + let type_col = find_column(&df, &["table_type", "kind"]) + .map_err(|e| format!("Missing table_type/kind column: {}", e))?; + + let mut objects = Vec::new(); + for i in 0..df.height() { + if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) { + let name = name_val.to_string().trim_matches('"').to_string(); + let table_type = type_val.to_string().trim_matches('"').to_uppercase(); + let kind = if table_type.contains("VIEW") { + "view" + } else if table_type == "TABLE" + || table_type == "BASE TABLE" + || table_type.contains("TABLE") + { + "table" + } else { + continue; // Skip non-table/view objects (stages, procedures, etc.) 
+ }; + objects.push(ObjectSchema { + name, + kind: kind.to_string(), + }); + } + } + Ok(objects) +} + +fn list_columns( + reader: &dyn Reader, + catalog: &str, + schema: &str, + table: &str, +) -> Result, String> { + let sql = reader.dialect().sql_list_columns(catalog, schema, table); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list columns: {}", e))?; + + let name_col = find_column(&df, &["column_name"]) + .map_err(|e| format!("Missing column_name column: {}", e))?; + let type_col = + find_column(&df, &["data_type"]).map_err(|e| format!("Missing data_type column: {}", e))?; + + let mut fields = Vec::new(); + for i in 0..df.height() { + if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) { + let name = name_val.to_string().trim_matches('"').to_string(); + let dtype = type_val.to_string().trim_matches('"').to_string(); + fields.push(FieldSchema { name, dtype }); + } + } + Ok(fields) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_contains_data_table() { + let path = vec![ + serde_json::json!({"name": "memory", "kind": "catalog"}), + serde_json::json!({"name": "main", "kind": "schema"}), + serde_json::json!({"name": "users", "kind": "table"}), + ]; + assert!(contains_data(&path)); + } + + #[test] + fn test_contains_data_schema() { + let path = vec![ + serde_json::json!({"name": "memory", "kind": "catalog"}), + serde_json::json!({"name": "main", "kind": "schema"}), + ]; + assert!(!contains_data(&path)); + } + + #[test] + fn test_contains_data_catalog() { + let path = vec![serde_json::json!({"name": "memory", "kind": "catalog"})]; + assert!(!contains_data(&path)); + } + + #[test] + fn test_contains_data_view() { + let path = vec![ + serde_json::json!({"name": "memory", "kind": "catalog"}), + serde_json::json!({"name": "main", "kind": "schema"}), + serde_json::json!({"name": "my_view", "kind": "view"}), + ]; + assert!(contains_data(&path)); + } +} diff --git a/ggsql-jupyter/src/data_explorer.rs 
b/ggsql-jupyter/src/data_explorer.rs new file mode 100644 index 00000000..306a729d --- /dev/null +++ b/ggsql-jupyter/src/data_explorer.rs @@ -0,0 +1,1014 @@ +//! Data explorer backend for the Positron data viewer. +//! +//! Implements the `positron.dataExplorer` comm protocol, providing SQL-backed +//! paginated data access. + +use crate::util::find_column; +use ggsql::reader::Reader; +use serde_json::{json, Value}; + +/// Result of handling an RPC call. +pub struct RpcResponse { + /// The JSON-RPC result to send as the reply. + pub result: Value, + /// An optional event to send on iopub (e.g. `return_column_profiles`). + pub event: Option, +} + +/// An asynchronous event to send back on the comm after the RPC reply. +pub struct RpcEvent { + pub method: String, + pub params: Value, +} + +impl RpcResponse { + /// Create a simple reply with no async event. + pub fn reply(result: Value) -> Self { + Self { + result, + event: None, + } + } +} + +/// Cached column metadata for a table. +#[derive(Debug, Clone)] +pub struct ColumnInfo { + pub name: String, + /// Backend-specific type name (e.g. "INTEGER", "VARCHAR"). + pub type_name: String, + /// Positron display type (e.g. "integer", "string"). + pub type_display: String, +} + +/// State for one open data explorer comm. +pub struct DataExplorerState { + /// Fully qualified and quoted table path, e.g. `"memory"."main"."users"`. + table_path: String, + /// Display title shown in the data viewer tab. + title: String, + /// Cached column schemas. + columns: Vec, + /// Cached total row count. + num_rows: usize, +} + +impl DataExplorerState { + /// Open a data explorer for a table at the given connection path. + /// + /// Runs `SELECT COUNT(*)` and a column metadata query to cache schema + /// information. Does **not** load the full table into memory. 
+ pub fn open(reader: &dyn Reader, path: &[String]) -> Result { + if path.len() < 3 { + return Err(format!( + "Expected [catalog, schema, table] path, got {} elements", + path.len() + )); + } + + let catalog = &path[0]; + let schema = &path[1]; + let table = &path[2]; + + let table_path = format!( + "{}.{}.{}", + ggsql::naming::quote_ident(catalog), + ggsql::naming::quote_ident(schema), + ggsql::naming::quote_ident(table), + ); + + // Get row count + let count_sql = format!("SELECT COUNT(*) AS \"n\" FROM {}", table_path); + let count_df = reader + .execute_sql(&count_sql) + .map_err(|e| format!("Failed to count rows: {}", e))?; + let num_rows = count_df + .column("n") + .ok() + .and_then(|col| col.get(0).ok()) + .and_then(|val| { + // Polars AnyValue — try common integer representations + let s = format!("{}", val); + s.parse::().ok() + }) + .unwrap_or(0); + + // Get column metadata from information_schema + let columns_sql = reader.dialect().sql_list_columns(catalog, schema, table); + let columns_df = reader + .execute_sql(&columns_sql) + .map_err(|e| format!("Failed to list columns: {}", e))?; + + let name_col = find_column(&columns_df, &["column_name"]) + .map_err(|e| format!("Missing column_name: {}", e))?; + let type_col = find_column(&columns_df, &["data_type"]) + .map_err(|e| format!("Missing data_type: {}", e))?; + + let mut columns = Vec::new(); + for i in 0..columns_df.height() { + if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) { + let name = name_val.to_string().trim_matches('"').to_string(); + let raw_type = type_val.to_string().trim_matches('"').to_string(); + let type_display = sql_type_to_display(&raw_type).to_string(); + let type_name = clean_type_name(&raw_type); + columns.push(ColumnInfo { + name, + type_name, + type_display, + }); + } + } + + Ok(Self { + table_path, + title: table.clone(), + columns, + num_rows, + }) + } + + /// Dispatch a JSON-RPC method call. 
+ /// + /// Returns the RPC result and an optional async event to send on iopub + /// (used by `get_column_profiles` to deliver results asynchronously). + pub fn handle_rpc(&self, method: &str, params: &Value, reader: &dyn Reader) -> RpcResponse { + match method { + "get_state" => RpcResponse::reply(self.get_state()), + "get_schema" => RpcResponse::reply(self.get_schema(params)), + "get_data_values" => RpcResponse::reply(self.get_data_values(params, reader)), + "get_column_profiles" => self.get_column_profiles(params, reader), + // TODO: Implement filters, sorting, and searching. + "set_row_filters" => { + // Stub: accept but ignore filters, return current shape + RpcResponse::reply(json!({ + "selected_num_rows": self.num_rows, + "had_errors": false + })) + } + "set_sort_columns" | "set_column_filters" | "search_schema" => { + RpcResponse::reply(json!(null)) + } + _ => { + tracing::warn!("Unhandled data explorer method: {}", method); + RpcResponse::reply(json!(null)) + } + } + } + + fn get_state(&self) -> Value { + let num_columns = self.columns.len(); + json!({ + "display_name": self.title, + "table_shape": { + "num_rows": self.num_rows, + "num_columns": num_columns + }, + "table_unfiltered_shape": { + "num_rows": self.num_rows, + "num_columns": num_columns + }, + "has_row_labels": false, + "column_filters": [], + "row_filters": [], + "sort_keys": [], + "supported_features": { + "search_schema": { + "support_status": "unsupported", + "supported_types": [] + }, + "set_column_filters": { + "support_status": "unsupported", + "supported_types": [] + }, + "set_row_filters": { + "support_status": "unsupported", + "supports_conditions": "unsupported", + "supported_types": [] + }, + "get_column_profiles": { + "support_status": "supported", + "supported_types": [ + {"profile_type": "null_count", "support_status": "supported"}, + {"profile_type": "summary_stats", "support_status": "supported"}, + {"profile_type": "small_histogram", "support_status": "supported"}, + 
{"profile_type": "small_frequency_table", "support_status": "supported"} + ] + }, + "set_sort_columns": { + "support_status": "unsupported" + }, + "export_data_selection": { + "support_status": "unsupported", + "supported_formats": [] + }, + "convert_to_code": { + "support_status": "unsupported" + } + } + }) + } + + fn get_schema(&self, params: &Value) -> Value { + let indices: Vec = params + .get("column_indices") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_u64().map(|n| n as usize)) + .collect() + }) + .unwrap_or_default(); + + let columns: Vec = indices + .iter() + .filter_map(|&idx| { + self.columns.get(idx).map(|col| { + json!({ + "column_name": col.name, + "column_index": idx, + "type_name": col.type_name, + "type_display": col.type_display + }) + }) + }) + .collect(); + + json!({ "columns": columns }) + } + + fn get_data_values(&self, params: &Value, reader: &dyn Reader) -> Value { + let selections = match params.get("columns").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => return json!({ "columns": [] }), + }; + + // Determine the row range from the first selection's spec + let (first_index, last_index) = selections + .first() + .and_then(|sel| sel.get("spec")) + .map(|spec| { + let first = spec + .get("first_index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let last = spec.get("last_index").and_then(|v| v.as_u64()).unwrap_or(0) as usize; + (first, last) + }) + .unwrap_or((0, 0)); + + let limit = last_index.saturating_sub(first_index) + 1; + + // Collect requested column indices + let col_indices: Vec = selections + .iter() + .filter_map(|sel| { + sel.get("column_index") + .and_then(|v| v.as_u64()) + .map(|n| n as usize) + }) + .collect(); + + // Build column list for SELECT + let col_names: Vec = col_indices + .iter() + .filter_map(|&idx| { + self.columns + .get(idx) + .map(|col| ggsql::naming::quote_ident(&col.name)) + }) + .collect(); + + if col_names.is_empty() { + return json!({ "columns": 
[] }); + } + + let sql = format!( + "SELECT {} FROM {} LIMIT {} OFFSET {}", + col_names.join(", "), + self.table_path, + limit, + first_index, + ); + + let df = match reader.execute_sql(&sql) { + Ok(df) => df, + Err(e) => { + tracing::error!("get_data_values query failed: {}", e); + let empty: Vec> = col_indices.iter().map(|_| vec![]).collect(); + return json!({ "columns": empty }); + } + }; + + // Format each column's values as strings. + // Positron's ColumnValue is `number | string`: numbers are special + // value codes (0 = NULL, 1 = NA, 2 = NaN), strings are formatted data. + const SPECIAL_VALUE_NULL: i64 = 0; + + let columns: Vec> = (0..df.width()) + .map(|col_idx| { + let col = df.get_columns()[col_idx].clone(); + (0..df.height()) + .map(|row_idx| { + match col.get(row_idx) { + Ok(val) => { + if val.is_null() { + json!(SPECIAL_VALUE_NULL) + } else { + let s = format!("{}", val); + // Strip surrounding quotes from string values + let s = s.trim_matches('"'); + Value::String(s.to_string()) + } + } + Err(_) => json!(SPECIAL_VALUE_NULL), + } + }) + .collect() + }) + .collect(); + + json!({ "columns": columns }) + } + + /// Handle `get_column_profiles` — returns `{}` as the RPC result and sends + /// profile data back as an async `return_column_profiles` event. 
+ fn get_column_profiles(&self, params: &Value, reader: &dyn Reader) -> RpcResponse { + let callback_id = params + .get("callback_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let requests = match params.get("profiles").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => { + return RpcResponse { + result: json!({}), + event: Some(RpcEvent { + method: "return_column_profiles".into(), + params: json!({ + "callback_id": callback_id, + "profiles": [] + }), + }), + }; + } + }; + + let mut profiles = Vec::new(); + for req in requests { + let col_idx = req + .get("column_index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + + let specs = req + .get("profiles") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + + let profile = self.compute_column_profile(col_idx, &specs, reader); + profiles.push(profile); + } + + RpcResponse { + result: json!({}), + event: Some(RpcEvent { + method: "return_column_profiles".into(), + params: json!({ + "callback_id": callback_id, + "profiles": profiles + }), + }), + } + } + + /// Compute profile results for a single column. 
+ fn compute_column_profile( + &self, + col_idx: usize, + specs: &[Value], + reader: &dyn Reader, + ) -> Value { + let col = match self.columns.get(col_idx) { + Some(c) => c, + None => return json!({}), + }; + + let mut wants_null_count = false; + let mut wants_summary = false; + let mut histogram_params: Option<&Value> = None; + let mut freq_table_params: Option<&Value> = None; + for spec in specs { + match spec + .get("profile_type") + .and_then(|v| v.as_str()) + .unwrap_or("") + { + "null_count" => wants_null_count = true, + "summary_stats" => wants_summary = true, + "small_histogram" => histogram_params = spec.get("params"), + "small_frequency_table" => freq_table_params = spec.get("params"), + _ => {} + } + } + + let dialect = reader.dialect(); + let quoted_col = ggsql::naming::quote_ident(&col.name); + let display = col.type_display.as_str(); + + // Build a single SQL query that computes all needed aggregates. + let mut select_parts = Vec::new(); + if wants_null_count { + select_parts.push(format!( + "SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) AS \"null_count\"", + quoted_col + )); + } + if wants_summary { + match display { + "integer" | "floating" => { + let float_type = dialect.number_type_name().unwrap_or("DOUBLE PRECISION"); + select_parts.push(format!("MIN({}) AS \"min_val\"", quoted_col)); + select_parts.push(format!("MAX({}) AS \"max_val\"", quoted_col)); + select_parts.push(format!( + "AVG(CAST({} AS {})) AS \"mean_val\"", + quoted_col, float_type + )); + // Stddev: fetch raw aggregates, compute in Rust + select_parts.push(format!( + "SUM(CAST({c} AS {t}) * CAST({c} AS {t})) AS \"sum_sq\"", + c = quoted_col, + t = float_type + )); + select_parts.push(format!( + "SUM(CAST({} AS {})) AS \"sum_val\"", + quoted_col, float_type + )); + select_parts.push(format!("COUNT({}) AS \"cnt\"", quoted_col)); + } + "boolean" => { + let true_lit = dialect.sql_boolean_literal(true); + let false_lit = dialect.sql_boolean_literal(false); + select_parts.push(format!( + 
"SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS \"true_count\"", + quoted_col, true_lit + )); + select_parts.push(format!( + "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS \"false_count\"", + quoted_col, false_lit + )); + } + "string" => { + select_parts.push(format!("COUNT(DISTINCT {}) AS \"num_unique\"", quoted_col)); + select_parts.push(format!( + "SUM(CASE WHEN {} = '' THEN 1 ELSE 0 END) AS \"num_empty\"", + quoted_col + )); + } + "date" | "datetime" => { + select_parts.push(format!("MIN({}) AS \"min_val\"", quoted_col)); + select_parts.push(format!("MAX({}) AS \"max_val\"", quoted_col)); + select_parts.push(format!("COUNT(DISTINCT {}) AS \"num_unique\"", quoted_col)); + } + _ => {} + } + } + + if select_parts.is_empty() { + return json!({}); + } + + let sql = format!( + "SELECT {} FROM {}", + select_parts.join(", "), + self.table_path + ); + + let df = match reader.execute_sql(&sql) { + Ok(df) => df, + Err(e) => { + tracing::error!("Column profile query failed: {}", e); + return json!({}); + } + }; + + let get_str = |name: &str| -> Option { + df.column(name) + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + }; + + let get_i64 = + |name: &str| -> Option { get_str(name).and_then(|s| s.parse::().ok()) }; + + let get_f64 = + |name: &str| -> Option { get_str(name).and_then(|s| s.parse::().ok()) }; + + let mut result = json!({}); + + if wants_null_count { + if let Some(n) = get_i64("null_count") { + result["null_count"] = json!(n); + } + } + + if wants_summary { + let stats = match display { + "integer" | "floating" => { + let mut number_stats = json!({}); + if let Some(v) = get_str("min_val") { + number_stats["min_value"] = json!(v); + } + if let Some(v) = get_str("max_val") { + number_stats["max_value"] = json!(v); + } + if let Some(v) = get_str("mean_val") { + number_stats["mean"] = json!(v); + } + // Compute sample stddev from raw aggregates + if let 
(Some(sum_sq), Some(sum_val), Some(cnt)) = + (get_f64("sum_sq"), get_f64("sum_val"), get_i64("cnt")) + { + if cnt > 1 { + let variance = + (sum_sq - sum_val * sum_val / cnt as f64) / (cnt - 1) as f64; + let stdev = variance.max(0.0).sqrt(); + number_stats["stdev"] = json!(format!("{}", stdev)); + } + } + // Median via dialect's sql_percentile + let col_name = col.name.replace('"', "\"\""); + let from_query = format!("SELECT * FROM {}", self.table_path); + let median_expr = dialect.sql_percentile(&col_name, 0.5, &from_query, &[]); + let median_sql = format!("SELECT {} AS \"median_val\"", median_expr); + if let Ok(median_df) = reader.execute_sql(&median_sql) { + if let Some(v) = median_df + .column("median_val") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + { + number_stats["median"] = json!(v); + } + } + json!({ + "type_display": display, + "number_stats": number_stats + }) + } + "boolean" => { + json!({ + "type_display": display, + "boolean_stats": { + "true_count": get_i64("true_count").unwrap_or(0), + "false_count": get_i64("false_count").unwrap_or(0) + } + }) + } + "string" => { + json!({ + "type_display": display, + "string_stats": { + "num_unique": get_i64("num_unique").unwrap_or(0), + "num_empty": get_i64("num_empty").unwrap_or(0) + } + }) + } + "date" => { + let mut date_stats = json!({}); + if let Some(v) = get_str("min_val") { + date_stats["min_date"] = json!(v); + } + if let Some(v) = get_str("max_val") { + date_stats["max_date"] = json!(v); + } + if let Some(n) = get_i64("num_unique") { + date_stats["num_unique"] = json!(n); + } + json!({ + "type_display": display, + "date_stats": date_stats + }) + } + "datetime" => { + let mut datetime_stats = json!({}); + if let Some(v) = get_str("min_val") { + datetime_stats["min_date"] = json!(v); + } + if let Some(v) = get_str("max_val") { + datetime_stats["max_date"] = json!(v); + } + if let Some(n) = 
get_i64("num_unique") { + datetime_stats["num_unique"] = json!(n); + } + json!({ + "type_display": display, + "datetime_stats": datetime_stats + }) + } + _ => json!({"type_display": display}), + }; + result["summary_stats"] = stats; + } + + // Compute histogram if requested (only for numeric types) + if let Some(params) = histogram_params { + if matches!(display, "integer" | "floating") { + if let Some(hist) = self.compute_histogram(col, params, reader) { + result["small_histogram"] = hist; + } + } + } + + // Compute frequency table if requested (for string/boolean types) + if let Some(params) = freq_table_params { + if matches!(display, "string" | "boolean") { + if let Some(ft) = self.compute_frequency_table(col, params, reader) { + result["small_frequency_table"] = ft; + } + } + } + + result + } + + /// Compute a histogram for a numeric column. + fn compute_histogram( + &self, + col: &ColumnInfo, + params: &Value, + reader: &dyn Reader, + ) -> Option { + let max_bins = params + .get("num_bins") + .and_then(|v| v.as_u64()) + .unwrap_or(20) as usize; + + if max_bins == 0 { + return None; + } + + let dialect = reader.dialect(); + let float_type = dialect.number_type_name().unwrap_or("DOUBLE PRECISION"); + let quoted_col = ggsql::naming::quote_ident(&col.name); + let is_integer = col.type_display == "integer"; + + // Get min, max, count in one query + let bounds_sql = format!( + "SELECT \ + MIN(CAST({c} AS {t})) AS \"min_val\", \ + MAX(CAST({c} AS {t})) AS \"max_val\", \ + COUNT({c}) AS \"cnt\" \ + FROM {table} WHERE {c} IS NOT NULL", + c = quoted_col, + t = float_type, + table = self.table_path, + ); + + let bounds_df = reader.execute_sql(&bounds_sql).ok()?; + let get_f64 = |name: &str| -> Option { + bounds_df + .column(name) + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + format!("{}", v).trim_matches('"').parse::().ok() + } + }) + }; + + let min_val = get_f64("min_val")?; + let max_val = get_f64("max_val")?; + let 
count = get_f64("cnt").unwrap_or(0.0) as usize; + + // Handle edge case: all values identical + if (max_val - min_val).abs() < f64::EPSILON { + return Some(json!({ + "bin_edges": [format!("{}", min_val), format!("{}", max_val)], + "bin_counts": [count as i64], + "quantiles": [] + })); + } + + // Determine actual bin count using Sturges' formula, capped at max_bins. + // For integers, also cap at (max - min + 1) to avoid sub-unit bins. + let mut num_bins = if count > 1 { + ((count as f64).log2().ceil() as usize + 1).max(1) + } else { + 1 + }; + if is_integer { + let int_range = (max_val - min_val) as usize + 1; + num_bins = num_bins.min(int_range); + } + num_bins = num_bins.min(max_bins).max(1); + + let bin_width = (max_val - min_val) / num_bins as f64; + + // Bin the data using FLOOR. Clamp the last bin to num_bins-1 so + // max value doesn't create an extra bin. + let hist_sql = format!( + "SELECT \ + CASE \ + WHEN \"bin\" >= {num_bins} THEN {last_bin} \ + ELSE \"bin\" \ + END AS \"clamped_bin\", \ + COUNT(*) AS \"cnt\" \ + FROM ( \ + SELECT FLOOR((CAST({c} AS {t}) - {min}) / {width}) AS \"bin\" \ + FROM {table} \ + WHERE {c} IS NOT NULL \ + ) AS \"__bins__\" \ + GROUP BY \"clamped_bin\" \ + ORDER BY \"clamped_bin\"", + c = quoted_col, + t = float_type, + table = self.table_path, + min = min_val, + width = bin_width, + num_bins = num_bins, + last_bin = num_bins - 1, + ); + + let hist_df = reader.execute_sql(&hist_sql).ok()?; + + // Build bin_edges: num_bins + 1 edges + let bin_edges: Vec = (0..=num_bins) + .map(|i| format!("{}", min_val + i as f64 * bin_width)) + .collect(); + + // Build bin_counts: fill from query results (sparse bins get 0) + let mut bin_counts = vec![0i64; num_bins]; + let bin_col = hist_df.column("clamped_bin").ok()?; + let cnt_col = hist_df.column("cnt").ok()?; + for i in 0..hist_df.height() { + if let (Ok(bin_val), Ok(cnt_val)) = (bin_col.get(i), cnt_col.get(i)) { + let bin_str = format!("{}", bin_val); + // Parse bin index — may be float 
(e.g., "3.0") on some backends + if let Ok(bin_idx) = bin_str.parse::() { + let idx = bin_idx as usize; + if idx < num_bins { + let count_str = format!("{}", cnt_val); + bin_counts[idx] = count_str.parse::().unwrap_or(0); + } + } + } + } + + // Compute requested quantiles + let quantiles_param = params + .get("quantiles") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + + let mut quantile_results = Vec::new(); + let from_query = format!("SELECT * FROM {}", self.table_path); + let col_name = col.name.replace('"', "\"\""); + for q in &quantiles_param { + if let Some(q_val) = q.as_f64() { + let expr = dialect.sql_percentile(&col_name, q_val, &from_query, &[]); + let q_sql = format!("SELECT {} AS \"q_val\"", expr); + if let Ok(q_df) = reader.execute_sql(&q_sql) { + if let Some(v) = q_df + .column("q_val") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + { + quantile_results.push(json!({"q": q_val, "value": v})); + } + } + } + } + + Some(json!({ + "bin_edges": bin_edges, + "bin_counts": bin_counts, + "quantiles": quantile_results + })) + } + + /// Compute a frequency table for a string or boolean column. 
+ fn compute_frequency_table( + &self, + col: &ColumnInfo, + params: &Value, + reader: &dyn Reader, + ) -> Option { + let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(8) as usize; + + let quoted_col = ggsql::naming::quote_ident(&col.name); + + let sql = format!( + "SELECT {c} AS \"value\", COUNT(*) AS \"count\" \ + FROM {table} \ + WHERE {c} IS NOT NULL \ + GROUP BY {c} \ + ORDER BY COUNT(*) DESC \ + LIMIT {limit}", + c = quoted_col, + table = self.table_path, + limit = limit, + ); + + let df = reader.execute_sql(&sql).ok()?; + + let val_col = df.column("value").ok()?; + let cnt_col = df.column("count").ok()?; + + let mut values = Vec::new(); + let mut counts = Vec::new(); + let mut top_total: i64 = 0; + + for i in 0..df.height() { + if let (Ok(v), Ok(c)) = (val_col.get(i), cnt_col.get(i)) { + let val_str = format!("{}", v).trim_matches('"').to_string(); + let count: i64 = format!("{}", c).parse().unwrap_or(0); + values.push(Value::String(val_str)); + counts.push(count); + top_total += count; + } + } + + // Compute other_count: total non-null rows minus the top-K sum + let count_sql = format!( + "SELECT COUNT({c}) AS \"total\" FROM {table}", + c = quoted_col, + table = self.table_path, + ); + let other_count = reader + .execute_sql(&count_sql) + .ok() + .and_then(|df| { + df.column("total") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| format!("{}", v).parse::().ok()) + }) + .map(|total| total - top_total) + .unwrap_or(0); + + Some(json!({ + "values": values, + "counts": counts, + "other_count": other_count + })) + } +} + +/// Map a SQL type name (from information_schema or SHOW COLUMNS) to a Positron display type. +/// +/// Handles both simple type names (e.g. "INTEGER", "VARCHAR") and Snowflake's +/// JSON format (e.g. `{"type":"FIXED","precision":38,"scale":0,...}`). 
+fn sql_type_to_display(type_name: &str) -> &'static str { + // Handle Snowflake JSON type format + if type_name.starts_with('{') { + if let Ok(obj) = serde_json::from_str::(type_name) { + if let Some(t) = obj.get("type").and_then(|v| v.as_str()) { + return match t { + "FIXED" => { + let scale = obj.get("scale").and_then(|v| v.as_i64()).unwrap_or(0); + if scale > 0 { + "floating" + } else { + "integer" + } + } + "REAL" | "FLOAT" => "floating", + "TEXT" => "string", + "BOOLEAN" => "boolean", + "DATE" => "date", + "TIMESTAMP_NTZ" | "TIMESTAMP_LTZ" | "TIMESTAMP_TZ" => "datetime", + "TIME" => "time", + "BINARY" => "string", + "VARIANT" | "OBJECT" | "ARRAY" => "string", + _ => "unknown", + }; + } + } + } + + // Simple type names (DuckDB, PostgreSQL, SQLite, etc.) + let upper = type_name.to_uppercase(); + let upper = upper.as_str(); + + if upper.contains("INT") { + return "integer"; + } + if upper.contains("FLOAT") + || upper.contains("DOUBLE") + || upper.contains("REAL") + || upper.contains("NUMERIC") + || upper.contains("DECIMAL") + { + return "floating"; + } + if upper.contains("BOOL") { + return "boolean"; + } + if upper.contains("TIMESTAMP") || upper.contains("DATETIME") { + return "datetime"; + } + if upper.contains("DATE") { + return "date"; + } + if upper.contains("TIME") { + return "time"; + } + if upper.contains("CHAR") + || upper.contains("TEXT") + || upper.contains("STRING") + || upper.contains("VARCHAR") + || upper.contains("CLOB") + { + return "string"; + } + if upper.contains("BLOB") || upper.contains("BINARY") || upper.contains("BYTE") { + return "string"; + } + + "unknown" +} + +/// Clean up a raw type name for display in the schema response. +/// +/// For Snowflake JSON types, extracts the `type` field (e.g. "NUMBER", "TEXT"). +/// For simple type names, returns as-is. 
+fn clean_type_name(type_name: &str) -> String { + if type_name.starts_with('{') { + if let Ok(obj) = serde_json::from_str::(type_name) { + if let Some(t) = obj.get("type").and_then(|v| v.as_str()) { + return match t { + "FIXED" => { + let scale = obj.get("scale").and_then(|v| v.as_i64()).unwrap_or(0); + if scale > 0 { + format!( + "NUMBER({},{})", + obj.get("precision").and_then(|v| v.as_i64()).unwrap_or(38), + scale + ) + } else { + "NUMBER".to_string() + } + } + other => other.to_string(), + }; + } + } + } + type_name.to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sql_type_to_display() { + assert_eq!(sql_type_to_display("INTEGER"), "integer"); + assert_eq!(sql_type_to_display("BIGINT"), "integer"); + assert_eq!(sql_type_to_display("SMALLINT"), "integer"); + assert_eq!(sql_type_to_display("TINYINT"), "integer"); + assert_eq!(sql_type_to_display("INT"), "integer"); + assert_eq!(sql_type_to_display("DOUBLE"), "floating"); + assert_eq!(sql_type_to_display("FLOAT"), "floating"); + assert_eq!(sql_type_to_display("REAL"), "floating"); + assert_eq!(sql_type_to_display("NUMERIC(10,2)"), "floating"); + assert_eq!(sql_type_to_display("DECIMAL(10,2)"), "floating"); + assert_eq!(sql_type_to_display("BOOLEAN"), "boolean"); + assert_eq!(sql_type_to_display("BOOL"), "boolean"); + assert_eq!(sql_type_to_display("VARCHAR"), "string"); + assert_eq!(sql_type_to_display("TEXT"), "string"); + assert_eq!(sql_type_to_display("DATE"), "date"); + assert_eq!(sql_type_to_display("TIMESTAMP"), "datetime"); + assert_eq!(sql_type_to_display("TIMESTAMP WITH TIME ZONE"), "datetime"); + assert_eq!(sql_type_to_display("TIME"), "time"); + assert_eq!(sql_type_to_display("BLOB"), "string"); + assert_eq!(sql_type_to_display("UNKNOWN_TYPE"), "unknown"); + } +} diff --git a/ggsql-jupyter/src/display.rs b/ggsql-jupyter/src/display.rs index 1e1dbe67..78823779 100644 --- a/ggsql-jupyter/src/display.rs +++ b/ggsql-jupyter/src/display.rs @@ -35,9 +35,24 @@ pub fn 
format_display_data(result: ExecutionResult) -> Option { Some(format_dataframe(df)) } } + ExecutionResult::ConnectionChanged { display_name, .. } => { + Some(format_connection_changed(&display_name)) + } } } +/// Format a connection-changed message +fn format_connection_changed(display_name: &str) -> Value { + let text = format!("Connected to {}", display_name); + json!({ + "data": { + "text/plain": text + }, + "metadata": {}, + "transient": {} + }) +} + /// Format Vega-Lite visualization as display_data fn format_vegalite(spec: String) -> Value { let spec_value: Value = serde_json::from_str(&spec).unwrap_or_else(|e| { diff --git a/ggsql-jupyter/src/executor.rs b/ggsql-jupyter/src/executor.rs index 1548f5e3..434345c3 100644 --- a/ggsql-jupyter/src/executor.rs +++ b/ggsql-jupyter/src/executor.rs @@ -2,10 +2,11 @@ //! //! This module handles the execution of ggsql queries using the existing //! ggsql library components (parser, DuckDB reader, Vega-Lite writer). +//! It supports dynamic reader switching via `-- @connect:` meta-commands. use anyhow::Result; use ggsql::{ - reader::{DuckDBReader, Reader}, + reader::{connection::parse_connection_string, DuckDBReader, Reader}, validate::validate, writer::{VegaLiteWriter, Writer}, }; @@ -20,39 +21,196 @@ pub enum ExecutionResult { Visualization { spec: String, // Vega-Lite JSON }, + /// Connection changed via meta-command + ConnectionChanged { uri: String, display_name: String }, } -/// Query executor maintaining persistent DuckDB connection +/// Create a reader from a connection URI string. 
+/// +/// Supported schemes: +/// - `duckdb://memory` or `duckdb://` (always available) +/// - `sqlite://` (requires `sqlite` feature) +/// - `odbc://...` (requires `odbc` feature) +pub fn create_reader(uri: &str) -> Result> { + use ggsql::reader::connection::ConnectionInfo; + + let info = parse_connection_string(uri)?; + match info { + ConnectionInfo::DuckDBMemory => { + let reader = DuckDBReader::from_connection_string("duckdb://memory")?; + Ok(Box::new(reader)) + } + ConnectionInfo::DuckDBFile(path) => { + let reader = DuckDBReader::from_connection_string(&format!("duckdb://{}", path))?; + Ok(Box::new(reader)) + } + #[cfg(feature = "odbc")] + ConnectionInfo::ODBC(conn_str) => { + let reader = + ggsql::reader::OdbcReader::from_connection_string(&format!("odbc://{}", conn_str))?; + Ok(Box::new(reader)) + } + #[cfg(feature = "sqlite")] + ConnectionInfo::SQLite(path) => { + let reader = + ggsql::reader::SqliteReader::from_connection_string(&format!("sqlite://{}", path))?; + Ok(Box::new(reader)) + } + _ => anyhow::bail!("Unsupported reader type for connection string: {}", uri), + } +} + +/// Generate a human-readable display name for a connection URI. 
+pub fn display_name_for_uri(uri: &str) -> String { + if uri == "duckdb://memory" { + return "DuckDB (memory)".to_string(); + } + if let Some(path) = uri.strip_prefix("duckdb://") { + return format!("DuckDB ({})", path); + } + if let Some(path) = uri.strip_prefix("sqlite://") { + if path.is_empty() { + return "SQLite (memory)".to_string(); + } + return format!("SQLite ({})", path); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + // Try to extract driver name from ODBC string + if let Some(driver_start) = odbc.to_lowercase().find("driver=") { + let rest = &odbc[driver_start + 7..]; + let driver = rest + .split(';') + .next() + .unwrap_or("ODBC") + .trim_matches(|c| c == '{' || c == '}'); + return format!("{} (ODBC)", driver); + } + return "ODBC".to_string(); + } + uri.to_string() +} + +/// Detect the database type name from a connection URI (e.g. "DuckDB", "Snowflake"). +pub fn type_name_for_uri(uri: &str) -> String { + if uri.starts_with("duckdb://") { + return "DuckDB".to_string(); + } + if uri.starts_with("sqlite://") { + return "SQLite".to_string(); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + if odbc.to_lowercase().contains("driver=snowflake") { + return "Snowflake".to_string(); + } + if odbc.to_lowercase().contains("driver={postgresql}") + || odbc.to_lowercase().contains("driver=postgresql") + { + return "PostgreSQL".to_string(); + } + return "ODBC".to_string(); + } + "Unknown".to_string() +} + +/// Extract the host portion from a connection URI. 
+pub fn host_for_uri(uri: &str) -> String { + if uri == "duckdb://memory" { + return "memory".to_string(); + } + if let Some(path) = uri.strip_prefix("duckdb://") { + return path.to_string(); + } + if let Some(path) = uri.strip_prefix("sqlite://") { + if path.is_empty() { + return "memory".to_string(); + } + return path.to_string(); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + // Try to extract server + if let Some(server_start) = odbc.to_lowercase().find("server=") { + let rest = &odbc[server_start + 7..]; + if let Some(host) = rest.split(';').next() { + return host.to_string(); + } + } + } + uri.to_string() +} + +/// The `-- @connect:` meta-command prefix. +const META_CONNECT_PREFIX: &str = "-- @connect:"; + +/// Parse a `-- @connect: ` meta-command, returning the URI if present. +pub fn parse_meta_command(code: &str) -> Option { + let trimmed = code.trim(); + trimmed + .strip_prefix(META_CONNECT_PREFIX) + .map(|rest| rest.trim().to_string()) +} + +/// Query executor maintaining persistent database connection pub struct QueryExecutor { - reader: DuckDBReader, + reader: Box, writer: VegaLiteWriter, + reader_uri: String, } impl QueryExecutor { - /// Create a new query executor with in-memory DuckDB database - pub fn new() -> Result { - tracing::info!("Initializing query executor with in-memory DuckDB"); - let reader = DuckDBReader::from_connection_string("duckdb://memory")?; + /// Create a new query executor with a given connection URI + pub fn new_with_uri(uri: &str) -> Result { + tracing::info!("Initializing query executor with reader: {}", uri); + let reader = create_reader(uri)?; let writer = VegaLiteWriter::new(); - Ok(Self { reader, writer }) + Ok(Self { + reader, + writer, + reader_uri: uri.to_string(), + }) } - /// Execute a ggsql query - /// - /// This handles both pure SQL queries and queries with VISUALISE clauses. 
- /// - /// # Arguments - /// - /// * `code` - The ggsql query to execute - /// - /// # Returns + /// Create a new query executor with the default in-memory DuckDB database + #[cfg(test)] + pub fn new() -> Result { + Self::new_with_uri("duckdb://memory") + } + + /// Get the current reader URI + pub fn reader_uri(&self) -> &str { + &self.reader_uri + } + + /// Get a reference to the current reader (for schema introspection) + pub fn reader(&self) -> &dyn Reader { + &*self.reader + } + + /// Swap the reader to a new connection, returning the old URI + pub fn swap_reader(&mut self, uri: &str) -> Result { + let new_reader = create_reader(uri)?; + self.reader = new_reader; + let old_uri = std::mem::replace(&mut self.reader_uri, uri.to_string()); + Ok(old_uri) + } + + /// Execute a ggsql query or meta-command /// - /// An ExecutionResult containing either a DataFrame (for pure SQL) or - /// a Visualization (for queries with VISUALISE clause) - pub fn execute(&self, code: &str) -> Result { + /// This handles: + /// - `-- @connect: ` meta-commands for switching readers + /// - Pure SQL queries (no VISUALISE) + /// - ggsql queries with VISUALISE clauses + pub fn execute(&mut self, code: &str) -> Result { tracing::debug!("Executing query: {} chars", code.len()); + // Check for meta-commands first + if let Some(uri) = parse_meta_command(code) { + tracing::info!("Meta-command: switching reader to {}", uri); + self.swap_reader(&uri)?; + let display_name = display_name_for_uri(&uri); + return Ok(ExecutionResult::ConnectionChanged { uri, display_name }); + } + // 1. 
Validate to check if there's a visualization let validated = validate(code)?; @@ -93,7 +251,7 @@ mod tests { #[test] fn test_simple_visualization() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT 1 as x, 2 as y VISUALISE x, y DRAW point"; let result = executor.execute(code).unwrap(); @@ -102,7 +260,7 @@ mod tests { #[test] fn test_pure_sql() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT 1 as x, 2 as y"; let result = executor.execute(code).unwrap(); @@ -111,10 +269,38 @@ mod tests { #[test] fn test_error_handling() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT * FROM nonexistent_table"; let result = executor.execute(code); assert!(result.is_err()); } + + #[test] + fn test_parse_meta_command() { + assert_eq!( + parse_meta_command("-- @connect: duckdb://memory"), + Some("duckdb://memory".to_string()) + ); + assert_eq!( + parse_meta_command(" -- @connect: duckdb://my.db "), + Some("duckdb://my.db".to_string()) + ); + assert_eq!(parse_meta_command("SELECT 1"), None); + } + + #[test] + fn test_meta_command_switches_reader() { + let mut executor = QueryExecutor::new().unwrap(); + assert_eq!(executor.reader_uri(), "duckdb://memory"); + + let result = executor.execute("-- @connect: duckdb://memory").unwrap(); + assert!(matches!(result, ExecutionResult::ConnectionChanged { .. })); + } + + #[test] + fn test_display_name_for_uri() { + assert_eq!(display_name_for_uri("duckdb://memory"), "DuckDB (memory)"); + assert_eq!(display_name_for_uri("duckdb://my.db"), "DuckDB (my.db)"); + } } diff --git a/ggsql-jupyter/src/kernel.rs b/ggsql-jupyter/src/kernel.rs index 14c07340..439adabd 100644 --- a/ggsql-jupyter/src/kernel.rs +++ b/ggsql-jupyter/src/kernel.rs @@ -3,13 +3,16 @@ //! This module implements the Jupyter messaging protocol over ZeroMQ sockets, //! 
handling kernel_info, execute, and shutdown requests. +use crate::connection; +use crate::data_explorer::{DataExplorerState, RpcResponse}; use crate::display::format_display_data; -use crate::executor::QueryExecutor; +use crate::executor::{self, ExecutionResult, QueryExecutor}; use crate::message::{ConnectionInfo, JupyterMessage, MessageHeader}; use anyhow::Result; use hmac::{Hmac, Mac}; use serde_json::{json, Value}; use sha2::Sha256; +use std::collections::HashMap; use zeromq::{PubSocket, RepSocket, RouterSocket, Socket, SocketRecv, SocketSend}; type HmacSha256 = Hmac; @@ -32,11 +35,13 @@ pub struct KernelServer { variables_comm_id: Option, ui_comm_id: Option, plot_comm_id: Option, + connection_comm_id: Option, + data_explorer_comms: HashMap, } impl KernelServer { /// Create a new kernel server from connection info - pub async fn new(connection: ConnectionInfo) -> Result { + pub async fn new(connection: ConnectionInfo, reader_uri: &str) -> Result { tracing::info!("Initializing kernel server"); // Initialize sockets @@ -68,8 +73,8 @@ impl KernelServer { tracing::info!("Binding heartbeat socket to {}", hb_addr); heartbeat.bind(&hb_addr).await?; - // Create executor - let executor = QueryExecutor::new()?; + // Create executor with the specified reader + let executor = QueryExecutor::new_with_uri(reader_uri)?; // Generate session ID let session = uuid::Uuid::new_v4().to_string(); @@ -92,12 +97,17 @@ impl KernelServer { variables_comm_id: None, ui_comm_id: None, plot_comm_id: None, + connection_comm_id: None, + data_explorer_comms: HashMap::new(), }; // Send initial "starting" status on IOPub // This is required by Jupyter protocol - exactly once at process startup kernel.send_status_initial("starting").await?; + // Open initial connection comm so the Connections pane shows the database + kernel.open_connection_comm(reader_uri).await?; + Ok(kernel) } @@ -310,10 +320,17 @@ impl KernelServer { match result { Ok(exec_result) => { + // If the connection changed, open a 
new connection comm + let is_connection_changed = + matches!(&exec_result, ExecutionResult::ConnectionChanged { .. }); + if let ExecutionResult::ConnectionChanged { ref uri, .. } = &exec_result { + self.open_connection_comm(uri).await?; + } + // Send execute_result (not display_data) // Per Jupyter spec: execute_result includes execution_count // Only send if there's something to display (DDL returns None) - if !silent { + if !silent && !is_connection_changed { if let Some(display_data) = format_display_data(exec_result) { // Build message content, including output_location if present let mut content = json!({ @@ -498,6 +515,7 @@ impl KernelServer { self.send_status("busy", parent).await?; // Check if it's a JSON-RPC request + #[allow(clippy::if_same_then_else)] if let Some(method) = data["method"].as_str() { let rpc_id = &data["id"]; @@ -588,11 +606,47 @@ impl KernelServer { } // Handle positron.ui requests else if Some(comm_id.to_string()) == self.ui_comm_id { - tracing::info!("Received UI request: {} (ignoring)", method); + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": null + } + }), + parent, + identities, + ) + .await?; } // Handle positron.plot requests else if Some(comm_id.to_string()) == self.plot_comm_id { - tracing::info!("Received plot request: {} (ignoring)", method); + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": null + } + }), + parent, + identities, + ) + .await?; + } + // Handle positron.connection requests + else if Some(comm_id.to_string()) == self.connection_comm_id { + self.handle_connection_rpc(method, rpc_id, comm_id, parent, identities) + .await?; + } + // Handle positron.dataExplorer requests + else if self.data_explorer_comms.contains_key(comm_id) { + self.handle_data_explorer_rpc(method, rpc_id, comm_id, parent, identities) + .await?; } // Unknown comm else { @@ -634,6 
+688,16 @@ impl KernelServer { comms[id] = json!({"target_name": "positron.plot"}); } } + if let Some(id) = &self.connection_comm_id { + if target_name.is_none() || target_name == Some("positron.connection") { + comms[id] = json!({"target_name": "positron.connection"}); + } + } + for id in self.data_explorer_comms.keys() { + if target_name.is_none() || target_name == Some("positron.dataExplorer") { + comms[id] = json!({"target_name": "positron.dataExplorer"}); + } + } tracing::info!( "Returning comms: {}", @@ -677,6 +741,11 @@ impl KernelServer { } else if Some(comm_id.to_string()) == self.plot_comm_id { tracing::info!("Closing positron.plot comm"); self.plot_comm_id = None; + } else if Some(comm_id.to_string()) == self.connection_comm_id { + tracing::info!("Closing positron.connection comm"); + self.connection_comm_id = None; + } else if self.data_explorer_comms.remove(comm_id).is_some() { + tracing::info!("Closing data explorer comm: {}", comm_id); } else { tracing::warn!("Close for unknown comm_id: {}", comm_id); } @@ -685,6 +754,248 @@ impl KernelServer { Ok(()) } + /// Open (or replace) a `positron.connection` comm for the current reader. + /// + /// The kernel initiates this comm (backend-initiated). If an existing + /// connection comm is open, it is closed first. 
+ async fn open_connection_comm(&mut self, uri: &str) -> Result<()> { + // Close existing connection comm if any + if let Some(old_id) = self.connection_comm_id.take() { + tracing::info!("Closing old connection comm: {}", old_id); + let close_msg = self.create_message("comm_close", json!({ "comm_id": old_id }), None); + let zmq_msg = self.serialize_message_with_topic(&close_msg, "comm_close")?; + self.iopub.send(zmq_msg).await?; + } + + let comm_id = uuid::Uuid::new_v4().to_string(); + let display_name = executor::display_name_for_uri(uri); + let type_name = executor::type_name_for_uri(uri); + let host = executor::host_for_uri(uri); + let meta_command = format!("-- @connect: {}", uri); + + tracing::info!( + "Opening positron.connection comm: {} ({})", + comm_id, + display_name + ); + + let msg = self.create_message( + "comm_open", + json!({ + "comm_id": comm_id, + "target_name": "positron.connection", + "data": { + "name": display_name, + "language_id": "ggsql", + "host": host, + "type": type_name, + "code": meta_command + } + }), + None, + ); + let zmq_msg = self.serialize_message_with_topic(&msg, "comm_open")?; + self.iopub.send(zmq_msg).await?; + + self.connection_comm_id = Some(comm_id); + Ok(()) + } + + /// Handle JSON-RPC requests on the connection comm + async fn handle_connection_rpc( + &mut self, + method: &str, + rpc_id: &Value, + comm_id: &str, + parent: &JupyterMessage, + identities: &[Vec], + ) -> Result<()> { + tracing::info!("Connection RPC: {}", method); + + let params = &parent.content["data"]["params"]; + + let result = match method { + "list_objects" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + match connection::list_objects(self.executor.reader(), &path) { + Ok(objects) => json!(objects), + Err(e) => { + tracing::error!("list_objects failed: {}", e); + json!([]) + } + } + } 
+ "list_fields" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + match connection::list_fields(self.executor.reader(), &path) { + Ok(fields) => json!(fields), + Err(e) => { + tracing::error!("list_fields failed: {}", e); + json!([]) + } + } + } + "contains_data" => { + let path: Vec = params["path"].as_array().cloned().unwrap_or_default(); + let has_data = connection::contains_data(&path); + json!(has_data) + } + "get_icon" => json!(""), + "preview_object" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + + match DataExplorerState::open(self.executor.reader(), &path) { + Ok(state) => { + let de_comm_id = uuid::Uuid::new_v4().to_string(); + let title = path.last().cloned().unwrap_or_default(); + + // Send comm_open on iopub to open the data viewer + let msg = self.create_message( + "comm_open", + json!({ + "comm_id": de_comm_id, + "target_name": "positron.dataExplorer", + "data": { + "title": title + } + }), + Some(parent), + ); + let zmq_msg = self.serialize_message_with_topic(&msg, "comm_open")?; + self.iopub.send(zmq_msg).await?; + + tracing::info!("Opened data explorer comm: {} for {}", de_comm_id, title); + self.data_explorer_comms.insert(de_comm_id, state); + } + Err(e) => { + tracing::error!("preview_object failed: {}", e); + } + } + json!(null) + } + "get_metadata" => { + let uri = self.executor.reader_uri(); + json!({ + "name": executor::display_name_for_uri(uri), + "language_id": "ggsql", + "host": executor::host_for_uri(uri), + "type": executor::type_name_for_uri(uri), + "code": format!("-- @connect: {}", uri) + }) + } + _ => { + tracing::warn!("Unknown connection method: {}", method); + json!(null) + } + }; + + 
self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": result + } + }), + parent, + identities, + ) + .await?; + + Ok(()) + } + + /// Handle JSON-RPC requests on a data explorer comm + async fn handle_data_explorer_rpc( + &mut self, + method: &str, + rpc_id: &Value, + comm_id: &str, + parent: &JupyterMessage, + identities: &[Vec], + ) -> Result<()> { + tracing::info!("Data explorer RPC: {}", method); + + let params = &parent.content["data"]["params"]; + + let RpcResponse { result, event } = + if let Some(state) = self.data_explorer_comms.get(comm_id) { + state.handle_rpc(method, params, self.executor.reader()) + } else { + RpcResponse::reply(json!(null)) + }; + + // Send the RPC reply + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": result + } + }), + parent, + identities, + ) + .await?; + + // Send async event on iopub if present (e.g. return_column_profiles) + if let Some(evt) = event { + self.send_iopub( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "method": evt.method, + "params": evt.params + } + }), + parent, + ) + .await?; + } + + Ok(()) + } + /// Send a message on the IOPub channel async fn send_iopub( &mut self, diff --git a/ggsql-jupyter/src/lib.rs b/ggsql-jupyter/src/lib.rs index 40861748..e7c56850 100644 --- a/ggsql-jupyter/src/lib.rs +++ b/ggsql-jupyter/src/lib.rs @@ -2,9 +2,12 @@ //! //! This module exposes the internal components for testing. +pub mod connection; +pub mod data_explorer; pub mod display; pub mod executor; pub mod message; +pub mod util; // Re-export commonly used types pub use display::format_display_data; diff --git a/ggsql-jupyter/src/main.rs b/ggsql-jupyter/src/main.rs index 432c57ed..73ba93d8 100644 --- a/ggsql-jupyter/src/main.rs +++ b/ggsql-jupyter/src/main.rs @@ -2,10 +2,13 @@ //! //! 
A Jupyter kernel for executing ggsql queries with rich Vega-Lite visualizations. +mod connection; +mod data_explorer; mod display; mod executor; mod kernel; mod message; +mod util; use anyhow::{Context, Result}; use clap::Parser; @@ -22,6 +25,10 @@ struct Args { #[arg(short = 'f', long = "connection-file")] connection_file: Option, + /// Database connection URI (e.g. "duckdb://memory") + #[arg(long, default_value = "duckdb://memory")] + reader: String, + /// Install the kernel spec #[arg(long)] install: bool, @@ -69,7 +76,7 @@ async fn main() -> Result<()> { tracing::info!("Creating kernel server"); // Create and run kernel - let mut kernel = kernel::KernelServer::new(connection).await?; + let mut kernel = kernel::KernelServer::new(connection, &args.reader).await?; tracing::info!("Kernel ready, starting event loop"); diff --git a/ggsql-jupyter/src/util.rs b/ggsql-jupyter/src/util.rs new file mode 100644 index 00000000..1e1bf152 --- /dev/null +++ b/ggsql-jupyter/src/util.rs @@ -0,0 +1,24 @@ +use polars::prelude::{Column, DataFrame}; + +/// Find a DataFrame column by name, trying multiple names and falling back to +/// case-insensitive matching. This handles ODBC drivers that return uppercase +/// column names (e.g. `TABLE_NAME` instead of `table_name`). 
+pub fn find_column<'a>(df: &'a DataFrame, names: &[&str]) -> Result<&'a Column, String> { + // Try exact match first + for name in names { + if let Ok(col) = df.column(name) { + return Ok(col); + } + } + // Fall back to case-insensitive match + let col_names = df.get_column_names(); + for name in names { + let lower = name.to_lowercase(); + for cn in &col_names { + if cn.to_lowercase() == lower { + return df.column(cn).map_err(|e| e.to_string()); + } + } + } + Err(format!("Missing column (tried: {:?})", names)) +} diff --git a/ggsql-python/src/lib.rs b/ggsql-python/src/lib.rs index d477275e..6e5c543c 100644 --- a/ggsql-python/src/lib.rs +++ b/ggsql-python/src/lib.rs @@ -164,6 +164,10 @@ impl Reader for PyReaderBridge { }) } + fn execute(&self, query: &str) -> ggsql::Result { + ggsql::reader::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn ggsql::reader::SqlDialect { &ANSI_DIALECT } diff --git a/ggsql-vscode/package-lock.json b/ggsql-vscode/package-lock.json index 7e5b7c3f..d61cfc95 100644 --- a/ggsql-vscode/package-lock.json +++ b/ggsql-vscode/package-lock.json @@ -8,6 +8,9 @@ "name": "ggsql", "version": "0.1.9", "license": "MIT", + "dependencies": { + "toml": "^3.0.0" + }, "devDependencies": { "@posit-dev/positron": "^0.2.2", "@types/node": "^18.x", @@ -3477,6 +3480,12 @@ "url": "https://github.com/sponsors/SuperchupuDev" } }, + "node_modules/toml": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/toml/-/toml-3.0.0.tgz", + "integrity": "sha512-y/mWCZinnvxjTKYhJ+pYxwD0mRLVvOtdS2Awbgxln6iEnt4rk0yBxeSBHkGJcPucRiG0e55mwWp+g/05rsrd6w==", + "license": "MIT" + }, "node_modules/ts-api-utils": { "version": "2.5.0", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.5.0.tgz", diff --git a/ggsql-vscode/package.json b/ggsql-vscode/package.json index e9847c02..7032bf45 100644 --- a/ggsql-vscode/package.json +++ b/ggsql-vscode/package.json @@ -77,6 +77,9 @@ "check-types": "tsc --noEmit", "lint": "eslint src --ext ts" }, + 
"dependencies": { + "toml": "^3.0.0" + }, "devDependencies": { "@posit-dev/positron": "^0.2.2", "@types/node": "^18.x", diff --git a/ggsql-vscode/src/connections.ts b/ggsql-vscode/src/connections.ts new file mode 100644 index 00000000..0df0641f --- /dev/null +++ b/ggsql-vscode/src/connections.ts @@ -0,0 +1,397 @@ +/* + * Connection Drivers for Positron's Connections pane + * + * Registers drivers that let users create database connections via the + * "New Connection" dialog. Each driver generates a `-- @connect:` meta-command + * that the ggsql-jupyter kernel interprets to switch readers. + */ + +import * as os from 'os'; +import * as path from 'path'; +import * as fs from 'fs'; +import * as toml from 'toml'; +import type * as positron from '@posit-dev/positron'; + +type PositronApi = positron.PositronApi; +type ConnectionsDriverMetadata = positron.ConnectionsDriverMetadata & { description?: string }; + +/** + * Create the set of ggsql connection drivers to register with Positron. + */ +export function createConnectionDrivers( + positronApi: PositronApi +): positron.ConnectionsDriver[] { + return [ + createDuckDBDriver(positronApi), + createSnowflakeDefaultDriver(positronApi), + createSnowflakePasswordDriver(positronApi), + createSnowflakeSSODriver(positronApi), + createSnowflakePATDriver(positronApi), + createOdbcDriver(positronApi), + ]; +} + +// ============================================================================ +// DuckDB +// ============================================================================ + +/** + * DuckDB connection driver. + * + * Inputs: optional database file path (empty = in-memory). 
+ */ +function createDuckDBDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-duckdb', + metadata: { + languageId: 'ggsql', + name: 'DuckDB', + inputs: [ + { + id: 'database', + label: 'Database', + type: 'string', + value: '', + }, + ], + }, + generateCode: (inputs) => { + const db = inputs.find((i) => i.id === 'database')?.value?.trim(); + if (!db) { + return '-- @connect: duckdb://memory'; + } + return `-- @connect: duckdb://${db}`; + }, + connect: async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }, + }; +} + +// ============================================================================ +// Snowflake — shared helpers +// ============================================================================ + +interface SnowflakeConnectionEntry { + name: string; + account?: string; +} + +/** + * Find the Snowflake connections.toml file, checking standard locations. + */ +function findSnowflakeConnectionsToml(): string | undefined { + const candidates: string[] = []; + + // 1. $SNOWFLAKE_HOME/connections.toml + const snowflakeHome = process.env.SNOWFLAKE_HOME; + if (snowflakeHome) { + candidates.push(path.join(snowflakeHome, 'connections.toml')); + } + + // 2. ~/.snowflake/connections.toml + const home = os.homedir(); + candidates.push(path.join(home, '.snowflake', 'connections.toml')); + + // 3. 
Platform-specific paths + if (process.platform === 'darwin') { + candidates.push( + path.join(home, 'Library', 'Application Support', 'snowflake', 'connections.toml') + ); + } else if (process.platform === 'linux') { + const xdgConfig = process.env.XDG_CONFIG_HOME || path.join(home, '.config'); + candidates.push(path.join(xdgConfig, 'snowflake', 'connections.toml')); + } else if (process.platform === 'win32') { + candidates.push( + path.join(home, 'AppData', 'Local', 'snowflake', 'connections.toml') + ); + } + + for (const candidate of candidates) { + if (fs.existsSync(candidate)) { + return candidate; + } + } + return undefined; +} + +/** + * Read Snowflake connection entries from connections.toml. + */ +function readSnowflakeConnections(): { + connections: SnowflakeConnectionEntry[]; + defaultConnection?: string; +} { + const tomlPath = findSnowflakeConnectionsToml(); + if (!tomlPath) { + return { connections: [] }; + } + + try { + const content = fs.readFileSync(tomlPath, 'utf-8'); + const parsed = toml.parse(content); + + const defaultConnection = + process.env.SNOWFLAKE_DEFAULT_CONNECTION_NAME || + parsed.default_connection_name || + undefined; + + const connections: SnowflakeConnectionEntry[] = Object.keys(parsed) + .filter( + (key) => + key !== 'default_connection_name' && + typeof parsed[key] === 'object' && + parsed[key] !== null + ) + .map((name) => ({ + name, + account: parsed[name].account as string | undefined, + })); + + return { connections, defaultConnection }; + } catch { + return { connections: [] }; + } +} + +/** + * Build an ODBC connection string for Snowflake with the given parts. 
+ */ +function buildSnowflakeOdbc(parts: Record): string { + let connStr = `Driver=Snowflake;Server=${parts.account}.snowflakecomputing.com`; + if (parts.uid) { + connStr += `;UID=${parts.uid}`; + } + if (parts.pwd) { + connStr += `;PWD=${parts.pwd}`; + } + if (parts.authenticator) { + connStr += `;Authenticator=${parts.authenticator}`; + } + if (parts.token) { + connStr += `;Token=${parts.token}`; + } + if (parts.warehouse) { + connStr += `;Warehouse=${parts.warehouse}`; + } + if (parts.database) { + connStr += `;Database=${parts.database}`; + } + if (parts.schema) { + connStr += `;Schema=${parts.schema}`; + } + return `-- @connect: odbc://${connStr}`; +} + +function snowflakeConnect(positronApi: PositronApi) { + return async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }; +} + +// ============================================================================ +// Snowflake — Default Connection (connections.toml) +// ============================================================================ + +function createSnowflakeDefaultDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + const { connections, defaultConnection } = readSnowflakeConnections(); + + let inputs: positron.ConnectionsInput[]; + if (connections.length > 0) { + const defaultValue = + defaultConnection || + (connections.find((c) => c.name === 'default')?.name ?? connections[0].name); + + inputs = [ + { + id: 'connection_name', + label: 'Connection Name', + type: 'option', + options: connections.map((conn) => ({ + identifier: conn.name, + title: conn.account + ? 
`${conn.name} (${conn.account})` + : conn.name, + })), + value: defaultValue, + }, + ]; + } else { + inputs = [ + { + id: 'connection_name', + label: 'Connection Name', + type: 'string', + value: 'default', + }, + ]; + } + + return { + driverId: 'ggsql-snowflake-default', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'Default Connection (connections.toml)', + inputs, + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const name = + inputs.find((i) => i.id === 'connection_name')?.value?.trim() || 'default'; + return `-- @connect: odbc://Driver=Snowflake;ConnectionName=${name}`; + }, + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Snowflake — Username/Password +// ============================================================================ + +function createSnowflakePasswordDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-password', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'Username/Password', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'user', label: 'User', type: 'string' }, + { id: 'password', label: 'Password', type: 'string' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + uid: get('user'), + pwd: get('password'), + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); + }, + connect: snowflakeConnect(positronApi), + }; +} + +// 
============================================================================ +// Snowflake — External Browser (SSO) +// ============================================================================ + +function createSnowflakeSSODriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-sso', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'External Browser (SSO)', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'user', label: 'User', type: 'string', value: '' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + uid: get('user') || undefined, + authenticator: 'externalbrowser', + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); + }, + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Snowflake — Programmatic Access Token (PAT) +// ============================================================================ + +function createSnowflakePATDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-pat', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'Programmatic Access Token (PAT)', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'token', label: 'Token', type: 'string' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as 
ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + authenticator: 'programmatic_access_token', + token: get('token'), + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); + }, + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Generic ODBC +// ============================================================================ + +/** + * Generic ODBC connection driver. + * + * Lets users paste a raw ODBC connection string. + */ +function createOdbcDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-odbc', + metadata: { + languageId: 'ggsql', + name: 'ODBC', + inputs: [ + { + id: 'connection_string', + label: 'Connection String', + type: 'string', + }, + ], + }, + generateCode: (inputs) => { + const connStr = + inputs.find((i) => i.id === 'connection_string')?.value ?? 
''; + return `-- @connect: odbc://${connStr}`; + }, + connect: async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }, + }; +} diff --git a/ggsql-vscode/src/extension.ts b/ggsql-vscode/src/extension.ts index 6d6d1ec5..54b76edf 100644 --- a/ggsql-vscode/src/extension.ts +++ b/ggsql-vscode/src/extension.ts @@ -8,6 +8,7 @@ import * as vscode from 'vscode'; import { tryAcquirePositronApi } from '@posit-dev/positron'; import { GgsqlRuntimeManager } from './manager'; +import { createConnectionDrivers } from './connections'; // Output channel for logging const outputChannel = vscode.window.createOutputChannel('ggsql'); @@ -42,6 +43,15 @@ export function activate(context: vscode.ExtensionContext): void { context.subscriptions.push(disposable); log('ggsql runtime manager registered successfully'); + + // Register connection drivers for the Connections pane + const drivers = createConnectionDrivers(positronApi); + for (const driver of drivers) { + const driverDisposable = positronApi.connections.registerConnectionDriver(driver); + context.subscriptions.push(driverDisposable); + } + + log(`Registered ${drivers.length} connection drivers`); } /** diff --git a/ggsql-vscode/src/manager.ts b/ggsql-vscode/src/manager.ts index 0e4c454b..e8254044 100644 --- a/ggsql-vscode/src/manager.ts +++ b/ggsql-vscode/src/manager.ts @@ -199,9 +199,14 @@ function generateMetadata( * * @param kernelPath - Path to the ggsql-jupyter executable */ -function createKernelSpec(kernelPath: string): JupyterKernelSpec { +function createKernelSpec(kernelPath: string, readerUri?: string): JupyterKernelSpec { + const argv = [kernelPath, '-f', '{connection_file}']; + if (readerUri) { + argv.push('--reader', readerUri); + } + return { - argv: [kernelPath, '-f', '{connection_file}'], + argv, display_name: 'ggsql', language: 'ggsql', interrupt_mode: 'signal', diff --git a/src/Cargo.toml b/src/Cargo.toml index 38dc04e6..c24259cc 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml 
@@ -34,6 +34,8 @@ duckdb = { workspace = true, optional = true } arrow = { workspace = true, optional = true } postgres = { workspace = true, optional = true } rusqlite = { workspace = true, optional = true } +odbc-api = { workspace = true, optional = true } +toml_edit = { workspace = true, optional = true } # Writers plotters = { workspace = true, optional = true } @@ -61,20 +63,22 @@ pyo3 = { workspace = true, optional = true } [dev-dependencies] jsonschema = "0.44" proptest.workspace = true +tempfile = "3.8" ureq = "3" [features] -default = ["duckdb", "sqlite", "vegalite", "ipc", "parquet", "builtin-data"] +default = ["duckdb", "sqlite", "vegalite", "ipc", "parquet", "builtin-data", "odbc"] ipc = ["polars/ipc"] duckdb = ["dep:duckdb", "dep:arrow"] parquet = ["polars/parquet"] postgres = ["dep:postgres"] sqlite = ["dep:rusqlite"] +odbc = ["dep:odbc-api", "dep:toml_edit"] vegalite = [] ggplot2 = [] builtin-data = [] python = ["dep:pyo3"] -all-readers = ["duckdb", "postgres", "sqlite"] +all-readers = ["duckdb", "postgres", "sqlite", "odbc"] all-writers = ["vegalite", "ggplot2", "plotters"] # cargo-packager configuration for cross-platform installers diff --git a/src/execute/cte.rs b/src/execute/cte.rs index 5a6b665f..041bb298 100644 --- a/src/execute/cte.rs +++ b/src/execute/cte.rs @@ -94,7 +94,7 @@ pub fn transform_cte_references(sql: &str, cte_names: &HashSet) -> Strin let mut result = sql.to_string(); for cte_name in cte_names { - let temp_table_name = naming::cte_table(cte_name); + let temp_table_name = naming::quote_ident(&naming::cte_table(cte_name)); // Replace table references: FROM cte_name, JOIN cte_name, cte_name.column // Use word boundary matching to avoid replacing substrings @@ -360,7 +360,7 @@ mod tests { ( "SELECT * FROM sales WHERE year = 2024", vec!["sales"], - vec!["FROM __ggsql_cte_sales_", "__ WHERE year = 2024"], + vec!["FROM \"__ggsql_cte_sales_", "__\" WHERE year = 2024"], None, ), // Multiple CTE references with qualified columns @@ -368,8 
+368,8 @@ mod tests { "SELECT sales.date, targets.revenue FROM sales JOIN targets ON sales.id = targets.id", vec!["sales", "targets"], vec![ - "FROM __ggsql_cte_sales_", - "JOIN __ggsql_cte_targets_", + "FROM \"__ggsql_cte_sales_", + "JOIN \"__ggsql_cte_targets_", "__ggsql_cte_sales_", // qualified reference sales.date "__ggsql_cte_targets_", // qualified reference targets.revenue ], diff --git a/src/execute/layer.rs b/src/execute/layer.rs index a645bd29..9e2d2fc8 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -54,7 +54,10 @@ pub fn layer_source_query( None => { // Layer uses global data debug_assert!(has_global, "Layer has no source and no global data"); - Ok(format!("SELECT * FROM {}", naming::global_table())) + Ok(format!( + "SELECT * FROM {}", + naming::quote_ident(&naming::global_table()) + )) } } } @@ -109,17 +112,27 @@ pub fn build_layer_select_list( if let Some(req) = cast_map.get(name.as_str()) { // Cast and rename to prefixed aesthetic name format!( - "CAST(\"{}\" AS {}) AS \"{}\"", - name, req.sql_type_name, aes_col_name + "CAST({} AS {}) AS {}", + naming::quote_ident(name), + req.sql_type_name, + naming::quote_ident(&aes_col_name) ) } else { // Just rename to prefixed aesthetic name - format!("\"{}\" AS \"{}\"", name, aes_col_name) + format!( + "{} AS {}", + naming::quote_ident(name), + naming::quote_ident(&aes_col_name) + ) } } AestheticValue::Literal(lit) => { // Literals become columns with prefixed aesthetic name - format!("{} AS \"{}\"", lit.to_sql(dialect), aes_col_name) + format!( + "{} AS {}", + lit.to_sql(dialect), + naming::quote_ident(&aes_col_name) + ) } }; @@ -314,15 +327,15 @@ pub fn apply_pre_stat_transform( .filter(|col| seen.insert(&col.name)) .map(|col| { if let Some((_, sql)) = transform_exprs.iter().find(|(c, _)| c == &col.name) { - format!("{} AS \"{}\"", sql, col.name) + format!("{} AS {}", sql, naming::quote_ident(&col.name)) } else { - format!("\"{}\"", col.name) + naming::quote_ident(&col.name) } }) .collect(); 
format!( - "SELECT {} FROM ({}) AS __ggsql_pre__", + "SELECT {} FROM ({}) AS \"__ggsql_pre__\"", select_exprs.join(", "), query ) @@ -374,14 +387,14 @@ pub fn build_layer_base_query( // Build query with optional WHERE clause if let Some(ref f) = layer.filter { format!( - "SELECT {} FROM ({}) AS __ggsql_src__ WHERE {}", + "SELECT {} FROM ({}) AS \"__ggsql_src__\" WHERE {}", select_clause, source_query, f.as_str() ) } else { format!( - "SELECT {} FROM ({}) AS __ggsql_src__", + "SELECT {} FROM ({}) AS \"__ggsql_src__\"", select_clause, source_query ) } @@ -611,7 +624,11 @@ where final_remappings.get(stat).map(|aes| { let stat_col = naming::stat_column(stat); let prefixed_aes = naming::aesthetic_column(aes); - format!("\"{}\" AS \"{}\"", stat_col, prefixed_aes) + format!( + "{} AS {}", + naming::quote_ident(&stat_col), + naming::quote_ident(&prefixed_aes) + ) }) }) .collect(); @@ -620,7 +637,7 @@ where transformed_query } else { format!( - "SELECT *, {} FROM ({}) AS __ggsql_stat__", + "SELECT *, {} FROM ({}) AS \"__ggsql_stat__\"", stat_rename_exprs.join(", "), transformed_query ) @@ -809,7 +826,7 @@ fn process_annotation_layer(layer: &mut Layer, dialect: &dyn SqlDialect) -> Resu // Step 6: Build complete SQL query let column_list = column_names .iter() - .map(|c| format!("\"{}\"", c)) + .map(|c| naming::quote_ident(c)) .collect::>() .join(", "); diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 3a6aea9b..709d9060 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -929,7 +929,7 @@ pub struct PreparedData { /// # Arguments /// * `query` - The full ggsql query string /// * `reader` - A Reader implementation for executing SQL -pub fn prepare_data_with_reader(query: &str, reader: &R) -> Result { +pub fn prepare_data_with_reader(query: &str, reader: &dyn Reader) -> Result { let execute_query = |sql: &str| reader.execute_sql(sql); let dialect = reader.dialect(); diff --git a/src/execute/schema.rs b/src/execute/schema.rs index ad80aa7f..3df3c55f 100644 --- 
a/src/execute/schema.rs +++ b/src/execute/schema.rs @@ -21,16 +21,22 @@ pub type TypeInfo = (String, DataType, bool); pub fn build_minmax_query(source_query: &str, column_names: &[&str]) -> String { let min_exprs: Vec = column_names .iter() - .map(|name| format!("MIN(\"{}\") AS \"{}\"", name, name)) + .map(|name| { + let q = naming::quote_ident(name); + format!("MIN({q}) AS {q}") + }) .collect(); let max_exprs: Vec = column_names .iter() - .map(|name| format!("MAX(\"{}\") AS \"{}\"", name, name)) + .map(|name| { + let q = naming::quote_ident(name); + format!("MAX({q}) AS {q}") + }) .collect(); format!( - "WITH __ggsql_source__ AS ({}) SELECT {} FROM __ggsql_source__ UNION ALL SELECT {} FROM __ggsql_source__", + "WITH \"__ggsql_source__\" AS ({}) SELECT {} FROM \"__ggsql_source__\" UNION ALL SELECT {} FROM \"__ggsql_source__\"", source_query, min_exprs.join(", "), max_exprs.join(", ") diff --git a/src/naming.rs b/src/naming.rs index bb2773c8..a25cbdc4 100644 --- a/src/naming.rs +++ b/src/naming.rs @@ -224,6 +224,22 @@ pub fn aesthetic_column(aesthetic: &str) -> String { format!("{}{}{}", AES_PREFIX, aesthetic, GGSQL_SUFFIX) } +// ============================================================================ +// SQL Quoting +// ============================================================================ + +/// Double-quote a SQL identifier for case-preserving databases (e.g. Snowflake). 
+/// +/// # Example +/// ``` +/// use ggsql::naming; +/// assert_eq!(naming::quote_ident("__ggsql_aes_x__"), "\"__ggsql_aes_x__\""); +/// assert_eq!(naming::quote_ident("has\"quote"), "\"has\"\"quote\""); +/// ``` +pub fn quote_ident(name: &str) -> String { + format!("\"{}\"", name.replace('"', "\"\"")) +} + // ============================================================================ // Detection Functions // ============================================================================ diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index c16c5e63..7427e0e2 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -71,7 +71,7 @@ impl GeomTrait for Area { // Area geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { - query: format!("{} ORDER BY \"{}\"", query, order_col), + query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), stat_columns: vec![], dummy_columns: vec![], consumed_aesthetics: vec![], diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index cddda1d4..d64bce9f 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -178,37 +178,44 @@ fn stat_bar_count( if let Some(weight_col) = weight_value.column_name() { if schema_columns.contains(weight_col) { // weight column exists - use SUM (but still call it "count") - format!("SUM({}) AS {}", weight_col, stat_count) + format!( + "SUM({}) AS {}", + naming::quote_ident(weight_col), + naming::quote_ident(&stat_count) + ) } else { // weight mapped but column doesn't exist - fall back to COUNT // (this shouldn't happen with upfront validation, but handle gracefully) - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS {}", naming::quote_ident(&stat_count)) } } else { // Shouldn't happen (not literal, not column), fall back to COUNT - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS {}", 
naming::quote_ident(&stat_count)) } } else { // weight not mapped - use COUNT - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS {}", naming::quote_ident(&stat_count)) }; // Build the query based on whether x is mapped or not // Use two-stage query: first GROUP BY, then calculate proportion with window function let (transformed_query, stat_columns, dummy_columns, consumed_aesthetics) = if use_dummy_x { // x is not mapped - use dummy constant, no GROUP BY on x + let q_x = naming::quote_ident(&stat_x); + let q_count = naming::quote_ident(&stat_count); + let q_prop = naming::quote_ident(&stat_proportion); let (grouped_select, final_select) = if group_by.is_empty() { ( format!( "'{dummy}' AS {x}, {agg}", dummy = stat_dummy_value, - x = stat_x, + x = q_x, agg = agg_expr ), format!( "*, {count} * 1.0 / SUM({count}) OVER () AS {prop}", - count = stat_count, - prop = stat_proportion + count = q_count, + prop = q_prop ), ) } else { @@ -218,14 +225,14 @@ fn stat_bar_count( "{g}, '{dummy}' AS {x}, {agg}", g = grp_cols, dummy = stat_dummy_value, - x = stat_x, + x = q_x, agg = agg_expr ), format!( "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {prop}", - count = stat_count, + count = q_count, grp = grp_cols, - prop = stat_proportion + prop = q_prop ), ) }; @@ -233,7 +240,7 @@ fn stat_bar_count( let query_str = if group_by.is_empty() { // No grouping at all - single aggregate format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\") SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, final = final_select @@ -242,7 +249,7 @@ fn stat_bar_count( // Group by partition/facet variables only let group_cols = group_by.join(", "); format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM 
__grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, group = group_cols, @@ -264,7 +271,7 @@ fn stat_bar_count( ) } else { // x is mapped - use existing logic with two-stage query - let x_col = x_col.unwrap(); + let x_col = naming::quote_ident(&x_col.unwrap()); // Build grouped columns (group_by includes partition_by + facet variables + x) let group_cols = if group_by.is_empty() { @@ -276,13 +283,15 @@ fn stat_bar_count( }; // Keep original x column name, only add the aggregated stat column + let q_count = naming::quote_ident(&stat_count); + let q_prop = naming::quote_ident(&stat_proportion); let (grouped_select, final_select) = if group_by.is_empty() { ( format!("{x}, {agg}", x = x_col, agg = agg_expr), format!( "*, {count} * 1.0 / SUM({count}) OVER () AS {prop}", - count = stat_count, - prop = stat_proportion + count = q_count, + prop = q_prop ), ) } else { @@ -291,15 +300,15 @@ fn stat_bar_count( format!("{g}, {x}, {agg}", g = grp_cols, x = x_col, agg = agg_expr), format!( "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {prop}", - count = stat_count, + count = q_count, grp = grp_cols, - prop = stat_proportion + prop = q_prop ), ) }; let query_str = format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, group = group_cols, diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index 89dcc921..fdc7bae6 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -169,12 +169,16 @@ fn boxplot_sql_compute_summary( coef: &f64, dialect: &dyn SqlDialect, ) -> String { - let 
groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); let lower_expr = dialect.sql_greatest(&[&format!("q1 - {coef} * (q3 - q1)"), "min"]); let upper_expr = dialect.sql_least(&[&format!("q3 + {coef} * (q3 - q1)"), "max"]); let q1 = dialect.sql_percentile(value, 0.25, from, groups); let median = dialect.sql_percentile(value, 0.50, from, groups); let q3 = dialect.sql_percentile(value, 0.75, from, groups); + let qt = "\"__ggsql_qt__\""; + let fn_alias = "\"__ggsql_fn__\""; + let quoted_value = naming::quote_ident(value); format!( "SELECT *, @@ -188,14 +192,14 @@ fn boxplot_sql_compute_summary( {q1} AS q1, {median} AS median, {q3} AS q3 - FROM ({from}) AS __ggsql_qt__ + FROM ({from}) AS {qt} WHERE {value} IS NOT NULL GROUP BY {groups} - ) AS __ggsql_fn__", + ) AS {fn_alias}", lower_expr = lower_expr, upper_expr = upper_expr, groups = groups_str, - value = value, + value = quoted_value, from = from, q1 = q1, median = median, @@ -207,10 +211,12 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St let mut join_pairs = Vec::new(); let mut keep_columns = Vec::new(); for column in groups { - join_pairs.push(format!("raw.{} = summary.{}", column, column)); - keep_columns.push(format!("raw.{}", column)); + let quoted = naming::quote_ident(column); + join_pairs.push(format!("raw.{} = summary.{}", quoted, quoted)); + keep_columns.push(format!("raw.{}", quoted)); } + let quoted_value = naming::quote_ident(value); // We're joining outliers with the summary to use the lower/upper whisker // values as a filter format!( @@ -221,7 +227,7 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St FROM ({from}) raw JOIN summary ON {pairs} WHERE raw.{value} NOT BETWEEN summary.lower AND summary.upper", - value = value, + value = quoted_value, groups = keep_columns.join(", "), pairs = join_pairs.join(" AND "), from = from @@ 
-235,11 +241,12 @@ fn boxplot_sql_append_outliers( raw_query: &str, draw_outliers: &bool, ) -> String { - let value_name = naming::stat_column("value"); - let value2_name = naming::stat_column("value2"); - let type_name = naming::stat_column("type"); + let value_name = naming::quote_ident(&naming::stat_column("value")); + let value2_name = naming::quote_ident(&naming::stat_column("value2")); + let type_name = naming::quote_ident(&naming::stat_column("type")); - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); // Helper to build visual-element rows from summary table // Each row type maps to one visual element with y and yend where needed @@ -306,14 +313,14 @@ mod tests { fn test_sql_compute_summary_basic() { let groups = vec!["category".to_string()]; let result = boxplot_sql_compute_summary("data", &groups, "value", &1.5, &AnsiDialect); - assert!(result.contains("NTILE(4) OVER (ORDER BY value)")); + assert!(result.contains("NTILE(4) OVER (ORDER BY \"value\")")); assert!(result.contains("AS q1")); assert!(result.contains("AS median")); assert!(result.contains("AS q3")); - assert!(result.contains("MIN(value) AS min")); - assert!(result.contains("MAX(value) AS max")); - assert!(result.contains("WHERE value IS NOT NULL")); - assert!(result.contains("GROUP BY category")); + assert!(result.contains("MIN(\"value\") AS min")); + assert!(result.contains("MAX(\"value\") AS max")); + assert!(result.contains("WHERE \"value\" IS NOT NULL")); + assert!(result.contains("GROUP BY \"category\"")); assert!(result.contains("CASE WHEN (q1 - 1.5")); assert!(result.contains("CASE WHEN (q3 + 1.5")); } @@ -322,8 +329,8 @@ mod tests { fn test_sql_compute_summary_multiple_groups() { let groups = vec!["cat".to_string(), "region".to_string()]; let result = boxplot_sql_compute_summary("tbl", &groups, "val", &1.5, &AnsiDialect); - assert!(result.contains("GROUP BY cat, region")); 
- assert!(result.contains("NTILE(4) OVER (ORDER BY val)")); + assert!(result.contains("GROUP BY \"cat\", \"region\"")); + assert!(result.contains("NTILE(4) OVER (ORDER BY \"val\")")); } #[test] @@ -344,8 +351,8 @@ mod tests { let groups = vec!["cat".to_string(), "region".to_string()]; let result = boxplot_sql_filter_outliers(&groups, "value", "raw_data"); assert!(result.contains("JOIN summary ON")); - assert!(result.contains("raw.cat = summary.cat")); - assert!(result.contains("raw.region = summary.region")); + assert!(result.contains("raw.\"cat\" = summary.\"cat\"")); + assert!(result.contains("raw.\"region\" = summary.\"region\"")); assert!(result.contains("NOT BETWEEN summary.lower AND summary.upper")); assert!(result.contains("'outlier' AS type")); } @@ -373,16 +380,16 @@ mod tests { (CASE WHEN (q3 + 1.5 * (q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS upper FROM ( SELECT - category, - MIN(price) AS min, - MAX(price) AS max, + "category", + MIN("price") AS min, + MAX("price") AS max, {q1} AS q1, {median} AS median, {q3} AS q3 - FROM (SELECT * FROM sales) AS __ggsql_qt__ - WHERE price IS NOT NULL - GROUP BY category - ) AS __ggsql_fn__"# + FROM (SELECT * FROM sales) AS "__ggsql_qt__" + WHERE "price" IS NOT NULL + GROUP BY "category" + ) AS "__ggsql_fn__""# ); assert_eq!(result, expected); @@ -409,16 +416,16 @@ mod tests { (CASE WHEN (q3 + 1.5 * (q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS upper FROM ( SELECT - region, product, - MIN(revenue) AS min, - MAX(revenue) AS max, + "region", "product", + MIN("revenue") AS min, + MAX("revenue") AS max, {q1} AS q1, {median} AS median, {q3} AS q3 - FROM (SELECT * FROM data) AS __ggsql_qt__ - WHERE revenue IS NOT NULL - GROUP BY region, product - ) AS __ggsql_fn__"# + FROM (SELECT * FROM data) AS "__ggsql_qt__" + WHERE "revenue" IS NOT NULL + GROUP BY "region", "product" + ) AS "__ggsql_fn__""# ); assert_eq!(result, expected); @@ -445,9 +452,9 @@ mod tests { 
assert!(result.contains("'median'")); // Check column names - assert!(result.contains(&format!("AS {}", naming::stat_column("value")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("value2")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("type")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value2")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("type")))); } #[test] @@ -469,9 +476,9 @@ mod tests { assert!(result.contains("'median'")); // Check column names - assert!(result.contains(&format!("AS {}", naming::stat_column("value")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("value2")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("type")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value2")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("type")))); } #[test] @@ -481,8 +488,8 @@ mod tests { let raw = "(SELECT * FROM raw_data)"; let result = boxplot_sql_append_outliers(summary, &groups, "val", raw, &true); - // Verify all groups are present - assert!(result.contains("cat, region, year")); + // Verify all groups are present (quoted) + assert!(result.contains("\"cat\", \"region\", \"year\"")); // Check structure assert!(result.contains("WITH")); @@ -491,9 +498,9 @@ mod tests { // Verify outlier join conditions for all groups let outlier_section = result.split("outliers AS").nth(1).unwrap(); - assert!(outlier_section.contains("raw.cat = summary.cat")); - assert!(outlier_section.contains("raw.region = summary.region")); - assert!(outlier_section.contains("raw.year = summary.year")); + assert!(outlier_section.contains("raw.\"cat\" = summary.\"cat\"")); + assert!(outlier_section.contains("raw.\"region\" = summary.\"region\"")); 
+ assert!(outlier_section.contains("raw.\"year\" = summary.\"year\"")); } // ==================== GeomTrait Implementation Tests ==================== diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index e67e0cbe..a52e777f 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -229,13 +229,15 @@ fn density_sql_bandwidth( let (groups_select, group_by) = if groups.is_empty() { (String::new(), String::new()) } else { - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); ( format!("\n {},", groups_str), format!("\n GROUP BY {}", groups_str), ) }; + let quoted_value = naming::quote_ident(value); format!( "WITH RECURSIVE bandwidth AS ( @@ -243,12 +245,12 @@ fn density_sql_bandwidth( {bw_expr} AS bw,{groups_select} MIN({value}) AS x_min, MAX({value}) AS x_max - FROM ({from}) AS __ggsql_qt__ + FROM ({from}) AS \"__ggsql_qt__\" WHERE {value} IS NOT NULL{group_by} )", bw_expr = bw_expr, groups_select = groups_select, - value = value, + value = quoted_value, from = from, group_by = group_by ) @@ -264,7 +266,8 @@ fn silverman_rule( // The query computes Silverman's rule of thumb (R's `stats::bw.nrd0()`). 
// We absorb the adjustment in the 0.9 multiplier of the rule let adjust = 0.9 * adjust; - let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = value_column); + let v = naming::quote_ident(value_column); + let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = v); let q75 = dialect.sql_percentile(value_column, 0.75, from, groups); let q25 = dialect.sql_percentile(value_column, 0.25, from, groups); let iqr = format!("({q75} - {q25}) / 1.34"); @@ -351,34 +354,36 @@ fn build_data_cte( ) -> String { // Include weight column if provided, otherwise default to 1.0 let weight_col = if let Some(w) = weight { - format!(", {} AS weight", w) + format!(", {} AS weight", naming::quote_ident(w)) } else { ", 1.0 AS weight".to_string() }; let smooth_col = if let Some(s) = smooth { - format!(", {}", s) + format!(", {}", naming::quote_ident(s)) } else { "".to_string() }; + let quoted_value = naming::quote_ident(value); // Only filter out nulls in value column, keep NULLs in group columns - let mut filter_valid = format!("{} IS NOT NULL", value); + let mut filter_valid = format!("{} IS NOT NULL", quoted_value); if let Some(s) = smooth { filter_valid = format!( - "{filter} AND {smth} IS NOT NULL", + "{filter} AND {} IS NOT NULL", + naming::quote_ident(s), filter = filter_valid, - smth = s ); } + let quoted_groups: Vec<String> = group_by.iter().map(|g| naming::quote_ident(g)).collect(); format!( "data AS ( SELECT {groups}{value} AS val{weight_col}{smooth_col} FROM ({from}) WHERE {filter_valid} )", - groups = with_trailing_comma(&group_by.join(", ")), - value = value, + groups = with_trailing_comma(&quoted_groups.join(", ")), + value = quoted_value, weight_col = weight_col, smooth_col = smooth_col, from = from, @@ -420,12 +425,13 @@ fn build_grid_cte( "grid AS ( SELECT {x_formula} AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN \"__ggsql_seq__\" AS seq )", x_formula = x_formula ) } else { - let 
quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); // When tails is specified, create full_grid; otherwise create grid directly let cte_name = if tails.is_some() { "full_grid" } else { "grid" }; format!( @@ -434,7 +440,7 @@ fn build_grid_cte( {groups}, {x_formula} AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN \"__ggsql_seq__\" AS seq CROSS JOIN (SELECT DISTINCT {groups} FROM bandwidth) AS groups )", cte_name = cte_name, @@ -449,14 +455,14 @@ fn build_grid_cte( let bandwidth_join_conds: Vec = groups .iter() .map(|g| { - format!( - "full_grid.{col} IS NOT DISTINCT FROM bandwidth.{col}", - col = g - ) + let q = naming::quote_ident(g); + format!("full_grid.{q} IS NOT DISTINCT FROM bandwidth.{q}") }) .collect(); - let grid_groups_select: Vec = - groups.iter().map(|g| format!("full_grid.{}", g)).collect(); + let grid_groups_select: Vec = groups + .iter() + .map(|g| format!("full_grid.{}", naming::quote_ident(g))) + .collect(); format!( "{seq_cte}, @@ -513,7 +519,10 @@ fn compute_density( } else { group_by .iter() - .map(|g| format!("data.{col} IS NOT DISTINCT FROM bandwidth.{col}", col = g)) + .map(|g| { + let q = naming::quote_ident(g); + format!("data.{q} IS NOT DISTINCT FROM bandwidth.{q}") + }) .collect::>() .join(" AND ") }; @@ -524,7 +533,10 @@ fn compute_density( } else { let grid_data_conds: Vec = group_by .iter() - .map(|g| format!("grid.{col} IS NOT DISTINCT FROM data.{col}", col = g)) + .map(|g| { + let q = naming::quote_ident(g); + format!("grid.{q} IS NOT DISTINCT FROM data.{q}") + }) .collect(); format!("WHERE {}", grid_data_conds.join(" AND ")) }; @@ -538,7 +550,10 @@ fn compute_density( ); // Build group-related SQL fragments - let grid_groups: Vec = group_by.iter().map(|g| format!("grid.{}", g)).collect(); + let grid_groups: Vec = group_by + .iter() + .map(|g| format!("grid.{}", naming::quote_ident(g))) + .collect(); let aggregation = format!( "GROUP 
BY grid.x{grid_group_by} ORDER BY grid.x{grid_group_by}", @@ -548,9 +563,14 @@ fn compute_density( let groups = if group_by.is_empty() { String::new() } else { - format!("{},", group_by.join(", ")) + let quoted: Vec = group_by.iter().map(|g| naming::quote_ident(g)).collect(); + format!("{},", quoted.join(", ")) }; + let x_column = naming::quote_ident(&naming::stat_column(value_aesthetic)); + let intensity_column = naming::quote_ident(&naming::stat_column("intensity")); + let density_column = naming::quote_ident(&naming::stat_column("density")); + // Generate the density computation query format!( "{bandwidth_cte}, @@ -560,23 +580,23 @@ fn compute_density( {x_column}, {groups} {intensity_column}, - {intensity_column} / __norm AS {density_column} + {intensity_column} / \"__norm\" AS {density_column} FROM ( SELECT grid.x AS {x_column}, {grid_groups} {kernel} AS {intensity_column}, - SUM(data.weight) AS __norm + SUM(data.weight) AS \"__norm\" {join_logic} {aggregation} )", bandwidth_cte = bandwidth_cte, data_cte = data_cte, grid_cte = grid_cte, - x_column = naming::stat_column(value_aesthetic), + x_column = x_column, groups = groups, - intensity_column = naming::stat_column("intensity"), - density_column = naming::stat_column("density"), + intensity_column = intensity_column, + density_column = density_column, aggregation = aggregation, grid_groups = with_trailing_comma(&grid_groups.join(", ")) ) @@ -606,21 +626,21 @@ mod tests { let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE + let expected = r#"WITH RECURSIVE bandwidth AS ( SELECT 0.5 AS bw, - MIN(x) AS x_min, - MAX(x) AS x_max - FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) AS __ggsql_qt__ - WHERE x IS NOT NULL + MIN("x") AS x_min, + MAX("x") AS x_max + FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) AS "__ggsql_qt__" + WHERE "x" IS NOT NULL ), data AS 
( - SELECT x AS val, 1.0 AS weight + SELECT "x" AS val, 1.0 AS weight FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) - WHERE x IS NOT NULL + WHERE "x" IS NOT NULL ), - __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), global_range AS ( SELECT MIN(x_min) AS min, MAX(x_max) AS max, 3 * MAX(bw) AS expansion FROM bandwidth @@ -628,23 +648,23 @@ mod tests { grid AS ( SELECT (global.min - global.expansion) + (seq.n * ((global.max - global.min) + 2 * global.expansion) / 511) AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN "__ggsql_seq__" AS seq ) SELECT - __ggsql_stat_x, - __ggsql_stat_intensity, - __ggsql_stat_intensity / __norm AS __ggsql_stat_density + "__ggsql_stat_x", + "__ggsql_stat_intensity", + "__ggsql_stat_intensity" / "__norm" AS "__ggsql_stat_density" FROM ( SELECT - grid.x AS __ggsql_stat_x, - SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS __ggsql_stat_intensity, - SUM(data.weight) AS __norm + grid.x AS "__ggsql_stat_x", + SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity", + SUM(data.weight) AS "__norm" FROM data INNER JOIN bandwidth ON true CROSS JOIN grid GROUP BY grid.x ORDER BY grid.x - )"; + )"#; // Normalize whitespace for comparison let normalize = |s: &str| s.split_whitespace().collect::>().join(" "); @@ -682,53 
+702,53 @@ mod tests { let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE + let expected = r#"WITH RECURSIVE bandwidth AS ( SELECT 0.5 AS bw, - region, category, - MIN(x) AS x_min, - MAX(x) AS x_max - FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) AS __ggsql_qt__ - WHERE x IS NOT NULL - GROUP BY region, category + "region", "category", + MIN("x") AS x_min, + MAX("x") AS x_max + FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) AS "__ggsql_qt__" + WHERE "x" IS NOT NULL + GROUP BY "region", "category" ), data AS ( - SELECT region, category, x AS val, 1.0 AS weight + SELECT "region", "category", "x" AS val, 1.0 AS weight FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) - WHERE x IS NOT NULL + WHERE "x" IS NOT NULL ), - __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), global_range AS ( SELECT MIN(x_min) AS min, MAX(x_max) AS max, 3 * MAX(bw) AS expansion FROM bandwidth ), grid AS ( SELECT - region, category, + "region", "category", (global.min - global.expansion) + (seq.n * ((global.max - global.min) + 2 * global.expansion) / 511) AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq - CROSS JOIN (SELECT DISTINCT region, category FROM bandwidth) AS groups + CROSS JOIN 
"__ggsql_seq__" AS seq + CROSS JOIN (SELECT DISTINCT "region", "category" FROM bandwidth) AS groups ) SELECT - __ggsql_stat_x, - region, category, - __ggsql_stat_intensity, - __ggsql_stat_intensity / __norm AS __ggsql_stat_density + "__ggsql_stat_x", + "region", "category", + "__ggsql_stat_intensity", + "__ggsql_stat_intensity" / "__norm" AS "__ggsql_stat_density" FROM ( SELECT - grid.x AS __ggsql_stat_x, - grid.region, grid.category, - SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS __ggsql_stat_intensity, - SUM(data.weight) AS __norm + grid.x AS "__ggsql_stat_x", + grid."region", grid."category", + SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity", + SUM(data.weight) AS "__norm" FROM data - INNER JOIN bandwidth ON data.region IS NOT DISTINCT FROM bandwidth.region AND data.category IS NOT DISTINCT FROM bandwidth.category + INNER JOIN bandwidth ON data."region" IS NOT DISTINCT FROM bandwidth."region" AND data."category" IS NOT DISTINCT FROM bandwidth."category" CROSS JOIN grid - WHERE grid.region IS NOT DISTINCT FROM data.region AND grid.category IS NOT DISTINCT FROM data.category - GROUP BY grid.x, grid.region, grid.category - ORDER BY grid.x, grid.region, grid.category - )"; + WHERE grid."region" IS NOT DISTINCT FROM data."region" AND grid."category" IS NOT DISTINCT FROM data."category" + GROUP BY grid.x, grid."region", grid."category" + ORDER BY grid.x, grid."region", grid."category" + )"#; // Normalize whitespace for comparison let normalize = |s: &str| s.split_whitespace().collect::>().join(" "); @@ -822,7 +842,7 @@ mod tests { // Verify SQL uses NTILE-based percentile subqueries with grouping assert!(bw_cte.contains("NTILE(4)")); - assert!(bw_cte.contains("GROUP BY region")); + assert!(bw_cte.contains("GROUP BY \"region\"")); let 
expected_rule = silverman_rule(1.0, "x", query, &groups, &AnsiDialect); assert!(normalize(&bw_cte).contains(&normalize(&expected_rule))); diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index fef6fb2e..4a4165ad 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; -use super::types::{get_column_name, CLOSED_VALUES, POSITION_VALUES}; +use super::types::{get_quoted_column_name, CLOSED_VALUES, POSITION_VALUES}; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, StatResult, @@ -125,7 +125,7 @@ fn stat_histogram( dialect: &dyn SqlDialect, ) -> Result { // Get x column name from aesthetics - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Histogram requires 'x' aesthetic mapping".to_string()) })?; @@ -149,7 +149,7 @@ fn stat_histogram( // Query min/max to compute bin width let stats_query = format!( - "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query}) AS __ggsql_stats__", + "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query}) AS \"__ggsql_stats__\"", x = x_col, query = query ); @@ -213,7 +213,7 @@ fn stat_histogram( )); } if let Some(weight_col) = weight_value.column_name() { - format!("SUM({})", weight_col) + format!("SUM({})", naming::quote_ident(weight_col)) } else { "COUNT(*)".to_string() } @@ -229,16 +229,20 @@ fn stat_histogram( let stat_count = naming::stat_column("count"); let stat_density = naming::stat_column("density"); + let q_bin = naming::quote_ident(&stat_bin); + let q_bin_end = naming::quote_ident(&stat_bin_end); + let q_count = naming::quote_ident(&stat_count); + let q_density = naming::quote_ident(&stat_density); let (binned_select, final_select) = if group_by.is_empty() { ( format!( "{} AS {}, {} AS {}, {} AS {}", - bin_expr, stat_bin, bin_end_expr, 
stat_bin_end, agg_expr, stat_count + bin_expr, q_bin, bin_end_expr, q_bin_end, agg_expr, q_count ), format!( "*, {count} * 1.0 / SUM({count}) OVER () AS {density}", - count = stat_count, - density = stat_density + count = q_count, + density = q_density ), ) } else { @@ -246,19 +250,19 @@ fn stat_histogram( ( format!( "{}, {} AS {}, {} AS {}, {} AS {}", - grp_cols, bin_expr, stat_bin, bin_end_expr, stat_bin_end, agg_expr, stat_count + grp_cols, bin_expr, q_bin, bin_end_expr, q_bin_end, agg_expr, q_count ), format!( "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {density}", - count = stat_count, + count = q_count, grp = grp_cols, - density = stat_density + density = q_density ), ) }; let transformed_query = format!( - "WITH __stat_src__ AS ({query}), __binned__ AS (SELECT {binned} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __binned__", + "WITH \"__stat_src__\" AS ({query}), \"__binned__\" AS (SELECT {binned} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__binned__\"", query = query, binned = binned_select, group = group_cols, diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index c61a458d..e0009961 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -56,7 +56,7 @@ impl GeomTrait for Line { // Line geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { - query: format!("{} ORDER BY \"{}\"", query, order_col), + query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), stat_columns: vec![], dummy_columns: vec![], consumed_aesthetics: vec![], diff --git a/src/plot/layer/geom/rect.rs b/src/plot/layer/geom/rect.rs index 384bcc08..6eb11df5 100644 --- a/src/plot/layer/geom/rect.rs +++ b/src/plot/layer/geom/rect.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; -use super::types::get_column_name; use super::types::POSITION_VALUES; +use super::types::{get_column_name, 
get_quoted_column_name}; use super::{DefaultAesthetics, GeomTrait, GeomType, ParamConstraint, StatResult}; use crate::naming; use crate::plot::types::{DefaultAestheticValue, ParameterValue}; @@ -130,15 +130,17 @@ fn process_direction( _ => unreachable!("axis must be 'x' or 'y'"), }; - // Get column names from MAPPING, with SETTING fallback for size - let center = get_column_name(aesthetics, center_aes); - let min = get_column_name(aesthetics, min_aes); - let max = get_column_name(aesthetics, max_aes); - let size = get_column_name(aesthetics, size_aes) + // Get unquoted center name for schema lookup + let center_unquoted = get_column_name(aesthetics, center_aes); + let center = center_unquoted.as_deref().map(naming::quote_ident); + let min = get_quoted_column_name(aesthetics, min_aes); + let max = get_quoted_column_name(aesthetics, max_aes); + // SETTING fallback for size is a literal value, no quoting needed. + let size = get_quoted_column_name(aesthetics, size_aes) .or_else(|| parameters.get(size_aes).map(|v| v.to_string())); // Detect if discrete by checking schema - let is_discrete = center + let is_discrete = center_unquoted .as_ref() .and_then(|col| schema.iter().find(|c| &c.name == col)) .map(|c| c.is_discrete) @@ -172,8 +174,16 @@ fn process_direction( // Build SELECT parts using the stat columns let select_parts = vec![ - format!("{} AS {}", expr_1, naming::stat_column(&stat_cols[0])), - format!("{} AS {}", expr_2, naming::stat_column(&stat_cols[1])), + format!( + "{} AS {}", + expr_1, + naming::quote_ident(&naming::stat_column(&stat_cols[0])) + ), + format!( + "{} AS {}", + expr_2, + naming::quote_ident(&naming::stat_column(&stat_cols[1])) + ), ]; Ok((select_parts, stat_cols)) @@ -208,7 +218,7 @@ fn stat_rect( let mut select_parts: Vec = schema .iter() .filter(|col| !consumed_columns.contains(&col.name)) - .map(|col| col.name.clone()) + .map(|col| naming::quote_ident(&col.name)) .collect(); // Add X direction SELECT parts and collect stat columns @@ -223,7 
+233,7 @@ fn stat_rect( // Build transformed query let transformed_query = format!( - "SELECT {} FROM ({}) AS __ggsql_rect_stat__", + "SELECT {} FROM ({}) AS \"__ggsql_rect_stat__\"", select_list, query ); @@ -446,44 +456,44 @@ mod tests { ( "xmin + xmax", vec!["pos1min", "pos1max"], - "__ggsql_aes_pos1min__", - "__ggsql_aes_pos1max__", + "\"__ggsql_aes_pos1min__\"", + "\"__ggsql_aes_pos1max__\"", ), ( "x + width", vec!["pos1", "width"], - "(__ggsql_aes_pos1__ - __ggsql_aes_width__ / 2.0)", - "(__ggsql_aes_pos1__ + __ggsql_aes_width__ / 2.0)", + "(\"__ggsql_aes_pos1__\" - \"__ggsql_aes_width__\" / 2.0)", + "(\"__ggsql_aes_pos1__\" + \"__ggsql_aes_width__\" / 2.0)", ), ( "x only (default width 1.0)", vec!["pos1"], - "(__ggsql_aes_pos1__ - 0.5)", - "(__ggsql_aes_pos1__ + 0.5)", + "(\"__ggsql_aes_pos1__\" - 0.5)", + "(\"__ggsql_aes_pos1__\" + 0.5)", ), ( "x + xmin", vec!["pos1", "pos1min"], - "__ggsql_aes_pos1min__", - "(2 * __ggsql_aes_pos1__ - __ggsql_aes_pos1min__)", + "\"__ggsql_aes_pos1min__\"", + "(2 * \"__ggsql_aes_pos1__\" - \"__ggsql_aes_pos1min__\")", ), ( "x + xmax", vec!["pos1", "pos1max"], - "(2 * __ggsql_aes_pos1__ - __ggsql_aes_pos1max__)", - "__ggsql_aes_pos1max__", + "(2 * \"__ggsql_aes_pos1__\" - \"__ggsql_aes_pos1max__\")", + "\"__ggsql_aes_pos1max__\"", ), ( "xmin + width", vec!["pos1min", "width"], - "__ggsql_aes_pos1min__", - "(__ggsql_aes_pos1min__ + __ggsql_aes_width__)", + "\"__ggsql_aes_pos1min__\"", + "(\"__ggsql_aes_pos1min__\" + \"__ggsql_aes_width__\")", ), ( "xmax + width", vec!["pos1max", "width"], - "(__ggsql_aes_pos1max__ - __ggsql_aes_width__)", - "__ggsql_aes_pos1max__", + "(\"__ggsql_aes_pos1max__\" - \"__ggsql_aes_width__\")", + "\"__ggsql_aes_pos1max__\"", ), ]; @@ -522,7 +532,7 @@ mod tests { let stat_pos1min = naming::stat_column("pos1min"); let stat_pos1max = naming::stat_column("pos1max"); assert!( - query.contains(&format!("{} AS {}", expected_min, stat_pos1min)), + query.contains(&format!("{} AS \"{}\"", expected_min, 
stat_pos1min)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_min, @@ -530,7 +540,7 @@ mod tests { query ); assert!( - query.contains(&format!("{} AS {}", expected_max, stat_pos1max)), + query.contains(&format!("{} AS \"{}\"", expected_max, stat_pos1max)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_max, @@ -562,38 +572,38 @@ mod tests { ( "ymin + ymax", vec!["pos2min", "pos2max"], - "__ggsql_aes_pos2min__", - "__ggsql_aes_pos2max__", + "\"__ggsql_aes_pos2min__\"", + "\"__ggsql_aes_pos2max__\"", ), ( "y + height", vec!["pos2", "height"], - "(__ggsql_aes_pos2__ - __ggsql_aes_height__ / 2.0)", - "(__ggsql_aes_pos2__ + __ggsql_aes_height__ / 2.0)", + "(\"__ggsql_aes_pos2__\" - \"__ggsql_aes_height__\" / 2.0)", + "(\"__ggsql_aes_pos2__\" + \"__ggsql_aes_height__\" / 2.0)", ), ( "y + ymin", vec!["pos2", "pos2min"], - "__ggsql_aes_pos2min__", - "(2 * __ggsql_aes_pos2__ - __ggsql_aes_pos2min__)", + "\"__ggsql_aes_pos2min__\"", + "(2 * \"__ggsql_aes_pos2__\" - \"__ggsql_aes_pos2min__\")", ), ( "y + ymax", vec!["pos2", "pos2max"], - "(2 * __ggsql_aes_pos2__ - __ggsql_aes_pos2max__)", - "__ggsql_aes_pos2max__", + "(2 * \"__ggsql_aes_pos2__\" - \"__ggsql_aes_pos2max__\")", + "\"__ggsql_aes_pos2max__\"", ), ( "ymin + height", vec!["pos2min", "height"], - "__ggsql_aes_pos2min__", - "(__ggsql_aes_pos2min__ + __ggsql_aes_height__)", + "\"__ggsql_aes_pos2min__\"", + "(\"__ggsql_aes_pos2min__\" + \"__ggsql_aes_height__\")", ), ( "ymax + height", vec!["pos2max", "height"], - "(__ggsql_aes_pos2max__ - __ggsql_aes_height__)", - "__ggsql_aes_pos2max__", + "(\"__ggsql_aes_pos2max__\" - \"__ggsql_aes_height__\")", + "\"__ggsql_aes_pos2max__\"", ), ]; @@ -632,7 +642,7 @@ mod tests { let stat_pos2min = naming::stat_column("pos2min"); let stat_pos2max = naming::stat_column("pos2max"); assert!( - query.contains(&format!("{} AS {}", expected_min, stat_pos2min)), + query.contains(&format!("{} AS \"{}\"", expected_min, stat_pos2min)), "{}: Expected '{} AS {}' in 
query, got: {}", name, expected_min, @@ -640,7 +650,7 @@ mod tests { query ); assert!( - query.contains(&format!("{} AS {}", expected_max, stat_pos2max)), + query.contains(&format!("{} AS \"{}\"", expected_max, stat_pos2max)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_max, @@ -687,8 +697,8 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS __ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_pos1__\" AS \"__ggsql_stat_pos1")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); assert!(stat_columns.contains(&"pos1".to_string())); assert!(stat_columns.contains(&"width".to_string())); assert!(stat_columns.contains(&"pos2min".to_string())); @@ -718,8 +728,8 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos2__ AS __ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_pos2__\" AS \"__ggsql_stat_pos2")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); assert!(stat_columns.contains(&"pos1min".to_string())); assert!(stat_columns.contains(&"pos1max".to_string())); assert!(stat_columns.contains(&"pos2".to_string())); @@ -749,10 +759,10 @@ mod tests { .. 
}) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS __ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_pos2__ AS __ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_pos1__\" AS \"__ggsql_stat_pos1")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_pos2__\" AS \"__ggsql_stat_pos2")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); assert_eq!(stat_columns.len(), 4); } } @@ -782,8 +792,8 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("(__ggsql_aes_pos1__ - 0.5)")); - assert!(query.contains("(__ggsql_aes_pos1__ + 0.5)")); + assert!(query.contains("(\"__ggsql_aes_pos1__\" - 0.5)")); + assert!(query.contains("(\"__ggsql_aes_pos1__\" + 0.5)")); assert!(stat_columns.contains(&"pos1min".to_string())); assert!(stat_columns.contains(&"pos1max".to_string())); } @@ -852,7 +862,7 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("1.0 AS __ggsql_stat_width")); + assert!(query.contains("1.0 AS \"__ggsql_stat_width")); assert!(stat_columns.contains(&"width".to_string())); } _ => panic!("Expected Transformed"), @@ -879,12 +889,12 @@ mod tests { assert!(result.is_ok()); if let Ok(StatResult::Transformed { query, .. 
}) = result { - // Should include fill column (non-consumed aesthetic from schema) - assert!(query.contains("__ggsql_aes_fill__")); + // Should include fill column (non-consumed aesthetic from schema, quoted) + assert!(query.contains("\"__ggsql_aes_fill__\"")); // Should NOT include width/height as pass-through (they're consumed) // They should only appear as stat columns - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); } } @@ -909,8 +919,8 @@ mod tests { if let Ok(StatResult::Transformed { query, .. }) = result { // Should use SETTING values as SQL literals - assert!(query.contains("0.7 AS __ggsql_stat_width")); - assert!(query.contains("0.9 AS __ggsql_stat_height")); + assert!(query.contains("0.7 AS \"__ggsql_stat_width")); + assert!(query.contains("0.9 AS \"__ggsql_stat_height")); } } } diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index d3615d12..46470c72 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -56,7 +56,7 @@ impl GeomTrait for Ribbon { // Ribbon geom needs ordering by pos1 (domain axis) for proper rendering let order_col = naming::aesthetic_column("pos1"); Ok(StatResult::Transformed { - query: format!("{} ORDER BY \"{}\"", query, order_col), + query: format!("{} ORDER BY {}", query, naming::quote_ident(&order_col)), stat_columns: vec![], dummy_columns: vec![], consumed_aesthetics: vec![], diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index fad14432..4509fc5c 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -4,7 +4,7 @@ use super::types::POSITION_VALUES; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, }; -use 
crate::plot::geom::types::get_column_name; +use crate::plot::geom::types::get_quoted_column_name; use crate::plot::types::DefaultAestheticValue; use crate::plot::{ParameterValue, StatResult}; use crate::reader::SqlDialect; @@ -136,10 +136,10 @@ impl std::fmt::Display for Smooth { } fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result { - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string()) })?; - let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| { + let y_col = get_quoted_column_name(aesthetics, "pos2").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string()) })?; @@ -184,8 +184,8 @@ fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result Result Result { - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string()) })?; - let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| { + let y_col = get_quoted_column_name(aesthetics, "pos2").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string()) })?; @@ -257,8 +257,8 @@ fn stat_tls(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result Option }) } +/// Helper to extract a double-quoted column name for use in SQL expressions. 
+pub fn get_quoted_column_name(aesthetics: &Mappings, aesthetic: &str) -> Option { + get_column_name(aesthetics, aesthetic).map(|n| naming::quote_ident(&n)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/plot/scale/scale_type/binned.rs b/src/plot/scale/scale_type/binned.rs index cf1a77d4..dd4b028e 100644 --- a/src/plot/scale/scale_type/binned.rs +++ b/src/plot/scale/scale_type/binned.rs @@ -8,6 +8,7 @@ use super::{ expand_numeric_range, resolve_common_steps, ScaleDataContext, ScaleTypeKind, ScaleTypeTrait, TransformKind, CLOSED_VALUES, OOB_CENSOR, OOB_SQUISH, OOB_VALUES_BINNED, }; +use crate::naming; use crate::plot::types::{ ArrayConstraint, DefaultParamValue, NumberConstraint, ParamConstraint, ParamDefinition, }; @@ -727,20 +728,21 @@ fn build_bin_condition( (if is_first { ">=" } else { ">" }, "<=") }; + let quoted = naming::quote_ident(column_name); if oob_squish && is_first && is_last { // Single bin with squish: capture everything "TRUE".to_string() } else if oob_squish && is_first { // First bin with squish: no lower bound, extends to -∞ - format!("{} {} {}", column_name, upper_op, upper_expr) + format!("{} {} {}", quoted, upper_op, upper_expr) } else if oob_squish && is_last { // Last bin with squish: no upper bound, extends to +∞ - format!("{} {} {}", column_name, lower_op, lower_expr) + format!("{} {} {}", quoted, lower_op, lower_expr) } else { // Normal bin with both bounds format!( "{} {} {} AND {} {} {}", - column_name, lower_op, lower_expr, column_name, upper_op, upper_expr + quoted, lower_op, lower_expr, quoted, upper_op, upper_expr ) } } @@ -855,10 +857,10 @@ mod tests { // Should produce CASE WHEN with bin centers 5, 15, 25 assert!(sql.contains("CASE")); - assert!(sql.contains("WHEN value >= 0 AND value < 10 THEN 5")); - assert!(sql.contains("WHEN value >= 10 AND value < 20 THEN 15")); + assert!(sql.contains("WHEN \"value\" >= 0 AND \"value\" < 10 THEN 5")); + assert!(sql.contains("WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15")); // 
Last bin should be inclusive on both ends - assert!(sql.contains("WHEN value >= 20 AND value <= 30 THEN 25")); + assert!(sql.contains("WHEN \"value\" >= 20 AND \"value\" <= 30 THEN 25")); assert!(sql.contains("ELSE NULL END")); } @@ -906,8 +908,8 @@ mod tests { .unwrap(); // closed="left": [lower, upper) except last which is [lower, upper] - assert!(sql.contains("col >= 0 AND col < 10")); - assert!(sql.contains("col >= 10 AND col <= 20")); // last bin inclusive + assert!(sql.contains("\"col\" >= 0 AND \"col\" < 10")); + assert!(sql.contains("\"col\" >= 10 AND \"col\" <= 20")); // last bin inclusive } #[test] @@ -932,8 +934,8 @@ mod tests { .unwrap(); // closed="right": first bin is [lower, upper], rest are (lower, upper] - assert!(sql.contains("col >= 0 AND col <= 10")); // first bin inclusive - assert!(sql.contains("col > 10 AND col <= 20")); + assert!(sql.contains("\"col\" >= 0 AND \"col\" <= 10")); // first bin inclusive + assert!(sql.contains("\"col\" > 10 AND \"col\" <= 20")); } #[test] @@ -1191,8 +1193,8 @@ mod tests { sql ); assert!( - sql.contains("value >= 0"), - "SQL should use raw column name. Got: {}", + sql.contains("\"value\" >= 0"), + "SQL should use quoted column name. 
Got: {}", sql ); assert!( @@ -1227,7 +1229,10 @@ mod tests { !sql.contains("CAST("), "SQL should not contain CAST when column is numeric" ); - assert!(sql.contains("value >= 0"), "SQL should use raw column name"); + assert!( + sql.contains("\"value\" >= 0"), + "SQL should use quoted column name" + ); } #[test] @@ -1503,9 +1508,9 @@ mod tests { "left", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN value < 10 THEN 5", // First bin extends to -∞ - "WHEN value >= 10 AND value < 20 THEN 15", // Middle bin - "WHEN value >= 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" < 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15", // Middle bin + "WHEN \"value\" >= 20 THEN 25", // Last bin extends to +∞ ], ), // closed="right" with 3 bins (4 breaks) @@ -1513,9 +1518,9 @@ mod tests { "right", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN value <= 10 THEN 5", // First bin extends to -∞ - "WHEN value > 10 AND value <= 20 THEN 15", // Middle bin - "WHEN value > 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" <= 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" > 10 AND \"value\" <= 20 THEN 15", // Middle bin + "WHEN \"value\" > 20 THEN 25", // Last bin extends to +∞ ], ), ]; @@ -1576,11 +1581,11 @@ mod tests { .pre_stat_transform_sql("x", &DataType::Float64, &scale, &AnsiDialect) .unwrap(); assert!( - sql.contains("WHEN x < 50 THEN 25"), + sql.contains("WHEN \"x\" < 50 THEN 25"), "Two bins: first should extend to -∞" ); assert!( - sql.contains("WHEN x >= 50 THEN 75"), + sql.contains("WHEN \"x\" >= 50 THEN 75"), "Two bins: last should extend to +∞" ); } @@ -1625,11 +1630,11 @@ mod tests { .pre_stat_transform_sql("x", &DataType::Float64, &scale, &AnsiDialect) .unwrap(); assert!( - sql.contains("x >= 0 AND x < 10"), + sql.contains("\"x\" >= 0 AND \"x\" < 10"), "First bin should have lower bound with censor" ); assert!( - sql.contains("x >= 10 AND x <= 20"), + sql.contains("\"x\" >= 10 AND \"x\" <= 20"), "Last bin should have 
upper bound with censor" ); } @@ -1642,14 +1647,17 @@ mod tests { ( true, vec![ - "WHEN col < 10 THEN 5", - "WHEN col >= 10 AND col < 20 THEN 15", - "WHEN col >= 20 THEN 25", + "WHEN \"col\" < 10 THEN 5", + "WHEN \"col\" >= 10 AND \"col\" < 20 THEN 15", + "WHEN \"col\" >= 20 THEN 25", ], ), ( false, - vec!["col >= 0 AND col < 10", "col >= 10 AND col <= 20"], + vec![ + "\"col\" >= 0 AND \"col\" < 10", + "\"col\" >= 10 AND \"col\" <= 20", + ], ), ]; diff --git a/src/plot/scale/scale_type/continuous.rs b/src/plot/scale/scale_type/continuous.rs index 5f98af40..099ba168 100644 --- a/src/plot/scale/scale_type/continuous.rs +++ b/src/plot/scale/scale_type/continuous.rs @@ -5,6 +5,7 @@ use polars::prelude::DataType; use super::{ ScaleTypeKind, ScaleTypeTrait, TransformKind, OOB_CENSOR, OOB_SQUISH, OOB_VALUES_CONTINUOUS, }; +use crate::naming; use crate::plot::types::{ ArrayConstraint, DefaultParamValue, NumberConstraint, ParamConstraint, ParamDefinition, }; @@ -214,14 +215,18 @@ impl ScaleTypeTrait for Continuous { .unwrap_or(super::default_oob(&scale.aesthetic)); match oob { - OOB_CENSOR => Some(format!( - "(CASE WHEN {} >= {} AND {} <= {} THEN {} ELSE NULL END)", - column_name, min, column_name, max, column_name - )), + OOB_CENSOR => { + let quoted = naming::quote_ident(column_name); + Some(format!( + "(CASE WHEN {} >= {} AND {} <= {} THEN {} ELSE NULL END)", + quoted, min, quoted, max, quoted + )) + } OOB_SQUISH => { let min_s = min.to_string(); let max_s = max.to_string(); - let inner = dialect.sql_least(&[&max_s, column_name]); + let quoted = naming::quote_ident(column_name); + let inner = dialect.sql_least(&[&max_s, "ed]); Some(dialect.sql_greatest(&[&min_s, &inner])) } _ => None, // "keep" = no transformation @@ -259,8 +264,8 @@ mod tests { let sql = sql.unwrap(); // Should generate CASE WHEN for censor assert!(sql.contains("CASE WHEN")); - assert!(sql.contains("value >= 0")); - assert!(sql.contains("value <= 100")); + assert!(sql.contains("\"value\" >= 0")); + 
assert!(sql.contains("\"value\" <= 100")); assert!(sql.contains("ELSE NULL")); } diff --git a/src/plot/scale/scale_type/discrete.rs b/src/plot/scale/scale_type/discrete.rs index ca31edfa..c483cb3e 100644 --- a/src/plot/scale/scale_type/discrete.rs +++ b/src/plot/scale/scale_type/discrete.rs @@ -4,6 +4,7 @@ use polars::prelude::DataType; use super::super::transform::{Transform, TransformKind}; use super::{ScaleTypeKind, ScaleTypeTrait}; +use crate::naming; use crate::plot::types::{DefaultParamValue, ParamConstraint, ParamDefinition}; use crate::plot::ArrayElement; @@ -259,11 +260,12 @@ impl ScaleTypeTrait for Discrete { } // Always censor - discrete scales have no other valid OOB behavior + let quoted = naming::quote_ident(column_name); Some(format!( "(CASE WHEN {} IN ({}) THEN {} ELSE NULL END)", - column_name, + quoted, allowed_values.join(", "), - column_name + quoted )) } } diff --git a/src/plot/scale/scale_type/mod.rs b/src/plot/scale/scale_type/mod.rs index 932353c2..479f1d5e 100644 --- a/src/plot/scale/scale_type/mod.rs +++ b/src/plot/scale/scale_type/mod.rs @@ -3394,7 +3394,7 @@ mod tests { let dialect = AnsiDialect; assert_eq!( dialect.type_name_for(CastTargetType::Number), - Some("DOUBLE") + Some("DOUBLE PRECISION") ); assert_eq!( dialect.type_name_for(CastTargetType::Integer), diff --git a/src/plot/scale/scale_type/ordinal.rs b/src/plot/scale/scale_type/ordinal.rs index bc3c0d5e..3b83b46e 100644 --- a/src/plot/scale/scale_type/ordinal.rs +++ b/src/plot/scale/scale_type/ordinal.rs @@ -8,6 +8,7 @@ use polars::prelude::DataType; use super::super::transform::{Transform, TransformKind}; use super::{ScaleTypeKind, ScaleTypeTrait}; +use crate::naming; use crate::plot::types::{DefaultParamValue, ParamConstraint, ParamDefinition}; use crate::plot::ArrayElement; @@ -291,11 +292,12 @@ impl ScaleTypeTrait for Ordinal { } // Always censor - ordinal scales have no other valid OOB behavior + let quoted = naming::quote_ident(column_name); Some(format!( "(CASE WHEN {} IN 
({}) THEN {} ELSE NULL END)", - column_name, + quoted, allowed_values.join(", "), - column_name + quoted )) } } diff --git a/src/reader/connection.rs b/src/reader/connection.rs index 63f90cf7..b97bd553 100644 --- a/src/reader/connection.rs +++ b/src/reader/connection.rs @@ -17,6 +17,9 @@ pub enum ConnectionInfo { /// SQLite file-based database #[allow(dead_code)] SQLite(String), + /// Generic ODBC connection (raw connection string after `odbc://` prefix) + #[allow(dead_code)] + ODBC(String), } /// Parse a connection string into connection information @@ -70,8 +73,17 @@ pub fn parse_connection_string(uri: &str) -> Result { return Ok(ConnectionInfo::SQLite(cleaned_path.to_string())); } + if let Some(conn_str) = uri.strip_prefix("odbc://") { + if conn_str.is_empty() { + return Err(GgsqlError::ReaderError( + "ODBC connection string cannot be empty".to_string(), + )); + } + return Ok(ConnectionInfo::ODBC(conn_str.to_string())); + } + Err(GgsqlError::ReaderError(format!( - "Unsupported connection string format: {}. Supported: duckdb://, postgres://, sqlite://", + "Unsupported connection string format: {}. 
Supported: duckdb://, postgres://, sqlite://, odbc://", uri ))) } @@ -133,6 +145,26 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_odbc() { + let info = parse_connection_string( + "odbc://Driver=Snowflake;Server=myaccount.snowflakecomputing.com", + ) + .unwrap(); + assert_eq!( + info, + ConnectionInfo::ODBC( + "Driver=Snowflake;Server=myaccount.snowflakecomputing.com".to_string() + ) + ); + } + + #[test] + fn test_odbc_empty() { + let result = parse_connection_string("odbc://"); + assert!(result.is_err()); + } + #[test] fn test_unsupported_scheme() { let result = parse_connection_string("mysql://localhost/db"); diff --git a/src/reader/data.rs b/src/reader/data.rs index 720ccac7..b143a4b2 100644 --- a/src/reader/data.rs +++ b/src/reader/data.rs @@ -67,8 +67,8 @@ pub fn register_builtin_datasets_duckdb( } let create_sql = format!( - "CREATE TABLE IF NOT EXISTS \"{}\" AS SELECT * FROM read_parquet('{}')", - table_name, + "CREATE TABLE IF NOT EXISTS {} AS SELECT * FROM read_parquet('{}')", + naming::quote_ident(&table_name), tmp_path.display() ); @@ -185,7 +185,7 @@ pub fn rewrite_namespaced_sql(sql: &str) -> Result { replacements.push(( node.start_byte(), node.end_byte(), - naming::builtin_data_table(name), + naming::quote_ident(&naming::builtin_data_table(name)), )); } } @@ -315,7 +315,7 @@ mod tests { fn test_rewrite_namespaced_sql_simple() { let sql = "SELECT * FROM ggsql:penguins"; let rewritten = rewrite_namespaced_sql(sql).unwrap(); - assert_eq!(rewritten, "SELECT * FROM __ggsql_data_penguins__"); + assert_eq!(rewritten, "SELECT * FROM \"__ggsql_data_penguins__\""); } #[test] @@ -324,7 +324,7 @@ mod tests { let rewritten = rewrite_namespaced_sql(sql).unwrap(); assert_eq!( rewritten, - "SELECT * FROM __ggsql_data_penguins__ p, __ggsql_data_airquality__ a WHERE p.id = a.id" + "SELECT * FROM \"__ggsql_data_penguins__\" p, \"__ggsql_data_airquality__\" a WHERE p.id = a.id" ); } @@ -339,7 +339,7 @@ mod tests { fn 
test_rewrite_namespaced_sql_with_visualise() { let sql = "SELECT * FROM ggsql:penguins VISUALISE DRAW point MAPPING bill_len AS x, bill_dep AS y"; let rewritten = rewrite_namespaced_sql(sql).unwrap(); - assert!(rewritten.starts_with("SELECT * FROM __ggsql_data_penguins__")); + assert!(rewritten.starts_with("SELECT * FROM \"__ggsql_data_penguins__\"")); assert!(!rewritten.contains("ggsql:")); } } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index c35dc05f..53dfb288 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -3,7 +3,7 @@ //! Provides a reader for DuckDB databases with direct Polars DataFrame integration. use crate::reader::{connection::ConnectionInfo, Reader}; -use crate::{DataFrame, GgsqlError, Result}; +use crate::{naming, DataFrame, GgsqlError, Result}; use arrow::ipc::reader::FileReader; use duckdb::vtab::arrow::{arrow_recordbatch_to_query_params, ArrowVTab}; use duckdb::{params, Connection}; @@ -36,7 +36,7 @@ impl super::SqlDialect for DuckDbDialect { fn sql_generate_series(&self, n: usize) -> String { format!( - "__ggsql_seq__(n) AS (SELECT generate_series FROM GENERATE_SERIES(0, {}))", + "\"__ggsql_seq__\"(n) AS (SELECT generate_series FROM GENERATE_SERIES(0, {}))", n - 1 ) } @@ -44,14 +44,23 @@ impl super::SqlDialect for DuckDbDialect { fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() - .map(|g| format!("AND __ggsql_pct__.{g} IS NOT DISTINCT FROM __ggsql_qt__.{g}")) + .map(|g| { + let q = naming::quote_ident(g); + format!( + "AND {pct}.{q} IS NOT DISTINCT FROM {qt}.{q}", + pct = naming::quote_ident("__ggsql_pct__"), + qt = naming::quote_ident("__ggsql_qt__") + ) + }) .collect::>() .join(" "); + let quoted_column = naming::quote_ident(column); format!( "(SELECT QUANTILE_CONT({column}, {fraction}) \ - FROM ({from}) AS __ggsql_pct__ \ - WHERE {column} IS NOT NULL {group_filter})" + FROM ({from}) AS \"__ggsql_pct__\" \ + WHERE {column} IS NOT NULL 
{group_filter})", + column = quoted_column ) } } @@ -144,34 +153,7 @@ impl DuckDBReader { } } -/// Validate a table name -fn validate_table_name(name: &str) -> Result<()> { - if name.is_empty() { - return Err(GgsqlError::ReaderError("Table name cannot be empty".into())); - } - - // Reject characters that could break double-quoted identifiers or cause issues - let forbidden = ['"', '\0', '\n', '\r']; - for ch in forbidden { - if name.contains(ch) { - return Err(GgsqlError::ReaderError(format!( - "Table name '{}' contains invalid character '{}'", - name, - ch.escape_default() - ))); - } - } - - // Reasonable length limit - if name.len() > 128 { - return Err(GgsqlError::ReaderError(format!( - "Table name '{}' exceeds maximum length of 128 characters", - name - ))); - } - - Ok(()) -} +use super::validate_table_name; /// Convert a Polars DataFrame to DuckDB Arrow query parameters via IPC serialization fn dataframe_to_arrow_params(df: DataFrame) -> Result<[usize; 2]> { @@ -579,8 +561,9 @@ impl Reader for DuckDBReader { // Small DataFrame: register in a single batch let params = dataframe_to_arrow_params(df)?; let sql = format!( - "{} TEMP TABLE \"{}\" AS SELECT * FROM arrow(?, ?)", - create_or_replace, name + "{} TEMP TABLE {} AS SELECT * FROM arrow(?, ?)", + create_or_replace, + naming::quote_ident(name) ); self.conn.execute(&sql, params).map_err(|e| { GgsqlError::ReaderError(format!("Failed to register table '{}': {}", name, e)) @@ -590,8 +573,9 @@ impl Reader for DuckDBReader { let first_chunk = df.slice(0, MAX_ARROW_BATCH_ROWS); let params = dataframe_to_arrow_params(first_chunk)?; let create_sql = format!( - "{} TEMP TABLE \"{}\" AS SELECT * FROM arrow(?, ?)", - create_or_replace, name + "{} TEMP TABLE {} AS SELECT * FROM arrow(?, ?)", + create_or_replace, + naming::quote_ident(name) ); self.conn.execute(&create_sql, params).map_err(|e| { GgsqlError::ReaderError(format!("Failed to register table '{}': {}", name, e)) @@ -602,7 +586,10 @@ impl Reader for DuckDBReader 
{ let chunk_size = std::cmp::min(MAX_ARROW_BATCH_ROWS, total_rows - offset); let chunk = df.slice(offset as i64, chunk_size); let params = dataframe_to_arrow_params(chunk)?; - let insert_sql = format!("INSERT INTO \"{}\" SELECT * FROM arrow(?, ?)", name); + let insert_sql = format!( + "INSERT INTO {} SELECT * FROM arrow(?, ?)", + naming::quote_ident(name) + ); self.conn.execute(&insert_sql, params).map_err(|e| { GgsqlError::ReaderError(format!( "Failed to insert chunk into table '{}': {}", @@ -628,7 +615,7 @@ impl Reader for DuckDBReader { } // Drop the temp table - let sql = format!("DROP TABLE IF EXISTS \"{}\"", name); + let sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name)); self.conn.execute(&sql, []).map_err(|e| { GgsqlError::ReaderError(format!("Failed to unregister table '{}': {}", name, e)) })?; @@ -639,6 +626,10 @@ impl Reader for DuckDBReader { Ok(()) } + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn super::SqlDialect { &DuckDbDialect } @@ -771,13 +762,10 @@ mod tests { assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("cannot be empty")); - // Name with double quote + // Name with double quote should succeed (quote_ident escapes it) let result = reader.register("bad\"name", df.clone(), false); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("invalid character")); + assert!(result.is_ok()); + reader.unregister("bad\"name").unwrap(); // Name with null byte let result = reader.register("bad\0name", df.clone(), false); @@ -786,15 +774,6 @@ mod tests { .unwrap_err() .to_string() .contains("invalid character")); - - // Name too long - let long_name = "a".repeat(200); - let result = reader.register(&long_name, df, false); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("exceeds maximum length")); } #[test] diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 
06457f85..6a67f20b 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -36,7 +36,7 @@ use std::collections::HashMap; use crate::execute::prepare_data_with_reader; use crate::plot::{CastTargetType, Plot}; use crate::validate::{validate, ValidationWarning}; -use crate::{DataFrame, GgsqlError, Result}; +use crate::{naming, DataFrame, GgsqlError, Result}; // ============================================================================= // SQL Dialect @@ -46,9 +46,9 @@ use crate::{DataFrame, GgsqlError, Result}; /// /// Default implementations produce portable ANSI SQL. pub trait SqlDialect { - /// SQL type name for numeric columns (e.g., "DOUBLE") + /// SQL type name for numeric columns (e.g., "DOUBLE PRECISION") fn number_type_name(&self) -> Option<&str> { - Some("DOUBLE") + Some("DOUBLE PRECISION") } /// SQL type name for integer columns (e.g., "BIGINT") @@ -94,6 +94,48 @@ pub trait SqlDialect { } } + // ========================================================================= + // Schema introspection queries (for Connections pane) + // ========================================================================= + + /// SQL to list catalog names. Returns rows with column `catalog_name`. + fn sql_list_catalogs(&self) -> String { + "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY catalog_name".into() + } + + /// SQL to list schema names within a catalog. Returns rows with column `schema_name`. + fn sql_list_schemas(&self, catalog: &str) -> String { + format!( + "SELECT DISTINCT schema_name FROM information_schema.schemata \ + WHERE catalog_name = '{}' ORDER BY schema_name", + catalog.replace('\'', "''") + ) + } + + /// SQL to list tables/views within a catalog and schema. + /// Returns rows with columns `table_name` and `table_type`. 
+ fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { + format!( + "SELECT DISTINCT table_name, table_type FROM information_schema.tables \ + WHERE table_catalog = '{}' AND table_schema = '{}' ORDER BY table_name", + catalog.replace('\'', "''"), + schema.replace('\'', "''") + ) + } + + /// SQL to list columns in a table. + /// Returns rows with columns `column_name` and `data_type`. + fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { + format!( + "SELECT column_name, data_type FROM information_schema.columns \ + WHERE table_catalog = '{}' AND table_schema = '{}' AND table_name = '{}' \ + ORDER BY ordinal_position", + catalog.replace('\'', "''"), + schema.replace('\'', "''"), + table.replace('\'', "''") + ) + } + /// Scalar MAX across any number of SQL expressions. fn sql_greatest(&self, exprs: &[&str]) -> String { let mut result = exprs[0].to_string(); @@ -124,12 +166,12 @@ pub trait SqlDialect { let base_sq = base_size * base_size; let base_max = base_size - 1; format!( - "__ggsql_base__(n) AS (\ - SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < {base_max}\ + "\"__ggsql_base__\"(n) AS (\ + SELECT 0 UNION ALL SELECT n + 1 FROM \"__ggsql_base__\" WHERE n < {base_max}\ ),\ - __ggsql_seq__(n) AS (\ + \"__ggsql_seq__\"(n) AS (\ SELECT CAST(a.n * {base_sq} + b.n * {base_size} + c.n AS REAL) AS n \ - FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c \ + FROM \"__ggsql_base__\" a, \"__ggsql_base__\" b, \"__ggsql_base__\" c \ WHERE a.n * {base_sq} + b.n * {base_size} + c.n < {n}\ )" ) @@ -143,12 +185,20 @@ pub trait SqlDialect { // Uses NTILE(4) to divide data into quartiles, then interpolates between boundaries. 
let group_filter = groups .iter() - .map(|g| format!("AND __ggsql_pct__.{g} IS NOT DISTINCT FROM __ggsql_qt__.{g}")) + .map(|g| { + let q = naming::quote_ident(g); + format!( + "AND {pct}.{q} IS NOT DISTINCT FROM {qt}.{q}", + pct = naming::quote_ident("__ggsql_pct__"), + qt = naming::quote_ident("__ggsql_qt__") + ) + }) .collect::>() .join(" "); let lo_tile = (fraction * 4.0).ceil() as usize; let hi_tile = lo_tile + 1; + let quoted_column = naming::quote_ident(column); format!( "(SELECT (\ @@ -158,9 +208,10 @@ pub trait SqlDialect { FROM (\ SELECT {column} AS __val, \ NTILE(4) OVER (ORDER BY {column}) AS __tile \ - FROM ({from}) AS __ggsql_pct__ \ + FROM ({from}) AS \"__ggsql_pct__\" \ WHERE {column} IS NOT NULL {group_filter}\ - ))" + ))", + column = quoted_column ) } @@ -209,6 +260,12 @@ pub mod duckdb; #[cfg(feature = "sqlite")] pub mod sqlite; +#[cfg(feature = "odbc")] +pub mod odbc; + +#[cfg(feature = "odbc")] +pub mod snowflake; + pub mod connection; pub mod data; mod spec; @@ -219,6 +276,35 @@ pub use duckdb::DuckDBReader; #[cfg(feature = "sqlite")] pub use sqlite::SqliteReader; +#[cfg(feature = "odbc")] +pub use odbc::OdbcReader; + +// ============================================================================ +// Shared utilities +// ============================================================================ + +/// Validate a table name for use in SQL statements. +/// +/// Rejects empty names and names containing null bytes or newlines. 
+pub(crate) fn validate_table_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(GgsqlError::ReaderError("Table name cannot be empty".into())); + } + + let forbidden = ['\0', '\n', '\r']; + for ch in forbidden { + if name.contains(ch) { + return Err(GgsqlError::ReaderError(format!( + "Table name '{}' contains invalid character '{}'", + name, + ch.escape_default() + ))); + } + } + + Ok(()) +} + // ============================================================================ // Spec - Result of reader.execute() // ============================================================================ @@ -363,37 +449,7 @@ pub trait Reader { /// let writer = VegaLiteWriter::new(); /// let json = writer.render(&spec)?; /// ``` - fn execute(&self, query: &str) -> Result - where - Self: Sized, - { - // Run validation first to capture warnings - let validated = validate(query)?; - let warnings: Vec = validated.warnings().to_vec(); - - // Prepare data with type names for this reader - let prepared_data = prepare_data_with_reader(query, self)?; - - // Get the first (and typically only) spec - let plot = prepared_data.specs.into_iter().next().ok_or_else(|| { - GgsqlError::ValidationError("No visualization spec found".to_string()) - })?; - - // For now, layer_sql and stat_sql are not tracked in PreparedData - // (they were part of main's version but not HEAD's) - let layer_sql = vec![None; plot.layers.len()]; - let stat_sql = vec![None; plot.layers.len()]; - - Ok(Spec::new( - plot, - prepared_data.data, - prepared_data.sql, - prepared_data.visual, - layer_sql, - stat_sql, - warnings, - )) - } + fn execute(&self, query: &str) -> Result; /// Get the SQL dialect for this reader. /// @@ -403,6 +459,36 @@ pub trait Reader { } } +/// Execute a ggsql query using any reader +/// +/// This is the shared implementation behind `Reader::execute()`. Concrete +/// readers delegate to this so the trait stays object-safe (no `Self: Sized` +/// bound on `execute`). 
+pub fn execute_with_reader(reader: &dyn Reader, query: &str) -> Result { + let validated = validate(query)?; + let warnings: Vec = validated.warnings().to_vec(); + + let prepared_data = prepare_data_with_reader(query, reader)?; + + let plot = + prepared_data.specs.into_iter().next().ok_or_else(|| { + GgsqlError::ValidationError("No visualization spec found".to_string()) + })?; + + let layer_sql = vec![None; plot.layers.len()]; + let stat_sql = vec![None; plot.layers.len()]; + + Ok(Spec::new( + plot, + prepared_data.data, + prepared_data.sql, + prepared_data.visual, + layer_sql, + stat_sql, + warnings, + )) +} + #[cfg(test)] #[cfg(all(feature = "duckdb", feature = "vegalite"))] mod tests { diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs new file mode 100644 index 00000000..6f474da7 --- /dev/null +++ b/src/reader/odbc.rs @@ -0,0 +1,854 @@ +//! Generic ODBC data source implementation +//! +//! Provides a reader for any ODBC-compatible database (Snowflake, PostgreSQL, +//! SQL Server, etc.) using the `odbc-api` crate. + +use crate::reader::Reader; +use crate::{naming, DataFrame, GgsqlError, Result}; +use odbc_api::sys::{Date as OdbcDate, Time as OdbcTime, Timestamp as OdbcTimestamp}; +use odbc_api::{ + buffers::{AnyBuffer, AnySlice, BufferDesc, ColumnarBuffer}, + ConnectionOptions, Cursor, DataType as OdbcDataType, Environment, +}; +use polars::prelude::*; +use std::cell::RefCell; +use std::collections::HashSet; +use std::sync::OnceLock; + +/// Global ODBC environment (must be a singleton per process). +fn odbc_env() -> &'static Environment { + static ENV: OnceLock = OnceLock::new(); + ENV.get_or_init(|| Environment::new().expect("Failed to create ODBC environment")) +} + +/// Detect the backend SQL dialect from an ODBC connection string. +/// +/// Returns a dialect matching the detected backend (e.g. Snowflake, SQLite, +/// DuckDB, or ANSI for generic/unknown backends). 
+fn detect_dialect(conn_str: &str) -> Box { + let lower = conn_str.to_lowercase(); + if lower.contains("driver=snowflake") { + Box::new(super::snowflake::SnowflakeDialect) + } else if lower.contains("driver=sqlite") || lower.contains("driver={sqlite") { + #[cfg(feature = "sqlite")] + { + Box::new(super::sqlite::SqliteDialect) + } + #[cfg(not(feature = "sqlite"))] + { + Box::new(super::AnsiDialect) + } + } else if lower.contains("driver=duckdb") || lower.contains("driver={duckdb") { + #[cfg(feature = "duckdb")] + { + Box::new(super::duckdb::DuckDbDialect) + } + #[cfg(not(feature = "duckdb"))] + { + Box::new(super::AnsiDialect) + } + } else { + Box::new(super::AnsiDialect) + } +} + +/// Generic ODBC reader implementing the `Reader` trait. +pub struct OdbcReader { + connection: odbc_api::Connection<'static>, + dialect: Box, + registered_tables: RefCell>, +} + +// Safety: odbc_api::Connection is Send when we ensure single-threaded access. +// The Reader trait requires &self (immutable) for execute_sql, and ODBC +// connections are safe to use from one thread at a time. +unsafe impl Send for OdbcReader {} + +impl OdbcReader { + /// Create a new ODBC reader from a `odbc://` connection URI. + /// + /// The URI format is `odbc://` followed by the raw ODBC connection string. 
+ pub fn from_connection_string(uri: &str) -> Result { + let conn_str = uri + .strip_prefix("odbc://") + .ok_or_else(|| GgsqlError::ReaderError("ODBC URI must start with odbc://".into()))?; + + let mut conn_str = conn_str.to_string(); + + // Snowflake ConnectionName resolution from connections.toml + if is_snowflake(&conn_str) { + if let Some(resolved) = resolve_connection_name(&conn_str) { + conn_str = resolved; + } + } + + // Snowflake Workbench credential detection + if is_snowflake(&conn_str) && !has_token(&conn_str) { + if let Some(token) = detect_workbench_token() { + conn_str = inject_snowflake_token(&conn_str, &token); + } + } + + // Detect backend dialect from connection string + let dialect = detect_dialect(&conn_str); + + let env = odbc_env(); + let connection = env + .connect_with_connection_string(&conn_str, ConnectionOptions::default()) + .map_err(|e| GgsqlError::ReaderError(format!("ODBC connection failed: {}", e)))?; + + Ok(Self { + connection, + dialect, + registered_tables: RefCell::new(HashSet::new()), + }) + } +} + +impl Reader for OdbcReader { + fn execute_sql(&self, sql: &str) -> Result { + // Execute the query (3rd arg = query timeout, None = no timeout) + let cursor = self + .connection + .execute(sql, (), None) + .map_err(|e| GgsqlError::ReaderError(format!("ODBC execute failed: {}", e)))?; + + let Some(cursor) = cursor else { + // DDL or non-query statement — return empty DataFrame + return DataFrame::new(Vec::::new()) + .map_err(|e| GgsqlError::ReaderError(format!("Empty DataFrame error: {}", e))); + }; + + cursor_to_dataframe(cursor) + } + + fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> { + super::validate_table_name(name)?; + + if replace { + let drop_sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name)); + // Ignore errors from DROP — table may not exist + let _ = self.connection.execute(&drop_sql, (), None); + } + + // Build CREATE TEMP TABLE with typed columns + let schema = df.schema(); + 
let col_defs: Vec = schema + .iter() + .map(|(col_name, dtype)| { + format!( + "{} {}", + naming::quote_ident(col_name), + polars_dtype_to_sql(dtype) + ) + }) + .collect(); + let create_sql = format!( + "CREATE TEMPORARY TABLE {} ({})", + naming::quote_ident(name), + col_defs.join(", ") + ); + self.connection + .execute(&create_sql, (), None) + .map_err(|e| { + GgsqlError::ReaderError(format!("Failed to create temp table '{}': {}", name, e)) + })?; + + // Insert data using ODBC bulk text inserter + let num_rows = df.height(); + if num_rows > 0 { + let num_cols = df.width(); + let placeholders: Vec<&str> = vec!["?"; num_cols]; + let insert_sql = format!( + "INSERT INTO {} VALUES ({})", + naming::quote_ident(name), + placeholders.join(", ") + ); + + // Convert all columns to string representation for text insertion + let string_columns: Vec>> = df + .get_columns() + .iter() + .map(|col| { + (0..num_rows) + .map(|row| { + let val = col.get(row).ok()?; + if val == AnyValue::Null { + None + } else { + Some(format!("{}", val)) + } + }) + .collect() + }) + .collect(); + + // Determine max string length per column for buffer allocation + let max_str_lens: Vec = string_columns + .iter() + .map(|col| { + col.iter() + .filter_map(|v| v.as_ref().map(|s| s.len())) + .max() + .unwrap_or(1) + .max(1) // minimum buffer size of 1 + }) + .collect(); + + const BATCH_SIZE: usize = 1024; + let prepared = self.connection.prepare(&insert_sql).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to prepare INSERT for '{}': {}", name, e)) + })?; + + let batch_capacity = num_rows.min(BATCH_SIZE); + let mut inserter = prepared + .into_text_inserter(batch_capacity, max_str_lens) + .map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to create bulk inserter for '{}': {}", + name, e + )) + })?; + + let mut rows_in_batch = 0; + for row_idx in 0..num_rows { + let row_values: Vec> = string_columns + .iter() + .map(|col| col[row_idx].as_ref().map(|s| s.as_bytes())) + .collect(); + + 
inserter.append(row_values.into_iter()).map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to append row {} to '{}': {}", + row_idx, name, e + )) + })?; + rows_in_batch += 1; + + if rows_in_batch >= BATCH_SIZE { + inserter.execute().map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to execute batch insert into '{}': {}", + name, e + )) + })?; + inserter.clear(); + rows_in_batch = 0; + } + } + + // Execute final partial batch + if rows_in_batch > 0 { + inserter.execute().map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to execute final batch insert into '{}': {}", + name, e + )) + })?; + } + } + + self.registered_tables.borrow_mut().insert(name.to_string()); + Ok(()) + } + + fn unregister(&self, name: &str) -> Result<()> { + if !self.registered_tables.borrow().contains(name) { + return Err(GgsqlError::ReaderError(format!( + "Table '{}' was not registered via this reader", + name + ))); + } + + let sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name)); + self.connection.execute(&sql, (), None).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to unregister table '{}': {}", name, e)) + })?; + + self.registered_tables.borrow_mut().remove(name); + Ok(()) + } + + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + + fn dialect(&self) -> &dyn super::SqlDialect { + &*self.dialect + } +} + +/// Map a Polars data type to a SQL column type string. +fn polars_dtype_to_sql(dtype: &DataType) -> &'static str { + match dtype { + DataType::Boolean => "BOOLEAN", + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => "BIGINT", + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => "BIGINT", + DataType::Float32 | DataType::Float64 => "DOUBLE PRECISION", + DataType::Date => "DATE", + DataType::Datetime(_, _) => "TIMESTAMP", + DataType::Time => "TIME", + _ => "TEXT", + } +} + +/// Column builder that accumulates typed values across batches. 
+enum ColumnBuilder { + Int8(Vec>), + Int16(Vec>), + Int32(Vec>), + Int64(Vec>), + Float32(Vec>), + Float64(Vec>), + Boolean(Vec>), + Date(Vec>), + Time(Vec>), + Timestamp(Vec>), + Text(Vec>), +} + +impl ColumnBuilder { + fn from_odbc_type(data_type: &OdbcDataType) -> Self { + match data_type { + OdbcDataType::TinyInt => Self::Int8(Vec::new()), + OdbcDataType::SmallInt => Self::Int16(Vec::new()), + OdbcDataType::Integer => Self::Int32(Vec::new()), + OdbcDataType::BigInt => Self::Int64(Vec::new()), + OdbcDataType::Real | OdbcDataType::Float { precision: 0..=24 } => { + Self::Float32(Vec::new()) + } + OdbcDataType::Double | OdbcDataType::Float { .. } => Self::Float64(Vec::new()), + OdbcDataType::Numeric { + scale: 0, + precision, + } + | OdbcDataType::Decimal { + scale: 0, + precision, + } => { + if *precision < 10 { + Self::Int32(Vec::new()) + } else if *precision < 19 { + Self::Int64(Vec::new()) + } else { + Self::Float64(Vec::new()) + } + } + OdbcDataType::Numeric { .. } | OdbcDataType::Decimal { .. } => { + Self::Float64(Vec::new()) + } + OdbcDataType::Bit => Self::Boolean(Vec::new()), + OdbcDataType::Date => Self::Date(Vec::new()), + OdbcDataType::Time { .. } => Self::Time(Vec::new()), + OdbcDataType::Timestamp { .. 
} => Self::Timestamp(Vec::new()), + _ => Self::Text(Vec::new()), + } + } + + fn append_from_slice(&mut self, slice: AnySlice<'_>) -> std::result::Result<(), String> { + match (self, slice) { + (Self::Int8(v), AnySlice::NullableI8(s)) => { + v.extend(s.map(|opt| opt.copied())); + } + (Self::Int16(v), AnySlice::NullableI16(s)) => { + v.extend(s.map(|opt| opt.copied())); + } + (Self::Int32(v), AnySlice::NullableI32(s)) => { + v.extend(s.map(|opt| opt.copied())); + } + (Self::Int64(v), AnySlice::NullableI64(s)) => { + v.extend(s.map(|opt| opt.copied())); + } + (Self::Float32(v), AnySlice::NullableF32(s)) => { + v.extend(s.map(|opt| opt.copied())); + } + (Self::Float64(v), AnySlice::NullableF64(s)) => { + v.extend(s.map(|opt| opt.copied())); + } + (Self::Boolean(v), AnySlice::NullableBit(s)) => { + v.extend(s.map(|opt| opt.map(|b| b.as_bool()))); + } + (Self::Date(v), AnySlice::NullableDate(s)) => { + v.extend(s.map(|opt| opt.and_then(odbc_date_to_days))); + } + (Self::Time(v), AnySlice::NullableTime(s)) => { + v.extend(s.map(|opt| opt.map(odbc_time_to_nanos))); + } + (Self::Timestamp(v), AnySlice::NullableTimestamp(s)) => { + v.extend(s.map(|opt| opt.and_then(odbc_timestamp_to_micros))); + } + (Self::Text(v), AnySlice::Text(view)) => { + v.extend(view.iter().map(|opt| { + opt.and_then(|bytes| std::str::from_utf8(bytes).ok().map(|s| s.to_string())) + })); + } + (Self::Text(v), AnySlice::WText(view)) => { + v.extend( + view.iter() + .map(|opt| opt.map(|chars| String::from_utf16_lossy(chars.into()))), + ); + } + // Decimal/Numeric with scale > 0 bound as text → parse to f64 + (Self::Float64(v), AnySlice::Text(view)) => { + v.extend(view.iter().map(|opt| { + opt.and_then(|bytes| { + std::str::from_utf8(bytes) + .ok() + .and_then(|s| s.parse::().ok()) + }) + })); + } + // Decimal with scale=0 bound as i32/i64 text fallback + (Self::Int32(v), AnySlice::Text(view)) => { + v.extend(view.iter().map(|opt| { + opt.and_then(|bytes| { + std::str::from_utf8(bytes) + .ok() + 
.and_then(|s| s.parse::().ok()) + }) + })); + } + (Self::Int64(v), AnySlice::Text(view)) => { + v.extend(view.iter().map(|opt| { + opt.and_then(|bytes| { + std::str::from_utf8(bytes) + .ok() + .and_then(|s| s.parse::().ok()) + }) + })); + } + (builder, _slice) => { + let builder_type = match builder { + Self::Int8(_) => "Int8", + Self::Int16(_) => "Int16", + Self::Int32(_) => "Int32", + Self::Int64(_) => "Int64", + Self::Float32(_) => "Float32", + Self::Float64(_) => "Float64", + Self::Boolean(_) => "Boolean", + Self::Date(_) => "Date", + Self::Time(_) => "Time", + Self::Timestamp(_) => "Timestamp", + Self::Text(_) => "Text", + }; + return Err(format!( + "ODBC type mismatch: expected {builder_type} buffer but driver returned a different type" + )); + } + } + Ok(()) + } + + fn into_series(self, name: &str) -> Series { + match self { + Self::Int8(v) => Series::new(name.into(), v), + Self::Int16(v) => Series::new(name.into(), v), + Self::Int32(v) => Series::new(name.into(), v), + Self::Int64(v) => Series::new(name.into(), v), + Self::Float32(v) => Series::new(name.into(), v), + Self::Float64(v) => Series::new(name.into(), v), + Self::Boolean(v) => Series::new(name.into(), v), + Self::Date(v) => { + let ca = Int32Chunked::new(name.into(), &v); + ca.into_date().into_series() + } + Self::Time(v) => { + let ca = Int64Chunked::new(name.into(), &v); + ca.into_time().into_series() + } + Self::Timestamp(v) => { + let ca = Int64Chunked::new(name.into(), &v); + ca.into_datetime(TimeUnit::Microseconds, None).into_series() + } + Self::Text(v) => Series::new(name.into(), v), + } + } +} + +fn odbc_date_to_days(d: &OdbcDate) -> Option { + chrono::NaiveDate::from_ymd_opt(d.year as i32, d.month as u32, d.day as u32).map(|date| { + let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + (date - epoch).num_days() as i32 + }) +} + +fn odbc_time_to_nanos(t: &OdbcTime) -> i64 { + let h = t.hour as i64; + let m = t.minute as i64; + let s = t.second as i64; + (h * 3600 + m * 60 + 
s) * 1_000_000_000 +} + +fn odbc_timestamp_to_micros(ts: &OdbcTimestamp) -> Option { + chrono::NaiveDate::from_ymd_opt(ts.year as i32, ts.month as u32, ts.day as u32) + .and_then(|date| { + date.and_hms_nano_opt( + ts.hour as u32, + ts.minute as u32, + ts.second as u32, + ts.fraction, + ) + }) + .map(|dt| dt.and_utc().timestamp_micros()) +} + +/// Convert an ODBC cursor to a Polars DataFrame using typed buffers. +fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result { + let col_count = cursor + .num_result_cols() + .map_err(|e| GgsqlError::ReaderError(format!("Failed to get column count: {}", e)))? + as usize; + + if col_count == 0 { + return DataFrame::new(Vec::::new()) + .map_err(|e| GgsqlError::ReaderError(e.to_string())); + } + + // Collect column names and types, build buffer descriptors + let mut col_names = Vec::with_capacity(col_count); + let mut col_types = Vec::with_capacity(col_count); + let mut descs = Vec::with_capacity(col_count); + + let text_fallback = BufferDesc::Text { max_str_len: 65536 }; + + for i in 1..=col_count as u16 { + let name = cursor.col_name(i).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to get column {} name: {}", i, e)) + })?; + let data_type = cursor.col_data_type(i).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to get column {} type: {}", i, e)) + })?; + + let desc = BufferDesc::from_data_type(data_type, true).unwrap_or(text_fallback); + + col_names.push(name); + col_types.push(data_type); + descs.push(desc); + } + + // Create typed columnar buffer and column builders + let batch_size = 1000; + let mut builders: Vec = col_types + .iter() + .map(ColumnBuilder::from_odbc_type) + .collect(); + + let mut buffer = ColumnarBuffer::::from_descs(batch_size, descs); + + let mut block_cursor = cursor + .bind_buffer(&mut buffer) + .map_err(|e| GgsqlError::ReaderError(format!("Failed to bind buffer: {}", e)))?; + + while let Some(batch) = block_cursor + .fetch() + .map_err(|e| GgsqlError::ReaderError(format!("Failed 
to fetch batch: {}", e)))? + { + for (col_idx, builder) in builders.iter_mut().enumerate() { + let slice = batch.column(col_idx); + builder.append_from_slice(slice).map_err(|e| { + GgsqlError::ReaderError(format!("Column '{}': {}", col_names[col_idx], e)) + })?; + } + } + + // Convert builders to Polars Series + let series: Vec = col_names + .iter() + .zip(builders) + .map(|(name, builder)| Column::from(builder.into_series(name))) + .collect(); + + DataFrame::new(series).map_err(|e| GgsqlError::ReaderError(e.to_string())) +} + +// ============================================================================ +// Snowflake Workbench credential detection +// ============================================================================ + +fn is_snowflake(conn_str: &str) -> bool { + conn_str.to_lowercase().contains("driver=snowflake") +} + +fn has_token(conn_str: &str) -> bool { + conn_str.to_lowercase().contains("token=") +} + +fn home_dir() -> Option { + #[cfg(target_os = "windows")] + { + std::env::var("USERPROFILE") + .ok() + .map(std::path::PathBuf::from) + } + #[cfg(not(target_os = "windows"))] + { + std::env::var("HOME").ok().map(std::path::PathBuf::from) + } +} + +/// Find the Snowflake connections.toml file, checking standard locations. +fn find_snowflake_connections_toml() -> Option { + use std::path::PathBuf; + + // 1. $SNOWFLAKE_HOME/connections.toml + if let Ok(snowflake_home) = std::env::var("SNOWFLAKE_HOME") { + let p = PathBuf::from(&snowflake_home).join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + // 2. ~/.snowflake/connections.toml + if let Some(home) = home_dir() { + let p = home.join(".snowflake").join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + // 3. 
Platform-specific paths + if let Some(home) = home_dir() { + #[cfg(target_os = "macos")] + { + let p = home.join("Library/Application Support/snowflake/connections.toml"); + if p.exists() { + return Some(p); + } + } + + #[cfg(target_os = "linux")] + { + let xdg = std::env::var("XDG_CONFIG_HOME") + .map(PathBuf::from) + .unwrap_or_else(|_| home.join(".config")); + let p = xdg.join("snowflake").join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + #[cfg(target_os = "windows")] + { + let p = home.join("AppData/Local/snowflake/connections.toml"); + if p.exists() { + return Some(p); + } + } + } + + None +} + +/// Resolve a `ConnectionName=` parameter in a Snowflake ODBC connection +/// string by reading the named entry from `~/.snowflake/connections.toml` and +/// building a full ODBC connection string from it. +fn resolve_connection_name(conn_str: &str) -> Option { + // Extract ConnectionName value (case-insensitive) + let lower = conn_str.to_lowercase(); + let cn_key = "connectionname="; + let cn_start = lower.find(cn_key)?; + let value_start = cn_start + cn_key.len(); + + let rest = &conn_str[value_start..]; + let value_end = rest.find(';').unwrap_or(rest.len()); + let connection_name = rest[..value_end].trim(); + + if connection_name.is_empty() { + return None; + } + + // Read and parse connections.toml + let toml_path = find_snowflake_connections_toml()?; + let content = std::fs::read_to_string(&toml_path).ok()?; + let doc = content.parse::().ok()?; + + let entry = doc.get(connection_name)?; + if !entry.is_table() && !entry.is_inline_table() { + return None; + } + + // Build ODBC connection string from TOML entry fields + let get_str = |key: &str| -> Option { entry.get(key)?.as_str().map(|s| s.to_string()) }; + + let account = get_str("account")?; + let mut parts = vec![ + "Driver=Snowflake".to_string(), + format!("Server={}.snowflakecomputing.com", account), + ]; + + if let Some(user) = get_str("user") { + parts.push(format!("UID={}", user)); + 
} + if let Some(password) = get_str("password") { + parts.push(format!("PWD={}", password)); + } + if let Some(authenticator) = get_str("authenticator") { + parts.push(format!("Authenticator={}", authenticator)); + } + if let Some(token) = get_str("token") { + parts.push(format!("Token={}", token)); + } + if let Some(warehouse) = get_str("warehouse") { + parts.push(format!("Warehouse={}", warehouse)); + } + if let Some(database) = get_str("database") { + parts.push(format!("Database={}", database)); + } + if let Some(schema) = get_str("schema") { + parts.push(format!("Schema={}", schema)); + } + if let Some(role) = get_str("role") { + parts.push(format!("Role={}", role)); + } + + Some(parts.join(";")) +} + +/// Detect Posit Workbench Snowflake OAuth token. +/// +/// Checks `SNOWFLAKE_HOME` for a Workbench-managed `connections.toml` file +/// containing OAuth credentials. +fn detect_workbench_token() -> Option { + let snowflake_home = std::env::var("SNOWFLAKE_HOME").ok()?; + + // Only use Workbench credentials if the path indicates Workbench management + if !snowflake_home.contains("posit-workbench") { + return None; + } + + let toml_path = std::path::Path::new(&snowflake_home).join("connections.toml"); + let content = std::fs::read_to_string(&toml_path).ok()?; + + let doc = content.parse::().ok()?; + let token = doc.get("workbench")?.get("token")?.as_str()?.to_string(); + + if token.is_empty() { + None + } else { + Some(token) + } +} + +/// Inject OAuth token into a Snowflake ODBC connection string. 
+fn inject_snowflake_token(conn_str: &str, token: &str) -> String { + // Append authenticator and token parameters + let mut result = conn_str.trim_end_matches(';').to_string(); + result.push_str(";Authenticator=oauth;Token="); + result.push_str(token); + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_snowflake() { + assert!(is_snowflake( + "Driver=Snowflake;Server=foo.snowflakecomputing.com" + )); + assert!(!is_snowflake("Driver={PostgreSQL};Server=localhost")); + } + + #[test] + fn test_has_token() { + assert!(has_token("Driver=Snowflake;Token=abc123")); + assert!(!has_token("Driver=Snowflake;Server=foo")); + } + + #[test] + fn test_detect_dialect() { + // Snowflake uses SHOW commands + let dialect = detect_dialect("Driver=Snowflake;Server=foo"); + assert!(dialect.sql_list_catalogs().contains("SHOW")); + + // PostgreSQL uses information_schema (ANSI default) + let dialect = detect_dialect("Driver={PostgreSQL};Server=localhost"); + assert!(dialect.sql_list_catalogs().contains("information_schema")); + + // Generic uses information_schema (ANSI default) + let dialect = detect_dialect("Driver=SomeOther;Server=localhost"); + assert!(dialect.sql_list_catalogs().contains("information_schema")); + } + + #[test] + fn test_inject_snowflake_token() { + let result = inject_snowflake_token( + "Driver=Snowflake;Server=foo.snowflakecomputing.com", + "mytoken", + ); + assert!(result.contains("Authenticator=oauth")); + assert!(result.contains("Token=mytoken")); + } + + #[test] + fn test_resolve_connection_name_with_toml() { + use std::io::Write; + + // Create a temp dir with a connections.toml + let dir = tempfile::tempdir().unwrap(); + let toml_path = dir.path().join("connections.toml"); + let mut f = std::fs::File::create(&toml_path).unwrap(); + writeln!( + f, + r#" +default_connection_name = "myconn" + +[myconn] +account = "myaccount" +user = "myuser" +password = "mypass" +warehouse = "mywh" +database = "mydb" +schema = "public" +role = 
"myrole" + +[other] +account = "otheraccount" +"# + ) + .unwrap(); + + // Point SNOWFLAKE_HOME at our temp dir + std::env::set_var("SNOWFLAKE_HOME", dir.path()); + + let result = resolve_connection_name("Driver=Snowflake;ConnectionName=myconn"); + assert!(result.is_some()); + let conn = result.unwrap(); + assert!(conn.contains("Driver=Snowflake")); + assert!(conn.contains("Server=myaccount.snowflakecomputing.com")); + assert!(conn.contains("UID=myuser")); + assert!(conn.contains("PWD=mypass")); + assert!(conn.contains("Warehouse=mywh")); + assert!(conn.contains("Database=mydb")); + assert!(conn.contains("Schema=public")); + assert!(conn.contains("Role=myrole")); + + // Test with a connection that has fewer fields + let result2 = resolve_connection_name("Driver=Snowflake;ConnectionName=other"); + assert!(result2.is_some()); + let conn2 = result2.unwrap(); + assert!(conn2.contains("Server=otheraccount.snowflakecomputing.com")); + assert!(!conn2.contains("UID=")); + + // Test with non-existent connection name + let result3 = resolve_connection_name("Driver=Snowflake;ConnectionName=nonexistent"); + assert!(result3.is_none()); + + // No ConnectionName param → None + let result4 = resolve_connection_name("Driver=Snowflake;Server=foo"); + assert!(result4.is_none()); + + // Clean up env + std::env::remove_var("SNOWFLAKE_HOME"); + } + + #[test] + fn test_polars_dtype_to_sql() { + assert_eq!(polars_dtype_to_sql(&DataType::Int64), "BIGINT"); + assert_eq!(polars_dtype_to_sql(&DataType::Float64), "DOUBLE PRECISION"); + assert_eq!(polars_dtype_to_sql(&DataType::Boolean), "BOOLEAN"); + assert_eq!(polars_dtype_to_sql(&DataType::Date), "DATE"); + assert_eq!(polars_dtype_to_sql(&DataType::String), "TEXT"); + } +} diff --git a/src/reader/snowflake.rs b/src/reader/snowflake.rs new file mode 100644 index 00000000..9257052f --- /dev/null +++ b/src/reader/snowflake.rs @@ -0,0 +1,35 @@ +//! Snowflake-specific SQL dialect. +//! +//! 
Overrides schema introspection to use Snowflake's SHOW commands +//! instead of information_schema queries. + +use crate::naming; + +pub struct SnowflakeDialect; + +impl super::SqlDialect for SnowflakeDialect { + fn sql_list_catalogs(&self) -> String { + "SHOW DATABASES".into() + } + + fn sql_list_schemas(&self, catalog: &str) -> String { + format!("SHOW SCHEMAS IN DATABASE {}", naming::quote_ident(catalog)) + } + + fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { + format!( + "SHOW OBJECTS IN SCHEMA {}.{}", + naming::quote_ident(catalog), + naming::quote_ident(schema) + ) + } + + fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { + format!( + "SHOW COLUMNS IN TABLE {}.{}.{}", + naming::quote_ident(catalog), + naming::quote_ident(schema), + naming::quote_ident(table) + ) + } +} diff --git a/src/reader/sqlite.rs b/src/reader/sqlite.rs index 793ec928..c6fd92d5 100644 --- a/src/reader/sqlite.rs +++ b/src/reader/sqlite.rs @@ -4,7 +4,7 @@ //! Works on both native targets and wasm32-unknown-unknown (via sqlite-wasm-rs). 
use crate::reader::Reader; -use crate::{DataFrame, GgsqlError, Result}; +use crate::{naming, DataFrame, GgsqlError, Result}; use chrono::Datelike; use polars::prelude::*; use rusqlite::Connection; @@ -67,6 +67,29 @@ impl super::SqlDialect for SqliteDialect { "0".to_string() } } + + fn sql_list_catalogs(&self) -> String { + "SELECT name AS catalog_name FROM pragma_database_list ORDER BY name".into() + } + + fn sql_list_schemas(&self, _catalog: &str) -> String { + "SELECT 'main' AS schema_name".into() + } + + fn sql_list_tables(&self, catalog: &str, _schema: &str) -> String { + format!( + "SELECT name AS table_name, type AS table_type FROM {}.sqlite_master \ + WHERE type IN ('table', 'view') ORDER BY name", + naming::quote_ident(catalog) + ) + } + + fn sql_list_columns(&self, _catalog: &str, _schema: &str, table: &str) -> String { + format!( + "SELECT name AS column_name, type AS data_type FROM pragma_table_info('{}') ORDER BY cid", + table.replace('\'', "''") + ) + } } /// SQLite database reader @@ -153,7 +176,7 @@ fn validate_table_name(name: &str) -> Result<()> { return Err(GgsqlError::ReaderError("Table name cannot be empty".into())); } - let forbidden = ['"', '\0', '\n', '\r']; + let forbidden = ['\0', '\n', '\r']; for ch in forbidden { if name.contains(ch) { return Err(GgsqlError::ReaderError(format!( @@ -164,13 +187,6 @@ fn validate_table_name(name: &str) -> Result<()> { } } - if name.len() > 128 { - return Err(GgsqlError::ReaderError(format!( - "Table name '{}' exceeds maximum length of 128 characters", - name - ))); - } - Ok(()) } @@ -248,7 +264,7 @@ impl Reader for SqliteReader { { let dataset_names = super::data::extract_builtin_dataset_names(sql)?; for name in &dataset_names { - let table_name = crate::naming::builtin_data_table(name); + let table_name = naming::builtin_data_table(name); if !self.table_exists(&table_name) { let df = super::data::load_builtin_dataframe(name)?; self.register(&table_name, df, true)?; @@ -331,7 +347,7 @@ impl Reader for 
SqliteReader { if self.table_exists(name) { if replace { - let sql = format!("DROP TABLE IF EXISTS \"{}\"", name); + let sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name)); self.conn.execute(&sql, []).map_err(|e| { GgsqlError::ReaderError(format!("Failed to drop table '{}': {}", name, e)) })?; @@ -351,11 +367,15 @@ impl Reader for SqliteReader { .map(|col| { let col_name = col.name().to_string(); let col_type = polars_type_to_sqlite(col.dtype()); - format!("\"{}\" {}", col_name, col_type) + format!("{} {}", naming::quote_ident(&col_name), col_type) }) .collect(); - let create_sql = format!("CREATE TABLE \"{}\" ({})", name, col_defs.join(", ")); + let create_sql = format!( + "CREATE TABLE {} ({})", + naming::quote_ident(name), + col_defs.join(", ") + ); self.conn.execute(&create_sql, []).map_err(|e| { GgsqlError::ReaderError(format!("Failed to create table '{}': {}", name, e)) })?; @@ -364,8 +384,8 @@ impl Reader for SqliteReader { if df.height() > 0 { let placeholders: Vec<&str> = vec!["?"; df.width()]; let insert_sql = format!( - "INSERT INTO \"{}\" VALUES ({})", - name, + "INSERT INTO {} VALUES ({})", + naming::quote_ident(name), placeholders.join(", ") ); @@ -433,7 +453,7 @@ impl Reader for SqliteReader { ))); } - let sql = format!("DROP TABLE IF EXISTS \"{}\"", name); + let sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name)); self.conn.execute(&sql, []).map_err(|e| { GgsqlError::ReaderError(format!("Failed to unregister table '{}': {}", name, e)) })?; @@ -442,6 +462,10 @@ impl Reader for SqliteReader { Ok(()) } + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn super::SqlDialect { &SqliteDialect } @@ -716,12 +740,10 @@ mod tests { assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("cannot be empty")); + // Name with double quote should succeed (quote_ident escapes it) let result = reader.register("bad\"name", df.clone(), false); - 
assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("invalid character")); + assert!(result.is_ok()); + reader.unregister("bad\"name").unwrap(); let result = reader.register("bad\0name", df.clone(), false); assert!(result.is_err()); @@ -729,14 +751,6 @@ mod tests { .unwrap_err() .to_string() .contains("invalid character")); - - let long_name = "a".repeat(200); - let result = reader.register(&long_name, df, false); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("exceeds maximum length")); } #[test]